大模型额外篇章一:用huggingface的电影评论数据集情感分类训练模型

文章目录

    • 一、介绍和准备
      • 1)介绍
      • 2)准备(安装依赖)
    • 二、开始训练

一、介绍和准备

1)介绍

工具:huggingface
目的:情感分类
输入:电影评论
输出:标签 [‘neg’,‘pos’]
数据源:https://huggingface.co/datasets/rotten_tomatoes或https://hf-mirror.com/datasets

2)准备(安装依赖)

# pip安装
pip install transformers # 安装最新的版本
pip install transformers == 4.30 # 安装指定版本
# conda安装
conda install -c huggingface transformers  # 只4.0以后的版本

二、开始训练

  • 步骤
    1、指定训练集和数据集
    2、加载模型
    3、加载tokenizer(运行时自动下载)
    4、其它相关公共变量赋值(随机种子\标签集\标签转 token_id)
    5、处理数据集(模型接受的输入格式)
    6、定义数据规整器:训练时自动将数据拆分成Batch
    7、定义训练超参
    8、定义训练器,并开始训练
    9、开始训练

(这里之后是训练后的推理步骤)
10、加载训练后的模型进行推理
11、加载 checkpoint 并继续训练

  • 代码
import datasets
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import AutoModelForCausalLM
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from transformers import Trainer, Seq2SeqTrainer
import transformers
from transformers import DataCollatorWithPadding
from transformers import TextGenerationPipeline
import torch
import numpy as np
import os, re
from tqdm import tqdm
import torch.nn as nn#1、指定训练集和数据集
# 数据集名称(运行时下载)
DATASET_NAME = "rotten_tomatoes"
# 加载数据集
raw_datasets = load_dataset(DATASET_NAME)
# 训练集
raw_train_dataset = raw_datasets["train"]
# 验证集
raw_valid_dataset = raw_datasets["validation"]#2、加载模型
# 模型名称
MODEL_NAME = "gpt2"
# 加载模型
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,trust_remote_code=True)#3、加载tokenizer(运行时自动下载)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,trust_remote_code=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token_id = 0#4、其它相关公共变量赋值
# 设置随机种子:同个种子的随机序列可复现
transformers.set_seed(42)
# 标签集
named_labels = ['neg','pos']
# 标签转 token_id
label_ids = [tokenizer(named_labels[i],add_special_tokens=False)["input_ids"][0]for i in range(len(named_labels))
]#5、处理数据集(模型接受的输入格式)
MAX_LEN=32   #最大序列长度(输入+输出)
DATA_BODY_KEY = "text" # 数据集中的输入字段名
DATA_LABEL_KEY = "label" #数据集中输出字段名
# 定义数据处理函数,把原始数据转成input_ids, attention_mask, labels
def process_fn(examples):model_inputs = {"input_ids": [],"attention_mask": [],"labels": [],}for i in range(len(examples[DATA_BODY_KEY])):inputs = tokenizer(examples[DATA_BODY_KEY][i],add_special_tokens=False)label = label_ids[examples[DATA_LABEL_KEY][i]]input_ids = inputs["input_ids"] + [tokenizer.eos_token_id, label]raw_len = len(input_ids)input_len = len(inputs["input_ids"]) + 1if raw_len >= MAX_LEN:input_ids = input_ids[-MAX_LEN:]attention_mask = [1] * MAX_LENlabels = [-100]*(MAX_LEN - 1) + [label]else:input_ids = input_ids + [tokenizer.pad_token_id] * (MAX_LEN - raw_len)attention_mask = [1] * raw_len + [0] * (MAX_LEN - raw_len)labels = [-100]*input_len + [label] + [-100] * (MAX_LEN - raw_len)model_inputs["input_ids"].append(input_ids)                     #初始的纯数据model_inputs["attention_mask"].append(attention_mask)           #加0也就是pad成相等长度,方便矩阵计算model_inputs["labels"].append(labels)                           #-100一系列操作标识哪些部分token参与计算return model_inputs
# 处理训练数据集
tokenized_train_dataset = raw_train_dataset.map(process_fn,batched=True,remove_columns=raw_train_dataset.columns,desc="Running tokenizer on train dataset",
)
# 处理验证数据集
tokenized_valid_dataset = raw_valid_dataset.map(process_fn,batched=True,remove_columns=raw_valid_dataset.columns,desc="Running tokenizer on validation dataset",
)# 6、定义数据规整器:训练时自动将数据拆分成Batch
collater = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt",
)#7、定义训练超参
LR=2e-5         # 学习率
BATCH_SIZE=8    # Batch大小
INTERVAL=100    # 每多少步打一次 log / 做一次 eval
training_args = TrainingArguments(output_dir="./output",              # checkpoint保存路径evaluation_strategy="steps",        # 按步数计算eval频率overwrite_output_dir=True,num_train_epochs=1,                 # 训练epoch数per_device_train_batch_size=BATCH_SIZE,     # 每张卡的batch大小gradient_accumulation_steps=1,              # 累加几个step做一次参数更新per_device_eval_batch_size=BATCH_SIZE,      # evaluation batch sizeeval_steps=INTERVAL,                # 每N步eval一次logging_steps=INTERVAL,             # 每N步log一次save_steps=INTERVAL,                # 每N步保存一个checkpointlearning_rate=LR,                   # 学习率
)#8、定义训练器,并开始训练
# 节省显存
model.gradient_checkpointing_enable()
trainer = Trainer(model=model, # 待训练模型args=training_args, # 训练参数data_collator=collater, # 数据校准器train_dataset=tokenized_train_dataset,  # 训练集eval_dataset=tokenized_valid_dataset,   # 验证集# compute_metrics=compute_metric,         # 计算自定义评估指标
)
# 开始训练
trainer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/843997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「架构」单元测试及运用

在参与管理和研发软件项目的过程中,单元测试的实际运用对于确保最终产品的质量至关重要。以下是一些实际运用的案例和说明。 静态测试的实际运用 在TechCorp的电子商务平台项目中,静态测试作为代码质量保证的第一道防线。开发团队在编写代码的同时,使用SonarQube等静态代码…

【学习Day1】计算机基础

✍🏻记录学习过程中的输出,坚持每天学习一点点~ ❤️希望能给大家提供帮助~欢迎点赞👍🏻收藏⭐评论✍🏻指点🙏 1.1 中央处理单元CPU 中央处理器(CPU,central processing unit&…

在全志H616核桃派开发板上进行音频配置的方法详解

耳机口​ 核桃派板载的3.5mm音频输出口,该接口有一定的输出功率,可以使用耳机或者带功放的扬声器都可以播放声音。 查看音频设备​ 可以使用下面指令来查看音频信息: aplay -l音频播放测试​ 播放系统自带wav音频文件测试, 下面指令的au…

控制台生产厂家生产流程详解

控制台生产厂家的生产流程是一个复杂而精细的过程,它涉及多个环节,从原材料的准备到最终产品的出厂检验,每一步都至关重要。以下是控制台生产厂家的一般生产流程: 厂家会根据客户的需求和市场趋势进行产品设计。设计师会综合考虑控…

闪电加载:Hexo博客性能优化全攻略

巴索罗缪大熊 前言 这些年积累了很多前端性能优化的知识点和思路,日常工作很少涉及技术层极限优化,近期终于一点点把博客独立搭建并部署了,对之前的一些技术点进行了深度探索,最终结果也达到了预期效果,由于水平有限&…

河北奥润顺达集团研究院PMO经理常江南受邀为第十三届中国PMO大会演讲嘉宾

全国PMO专业人士年度盛会 河北奥润顺达集团研究院PMO经理、研发部运营管理办负责人常江南先生受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾,演讲议题为“初建PMO的体系宣贯和人员培养实践总结”。大会将于6月29-30日在北京举办,敬请关注&#xf…

如何利用云平台上更好地规划安全生产教育与培训

在平台上进行安全教育和培训,可以采取以下步骤和策略,以确保教育的有效性和参与度: 一、明确教育目标和培训内容 确定教育目标:明确希望员工通过培训达到的安全意识和技能水平。 制定培训内容:根据行业特点、岗位需求…

centos7安装python-gdal环境

python3 yum install python3 python3-pip -y gdal-3.6.2 参考编译postgis python安装gdal export CPLUS_INCLUDE_PATH/usr/local/gdal-3.6.2/include export C_INCLUDE_PATH/usr/local/gdal-3.6.2/include export LDFLAGS"-L/usr/local/gdal-3.6.2/lib64" pip3…

猿编程是用什么语言编程的:深入剖析其背后的语言选择与魅力

猿编程是用什么语言编程的:深入剖析其背后的语言选择与魅力 猿编程,这个富有创意和活力的编程平台,引发了众多编程爱好者的关注。那么,猿编程究竟是用什么语言进行编程的呢?这背后又蕴含着怎样的语言选择与魅力&#…

wordpress子比主题文章付费发卡插件

插件仅适用于子比主题 插件演示 免费下载 :子比主题文章付费发卡插件_麦田吧 如下图,添加卡密支持批量添加,按照卡号(英文逗号/空格/—-)密码的格式输入,一行一条,可以直接添加数据&#xff0…

​​人工智能_大模型083_大模型时代机遇02_提示词优化开发工具_立项_计量模式_真实需求_5why法---人工智能工作笔记0218

上一节我们提供了一个非常好用的提示词,优化开发的,调试工具 vellum 可以看到是这个工具 使用的时候,写完一段提示词,可以选择不同的模型,看看给出的效果情况 对应的模型非常多. ### 立项在立项阶段,要对这三个要素有初步的答案:1. 真实需求是什么? 2. 商业模式是什么? 3…

SSH远程登录时常见问题解决

SSH时出现WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! 问题解决——SSH时出现WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! WARNING: REMOTE HOST IDENTIFICATION HAS CHANGED! 翻译过来就是 警告:远程主机标识已更改! 此报错是由于远程的…

Tomcat端口配置和网页浏览

安装完成Tomcat后,到安装目录里看到内容如下: 各文件夹作用 bin:可执行文件(启动文件startup.bat、关闭文件shutdown.bat)conf:配置文件(修改端口号:server.xml,建议将s…

[自动驾驶技术]-5 Tesla自动驾驶方案之算法(AI Day 2021)

有朋友问我,如何有效学习一个新技术。笔者这么多年的经验是:1)了解国内外产业应用和标准法规现状,先建立宏观知识图谱及技术系统框架;2)根据系统框架逐块进行深入研究(横向、纵向)&a…

【html+css(大作业)】二级菜单导航栏

目录 实现效果 代码及其解释 html部分 CSS部分 hello&#xff0c;hello好久不见&#xff01; 今天我们来写二级导航栏&#xff0c;所谓二级导航栏&#xff0c;简单来说就是鼠标放上去就有菜单拉出&#xff1a; 实现效果 代码及其解释 html部分 <!DOCTYPE html> &l…

gulp入门4:dest

在Gulp中&#xff0c;dest() 方法是一个核心功能&#xff0c;用于指定文件处理流程后输出文件的目录。以下是对 gulp.dest() 的深入讲解&#xff0c;按照分点表示和归纳进行整理&#xff1a; 1. 基本用法 gulp.dest() 的基本语法为 gulp.dest(path[, options])&#xff0c;其…

嵌入式进阶——矩阵键盘

&#x1f3ac; 秋野酱&#xff1a;《个人主页》 &#x1f525; 个人专栏:《Java专栏》《Python专栏》 ⛺️心若有所向往,何惧道阻且长 文章目录 矩阵按键原理图按键状态检测单行按键状态检测多行按键状态检测 状态记录状态优化循环优化 矩阵按键 矩阵键盘是一种常见的数字输入…

Databend 开源周报第 146 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 支持 Expressio…

网络编程基础知识

一、网络的相关概念 二、Ip 对于ipv4&#xff0c;是由4个字节&#xff08;32位&#xff09;表示&#xff0c;一个字节的范围是0~255&#xff0c;采用的是十进制表示ipv6的地址长度位128位&#xff0c;是ipv4的4倍&#xff0c;采用的是16进制表示查看ip地址&#xff1a;在命令行…

windows 下载redis (通过redis-server.exe启动服务)

下载链接&#xff1a; https://github.com/MicrosoftArchive/redis/releases 启动&#xff1a; 查看&#xff1a; 人工智能学习网站 https://chat.xutongbao.top