# word2vec.py
import torch
import torch.nn.functional as F
import numpy as np
import time
import jieba
class SkipGram(torch.nn.Module):
    """Skip-gram word2vec model: one-hot word -> embedding -> vocabulary scores.

    forward() returns *raw logits* (no softmax).  The original applied
    F.softmax before the output was fed to CrossEntropyLoss in train();
    CrossEntropyLoss already performs log-softmax internally, so the
    double softmax flattened gradients and crippled training.  Apply
    F.softmax(out, dim=-1) externally when probabilities are needed
    for display.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        # Input projection: one-hot (vocab_size) -> dense embedding.
        self.hidden = torch.nn.Linear(self.vocab_size, self.embedding_size)
        # Output projection: embedding -> one score per vocabulary entry.
        self.predict = torch.nn.Linear(self.embedding_size, self.vocab_size)

    def forward(self, X):
        """Return unnormalized scores (logits) over the vocabulary for X."""
        hidden = self.hidden(X)
        return self.predict(hidden)
def data_iter(words, batch_size=3):
    """Yield (x, y) mini-batches of skip-gram training pairs.

    Each token *position* gets its own one-hot row (the vocabulary is
    taken as the full token sequence, so duplicate words collide in the
    word->id map; the last occurrence wins).  For every center position,
    one (one-hot, neighbour-id) pair is emitted per neighbour inside a
    +/-1 window; the final batch may be shorter than batch_size.
    """
    n = len(words)
    word2id = {w: i for i, w in enumerate(words)}
    one_hot = np.eye(n)
    window = 1
    inputs, targets = [], []
    for pos in range(n):
        # Neighbours on both sides of the center word (slice is empty at
        # the sequence edges, so no padding is needed).
        neighbours = words[pos - window:pos] + words[pos + 1:pos + 1 + window]
        for neighbour in neighbours:
            inputs.append(one_hot[pos])          # center word as one-hot input
            targets.append(word2id[neighbour])   # neighbour id as class label
    for start in range(0, len(inputs), batch_size):
        yield (inputs[start:start + batch_size],
               targets[start:start + batch_size])
def cut_sentence(sentence):
    """Tokenize a Chinese sentence with jieba; returns jieba's token generator."""
    return jieba.cut(sentence)
def train(words, batch_size=64, epochs=5, embedding_size=20, lr=0.2):
    """Train a SkipGram model on the token sequence and save its weights.

    Args:
        words: token sequence; vocabulary size is taken as len(words).
        batch_size: mini-batch size forwarded to data_iter.
        epochs: number of passes over the data (default keeps original 5).
        embedding_size: embedding dimension (default keeps original 20).
        lr: SGD learning rate (default keeps original 0.2).

    Returns:
        The filename (not the full path) the state_dict was saved under.
    """
    import os  # local import keeps this block self-contained

    net = SkipGram(len(words), embedding_size)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    # NOTE(review): CrossEntropyLoss expects raw logits -- confirm that
    # SkipGram.forward does not apply softmax before returning.
    loss_fun = torch.nn.CrossEntropyLoss()
    losses = []
    for _ in range(epochs):
        batch_num = 0
        for x, y in data_iter(words, batch_size):
            x, y = torch.FloatTensor(x), torch.LongTensor(y)
            pred = net(x)
            loss = loss_fun(pred, y)
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_num += 1
            if batch_num % 10 == 0:
                # Running average over *all* recorded losses (across epochs).
                print('batch_num %d,loss:%.4f' % (batch_num, sum(losses) / len(losses)))
    date_str = time.strftime('%Y%m%d%H%M%S', time.localtime())
    model_filename = 'skipgram_{}.pkl'.format(date_str)
    # Bug fix: torch.save raises FileNotFoundError if the target
    # directory is missing on a fresh checkout -- create it first.
    os.makedirs('saved_model', exist_ok=True)
    torch.save(net.state_dict(), 'saved_model/{}'.format(model_filename))
    print('model is saved as {}'.format(model_filename))
    return model_filename
def load_corpus():
    """Read data/corpus.txt (UTF-8) and strip punctuation/whitespace noise.

    Returns:
        The cleaned corpus as a single string.
    """
    stop_word = ['【', '】', ')', '(', '、', ',', '“', '”', '。', '\n', '《', '》', ' ', '-', '!', '?', '.', '\'', '[', ']',
                 ':', '/', '.', '"', '\u3000', '’', '.', ',', '…', '?', ';', '(', ')']
    # 'with' guarantees the handle is closed even if read() raises.
    with open('data/corpus.txt', 'r', encoding='utf-8') as f:
        text = f.read()
    # All stop entries are single characters, so one str.translate pass
    # is equivalent to (and much faster than) chained .replace() calls.
    # The debug print of the entire corpus was removed.
    text = text.translate({ord(c): None for c in stop_word})
    return text
def test():
    """End-to-end demo: load the corpus, train a skip-gram model, probe one word.

    (Removed dead code: an English demo sentence whose tokens were
    immediately overwritten by the real corpus, and an unused word2id map.)
    """
    text = load_corpus()
    words = list(cut_sentence(text))
    # One-hot row per token position -- must match data_iter's encoding.
    one_hot = np.eye(len(words))
    model_filename = train(words)
    # Rebuild an identically-shaped model and load the saved weights.
    net = SkipGram(len(words), 20)
    net.load_state_dict(torch.load('saved_model/{}'.format(model_filename)))
    idx = 5
    print('word:', words[idx])
    print('prediction:')
    print(net(torch.FloatTensor(one_hot[idx])))
# Script entry point: runs the end-to-end training demo.
if __name__ == '__main__':
    test()