SamOut 任意长度推理空间不变

项目地址


import numpy as np
import pandas as pd
import torch
from tqdm import tqdmfrom infer_model import SamOutdef load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)# net = SamOut(len(voc["voc"]), 512, 32, 8)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))# net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))# net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))# net.load_state_dict(torch.load("pretrain.pth", map_location=device))net.to(device)net.eval()return net, vocdef gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cuda"):print("agent:", end="", flush=True)model.to(device)state=Nonefor _ in range(max_len):prompt_list = []for i in prompt:if i not in voc["voc"]:prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]else:prompt_list.append(voc["voc"].index(i))if state is None:out, state = model(torch.Tensor([prompt_list]).to(device).long())else:out, state = model(torch.Tensor([prompt_list[-1:]]).to(device).long(),state)out = out[:, -1:]# 重复抑制for token_id in enumerate(prompt_list):out[:, :, token_id] /= rpscore = torch.softmax(out, -1)[0, 0]score, score_index = torch.sort(score,descending=True)if device=="cpu":score=score.detach().numpy()score_index = score_index.detach().numpy()else:score = score.cpu().detach().numpy()score_index = score_index.cpu().detach().numpy()score_sum = np.cumsum(score)score1=score[score_sum<0.9]if score1.size==0:score=score[:1]else:score=score1score_index=score_index[:min(top_k, score.size)]out = score / tempv= out[:min(top_k, score.size)]idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":breakprompt += [voc["voc"][score_index[idx_next.item()]]]print(prompt[-1], end="", flush=True)def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()if __name__ == '__main__':t_infre()

这段代码实现了一个基于PyTorch的文本生成模型的推理过程,它能够根据用户输入的提示(prompt)生成相应的回复。下面是对代码的主要部分进行解析:

1. 模型加载函数 load_model_and_voc

此函数负责加载词汇表和预训练模型,并将模型设置为评估模式。这里使用了Pandas读取了一个名为total_voc.pkl的词汇表文件,该文件包含了两个键:voc代表主要词汇表,而voc0可能是用于处理未知词汇的映射。

def load_model_and_voc(device="cpu"):voc = pd.read_pickle("total_voc.pkl")net = SamOut(len(voc["voc"]), 1024 + 512, 64, 16)print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum([i.shape[0] for i in net.parameters() if len(i.shape) == 1]))net.load_state_dict(torch.load("pretrain_sft_single_1024.pth", map_location=device))net.to(device)net.eval()return net, voc
  • SamOut 是一个自定义的神经网络模型类,它接收词汇大小、隐藏层维度、注意力头数量以及解码层数作为参数。
  • 加载预训练权重时指定了设备(CPU或GPU),并打印了模型参数的数量以供调试。
  • 最后返回了准备好的模型实例和词汇表。

2. 文本生成函数 gen_token

该函数实现了给定提示后的逐词生成逻辑,包括词汇索引转换、重复抑制、温度采样及Top-K采样等机制。

def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.13, top_k=16, device="cuda"):...
  • 输入参数:

    • voc: 包含词汇信息的数据结构。
    • model: 已经加载并准备好使用的神经网络模型。
    • prompt: 用户提供的初始文本序列。
    • max_len: 生成的最大长度。
    • rp, temp, top_k: 控制生成策略的超参数。
    • device: 执行计算的目标硬件(默认是CUDA)。
  • 核心步骤:

    • 将输入文本转换成对应的词汇ID列表。
    • 使用模型预测下一个词汇的概率分布,并应用一系列策略来选择最合适的词汇。
    • 更新状态(如果有),并将新词汇添加到输出序列中。
    • 循环直到达到最大长度或者遇到特殊终止标记(如<|sos|>)。

3. 推理循环 t_infre

这是主程序入口,创建了一个无限循环等待用户输入,并调用gen_token函数来生成回应。

def t_infre():model, voc = load_model_and_voc()while True:text = input("user:")gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 64)print()
  • 首先调用了load_model_and_voc初始化模型和词汇表。
  • 然后进入一个无限循环,每次迭代都会从标准输入获取一行文本作为用户的询问。
  • 对于每个询问,它会构造一个带有起始和结束标记的完整提示,并调用gen_token来生成响应。
  • 最终打印出生成的结果,并继续等待下一个用户输入。

总结

整个脚本通过结合上述三个主要组件——模型加载、文本生成以及交互式对话循环——实现了一个人机对话系统的基础框架。特别值得注意的是,代码中对于词汇表的处理方式,即如何将输入文本映射到模型可以理解的形式,以及在生成过程中采取的各种策略来提高生成质量。此外,还展示了如何利用tqdm库来跟踪长任务的进度,尽管在这个具体的例子中没有直接展示tqdm的应用,但在类似的长时间运行的任务中非常有用。最后,代码遵循了良好的实践,比如使用了上下文管理器(虽然在这里未显式出现)和适当的错误处理机制,确保了系统的健壮性和用户体验。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/64573.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

17.springcloud_openfeign之扩展组件一

文章目录 一、前言二、默认约定配置FeignAutoConfigurationCachingCapabilityFeignCachingInvocationHandlerFactoryFeignJacksonConfiguration熔断器配置FeignCircuitBreakerTargeterFeignCircuitBreaker.Builder FeignClientsConfigurationCircuitBreakerFactory 总结 一、前…

Python读取Excel批量写入到PPT生成词卡

一、问题的提出 有网友想把Excel表中的三列数据&#xff0c;分别是&#xff1a;单词、音标和释义分别写入到PPT当中&#xff0c;每一张PPT写一个单词的内容。这种批量操作是python的强项&#xff0c;尤其是在办公领域&#xff0c;它能较好地解放双手&#xff0c;读取Excel表后…

Proteus(8.15)仿真下载安装过程(附详细安装过程图)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、Proteus是什么&#xff1f; 二、下载链接 三、下安装步骤 1.解压&#xff0c;有键管理员运行 2.点击Next&#xff0c;进行下一步 3.勾选I accept…&#…

防止私接小路由器

电脑获取到IP地址不是DHCP服务器的IP地址段&#xff0c;导致整个公司网络瘫痪&#xff0c;这些故障现象通常80%原因是私接小路由器导致的&#xff0c;以下防止私接小路由器措施。 一、交换机配置DHCP Sooping DHCP snooping是一种DHCP安全特性&#xff0c;用于防止非法设备获…

动态导出word文件支持转pdf

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、功能说明二、使用步骤1.controller2.工具类 DocumentUtil 导出样式 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 例如&#xff…

紧固件设计之——开槽六角头防脱出杆螺栓仿真APP

按照产品形态分类&#xff0c;紧固件通常包括以下12类&#xff1a;螺栓、螺柱、螺钉、螺母、自攻螺钉、木螺钉、垫圈、挡圈、销、铆钉、焊钉、组合件与连接副&#xff0c;是一类用于连接和固定各种构件和零部件的重要机械零件&#xff0c;可确保机械装置或设备结构的牢固和稳定…

【Python装饰器】编写一个装饰器,并将其放到适当的位置,目的是让代码 1 秒钟打印一个结果

import timedef fib():back1, back2 0, 1def func():nonlocal back1, back2back1, back2 back2, back1 back2print(back1, end )return funcdef get_fib(n):f fib()for i in range(n):f()n int(input("请输入需要获取的斐波那契数&#xff1a;"))get_fib(n) imp…

mysql中与并发相关的问题?

今天我们来聊聊 MySQL 中与并发相关的一些问题。作为一名资深 Python 开发工程师&#xff0c;我觉得这些问题不仅关乎数据库的稳定性和数据的一致性&#xff0c;更与我们的代码实现和业务逻辑密切相关。 尤其是在高并发环境下&#xff0c;如何保证数据的一致性&#xff0c;如何…

【Mac】安装 PaddleOCR

环境&#xff1a;Mac M1 芯片 1、安装 Anaconda 安装较为简单&#xff0c;直接在 Anaconda 官网 下载pkg文件&#xff0c;根据向导提示完成安装。 Anaconda 用于搭建 Python 虚拟环境&#xff0c;目的是为了避免与之前环境安装库的版本冲突&#xff0c;另外 paddle 对Python…

使用k6进行kafka负载测试

1.安装环境 kafka环境 参考Docker搭建kafka环境-CSDN博客 xk6-kafka环境 ./xk6 build --with github.com/mostafa/xk6-kafkalatest 查看安装情况 2.编写脚本 test_kafka.js // Either import the module object import * as kafka from "k6/x/kafka";// Or in…

服务器ip:port服务用nginx 域名代理

ubuntu 1、安装nginx # 更新软件包列表 sudo apt update# 安装Nginx sudo apt install nginx -y# 检查Nginx状态 sudo systemctl status nginx2、创建存放域名 SSL证书的目录 # 创建目录 sudo mkdir -p /etc/nginx/ssl# 复制证书文件到该目录 sudo cp play.cn_bundle.crt /et…

[机器学习]XGBoost(3)——确定树的结构

XGBoost的目标函数详见[机器学习]XGBoost&#xff08;2&#xff09;——目标函数&#xff08;公式详解&#xff09; 确定树的结构 之前在关于目标函数的计算中&#xff0c;均假设树的结构是确定的&#xff0c;但实际上&#xff0c;当划分条件不同时&#xff0c;叶子节点包含的…

springboot444新冠物资管理系统的设计与实现(论文+源码)_kaic

摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装新冠物资管理系统软件来发挥其高效地信息处理的作用&#x…

Javascript-web API-day02

文章目录 01-事件监听02-点击关闭广告03-随机点名案例04-鼠标经过或离开事件05-可点击的轮播图06-小米搜索框07-键盘类型事件08-键盘事件-发布评论案例09-focus选择器10-评论回车发布11-事件对象12-trim方法13-环境对象14-回调函数15-tab栏切换 01-事件监听 <!DOCTYPE html…

使用xjar 对Spring-Boot JAR 包加密运行

1 Xjar 介绍 Spring Boot JAR 安全加密运行工具&#xff0c;同时支持的原生JAR。 基于对JAR包内资源的加密以及拓展ClassLoader来构建的一套程序加密启动&#xff0c;动态解密运行的方案&#xff0c;避免源码泄露或反编译。 功能特性 无需侵入代码&#xff0c;只需要把编译好的…

深度学习的下一站:解锁人工智能的新边界

引言&#xff1a;新边界的呼唤 深度学习的诞生&#xff0c;犹如人工智能领域的一次革命&#xff0c;激发了语音助手、自动驾驶、智能医疗等前沿技术的飞速发展。然而&#xff0c;面对现实世界的复杂性&#xff0c;现有的深度学习模型仍然存在数据依赖、可解释性差、环境适应力不…

基于DockerCompose搭建Redis主从哨兵模式

linux目录结构 内网配置 哨兵配置文件如下&#xff0c;创建3个哨兵配置文件 # sentinel26379.conf sentinel26380.conf sentinel26381.conf 内容如下 protected-mode no sentinel monitor mymaster redis-master 6379 2 sentinel down-after-milliseconds mymaster 60000 s…

Vite 与 Webpack 的区别

在前端开发中&#xff0c;构建工具是不可或缺的&#xff0c;Webpack 和 Vite 是当前最流行的选择之一。尽管它们的目标相似&#xff0c;但在实现方式和开发体验上却有显著差异。本文将探讨 Vite 和 Webpack 的主要区别&#xff0c;以便于根据项目需求选择合适的工具。 1. 构建…

upload-labs靶场1-19关

第 1 关&#xff08;删除前端js校验&#xff09; 点击第一关&#xff0c;我们可以看到页面上传区可以上传一个图片&#xff0c;我们要上传一个 webshell&#xff0c;这里我们上传一句话木马的 php 点击上传 显示文件不支持上传&#xff0c;这时我们查看源码 查看代码后发现&am…

vue3+vite 引入动画组件库 Inspira UI

关于Inspira UI Inspira UI不是传统的组件库。相反&#xff0c;它是精选的优雅组件集合&#xff0c;您可以轻松将其集成到您的应用程序中。只需选择所需的组件&#xff0c;复制代码&#xff0c;然后自定义以适合您的项目即可。您可以随意使用和修改代码&#xff01; 官网地址…