我们拿到在拿到一堆语料数据,或者是在网络中爬取下来的文本数据如何处理成为模型能够训练的数据呢?这里有我们先经过停用词和按字分词的处理之后,得到的问答对文本数据,input_by_word.txt 和 target_by_word.txt 。其中,input_by_word.txt 里面存放问题,如下:
target_by_word.txt 里面存放回答,如下:
对应行组成一个问答对(qa_pair)。
我们的目的是把一行中文句子转换为对应的一行数字,每一个字对应的数字都应该是独一无二的,比如,['不', '是'] 会对应的转换为 [ 6, 4 ]这样的形式。那么我们可以通过构建字典来完成这种映射关系,即 {..., '不': 6, '是': 4 ,...},另外,我们还需要将模型的输出转换回原本的句子的形式,即将 [ 6, 4 ] 转换为 ['不', '是'] ,所以我们还需要实现反转字典,即 {... , 6 : 不', 4 : '是' , ...} 。
我们可以把每个字作为一个词元,统计这个词元在 input_by_word.txt 和 target_by_word.txt 里出现的频率,根据词元的频率排序得到一个词典,或者是分别对input_by_word.txt 和 target_by_word.txt 里面词元出现的频率,根据频率排序,得到 input_by_word.txt 对应的词典 和 target_by_word.txt 对应的词典。另外,我们的词典中还包含几个特殊词元应对不同情况的需要,如未知词元<"UNK">等。假设我们通过按行遍历的方式,拿到了每一行的句子(sentence),通过对每一行句子进行处理,即目前待处理的数据是这种形式为 list :['是', '王', '若', '猫', '的' ] 。
class Wordsequence():PAD_TAG = "PAD"PAD = 0UNK_TAG = "UNK"UNK = 1SOS_TAG = "SOS"SOS = 2EOS_TAG = "EOS"EOS = 3def __init__(self):self.dict = {self.PAD_TAG: self.PAD,self.UNK_TAG: self.UNK,self.SOS_TAG: self.SOS,self.EOS_TAG: self.EOS}self.inverse_dict = {} # 反转字典self.count = {} # 用于记录词频def fit(self, sentence):"""词频统计:param sentence: list:return:"""for word in sentence:self.count[word] = self.count.get(word, 0) + 1
记录完所有词元的出现次数,我们着手构建词表,可以规定词元的最大频率 max_count 和最小频率 min_count ,以及词表最大容量 max_features ,通常我们最大词频对最大词频不做限制,对小于最小词频的词元不记录在字典,统一用 <'UNK'> 替代,另外还可以按照词频降序排序。下面我们实现了构建词典和反转词典的方法:
def build_vocab(self, min_count=5, max_count=None, max_features=None):"""构建词典:param min_count: 最小词频:param max_count: 最大词频:param max_features: 最大词元数:return:"""# 字典在遍历的过程中无法改变temp = self.count.copy()for key in temp:cur_count = self.count.get(key, 0)if min_count is not None:if cur_count < min_count:del self.count[key]if max_count is not None:if cur_count > max_count:del self.count[key]if max_features is not None:self.count = (sorted(self.count.items(), key=lambda x: x[1], reverse=True)[:max_features])for key in self.count:self.dict[key] = len(self.dict)self.inverse_dict = zip(self.dict.values(), self.dict.keys())
我们构建的两个词典分别如下:
input :{'PAD': 0, 'UNK': 1, 'SOS': 2, 'EOS': 3, '呵': 4, '不': 5, '是': 6, '怎': 7, '么': 8, '了': 9, ...}
target:{'PAD': 0, 'UNK': 1, 'SOS': 2, 'EOS': 3, '是': 4, '王': 5, '若': 6, '猫': 7, '的': 8, '。': 9, ...}
在我们得到了词典之后我们就可以依次把每个句子映射为序列了。
在这里我们把句子转换为序列
def transform(self, sentence, max_len, add_eos=False):"""把sentence转化为数值序列当add_eos=False时,sentence长度会是max_len当add_eos=True时,sentence长度会是max_len+1:param sentence::param max_len::param add_eos: add_eos=True使用 --> 也就是说输出长度会是max_len+1add_eos=False中使用 --> 也就是说输出长度会是max_len:return:"""# 把sentence长度固定为一致if len(sentence) > max_len:sentence = sentence[:max_len]sentence_len = len(sentence) # 提前计算句子长度,实现add_eos后,句子长度还是一致if add_eos:sentence = sentence+[self.EOS_TAG]if sentence_len < max_len:sentence = sentence + [self.PAD_TAG]*(max_len-sentence_len)result = [self.dict.get(i, self.UNK) for i in sentence] # 若在字典中未出现则用UNK字符替代return result
以及将序列反转回句子的形式和整个词典的大小:
def inverse_transform(self, indices):"""把序列转为sentence:param indices: 序列:return: list 不包含有EOS的词元"""result = []for i in indices:if i == self.EOS:breakresult.append(self.inverse_dict.get(i, self.UNK_TAG))return resultdef __len__(self):return len(self.dict) # 返回词典大小
至此我们完成了构建词典和序列化句子所需要的函数。并通过 pickle 保存把词典保存为pkl文件,在重新加载pkl的时候需要把该类重新导入一下。
def save_ws():ws = Wordsequence()for line in open(config.chatbot_input_by_word_path, encoding="utf-8").readlines():ws.fit(line.strip().split())ws.build_vocab()print(ws.dict, len(ws))pickle.dump(ws, open(config.chatbot_ws_input_by_word_path, "wb"))ws = Wordsequence()for line in open(config.chatbot_target_by_word_path, encoding="utf-8").readlines():ws.fit(line.strip().split())ws.build_vocab()print(ws.dict, len(ws))pickle.dump(ws, open(config.chatbot_ws_target_by_word_path, "wb"))
我们现在要将文本数据变成这种形式为 list :['是', '王', '若', '猫', '的' ] 。我们在构建 dataset 的时候处理:
class ChatbotDataset(Dataset):def __init__(self):self.input_path = config.chatbot_input_by_word_pathself.target_path = config.chatbot_target_by_word_pathself.input_lines = open(self.input_path, encoding="utf-8").readlines()self.target_lines = open(self.target_path, encoding="utf-8").readlines()# 确保文本数据input 和 target数量一致assert len(self.input_lines) == len(self.target_lines), "input和target长度一致"def __getitem__(self, idx):input = self.input_lines[idx].strip().split() # 按空格切分为列表target = self.target_lines[idx].strip().split()input_length = len(input) # input的句子真实长度target_length = len(target) # target的句子真实长度return input, target, input_length, target_lengthdef __len__(self):return len(self.input_lines)
然后重写DataLoader的 collate_fn 函数:
def collate_fn(batch):"""重写cllate_fn:param batch: [(input, target, input_length, target_length),(input, target, input_length, target_length),...]:return:"""# 对input的句子,按句子长度降序排序sorted(batch, key=lambda x: x[2], reverse=True)input, target, input_length, target_length = zip(*batch)input = [config.chatbot_ws_by_word_input.transform(i, max_len=config.chatbot_input_max_seq_len) for i in input]target = [config.chatbot_ws_by_word_target.transform(i, max_len=config.chatbot_target_max_seq_len, add_eos=True) for i in target]input = torch.LongTensor(input)target = torch.LongTensor(target)input_length = torch.LongTensor(input_length)target_length = torch.LongTensor(target_length)return input, target, input_length, target_length
train_data_loader = DataLoader(ChatbotDataset(),batch_size=config.chatbot_batch_size,shuffle=True, collate_fn=collate_fn)
if __name__ == '__main__':for idx, (input, target, input_length, target_length) in enumerate(train_data_loader):print(idx)print(input, target, input_length, target_length)break
最后我们可以的到序列化后的数据:
0
tensor([[ 22, 59, 311, ..., 0, 0, 0],[ 19, 482, 9, ..., 0, 0, 0],[617, 807, 0, ..., 0, 0, 0],...,[670, 671, 17, ..., 0, 0, 0],[ 30, 349, 764, ..., 0, 0, 0],[220, 955, 6, ..., 0, 0, 0]])
tensor([[ 582, 218, 386, ..., 0, 0, 0],[1740, 3, 0, ..., 0, 0, 0],[ 304, 578, 859, ..., 0, 0, 0],...,[1852, 25, 131, ..., 0, 0, 0],[ 70, 265, 23, ..., 0, 0, 0],[ 297, 4, 56, ..., 74, 1362, 3]])
tensor([ 9, 3, 2, 11, 18, 3, 3, 5, 10, 7, 10, 1, 9, 4, 2, 5, 7, 77,45, 2, 15, 2, 3, 12, 11, 4, 5, 4, 45, 3, 12, 2, 5, 9, 7, 13,5, 9, 4, 11, 21, 6, 7, 5, 5, 3, 7, 6, 10, 4, 18, 19, 10, 4,2, 5, 3, 5, 4, 7, 6, 7, 4, 9, 9, 3, 3, 8, 6, 19, 6, 14,3, 10, 7, 4, 5, 5, 2, 5, 5, 11, 2, 7, 9, 6, 6, 18, 8, 4,4, 12, 3, 8, 9, 10, 19, 7, 5, 3, 8, 9, 9, 13, 10, 3, 6, 5,4, 4, 34, 11, 12, 7, 5, 6, 20, 2, 8, 3, 16, 5, 3, 4, 7, 12,6, 4])
tensor([ 4, 1, 6, 10, 3, 12, 15, 2, 5, 4, 11, 3, 6, 6, 8, 7, 4, 7,4, 12, 9, 8, 6, 5, 5, 5, 18, 4, 5, 1, 9, 13, 8, 1, 5, 1,6, 4, 13, 4, 9, 14, 12, 2, 5, 13, 12, 9, 9, 9, 4, 2, 4, 14,4, 4, 9, 6, 6, 5, 5, 2, 7, 6, 10, 14, 6, 10, 4, 7, 6, 7,2, 6, 7, 7, 3, 3, 4, 4, 4, 10, 4, 6, 5, 3, 13, 6, 3, 5,16, 2, 7, 10, 2, 3, 6, 9, 6, 3, 9, 13, 12, 8, 7, 10, 7, 12,4, 8, 2, 3, 7, 5, 4, 14, 2, 6, 4, 1, 5, 9, 14, 2, 3, 4,5, 79])Process finished with exit code 0