训练自己的GPT2

训练自己的GPT2

  • 1.预训练与微调
  • 2.准备工作
  • 2.在自己的数据上进行微调

1.预训练与微调

所谓的预训练,就是在海量的通用数据上训练大模型。比如,我把全世界所有的网页上的文本内容都整理出来,把全人类所有的书籍、论文都整理出来,然后进行训练。这个训练过程代价很大,首先模型很大,同时数据量又很大,比如GPT3参数量达到了175B,训练数据达到了45TB,训练一次就话费上千万美元。如此大代价学出来的是一个通用知识的模型,他确实很强,但是这样一个模型,可能无法在一些专业性很强的领域上取得比较好的表现,因为他没有针对这个领域的数据进行训练过。

因此,大模型火了之后,很多人都开始把大模型用在自己的领域。通常也就是把自己领域的一些数据,比如专业书、论文等等整理出来,使用预训练好的大模型在新的数据集上进行微调。微调的成本相比于预训练就要小得多了。

2.准备工作

首先需要安装第三方库transformerstransformers是一个用于自然语言处理(NLP)的Python第三方库,实现Bert、GPT-2和XLNET等比较新的模型,支持TensorFlow和PyTorch。以及下载预训练好的模型权重。

pip install transformers

安装完成之后,我们可以直接使用下面的代码,来构造一个预训练的GPT2

from transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

当运行的时候,代码会自动从hugging face上下载模型。但是由于hugging face是国外网站,可能下载起来很慢或者无法下载,因此我们也可以自己手动下载之后在本地读取。

打开hugging face的网站,搜索GPT2。或者直接进入GPT2的页面。

下载上图中的几个文件到本地,假设下载到./gpt2文件夹

然后就可以使用下面的代码来尝试预训练的模型直接生成文本你的效果。

from transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("./gpt2")
model = GPT2LMHeadModel.from_pretrained("./gpt2")q = "tell me a fairy story"ids = tokenizer.encode(q, return_tensors='pt')
final_outputs = model.generate(ids,do_sample=True,max_length=100,pad_token_id=model.config.eos_token_id,top_k=50,top_p=0.95,
)print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))

回答如下:

2.在自己的数据上进行微调

首先把我们的数据,也就是文本,全部整理到一起。比如可以把所有文本拼接到一起。

假设所有的文本数据都存到一个文件中。那么可以直接使用下面的代码进行训练。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification, AdamW, GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, TextDatasetdef load_data_collator(tokenizer, mlm = False):data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm,)return data_collatordef load_dataset(file_path, tokenizer, block_size = 128):dataset = TextDataset(tokenizer = tokenizer,file_path = file_path,block_size = block_size,)return datasetdef train(train_file_path, model_name,output_dir,overwrite_output_dir,per_device_train_batch_size,num_train_epochs,save_steps):tokenizer = GPT2Tokenizer.from_pretrained(model_name)train_dataset = load_dataset(train_file_path, tokenizer)data_collator = load_data_collator(tokenizer)tokenizer.save_pretrained(output_dir)model = GPT2LMHeadModel.from_pretrained(model_name)model.save_pretrained(output_dir)training_args = TrainingArguments(output_dir=output_dir,overwrite_output_dir=overwrite_output_dir,per_device_train_batch_size=per_device_train_batch_size,num_train_epochs=num_train_epochs,)trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset,)trainer.train()trainer.save_model()train_file_path = "./train.txt"   # 你自己的训练文本
model_name = './gpt2'  # 预训练的模型路径
output_dir = './custom_data'  # 你自己设定的模型保存路径
overwrite_output_dir = False
per_device_train_batch_size = 96  # 每一台机器上的batch size。
num_train_epochs = 50   
save_steps = 50000# Train
train(train_file_path=train_file_path,model_name=model_name,output_dir=output_dir,overwrite_output_dir=overwrite_output_dir,per_device_train_batch_size=per_device_train_batch_size,num_train_epochs=num_train_epochs,save_steps=save_steps
)     

训练完成之后,推理的话,直接使用第二节里的代码,将预训练模型路径换成自己训练的模型路径就行了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/611281.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

刷题第十五天-存在重复元素Ⅲ

存在重复元素Ⅲ 题目要求 解题思路 主要使用滑动窗口方法,让滑动窗口代销固定为t。 本题最大的难点在于快速地找到滑动窗口内的最大值和最小值,以及删除指定元素。 如果遍历求滑动窗口内的最大值和最小值,时间复杂度是O(K&#…

Linux(上篇)

计算机硬件软体系 顺序执行程序 计算机硬件由运算器,控制器,存储器,输入设备,输出设备五大部分组成 计算机硬件组成 输入设备 用来将人们熟悉的信息形式转换为机器能够识别的信息形式。 输出设备 将机器运算的结果装换为人…

redis的使用、打开、关闭的详细介绍

redis的使用、打开、关闭的详细介绍 1.安装redis cd / cd opt/ wget https://download.redis.io/releases/redis-5.0.5.tar.gz 2.解压redis tar xzf redis-5.0.5.tar.gz 3.执行make cd redis-5.0.5/ make 如果出现找不到make的情况就yum install -y make 如果没有gcc就…

【BIAI】Lecture 7 - EEG data analysis

EEG data analysis 专业术语 EEG 脑电图 excitatory postsynaptic potential(EPSP)兴奋性突触后电位 inhibitory postsynaptic potential(IPSP) 抑制性突触后电位 action potential 动作电位 dipoles 偶极子 Pyramidal neurons 椎体细胞 Axon 轴突 Dendrite 树突 Synapse 突触…

【大数据架构】OLAP实时分析引擎选型

OLAP引擎面临的挑战 常见OLAP引擎对比 OLAP分析场景中,一般认为QPS达到1000就算高并发,而不是像电商、抢红包等业务场景中,10W以上才算高并发,毕竟数据分析场景,数据海量,计算复杂,QPS能够达到1…

慕课热搜01

uniapp过滤器使用 创建一个过滤器: 在入口函数注册过滤器 // 注册过滤器 import * as filters from "./filters/index.js"Object.keys(filters).forEach(key>{Vue.filter(key,filters[key]) })使用过滤器: onPageScroll , uniapp监听滚动…

Edge无法卸载也无法上网的处理

1、在C盘把Microsoft下的子文件删掉,注意最好用delete删,别右键删! 2、删掉用户文件夹下\AppData\Local\Microsoft\Edge\User Data下的所有文件 3、到微软官网下载最新的edge,再安装就可以了: https://www.microsoft.com/zh-cn…

FlinkAPI开发之数据合流

案例用到的测试数据请参考文章: Flink自定义Source模拟数据流 原文链接:https://blog.csdn.net/m0_52606060/article/details/135436048 概述 在实际应用中,我们经常会遇到来源不同的多条流,需要将它们的数据进行联合处理。所以…

【GDAL】Windows下VS+GDAL开发环境搭建

Step.0 环境说明(vs版本,CMake版本) 本地的IDE环境是vs2022,安装的CMake版本是3.25.1。 Step.1 下载GDAL和依赖的组件 编译gdal之前需要安装gdal依赖的组件,gdal所依赖的组件可以在官网文档找到,可以根据…

中文语音识别转文字的王者,阿里达摩院FunAsr足可与Whisper相颉顽

君不言语音识别技术则已,言则必称Whisper,没错,OpenAi开源的Whisper确实是世界主流语音识别技术的魁首,但在中文领域,有一个足以和Whisper相颉顽的项目,那就是阿里达摩院自研的FunAsr。 FunAsr主要依托达摩…

截图识别文字怎么弄?分享3个工具!

随着科技的不断发展,我们的生活和工作中需要处理越来越多的数字信息。有时候,我们需要从图片或者截图中提取文字,例如整理资料、处理图片注释等等。这时,一款好用的截图识别文字工具就显得尤为重要。今天,就让我们来聊…

浏览器不支持 css 中 :not 表达式的解决方法

问题 使用 :not 表达式的样式在不同浏览器中存在不生效的问题。 原因 不生效是因为浏览器版本较低所导致的。(更多详细信息请看:MDN) 解决方法 初始写法: .input-group:not(.user-name, .user-passwork){width: auto; }改成…

常见Mysql数据库操作语句

-- DDL创建数据库结构 -- 查询所有数据库 show databases ; -- 修改数据库字符集 alter database db02 charset utf8mb4; -- 创建字符编码为utf——8的数据库 create database db05 DEFAULT CHARACTER SET utf8;-- 创建表格 create table tb_user(id int auto_increment primar…

搜维尔科技:【简报】元宇宙数字人赛道,2022年金奖《金魚姬》赏析!

一名网络直播主名叫琉璃,在即将展开她日常进行的每日准时直播前,肚子极为不舒服,突然很想上厕所,由于时间紧迫,导致琉璃需要在厕所里面完成直播!为了掩饰自己所在的处境,她决定运用自己设计的虚…

85.乐理基础-记号篇-力度记号

内容来源于:三分钟音乐社 上一个内容:78.乐理基础-非常见拍号如何打拍子-CSDN博客 85-78之间的内容观看索引: 腾讯课堂-三分钟音乐社-打拍子(20)-总结、重点、练习与检验方法开始看 力度记号:p、f、mp、…

基于SpringBoot的精品在线试题库系统(系统+数据库+文档)

🍅点赞收藏关注 → 私信领取本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目 希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅一、绪论 1. 研究背景 现在大家…

软件测试|Python Faker库使用指南

简介 Faker是一个Python库,用于生成虚假(假的)数据,用于测试、填充数据库、生成模拟数据等目的。它可以快速生成各种类型的虚假数据,如姓名、地址、电子邮件、电话号码、日期等,非常适合在开发和测试过程中…

【Vue】文件管理页面制作

<template><div><div style"margin: 10px 0"><el-input style"width: 200px" placeholder"请输入名称" suffix-icon"el-icon-search" v-model"name"></el-input><el-button class"ml…

RIP复习实验

条件: R1为外网&#xff0c;R8和r9的环回分别是172.16.1.0/24和172.16.2.0/24 中间使用78.1.1.0/24 剩下的路由器2-6使用172.16.0.0/16 要求: R1为运营商 r1远程登录r2实际登录r7 R2访问r7要求走r5去访问 全网可达 实现流程: 首先配置好各接口ip address 然后r2-r7使用rip…

数据库授权问题 ERROR 1410 (42000): You are not allowed to create a user with GRANT

当我要给数据库授权时&#xff0c;却出现了错误。 ERROR 1410 (42000): You are not allowed to create a user with GRANT 包括对数据库角色权限信息的查询&#xff0c;同样也会出现问题 ERROR: 1141: There is no such grant defined for user xuxu on host localhost 这是…