Pytorch实现word2vec
主要内容
Word2Vec的原理网上有很多很多资料,这里就不再复述了。本人使用pytorch来尽可能复现Distributed Representations of Words and Phrases and their Compositionality 论文中训练词向量的方法。论文中有很多模型实现的细节,这些细节对于词向量的好坏至关重要。我们虽然无法完全复现论文中的实验结果,主要是由于计算资源等各种细节原因,但是还是可以大致展示如何训练词向量。
以下是一些未实现的细节。
- subsampling:参考论文section 2.3
训练数据为text8,所有相关代码及数据下载地址Word2Vec地址,提取密码:p46t。
在项目目录下运行:bash run_word2vec.sh。
数据预处理
- 从文本文件中读取所有的文字,通过这些文本创建一个vocabulary
- 由于单词数量可能太大,我们只选取最常见的MAX_VOCAB_SIZE个单词
- 我们添加一个UNK单词表示所有不常见的单词
- 我们需要记录单词到index的mapping,以及index到单词的mapping,单词的count,单词的(normalized) frequency,以及单词总数。
# NOTE(review): the enclosing `def` line is not visible in this excerpt; the
# trailing `return` shows these statements are the body of a function.
# Read the whole corpus file into memory.
with open(args.data_dir,'r') as fin:
    text = fin.read()
# Lower-case the corpus and split it on whitespace into a list of tokens.
text = [w for w in text.lower().split()]
# Keep the (max_vocab_size - 1) most frequent words with their counts; one
# slot is left for the "<unk>" entry added next.
vocab = dict(Counter(text).most_common(args.max_vocab_size-1))
# "<unk>" receives the count of every token not covered by the kept words.
# It is inserted last, so it ends up with the highest index below.
vocab["<unk>"] = len(text) - np.sum(list(vocab.values()))
# Index <-> word mappings, both following the dict's insertion order.
idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:idx for idx,word in enumerate(idx_to_word)}
# Sampling distribution used for negative sampling.
word_counts = np.asarray([value for value in vocab.values()],dtype=np.float32)
# Normalize counts to frequencies, raise them to the 3/4 power, then
# renormalize so the weights sum to 1 again.
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (3./4.)
word_freqs = word_freqs / np.sum(word_freqs)
vocab_size = len(idx_to_word)
return text,idx_to_word,word_to_idx,word_freqs,vocab_size
实现Dataloader
一个dataloader需要以下内容:
- 把所有text编码成数字(论文中的subsampling预处理在本实现中并未实现,见上文“未实现的细节”)。
- 保存vocabulary,单词count,normalized word frequency
- 每个iteration sample一个中心词
- 根据中心词sample一些negative单词
- 返回中心词、它周围的context单词(positive)以及采样得到的negative单词(本实现的`__getitem__`并不返回单词的counts)
直接使用pytorch的dataloader,使用方法参照这里Pytorch dataloader。我们需要定义以下两个function:
- `__len__`:返回整个数据集中的item总数
- `__getitem__`:根据给定的index返回指定的item
class wordEmbeddingDataset(Dataset):
    """Skip-gram dataset yielding (center word, context words, negative words).

    Every position of the encoded corpus is one item. The context window
    wraps around the ends of the corpus, and negatives are drawn with
    replacement from ``word_freqs`` without filtering out the center or
    context words.
    """

    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, C, K):
        '''
        :param text: corpus as a sequence of word tokens
        :param word_to_idx: mapping from word to integer index
        :param idx_to_word: sequence mapping integer index back to word
        :param word_freqs: unigram frequencies raised to the 3/4 power, used
            as sampling weights for negative sampling
        :param C: number of context words taken on each side of the center word
        :param K: number of negative samples drawn per context word
        '''
        super(wordEmbeddingDataset, self).__init__()
        self.vocab_size = len(word_to_idx)
        # Resolve the out-of-vocabulary index from the "<unk>" entry itself
        # instead of assuming it is the last index; fall back to the last
        # index only when the vocabulary has no "<unk>" key.
        unk_idx = word_to_idx.get("<unk>", self.vocab_size - 1)
        encoded = [word_to_idx.get(token, unk_idx) for token in text]
        self.text_encoded = torch.tensor(encoded).long()
        self.word_to_idx = word_to_idx
        self.idx_to_word = idx_to_word
        self.word_freqs = torch.tensor(word_freqs)
        self.C = C
        self.K = K

    def __len__(self):
        """Return the number of tokens in the encoded corpus."""
        return len(self.text_encoded)

    def __getitem__(self, idx):
        """Return ``(center_word, pos_words, neg_words)`` for position ``idx``.

        ``pos_words`` holds the ``2 * C`` surrounding word indices and
        ``neg_words`` holds ``K`` sampled indices per context word.
        """
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx - self.C, idx)) + list(range(idx + 1, idx + self.C + 1))
        # Positions that fall off either end of the corpus wrap around to the
        # other end.
        corpus_len = len(self.text_encoded)
        pos_indices = [i % corpus_len for i in pos_indices]
        pos_words = self.text_encoded[pos_indices]
        # Draw K negatives per context word, with replacement.
        neg_words = torch.multinomial(self.word_freqs, self.K * pos_words.shape[0], replacement=True)
        return center_word, pos_words, neg_words
定义Word2vec模型
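word2vec模型很简单,其实就是一个in_embed和一个out_embed,采用negative sampling。对每一对(中心词 $c$,上下文词 $o$),objective函数为:
$$\log \sigma\left(u_o^{\top} v_c\right) + \sum_{k=1}^{K} \mathbb{E}_{w_k \sim P_n(w)}\left[\log \sigma\left(-u_{w_k}^{\top} v_c\right)\right]$$
其中 $v_c$ 是中心词在in_embed中的向量,$u_o$、$u_{w_k}$ 分别是上下文词和负采样词在out_embed中的向量,$P_n(w)$ 是词频的3/4次方归一化后的采样分布;训练时最小化该目标的相反数。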
代码如下:
class word2VecModel(nn.Module):
    """Skip-gram model trained with negative sampling.

    Holds two embedding tables: ``in_embed`` for center words and
    ``out_embed`` for context and negative words.
    """

    def __init__(self, vocab_size, emb_size):
        """Create both embedding tables, initialized uniformly in
        ``[-0.5 / emb_size, 0.5 / emb_size]``.

        :param vocab_size: number of rows in each embedding table
        :param emb_size: dimensionality of each embedding vector
        """
        super(word2VecModel, self).__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        initrange = 0.5 / self.emb_size
        self.in_embed = nn.Embedding(vocab_size, emb_size)
        self.in_embed.weight.data.uniform_(-initrange, initrange)
        self.out_embed = nn.Embedding(vocab_size, emb_size)
        self.out_embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, center_words, pos_words, neg_words):
        '''
        Compute the negative-sampling loss averaged over the batch.

        center_words: center words, [batch_size]
        pos_words: words seen in the context window, [batch_size, window_size * 2]
        neg_words: words drawn by negative sampling, [batch_size, window_size * 2 * K]
        returns: scalar tensor, the negated mean of the per-example objective
        '''
        center_col = self.in_embed(center_words).unsqueeze(2)  # [batch, emb, 1]
        pos_embedding = self.out_embed(pos_words)  # [batch, 2c, emb]
        neg_embedding = self.out_embed(neg_words)  # [batch, 2c*k, emb]
        # Squeeze only the trailing singleton dimension. A bare squeeze()
        # would also drop the batch dimension when batch_size == 1 and make
        # the sum(1) below fail.
        pos_score = torch.matmul(pos_embedding, center_col).squeeze(2)  # [batch, 2c]
        neg_score = torch.matmul(neg_embedding, -center_col).squeeze(2)  # [batch, 2c*k]
        log_pos_los = F.logsigmoid(pos_score).sum(1)
        log_neg_los = F.logsigmoid(neg_score).sum(1)
        loss = log_neg_los + log_pos_los
        return -loss.mean()

    def input_embeddings(self):
        """Return the ``in_embed`` weight matrix as a NumPy array on the CPU."""
        return self.in_embed.weight.data.cpu().numpy()
模型的训练与评估
- 模型一般需要训练若干个epoch
- 每个epoch我们都把所有的数据分成若干个batch
- 把每个batch的输入和输出都包装成cuda tensor
- forward pass
- 清空模型当前gradient
- backward pass,更新模型参数
- 每隔一定的iteration输出模型在当前iteration的loss,以及在验证数据集上做模型的评估
- 模型保存
def train(args, model, dataloader, word_to_idx, idx_to_word):
    """Train ``model`` on ``dataloader`` and periodically log and evaluate it.

    Every 100 steps the mean loss of the last 100 steps and the current
    learning rate are printed, appended to ``word-embedding.log`` and sent to
    TensorBoard. Every 2000 steps the input embeddings are scored on the
    SimLex-999, MEN and WordSim-353 files and the nearest neighbours of
    "monster" are logged. The embeddings and the model state dict are saved
    to disk.

    :param args: namespace read for ``num_epoch``, ``learnning_rate``,
        ``warmup_steps``, ``device`` and ``embed_size``
    :param model: model called as ``model(center_words=..., pos_words=...,
        neg_words=...)`` and exposing ``input_embeddings()``
    :param dataloader: iterable of ``(center_words, pos_words, neg_words)``
        tensor batches
    :param word_to_idx: word -> index mapping passed to the evaluation helpers
    :param idx_to_word: index -> word mapping passed to ``find_nearest``
    """
    LOG_FILE = "word-embedding.log"
    tb_writer = SummaryWriter('./runs')
    model.train()
    t_total = args.num_epoch * len(dataloader)
    optimizer = AdamW(model.parameters(), lr=args.learnning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total,
    )
    train_iterator = trange(args.num_epoch, desc="epoch")
    tr_loss = 0.
    logg_loss = 0.
    global_step = 0
    for k in train_iterator:
        print("the {} epoch beginning!".format(k))
        epoch_iteration = tqdm(dataloader, desc="iteration")
        for step, batch in enumerate(epoch_iteration):
            batch = tuple(t.to(args.device) for t in batch)
            # Named `inputs` so the builtin `input` is not shadowed.
            inputs = {"center_words": batch[0], "pos_words": batch[1], "neg_words": batch[2]}
            loss = model(**inputs)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            global_step += 1
            tr_loss += loss.item()
            if (step + 1) % 100 == 0:
                # Mean loss over the last 100 steps.
                loss_scalar = (tr_loss - logg_loss) / 100
                logg_loss = tr_loss
                # get_last_lr() is the supported accessor after step();
                # get_lr() warns when called from outside the scheduler.
                current_lr = scheduler.get_last_lr()[0]
                with open(LOG_FILE, "a") as fout:
                    fout.write("epoch: {}, iter: {}, loss: {},learn_rate: {}\n".format(k, step, loss_scalar, current_lr))
                print("epoch: {}, iter: {}, loss: {}, learning_rate: {}".format(k, step, loss_scalar, current_lr))
                tb_writer.add_scalar("learning_rate", current_lr, global_step)
                tb_writer.add_scalar("loss", loss_scalar, global_step)
            if (step + 1) % 2000 == 0:
                embedding_weights = model.input_embeddings()
                sim_simlex = evaluate("./worddata/simlex-999.txt", embedding_weights, word_to_idx)
                sim_men = evaluate("./worddata/men.txt", embedding_weights, word_to_idx)
                sim_353 = evaluate("./worddata/wordsim353.csv", embedding_weights, word_to_idx)
                # Compute the neighbour list once and reuse it for both outputs.
                nearest = find_nearest("monster", embedding_weights, word_to_idx, idx_to_word)
                print("epoch: {}, iteration: {}, simlex-999: {}, men:{}, sim353:{}, nearest to monster: {}\n".format(
                    k, step, sim_simlex, sim_men, sim_353, nearest))
                with open(LOG_FILE, "a") as fout:
                    fout.write("epoch: {}, iteration: {}, simlex-999: {}, men: {}, sim353: {}, nearest to monster: {}\n".format(
                        k, step, sim_simlex, sim_men, sim_353, nearest))
        # NOTE(review): the original indentation was lost, so it is unclear
        # whether this save ran per epoch or once after training. Saving per
        # epoch overwrites the same files, so the final result is identical
        # whenever at least one epoch runs.
        embedding_weights = model.input_embeddings()
        np.save("embedding-{}".format(args.embed_size), embedding_weights)
        torch.save(model.state_dict(), "embedding-{}.th".format(args.embed_size))
    # Flush pending TensorBoard events and release the writer.
    tb_writer.close()
模型的loss下降曲线:
模型展示
在 MEN、SimLex-999 和 WordSim-353 数据集上做评估:
simlex999: SpearmanrResult(correlation=0.17249746449326459, pvalue=8.268870735375061e-08),
men: SpearmanrResult(correlation=0.427926614729899, pvalue=1.76732628946326e-115),
sim353: SpearmanrResult(correlation=0.4555634677353853, pvalue=9.452442365338771e-18)
寻找nearest neighbors:
word:good, nearest:[‘good’, ‘bad’, ‘things’, ‘happiness’, ‘everything’, ‘pleasure’, ‘nothing’, ‘something’, ‘think’, ‘whatever’]
word:fresh, nearest:[‘fresh’, ‘salt’, ‘dry’, ‘grain’, ‘vegetables’, ‘eggs’, ‘fruit’, ‘sugar’, ‘milk’, ‘drinking’]
word:monster, nearest:[‘monster’, ‘giant’, ‘loch’, ‘ness’, ‘creature’, ‘beast’, ‘hero’, ‘wolf’, ‘sword’, ‘serpent’]
word:green, nearest:[‘green’, ‘blue’, ‘yellow’, ‘orange’, ‘red’, ‘purple’, ‘white’, ‘colored’, ‘brown’, ‘colors’]
word:like, nearest:[‘like’, ‘similar’, ‘such’, ‘resemble’, ‘teeth’, ‘sometimes’, ‘soft’, ‘unlike’, ‘honey’, ‘etc’]
word:america, nearest:[‘america’, ‘africa’, ‘australia’, ‘europe’, ‘canada’, ‘african’, ‘caribbean’, ‘pacific’, ‘carolina’, ‘americas’]
word:chicago, nearest:[‘chicago’, ‘illinois’, ‘boston’, ‘detroit’, ‘atlanta’, ‘cleveland’, ‘cincinnati’, ‘miami’, ‘houston’, ‘denver’]
word:work, nearest:[‘work’, ‘works’, ‘ideas’, ‘scientific’, ‘writing’, ‘haydn’, ‘seminal’, ‘philosophical’, ‘philosophy’, ‘writings’]
word:computer, nearest:[‘computer’, ‘computers’, ‘hardware’, ‘software’, ‘computing’, ‘digital’, ‘graphics’, ‘machines’, ‘portable’, ‘interface’]
word:language, nearest:[‘language’, ‘languages’, ‘dialects’, ‘dialect’, ‘spoken’, ‘vocabulary’, ‘syntax’, ‘grammar’, ‘alphabet’, ‘speakers’]
单词之间的关系
the nearest to <women-man+king>:
queen、prince、emperor、king、son、daughter、throne、iii、kings、wife、iv、duke、heir、vii、henry、princess、father、empress、anne、brother