pytorch_nn

PyTorch implementation of the CBOW and SG architectures.

CBOW

Bases: Word2Vec

PyTorch implementation of the CBOW with Negative Sampling model.

Source code in rivertext/models/iword2vec/model.py
class CBOW(Word2Vec):
    """CBOW with Negative Sampling model Pytorch implementation."""

    def __init__(self, emb_size: int, emb_dimension: int, cbow_mean: bool = True):
        """Initialize a SG instance.

        Args:
            emb_size: The number of words to process.
            emb_dimension: The dimension of the word embeddings.
            cbow_mean: If True, average the context vectors over the non-padding
                positions; if False, sum them.
        """
        super(CBOW, self).__init__(emb_size, emb_dimension)
        self.cbow_mean = cbow_mean

    def forward(
        self, target: torch.Tensor, context: torch.Tensor, negatives: torch.Tensor
    ) -> float:
        """Forward pass the CBOW model.

        Args:
            target: Target positive samples.
            context: Context positive samples.
            negatives: Negative samples.

        Returns:
            Objective function result.
        """
        t = self.syn1(target)
        c = self.syn0(context)

        # Mean of context vector without considering padding idx (0)
        if self.cbow_mean:
            mean_c = torch.sum(c, dim=1) / torch.sum(context != 0, dim=1, keepdim=True)
        else:
            mean_c = c.sum(dim=1)

        score = torch.mul(t, mean_c).squeeze()
        score = torch.sum(score, dim=1)
        score = F.logsigmoid(score)

        n = self.syn1(negatives)
        neg_score = torch.bmm(n, mean_c.unsqueeze(2)).squeeze()
        neg_score = F.logsigmoid(-1 * neg_score)

        return -1 * (torch.sum(score) + torch.sum(neg_score))
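
A minimal usage sketch (not part of the library itself): the vocabulary size, embedding dimension, window size, number of negatives, and the import path are illustrative assumptions, and the index tensors are random stand-ins for real training batches.

import torch

from rivertext.models.iword2vec.model import CBOW  # import path assumed from the source path above

model = CBOW(emb_size=100, emb_dimension=50)

target = torch.randint(1, 100, (4,))       # (batch,) indices of the target words
context = torch.randint(1, 100, (4, 6))    # (batch, 2 * window) context indices, 0 = padding
negatives = torch.randint(1, 100, (4, 5))  # (batch, n_negatives) negative sample indices

loss = model(target, context, negatives)   # 0-dim tensor with the negative-sampling loss
loss.backward()                            # sparse gradients flow into syn0 and syn1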

__init__(emb_size, emb_dimension, cbow_mean=True)

Initialize a CBOW instance.

Parameters:

    emb_size (int, required): The number of words to process.
    emb_dimension (int, required): The dimension of the word embeddings.
    cbow_mean (bool, default True): If True, average the context vectors over the
        non-padding positions; if False, sum them.

Source code in rivertext/models/iword2vec/model.py
def __init__(self, emb_size: int, emb_dimension: int, cbow_mean: bool = True):
    """Initialize a SG instance.

    Args:
        emb_size: The number of words to process.
        emb_dimension: The dimension of the word embeddings.
        cbow_mean: If True, average the context vectors over the non-padding
            positions; if False, sum them.
    """
    super(CBOW, self).__init__(emb_size, emb_dimension)
    self.cbow_mean = cbow_mean
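
For intuition, the masked averaging that cbow_mean=True performs can be reproduced in isolation. The toy embedding table and indices below are purely illustrative; because the padding row (index 0) of an nn.Embedding with padding_idx=0 is all zeros, summing over the window and dividing by the number of non-padding positions yields the mean over the real context words only. With cbow_mean=False the forward pass simply keeps the sum.

import torch
from torch import nn

emb = nn.Embedding(10, 5, padding_idx=0)     # toy table; row 0 (padding) is all zeros
context = torch.tensor([[3, 7, 0, 0]])       # two real context words, two padding slots
c = emb(context)                             # (1, 4, 5)
mean_c = c.sum(dim=1) / (context != 0).sum(dim=1, keepdim=True)  # divides by 2, not 4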

forward(target, context, negatives)

Forward pass of the CBOW model.

Parameters:

    target (torch.Tensor, required): Target positive samples.
    context (torch.Tensor, required): Context positive samples.
    negatives (torch.Tensor, required): Negative samples.

Returns:

    float: Objective function result.

Source code in rivertext/models/iword2vec/model.py
def forward(
    self, target: torch.Tensor, context: torch.Tensor, negatives: torch.Tensor
) -> float:
    """Forward pass the CBOW model.

    Args:
        target: Target positive samples.
        context: Context positive samples.
        negatives: Negative samples.

    Returns:
        Objective function result.
    """
    t = self.syn1(target)
    c = self.syn0(context)

    # Mean of context vector without considering padding idx (0)
    if self.cbow_mean:
        mean_c = torch.sum(c, dim=1) / torch.sum(context != 0, dim=1, keepdim=True)
    else:
        mean_c = c.sum(dim=1)

    score = torch.mul(t, mean_c).squeeze()
    score = torch.sum(score, dim=1)
    score = F.logsigmoid(score)

    n = self.syn1(negatives)
    neg_score = torch.bmm(n, mean_c.unsqueeze(2)).squeeze()
    neg_score = F.logsigmoid(-1 * neg_score)

    return -1 * (torch.sum(score) + torch.sum(neg_score))
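
For reference, the value returned by forward (and minimized during training) is the usual negative-sampling objective computed over the batch,

$$
\mathcal{L} = -\sum_{(w,\,C)} \Big[ \log \sigma\big(\mathbf{v}'_{w} \cdot \bar{\mathbf{v}}_{C}\big) + \sum_{k=1}^{K} \log \sigma\big(-\mathbf{v}'_{n_k} \cdot \bar{\mathbf{v}}_{C}\big) \Big],
$$

where \(\bar{\mathbf{v}}_{C}\) is the averaged (or summed, if cbow_mean=False) syn0 context vector, \(\mathbf{v}'_{w}\) the syn1 vector of the target word, and \(\mathbf{v}'_{n_k}\) the syn1 vectors of the K negative samples. This mirrors the code above: score is the first log-sigmoid term, neg_score the second, and the returned value is their negated sum. The SG model below computes the analogous per-pair objective with syn0(target) in place of the averaged context vector.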

SG

Bases: Word2Vec

PyTorch implementation of the SkipGram with Negative Sampling model.

Source code in rivertext/models/iword2vec/model.py
class SG(Word2Vec):
    """SkipGram with Negative Sampling model Pytorch implementation."""

    def __init__(self, emb_size: int, emb_dimension: int):
        """Initialize a SG instance.

        Args:
            emb_size:
                The number of words to process.
            emb_dimension:
                The dimension of the word embeddings.
        """
        super(SG, self).__init__(emb_size, emb_dimension)

    def forward(
        self, target: torch.Tensor, context: torch.Tensor, negatives: torch.Tensor
    ) -> float:
        """Forward pass the SG model.

        Args:
            target: Target positive samples.
            context: Context positive samples.
            negatives: Negative samples.

        Returns:
            Objective function result.
        """
        t = self.syn0(target)
        c = self.syn1(context)

        score = torch.mul(t, c).squeeze()
        score = torch.sum(score, dim=1)
        score = F.logsigmoid(score)

        n = self.syn1(negatives)
        neg_score = torch.bmm(n, t.unsqueeze(2)).squeeze()
        neg_score = F.logsigmoid(-1 * neg_score)

        return -1 * (torch.sum(score) + torch.sum(neg_score))
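
A minimal usage sketch for the Skip-gram model, again with illustrative shapes and an assumed import path. Unlike CBOW, each training example is a single (target, context) pair, so both index tensors are one-dimensional.

import torch

from rivertext.models.iword2vec.model import SG  # import path assumed from the source path above

model = SG(emb_size=100, emb_dimension=50)

target = torch.randint(1, 100, (8,))       # (batch,) center word indices
context = torch.randint(1, 100, (8,))      # (batch,) one context word per pair
negatives = torch.randint(1, 100, (8, 5))  # (batch, n_negatives) negative sample indices

loss = model(target, context, negatives)   # 0-dim tensor with the negative-sampling loss
loss.backward()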

__init__(emb_size, emb_dimension)

Initialize a SG instance.

Parameters:

    emb_size (int, required): The number of words to process.
    emb_dimension (int, required): The dimension of the word embeddings.

Source code in rivertext/models/iword2vec/model.py
def __init__(self, emb_size: int, emb_dimension: int):
    """Initialize a SG instance.

    Args:
        emb_size:
            The number of words to process.
        emb_dimension:
            The dimension of the word embeddings.
    """
    super(SG, self).__init__(emb_size, emb_dimension)

forward(target, context, negatives)

Forward pass of the SG model.

Parameters:

    target (torch.Tensor, required): Target positive samples.
    context (torch.Tensor, required): Context positive samples.
    negatives (torch.Tensor, required): Negative samples.

Returns:

    float: Objective function result.

Source code in rivertext/models/iword2vec/model.py
def forward(
    self, target: torch.Tensor, context: torch.Tensor, negatives: torch.Tensor
) -> float:
    """Forward pass the SG model.

    Args:
        target: Target positive samples.
        context: Context positive samples.
        negatives: Negative samples.

    Returns:
        Objective function result.
    """
    t = self.syn0(target)
    c = self.syn1(context)

    score = torch.mul(t, c).squeeze()
    score = torch.sum(score, dim=1)
    score = F.logsigmoid(score)

    n = self.syn1(negatives)
    neg_score = torch.bmm(n, t.unsqueeze(2)).squeeze()
    neg_score = F.logsigmoid(-1 * neg_score)

    return -1 * (torch.sum(score) + torch.sum(neg_score))

Word2Vec

Bases: nn.Module

Base class encapsulating the parameters shared between the two models.

References
  1. Mikolov, T., Chen, K., Corrado, G. & Dean, J. (2013). Efficient Estimation of Word Representations in Vector Space. CoRR, abs/1301.3781.
Source code in rivertext/models/iword2vec/model.py
class Word2Vec(nn.Module):
    """Base class for encapsulating the shared parameter beetween the two models.

    References:
        1. Mikolov, T., Chen, K., Corrado, G. & Dean, J. (2013). Efficient Estimation
             of Word Representations in Vector Space. CoRR, abs/1301.3781.

    """

    def __init__(self, emb_size: int, emb_dimension: int):
        """Initialize a Word2Vec instance.

        Args:
            emb_size:
                The number of words to process.
            emb_dimension:
                The dimension of the word embeddings.
        """

        super(Word2Vec, self).__init__()
        self.emb_size = emb_size
        self.emb_dimension = emb_dimension

        # syn0: embedding for input words
        # syn1: embedding for output words
        self.syn0 = nn.Embedding(emb_size, emb_dimension, sparse=True, padding_idx=0)
        self.syn1 = nn.Embedding(emb_size, emb_dimension, sparse=True, padding_idx=0)

        init_range = 0.5 / self.emb_dimension
        init.uniform_(self.syn0.weight.data, -init_range, init_range)
        init.constant_(self.syn1.weight.data, 0)
        self.syn0.weight.data[0, :] = 0

    def forward(self, pos_u: torch.Tensor, pos_v: torch.Tensor, neg_v: torch.Tensor):
        """Forward network pass.

        Args:
            pos_u: Target positive samples.
            pos_v: Context positive samples.
            neg_v: Negative samples.

        """
        raise NotImplementedError()

    def get_embedding(self, idx: int) -> np.ndarray:
        """Obtain the vector associated with a word by its index.

        Args:
            idx:
                Index associated with a word.

        Returns:
            The vector associated with a word.
        """
        return (self.syn0.weight[idx] + self.syn1.weight[idx]).cpu().detach().numpy()
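
A small sketch of the initialization described above, with illustrative sizes: syn0 starts uniform in [-0.5/emb_dimension, 0.5/emb_dimension] with its padding row zeroed out, syn1 starts at zero, and calling forward on the base class raises NotImplementedError. The import path is assumed from the source path above.

import torch

from rivertext.models.iword2vec.model import Word2Vec  # assumed import path

base = Word2Vec(emb_size=100, emb_dimension=50)
assert torch.all(base.syn1.weight == 0)            # output embeddings start at zero
assert torch.all(base.syn0.weight[0] == 0)         # padding row of the input embeddings is zeroed
assert base.syn0.weight.abs().max() <= 0.5 / 50    # uniform init range is 0.5 / emb_dimension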

__init__(emb_size, emb_dimension)

Initialize a Word2Vec instance.

Parameters:

    emb_size (int, required): The number of words to process.
    emb_dimension (int, required): The dimension of the word embeddings.

Source code in rivertext/models/iword2vec/model.py
def __init__(self, emb_size: int, emb_dimension: int):
    """Initialize a Word2Vec instance.

    Args:
        emb_size:
            The number of words to process.
        emb_dimension:
            The dimension of the word embeddings.
    """

    super(Word2Vec, self).__init__()
    self.emb_size = emb_size
    self.emb_dimension = emb_dimension

    # syn0: embedding for input words
    # syn1: embedding for output words
    self.syn0 = nn.Embedding(emb_size, emb_dimension, sparse=True, padding_idx=0)
    self.syn1 = nn.Embedding(emb_size, emb_dimension, sparse=True, padding_idx=0)

    init_range = 0.5 / self.emb_dimension
    init.uniform_(self.syn0.weight.data, -init_range, init_range)
    init.constant_(self.syn1.weight.data, 0)
    self.syn0.weight.data[0, :] = 0

forward(pos_u, pos_v, neg_v)

Forward network pass.

Parameters:

    pos_u (torch.Tensor, required): Target positive samples.
    pos_v (torch.Tensor, required): Context positive samples.
    neg_v (torch.Tensor, required): Negative samples.

Source code in rivertext/models/iword2vec/model.py
def forward(self, pos_u: torch.Tensor, pos_v: torch.Tensor, neg_v: torch.Tensor):
    """Forward network pass.

    Args:
        pos_u: Target positive samples.
        pos_v: Context positive samples.
        neg_v: Negative samples.

    """
    raise NotImplementedError()

get_embedding(idx)

Obtain the vector associated with a word by its index.

Parameters:

    idx (int, required): Index associated with a word.

Returns:

    np.ndarray: The vector associated with a word.

Source code in rivertext/models/iword2vec/model.py
def get_embedding(self, idx: int) -> np.ndarray:
    """Obtain the vector associated with a word by its index.

    Args:
        idx:
            Index associated with a word.

    Returns:
        The vector associated with a word.
    """
    return (self.syn0.weight[idx] + self.syn1.weight[idx]).cpu().detach().numpy()
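
A brief usage sketch: for any model built on this base class, get_embedding returns the sum of the input (syn0) and output (syn1) vectors for the given index as a NumPy array of length emb_dimension. The model, sizes, and index below are illustrative, and the import path is assumed from the source path above.

from rivertext.models.iword2vec.model import SG  # assumed import path

model = SG(emb_size=100, emb_dimension=50)
vec = model.get_embedding(7)   # numpy array equal to syn0[7] + syn1[7]
print(vec.shape)               # (50,)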