项目地址
import numpy as np
import pandas as pd
import torch
from tqdm import tqdmfrom infer_model import SamOutdef load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)# net = SamOut(len(voc["voc"]), 512, 32, 8)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))# net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))# net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))# net.load_state_dict(torch.load("pretrain.pth", map_location=device))net.to(device)net.eval()return net, vocdef gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cuda"):print("agent:", end="", flush=True)model.to(device)state=Nonefor _ in range(max_len):prompt_list = []for i in prompt:if i not in voc["voc"]:prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]else:prompt_list.append(voc["voc"].index(i))if state is None:out, state = model(torch.Tensor([prompt_list]).to(device).long())else:out, state = model(torch.Tensor([prompt_list[-1:]]).to(device).long(),state)out = out[:, -1:]# 重复抑制for token_id in enumerate(prompt_list):out[:, :, token_id] /= rpscore = torch.softmax(out, -1)[0, 0]score, score_index = torch.sort(score,descending=True)if device=="cpu":score=score.detach().numpy()score_index = score_index.detach().numpy()else:score = score.cpu().detach().numpy()score_index = score_index.cpu().detach().numpy()score_sum = np.cumsum(score)score1=score[score_sum<0.9]if score1.size==0:score=score[:1]else:score=score1score_index=score_index[:min(top_k, score.size)]out = score / tempv= out[:min(top_k, score.size)]idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":breakprompt += [voc["voc"][score_index[idx_next.item()]]]print(prompt[-1], end="", flush=True)def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()if __name__ == '__main__':t_infre()
这段代码实现了一个基于PyTorch的文本生成模型的推理过程,它能够根据用户输入的提示(prompt)生成相应的回复。下面是对代码的主要部分进行解析:
1. 模型加载函数 load_model_and_voc
此函数负责加载词汇表和预训练模型,并将模型设置为评估模式。这里使用了Pandas读取了一个名为total_voc.pkl
的词汇表文件,该文件包含了两个键:voc
代表主要词汇表,而voc0
可能是用于处理未知词汇的映射。
def load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))net.to(device)net.eval()return net, voc
SamOut
是一个自定义的神经网络模型类,它接收词汇大小、隐藏层维度、注意力头数量以及解码层数作为参数。- 加载预训练权重时指定了设备(CPU或GPU),并打印了模型参数的数量以供调试。
- 最后返回了准备好的模型实例和词汇表。
2. 文本生成函数 gen_token
该函数实现了给定提示后的逐词生成逻辑,包括词汇索引转换、重复抑制、温度采样及Top-K采样等机制。
def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cuda"):...
-
输入参数:
voc
: 包含词汇信息的数据结构。model
: 已经加载并准备好使用的神经网络模型。prompt
: 用户提供的初始文本序列。max_len
: 生成的最大长度。rp
,temp
,top_k
: 控制生成策略的超参数。device
: 执行计算的目标硬件(默认是CUDA)。
-
核心步骤:
- 将输入文本转换成对应的词汇ID列表。
- 使用模型预测下一个词汇的概率分布,并应用一系列策略来选择最合适的词汇。
- 更新状态(如果有),并将新词汇添加到输出序列中。
- 循环直到达到最大长度或者遇到特殊终止标记(如
<|sos|>
)。
3. 推理循环 t_infre
这是主程序入口,创建了一个无限循环等待用户输入,并调用gen_token
函数来生成回应。
def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()
- 首先调用了
load_model_and_voc
初始化模型和词汇表。 - 然后进入一个无限循环,每次迭代都会从标准输入获取一行文本作为用户的询问。
- 对于每个询问,它会构造一个带有起始和结束标记的完整提示,并调用
gen_token
来生成响应。 - 最终打印出生成的结果,并继续等待下一个用户输入。
总结
整个脚本通过结合上述三个主要组件——模型加载、文本生成以及交互式对话循环——实现了一个人机对话系统的基础框架。特别值得注意的是,代码中对于词汇表的处理方式,即如何将输入文本映射到模型可以理解的形式,以及在生成过程中采取的各种策略来提高生成质量。此外,还展示了如何利用tqdm
库来跟踪长任务的进度,尽管在这个具体的例子中没有直接展示tqdm
的应用,但在类似的长时间运行的任务中非常有用。最后,代码遵循了良好的实践,比如使用了上下文管理器(虽然在这里未显式出现)和适当的错误处理机制,确保了系统的健壮性和用户体验。