从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练

3 数据加载函数

def load_dataset(logger, args):"""加载训练集"""logger.info("loading training dataset")train_path = args.train_pathwith open(train_path, "rb") as f:train_list = pickle.load(f)# test# train_list = train_list[:24]train_dataset = CPMDataset(train_list, args.max_len)return train_dataset
  1. List item

4 训练函数

def train(model, logger, train_dataset, args):train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,drop_last=True)logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochsoptimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)# 设置warmuplogger.info('start training')train_losses = []   # 记录每个epoch的平均loss# ========== start training ========== #for epoch in range(args.epochs):train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,logger=logger, epoch=epoch, args=args)train_losses.append(round(train_loss, 4))logger.info("train loss list:{}".format(train_losses))logger.info('training finished')logger.info("train_losses:{}".format(train_losses))

5 迭代训练函数

def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/195354.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centos系列:Centos7下部署nginx(三种方式安装部署,图文结合超详细,适合初学者)

Centos7下部署nginx(三种方式安装部署,图文结合超详细,适合初学者) Centos7下部署nginx一. ngxin是什么二. nginx的作用正向代理和反向代理的区别 三. 安装部署安装环境1. yum安装配置nginx源启动nginx浏览器访问, IP:…

打印菱形图案C语言

C代码实现&#xff1a; #include <stdio.h> void printDiamond(int n) { int i, j, space n - 1; // 打印上半部分包括中间行 for (i 0; i < n; i) { // 打印空格 for (j 0; j < space; j) printf(" "); // 打印星号 for (j 1; j < 2 *…

Canvas鼠标画线

鼠标按下开始画线,鼠标移动根据鼠标的轨迹去画,鼠标抬起停止画线 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0">…

Java多线程技术三:锁——ReentrantLock的使用

1 概述 在Java多线程中可以使用synchronzied关键字来实现线程间同步&#xff0c;不过在JDK1.5中新增的ReentrantLock类也能达到同样的效果&#xff0c;并且在扩展功能上更加强大。

举例说明自然语言处理(NLP)技术。

本文章由AI生成&#xff01; 以下是自然语言处理&#xff08;NLP&#xff09;技术的一些例子&#xff1a; 机器翻译&#xff1a;将一种语言翻译成另一种语言的自动化过程。常见的机器翻译系统包括谷歌翻译&#xff0c;百度翻译等。 语音识别&#xff1a;将口头语言转换成文本…

备忘录怎么传到电脑?备忘录手机电脑互传方法

对于那些记性不好的人来说&#xff0c;手机上的备忘录简直是个不可或缺的好帮手。可是有时候&#xff0c;我们在手机上记录的内容需要在电脑上查看&#xff0c;这时候该怎么办呢&#xff1f; 曾经&#xff0c;我也为备忘录的手机电脑互传问题头疼不已。手机上记录的事项&#…

Pytorch当中transpose()和permute()函数的区别

在 PyTorch 中&#xff0c;transpose() 和 permute() 都是用于张量维度的转换&#xff0c;但有一些区别&#xff1a; transpose() 方法&#xff1a; transpose() 方法允许你交换张量的两个维度&#xff0c;使其维度发生变化。当你使用 transpose(dim1, dim2) 时&#xff0c;它会…

element UI改写时间线组件为左右分布

2023.12.4今天我学习了如何使用element的时间线组件&#xff0c;效果如&#xff1a; 代码如下&#xff1a;&#xff08;关键代码 v-if"item.send_type"&#xff09;判断左右分布情况。因为如果没有这个判断的话&#xff0c;其实会两边都有显示。可以用一个判断表示0显…

基于ssm的疫苗预约系统(有报告)。Javaee项目。ssm项目。

演示视频&#xff1a; 基于ssm的疫苗预约系统&#xff08;有报告&#xff09;。Javaee项目。ssm项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;通过Spring Spri…

语音芯片的BUSY状态指示功能特征:提升用户体验与系统稳定性的关键

在电子产品的音频系统中&#xff0c;语音芯片扮演着至关重要的角色。为了保证音频的流畅播放和功能的正常运行&#xff0c;语音芯片的各种状态指示功能变得尤为重要。其中&#xff0c;BUSY状态指示功能是语音芯片中的一项关键特征&#xff0c;它对于提升用户体验和系统稳定性具…

Pytorch深度强化学习1-5:详解蒙特卡洛强化学习原理

目录 0 专栏介绍1 蒙特卡洛强化学习2 策略评估原理3 策略改进原理3.1 同轨蒙特卡洛强化学习3.2 离轨蒙特卡洛强化学习 0 专栏介绍 本专栏重点介绍强化学习技术的数学原理&#xff0c;并且采用Pytorch框架对常见的强化学习算法、案例进行实现&#xff0c;帮助读者理解并快速上手…

C++STL容器

一、顺序性容器 简述&#xff1a;顺序容器为程序员提供了控制元素存储和访问顺序的能力。这种顺序不依赖元素的值&#xff0c;而是与元素加入容器时的位置相对应。所有顺序容器都提供了快速顺序访问元素的能力 1.vector(向量) 基本概念和介绍 对于vector容器&#xff0c;它…

大模型概述

文章目录 AI大模型的定义AI大模型的分类LoRA 微调 AI大模型的定义 AI大模型是通过深度学习算法和人工神经网络训练出的具有庞大规模参数的人工智能模型。这些模型使用大量的多媒体数据资源作为输入&#xff0c;并通过复杂的数学运算和优化算法来完成大规模的训练&#xff0c;以…

4382系列数字荧光示波器

4382系列数字荧光示波器 简述&#xff1a; 4382系列手持式数字荧光示波器具有8个产品型号&#xff0c;带宽200MHz、350MHz、500MHz、1GHz&#xff0c;最高采样率5GSa/s&#xff0c;最大存储深度60kpts/CH&#xff0c;最快波形捕获率10万个波形/秒&#xff0c;独创的Any Acquire…

专业课145+总分440+东南大学920考研专业基础综合信号与系统数字电路经验分享

个人情况简介 今年考研440&#xff0c;专业课145&#xff0c;数一140&#xff0c;期间一年努力辛苦付出&#xff0c;就不多表了&#xff0c;考研之路虽然艰难&#xff0c;付出很多&#xff0c;当收获的时候&#xff0c;都是值得&#xff0c;考研还是非常公平&#xff0c;希望大…

【部署】Deploying Trino on linux

文章目录 一. Requirements1. Linux operating system2. Java 环境3. Python 二. Installing Trino三. Configuring Trino1. 节点配置2. JVM 配置3. Config properties4. Log levels5. Catalog properties 四. Running Trino 一. Requirements 1. Linux operating system 64位…

SpringBoot错误处理机制解析

SpringBoot错误处理----源码解析 文章目录 1、默认机制2、使用ExceptionHandler标识一个方法&#xff0c;处理用Controller标注的该类发生的指定错误1&#xff09;.局部错误处理部分源码2&#xff09;.测试 3、 创建一个全局错误处理类集中处理错误&#xff0c;使用Controller…

基于java技术的电子商务支撑平台

摘 要 随着网络技术的发展&#xff0c;Internet变成了一种处理日常事务的交互式的环境。互联网上开展各种服务已经成为许多企业和部门的急切需求。Web的普遍使用从根本上改变了人们的生活方式、工作方式&#xff0c;也改变了企业的经营方式和服务方式。人们可以足不出户办理各…

财务管理在IT服务管理中的重要作用

官方网站 www.itilzj.com 文档资料: wenku.itilzj.com 财务管理作为一种管理组织财务资源的方法&#xff0c;在IT服务领域扮演着关键的角色。其涵盖的范围涉及预算编制、成本控制、投资决策、财务报告和绩效评估等多个方面&#xff0c;直接关系到IT服务的财务健康和整体运作。…

WT588F02A-16S录放音语音芯片为何需要配备自动增益控制麦克风?

在语音录放领域&#xff0c;一款优秀的语音芯片如WT588F02A-16S不仅需要具备高品质的录音和播放功能&#xff0c;还需要与合适的麦克风配合&#xff0c;以确保音频输入的最佳效果。而其中&#xff0c;自动增益控制&#xff08;AGC&#xff09;麦克风在这一过程中发挥着重要作用…