Applying the Count-Based Method to the PTB Dataset
- The PTB dataset
- ptb.py
- Using ptb.py
- Applying the count-based method to the PTB dataset
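The PTB dataset
The contents look like this:
Each line holds one sentence; rare words are replaced with the special token <unk>; concrete numbers are replaced with "N".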
we 're talking about years ago before anyone heard of asbestos having any questionable properties there is no asbestos in our products now neither <unk> nor the researchers who studied the workers were aware of any research on smokers of the kent cigarettes we have no useful information on whether users are at risk said james a. <unk> of boston 's <unk> cancer institute dr. <unk> led a team of researchers from the national cancer institute and the medical schools of harvard university and boston university
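The three conventions above (one sentence per line, rare words mapped to <unk>, numbers mapped to N) can be illustrated with a small sketch. This is not how the PTB files were actually produced; the function name preprocess, the min_count threshold, and the number pattern are assumptions made up for the example.

```python
import re
from collections import Counter

def preprocess(sentences, min_count=2):
    """Hypothetical PTB-style cleanup: numbers -> 'N', rare words -> '<unk>'."""
    tokenized = []
    for sentence in sentences:
        tokens = []
        for tok in sentence.split():
            # A token that is a plain integer or decimal becomes 'N'
            tokens.append('N' if re.fullmatch(r'\d+(\.\d+)?', tok) else tok)
        tokenized.append(tokens)

    # Count every token over the whole corpus
    counts = Counter(tok for tokens in tokenized for tok in tokens)

    # Tokens seen fewer than min_count times become '<unk>'
    return [' '.join(tok if counts[tok] >= min_count else '<unk>' for tok in tokens)
            for tokens in tokenized]

lines = preprocess(["the cat sat on 3 mats", "the dog sat on 12 mats"])
for line in lines:
    print(line)
```

Here "cat" and "dog" each occur once, so both fall below the assumed threshold and are replaced, while both numbers collapse to the single token N.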
ptb.py
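Using the PTB dataset:
The line below shows that, when the PTB dataset is loaded, every newline is replaced with <eos> and all sentences are joined end to end into one long word sequence.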
words = open(file_path).read().replace('\n', '<eos>').strip().split()
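To see what this one-liner does, here is a minimal sketch on an in-memory string instead of a file. The sample text is made up, and it assumes each line is padded with a space on both sides; the one-liner depends on that whitespace, because without it `<eos>` would be glued to the neighbouring words and `split()` could not separate it.

```python
# Made-up two-sentence sample; note the space before each newline
# and at the start of each line.
raw = " the cat sat \n the dog sat \n"

# Same expression as in ptb.py, applied to a string instead of a file
words = raw.replace('\n', '<eos>').strip().split()
print(words)

# Build the vocabulary the same way ptb.py does: IDs in order of first appearance
word_to_id = {}
id_to_word = {}
for word in words:
    if word not in word_to_id:
        new_id = len(word_to_id)
        word_to_id[word] = new_id
        id_to_word[new_id] = word

# The corpus is the word sequence expressed as IDs
corpus = [word_to_id[w] for w in words]
print(corpus)
```

Both sentences end up in a single flat list, with `<eos>` marking where each sentence ended.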
ptb.py downloads the PTB dataset, saves it next to the script (in the directory that contains ptb.py), and then extracts corpus, word_to_id, and id_to_word from it.
import sys
import os
sys.path.append('..')
try:
    import urllib.request
except ImportError:
    raise ImportError('Use Python3!')
import pickle
import numpy as np

url_base = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/'
key_file = {
    'train': 'ptb.train.txt',
    'test': 'ptb.test.txt',
    'valid': 'ptb.valid.txt'
}
save_file = {
    'train': 'ptb.train.npy',
    'test': 'ptb.test.npy',
    'valid': 'ptb.valid.npy'
}
vocab_file = 'ptb.vocab.pkl'

dataset_dir = os.path.dirname(os.path.abspath(__file__))


def _download(file_name):
    file_path = dataset_dir + '/' + file_name
    if os.path.exists(file_path):
        return

    print('Downloading ' + file_name + ' ... ')

    try:
        urllib.request.urlretrieve(url_base + file_name, file_path)
    except urllib.error.URLError:
        # Retry without certificate verification
        import ssl
        ssl._create_default_https_context = ssl._create_unverified_context
        urllib.request.urlretrieve(url_base + file_name, file_path)

    print('Done')


def load_vocab():
    vocab_path = dataset_dir + '/' + vocab_file

    # Reuse the cached vocabulary if it exists
    if os.path.exists(vocab_path):
        with open(vocab_path, 'rb') as f:
            word_to_id, id_to_word = pickle.load(f)
        return word_to_id, id_to_word

    word_to_id = {}
    id_to_word = {}
    data_type = 'train'
    file_name = key_file[data_type]
    file_path = dataset_dir + '/' + file_name

    _download(file_name)

    words = open(file_path).read().replace('\n', '<eos>').strip().split()

    # Assign IDs in order of first appearance
    for i, word in enumerate(words):
        if word not in word_to_id:
            tmp_id = len(word_to_id)
            word_to_id[word] = tmp_id
            id_to_word[tmp_id] = word

    with open(vocab_path, 'wb') as f:
        pickle.dump((word_to_id, id_to_word), f)

    return word_to_id, id_to_word


def load_data(data_type='train'):
    '''
    :param data_type: type of data: 'train', 'test', or 'valid' ('val')
    :return:
    '''
    if data_type == 'val':
        data_type = 'valid'
    save_path = dataset_dir + '/' + save_file[data_type]

    word_to_id, id_to_word = load_vocab()

    # Reuse the cached corpus if it exists
    if os.path.exists(save_path):
        corpus = np.load(save_path)
        return corpus, word_to_id, id_to_word

    file_name = key_file[data_type]
    file_path = dataset_dir + '/' + file_name
    _download(file_name)

    words = open(file_path).read().replace('\n', '<eos>').strip().split()
    corpus = np.array([word_to_id[w] for w in words])

    np.save(save_path, corpus)
    return corpus, word_to_id, id_to_word


if __name__ == '__main__':
    for data_type in ('train', 'val', 'test'):
        load_data(data_type)
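Using ptb.py
corpus holds the list of word IDs, id_to_word is a dictionary that maps a word ID to its word, and word_to_id is a dictionary that maps a word to its word ID.
Load the data with ptb.load_data(). Its argument can be 'train', 'test', or 'valid', which select the training, test, and validation data respectively ('val' is also accepted as an alias for 'valid').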
import sys
sys.path.append('..')
from dataset import ptb

corpus, word_to_id, id_to_word = ptb.load_data('train')

print('corpus size:', len(corpus))
print('corpus[:30]:', corpus[:30])
print()
print('id_to_word[0]:', id_to_word[0])
print('id_to_word[1]:', id_to_word[1])
print('id_to_word[2]:', id_to_word[2])
print()
print("word_to_id['car']:", word_to_id['car'])
print("word_to_id['happy']:", word_to_id['happy'])
print("word_to_id['lexus']:", word_to_id['lexus'])
Output:
corpus size: 929589
corpus[:30]: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]

id_to_word[0]: aer
id_to_word[1]: banknote
id_to_word[2]: berlitz

word_to_id['car']: 3856
word_to_id['happy']: 4428
word_to_id['lexus']: 7426
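Applying the count-based method to the PTB dataset
The only real difference from the version that does not use the PTB dataset is this line: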
corpus, word_to_id, id_to_word = ptb.load_data('train')
The following line performs the dimensionality reduction: it keeps only the first wordvec_size columns of U.
word_vecs = U[:, :wordvec_size]
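Why slicing the columns of U reduces the dimension can be seen on a toy matrix. The sketch below uses a small random matrix as a stand-in for the PPMI matrix and a made-up wordvec_size of 2; only the slicing step mirrors the line above.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.random((6, 6))   # stand-in for a (vocab_size x vocab_size) PPMI matrix
wordvec_size = 2         # hypothetical target dimension for this toy example

# np.linalg.svd returns the singular values S sorted in descending order,
# so the leading columns of U belong to the largest singular values.
U, S, V = np.linalg.svd(W)

# Keep one row per word, but only the first wordvec_size columns
word_vecs = U[:, :wordvec_size]

print(U.shape, '->', word_vecs.shape)
```

Each word keeps its own row, and its vector shrinks from vocab_size components to wordvec_size components.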
The most time-consuming step in the whole script is the following function call:
W = ppmi(C, verbose=True)
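The source of ppmi in common.util is not shown here. Assuming it computes max(0, log2(C[i, j] * N / (S[i] * S[j]) + eps)) with two nested Python loops over the matrix, the cost grows with the square of the vocabulary size, which would explain the long run time. The sketch below compares such a loop with a vectorized NumPy equivalent; ppmi_loop and ppmi_vectorized are hypothetical names, not part of common.util.

```python
import numpy as np

def ppmi_loop(C, eps=1e-8):
    """Element-by-element PPMI (assumed shape of the slow version)."""
    M = np.zeros_like(C, dtype=np.float32)
    N = np.sum(C)            # total number of co-occurrences
    S = np.sum(C, axis=0)    # occurrence count per word
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            pmi = np.log2(C[i, j] * N / (S[j] * S[i]) + eps)
            M[i, j] = max(0, pmi)
    return M

def ppmi_vectorized(C, eps=1e-8):
    """Same formula, computed for the whole matrix at once."""
    C = C.astype(np.float64)
    N = C.sum()
    S = C.sum(axis=0)
    pmi = np.log2(C * N / np.outer(S, S) + eps)
    return np.maximum(0, pmi).astype(np.float32)

# Small symmetric co-occurrence matrix (7 words, window size 1)
C = np.array([
    [0, 1, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 1, 1, 0],
    [0, 1, 0, 1, 0, 0, 0],
    [0, 0, 1, 0, 1, 0, 0],
    [0, 1, 0, 1, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 0, 1, 0],
], dtype=np.int32)

same = np.allclose(ppmi_loop(C), ppmi_vectorized(C), atol=1e-5)
print('loop and vectorized agree:', same)
```

The vectorized form builds a dense vocab_size x vocab_size temporary, so it trades memory for speed; whether that is acceptable depends on the vocabulary size.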
Full code:
import sys
sys.path.append('..')
import numpy as np
from common.util import most_similar, create_co_matrix, ppmi
from dataset import ptb

window_size = 2
wordvec_size = 100

corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
print('counting co-occurrence ...')
C = create_co_matrix(corpus, vocab_size, window_size)
print('calculating PPMI ...')
W = ppmi(C, verbose=True)

print('calculating SVD ...')
# try:
# truncated SVD (fast!)
print("ok")
from sklearn.utils.extmath import randomized_svd
U, S, V = randomized_svd(W, n_components=wordvec_size, n_iter=5,
                         random_state=None)
# except ImportError:
#     # SVD (slow)
#     U, S, V = np.linalg.svd(W)

word_vecs = U[:, :wordvec_size]

querys = ['you', 'year', 'car', 'toyota']
for query in querys:
    most_similar(query, word_to_id, id_to_word, word_vecs, top=5)
The results below were produced with the ordinary np.linalg.svd(W).
[query] you
 i: 0.7016294002532959
 we: 0.6388039588928223
 anybody: 0.5868048667907715
 do: 0.5612815618515015
 'll: 0.512611985206604

[query] year
 month: 0.6957005262374878
 quarter: 0.691483736038208
 earlier: 0.6661213636398315
 last: 0.6327787041664124
 third: 0.6230476498603821

[query] car
 luxury: 0.6767407655715942
 auto: 0.6339930295944214
 vehicle: 0.5972712635993958
 cars: 0.5888376235961914
 truck: 0.5693157315254211

[query] toyota
 motor: 0.7481387853622437
 nissan: 0.7147319316864014
 motors: 0.6946366429328918
 lexus: 0.6553674340248108
 honda: 0.6343469619750977
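The results below were produced with the randomized_svd function from sklearn. It is a Truncated SVD that uses random numbers and computes only the components with large singular values, so it runs faster than a regular SVD. Because it relies on random numbers and random_state=None is passed, the similarity scores can differ slightly from run to run.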
calculating SVD ...
ok

[query] you
 i: 0.6678948998451233
 we: 0.6213737726211548
 something: 0.560122013092041
 do: 0.5594725608825684
 someone: 0.5490139126777649

[query] year
 month: 0.6444296836853027
 quarter: 0.6192560791969299
 next: 0.6152222156524658
 fiscal: 0.5712860226631165
 earlier: 0.5641934871673584

[query] car
 luxury: 0.6612467765808105
 auto: 0.6166062355041504
 corsica: 0.5270425081253052
 cars: 0.5142025947570801
 truck: 0.5030257105827332

[query] toyota
 motor: 0.7747215628623962
 motors: 0.6871038675308228
 lexus: 0.6786072850227356
 nissan: 0.6618651151657104
 mazda: 0.6237337589263916
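To give a feel for how a randomized Truncated SVD can skip most of the work, here is a minimal NumPy sketch of the general idea: project the matrix onto a small random subspace, then run an exact SVD on the much smaller projected matrix. This is not scikit-learn's implementation; randomized_svd_sketch, its parameters, and the test matrix are all assumptions made for illustration.

```python
import numpy as np

def randomized_svd_sketch(A, k, n_oversamples=10, n_iter=5, seed=0):
    """Approximate the top-k singular triplets of A (illustrative only)."""
    rng = np.random.default_rng(seed)

    # 1. Sample the range of A with a random Gaussian test matrix
    Omega = rng.standard_normal((A.shape[1], k + n_oversamples))
    Q, _ = np.linalg.qr(A @ Omega)

    # 2. Power iterations emphasize the directions with large singular values
    for _ in range(n_iter):
        Q, _ = np.linalg.qr(A.T @ Q)
        Q, _ = np.linalg.qr(A @ Q)

    # 3. Exact SVD of the small (k + n_oversamples) x n projected matrix
    B = Q.T @ A
    Ub, S, Vt = np.linalg.svd(B, full_matrices=False)

    # 4. Map the left singular vectors back to the original space
    return (Q @ Ub)[:, :k], S[:k], Vt[:k]

# Toy input: a 200 x 200 matrix of exact rank 5
rng = np.random.default_rng(1)
A = rng.standard_normal((200, 5)) @ rng.standard_normal((5, 200))

U, S, Vt = randomized_svd_sketch(A, k=5)
S_full = np.linalg.svd(A, compute_uv=False)[:5]

print('top singular values match:', np.allclose(S, S_full))
print('U shape:', U.shape)
```

The expensive decomposition only ever touches a 15-row matrix here instead of the full 200 x 200 one, which is where the speedup over a regular SVD comes from.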