GPT建模与预测实战

代码链接见文末

效果图:

1.数据样本生成方法

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl

其中train.pkl是处理后的文件

因此,我们首先需要执行preprocess.py进行预处理操作,配置参数:

--data_path data/novel --save_path data/train.pkl --win_size 200 --step 200

其中--vocab_file是语料表,一般不用修改,--log_path是日志路径

预处理流程如下:

  • 首先,初始化tokenizer
  • 读取作文数据集目录下的所有文件,预处理后,对于每条数据,使用滑动窗口对其进行截断
  • 最后,序列化训练数据 

代码如下:

# 初始化tokenizertokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")#pip install jiebaeod_id = tokenizer.convert_tokens_to_ids("<eod>")   # 文档结束符sep_id = tokenizer.sep_token_id# 读取作文数据集目录下的所有文件train_list = []logger.info("start tokenizing data")for file in tqdm(os.listdir(args.data_path)):file = os.path.join(args.data_path, file)with open(file, "r", encoding="utf8")as reader:lines = reader.readlines()title = lines[1][3:].strip()    # 取出标题lines = lines[7:]   # 取出正文内容article = ""for line in lines:if line.strip() != "":  # 去除换行article += linetitle_ids = tokenizer.encode(title, add_special_tokens=False)article_ids = tokenizer.encode(article, add_special_tokens=False)token_ids = title_ids + [sep_id] + article_ids + [eod_id]# train_list.append(token_ids)# 对于每条数据,使用滑动窗口对其进行截断win_size = args.win_sizestep = args.stepstart_index = 0end_index = win_sizedata = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += stepwhile end_index+50 < len(token_ids):  # 剩下的数据长度,大于或等于50,才加入训练数据集data = token_ids[start_index:end_index]train_list.append(data)start_index += stepend_index += step# 序列化训练数据with open(args.save_path, "wb") as f:pickle.dump(train_list, f)

2.模型训练过程

 (1) 数据与标签

        在训练过程中,我们需要根据前面的内容预测后面的内容,因此,对于每一个词的标签需要向后错一位。最终预测的是每一个位置的下一个词的token_id的概率。

(2)训练过程

        对于每一轮epoch,我们需要统计该batch的预测token的正确数与总数,并计算损失,更新梯度。

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

(3)部署与网页预测展示

        app.py既是模型预测文件,又能够在网页中展示,这需要我们下载一个依赖库:

pip install streamlit

        

生成下一个词流程,每次只根据当前位置的前context_len个token进行生成:

  • 第一步,先将输入文本截断成训练的token大小,训练时我们采用的200,截断为后200个词
  • 第二步,预测的下一个token的概率,采用温度采样和topk/topp采样

最终,我们不断的以自回归的方式不断生成预测结果

这里指定模型目录 

进入项目路径

执行streamlit run app.py 

 生成效果:

 数据与代码链接:https://pan.baidu.com/s/1XmurJn3k_VI5OR3JsFJgTQ?pwd=x3ci 
提取码:x3ci 

 

         

      

 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/810715.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android-NDK的linux交叉编译环境

NDK工具包下载 NDK 下载 | Android NDK | Android Developers https://github.com/android/ndk/wiki/Unsupported-Downloads 以android-ndk-r26c下载为例&#xff0c;下载后将压缩包解压至/usr目录下 CMakeLists编译选项设置 编译平台变量判断条件中增加一下android条件…

ubuntu下man手册 查不到 pthread_mutex_lock等系列函数用法的问题

问题 在ubuntu系统中无法man到 pthread_mutex_lock pthread_mutex_trylock pthread_mutex_unlock等函数 $ man pthread_mutex_lock 没有 pthread_mutex_lock 的手册页条目解决方式 输入以下命令 sudo apt-get install manpages-posix manpages-posix-dev 然后输入密码 再次m…

MobaXterm无法登陆oracle cloud的问题

问题 我在oracle cloud上创建实例的时候&#xff0c;只能使用密钥的方式登陆&#xff0c;当时下载了私钥文件。实例创建好以后&#xff0c;在mobaxterm上使用这个私钥文件无法登陆 排查 尝试使用mobaxterm的keygen&#xff0c;把私钥文件转成ppk格式&#xff0c;还是不行。…

高中数学:三角函数-同角与异角的三角函数关系

一、同角三角函数关系 1、基本公式 知一求二 2、快速求值方法 重点掌握辅助三角形方法 3、题型 3.1、一次式整式求值 sinα和cosα指数是一次的求值&#xff0c;建议用辅助三角形方法 例题 3.2、一次式分式求值 分子、分母同除以sinα或者cosα 例题 3.3、二次式…

区块链安全-----接口测试-Postman

Postman是一款支持http协议的接口调试与测试工具&#xff0c;其主要特点就是功能强大&#xff0c;使用简单且易 用性好 。无论是开发人员进行接口调试&#xff0c;还是测试人员做接口测试&#xff0c;Postman都是我们的首选工具 之一 。 更早的接入测试&#xff0c;更早的发现问…

切面条(蓝桥杯)

目录 题目 分析 代码实现 题目 一根高筋拉面&#xff0c;中间切一刀&#xff0c;可以得到2根面条。 如果先对折1次&#xff0c;中间切一刀&#xff0c;可以得到3根面条。 如果连续对折2次&#xff0c;中间切一刀&#xff0c;可以得到5根面条。 那么&#xff0c;连续对折1…

光耦合器的使用:了解输入和输出之间的关系

光耦合器也称为光隔离器&#xff0c;是许多电子电路中的重要组件&#xff0c;可在输入和输出信号之间提供隔离。它们在各种应用中确保安全、降低噪声和防止接地环路方面发挥着至关重要的作用。在本文中&#xff0c;我们将深入研究光耦合器的基础知识&#xff0c;探讨它们的工作…

人形机器人行业报告:AI赋能人形机器人开启产业化元年

今天分享的是人形机器人专题系列深度研究报告&#xff1a;《AI赋能&#xff0c;人形机器人开启产业化元年》。 &#xff08;报告出品方&#xff1a;国泰君安证券&#xff09; 报告共计&#xff1a;56页 要点 通用性是人形机器人商业化的关键&#xff0c;AI大模型赋能加速机…

打破传统,蔚莱普康定义国货美妆新未来

在全球美妆市场经济改革的今天&#xff0c;中国新兴品牌蔚莱普康&#xff0c;正在以前所未有的速度和规模&#xff0c;冲破瓶颈&#xff0c;赢得市场的广泛认可。这一切&#xff0c;得益于国家政策的扶持和国货品牌自身的不懈努力与创新。 各类国潮产品不断‘出圈’的背后&…

java程序 .exe启动nginx防止重复启动,已解决

java代码生成好的.exe启动nginx服务程序 根据nginx占用端口来解决nginx服务重复启动问题&#xff08;下面代码了解代码逻辑后根据自己的业务需求修改即可&#xff09; 代码&#xff1a; package org.example;import javax.swing.*; import java.awt.*; import java.io.*; …

Linux-select剖析

一、select函数 select函数是IO多路复用的函数&#xff0c;它主要的功能是用来等文件描述符中的事件是否就绪&#xff0c;select可以使我们在同时等待多个文件缓冲区 &#xff0c;减少IO等待的时间&#xff0c;能够提高进程的IO效率。 select()函数允许程序监视多个文件描述符…

书生·浦语大模型全链路开源体系-第3课

书生浦语大模型全链路开源体系-第3课 书生浦语大模型全链路开源体系-第3课相关资源RAG 概述在 InternLM Studio 上部署茴香豆技术助手环境配置配置基础环境下载基础文件下载安装茴香豆 使用茴香豆搭建 RAG 助手修改配置文件 创建知识库运行茴香豆知识助手 在茴香豆 Web 版中创建…

工作日常随记-总

软件测试主管工作日常随记-总 前言 接下来&#xff0c;我将开始散文式地记录我作为一位从业3年多的软件测试人员的软测经验。这是我在繁忙的日常工作的中跋涉出来又又投入的另一工作&#xff08;bushi&#xff09;另一兴趣中去。 我将简单&#xff08;偏流水线向&#xff09;…

50. QT/QML中创建多线程的方式汇总

1. 说明 在QT / QML中创建线程主要有三种方式。第一种:在定义类时继承 QThread 这个类,然后重写父类的虚函数 run(),将子线程需要执行的业务代码放到 run() 函数当中即可。**注意:**这种方式官方已经摒弃了。第二种:使用moveToThread()函数将需要在子线程中执行的函数类移…

【每日练习】二叉树

⭐ 作者&#xff1a;小胡_不糊涂 &#x1f331; 作者主页&#xff1a;小胡_不糊涂的个人主页 &#x1f4c0; 收录专栏&#xff1a;二叉树 &#x1f496; 持续更文&#xff0c;关注博主少走弯路&#xff0c;谢谢大家支持 &#x1f496; 文章目录 一、100. 相同的树1. 题目简介2.…

问题汇总

一、TCP的粘包和拆包问题&#xff1f; TCP在发送和接受数据的时候&#xff0c;有一个滑动窗口来控制接受数据的大小&#xff0c;这个滑动窗口你就可以理解为一个缓冲区的大小。缓冲区满了就会把数据发送&#xff0c;数据包的大小是不固定的&#xff0c;有时候比缓冲区大有时候…

NIST再次强调:2024-2030年,必须升级至抗量子算法

4月10日至12日&#xff0c;美国国家标准与技术研究院&#xff08;NIST&#xff09;在马里兰州罗克维尔举办第五届NIST PQC&#xff08;后量子密码学&#xff09;标准化会议&#xff0c;该会议的目的是对PQC算法进行全面讨论&#xff08;包括已选定和正在评估的算法&#xff09;…

如何处理ubuntu22.04LTS安装过程中出现“Daemons using outdated libraries”提示

Ubuntu 22.04 LTS 中使用命令行升级软件或安装任何新软件时&#xff0c;您可能收到“Daemons using outdated libraries”&#xff0c;“Which services should be restarted?”的提示&#xff0c;提示下面列出备选的重启服务&#xff0c;如下。 使用以下命令&#xff0c;能够…

LangChain - 文档加载

文章目录 一、关于 检索二、文档加载器入门指南 三、CSV1、使用每个文档一行的 CSV 数据加载 CSVLoader2、自定义 csv 解析和加载 &#xff08;csv_args3、指定用于 标识文档来源的 列&#xff08;source_column 四、文件目录 file_directory1、加载文件目录数据&#xff08;Di…

day05-java面向对象(上)

5.1 面向对象编程 5.1.1 类和对象 1、什么是类 类是一类具有相同特性的事物的抽象描述&#xff0c;是一组相关属性和行为的集合。 属性&#xff1a;就是该事物的状态信息。 行为&#xff1a;就是在你这个程序中&#xff0c;该状态信息要做什么操作&#xff0c;或者基于事物…