HuggingFace学习笔记--BitFit高效微调

1--BitFit高效微调

        BitFit,全称是 bias-term fine-tuning,其高效微调只去微调带有 bias 的参数,其余参数全部固定;

2--实例代码

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}if __name__ == "__main__":# 加载数据集dataset = load_from_disk("./PEFT/data/alpaca_data_zh")# 处理数据tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)# print(tokenizer.decode(tokenized_ds[1]["input_ids"]))# print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))# 创建模型model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)# 基于bitfit只训练带有bias的参数for name, param in model.named_parameters():if "bias" not in name:param.requires_grad = False# 训练参数args = TrainingArguments(output_dir = "./chatbot",per_device_train_batch_size = 1,gradient_accumulation_steps = 8,logging_steps = 10,num_train_epochs = 1)# trainertrainer = Trainer(model = model,args = args,train_dataset = tokenized_ds,data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True))# 训练模型trainer.train()# 模型推理pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "output = pipe(ipt, max_length=256, do_sample=True)print(output)

结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【每日OJ —— 226. 翻转二叉树】

每日OJ —— 226. 翻转二叉树 1.题目:226. 翻转二叉树2.解法2.1.算法讲解2.2.代码实现2.3.代码提交通过展示 1.题目:226. 翻转二叉树 2.解法 2.1.算法讲解 我们从根节点开始,递归地对树进行遍历,并从叶子节点先开始翻转。如果当前…

持续集成交付CICD:CentOS 7 安装 Sonarqube9.6

目录 一、实验 1.CentOS 7 安装 Sonarqube9.6 二、问题 1.安装postgresql13服务端报错 2.postgresql13创建用户报错 一、实验 1.CentOS 7 安装 Sonarqube9.6 (1)下载软件及依赖包 ①Sonarqube9.6下载地址 https://binaries.sonarsource.com/Dis…

深度学习之基于yolov3学生课堂行为及专注力检测预警监督系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 深度学习技术在学生课堂行为及专注力检测预警监督系统的应用是一项极具挑战性和创新性的研究领域。利用YOLOv3&…

Docker常见命令介绍

命令说明 docker pull 拉取镜像 docker push 推送镜像到DockerRegistry docker images 查看本地镜像 docker rmi 删除本地镜像 docker run 创建并运行容器(不能重复创建) docker stop 停止指定容器 docker start 启动指定容器 docker rest…

设计模式-结构型模式之外观设计模式

文章目录 七、外观模式 七、外观模式 外观模式(Facade Pattern)隐藏系统的复杂性,并向客户端提供了一个客户端可以访问系统的接口。它向现有的系统添加一个接口,来隐藏系统的复杂性。 这种模式涉及到一个单一的类,该类…

揭秘原型链:探索 JavaScript 面向对象编程的核心(上)

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

第九节HarmonyOS 常用基础组件2-Image

一、组件介绍 组件(Component)是界面搭建与显示的最小单位,HarmonyOS ArkUI声名式为开发者提供了丰富多样的UI组件,我们可以使用这些组件轻松的编写出更加丰富、漂亮的界面。 组件根据功能可以分为以下五大类:基础组件…

SmartSoftHelp8,Web前端性能提升,js,css,html 优化压缩工具

Web前端js,css,html 优化压缩工具 提高web 前端性能,访问速度优化专业工具 CSS,js,html 单文件,多文件 单个,批量压缩优化 web前端优化:减少空格,体积压缩&#xff0…

基于算能的国产AI边缘计算盒子8核心A53丨17.6Tops算力

边缘计算盒子 8核心A53丨17.6Tops算力 ● 可提供17.6TOPS(INT8)的峰值计算能力、2.2TFLOPS(FP32)的高精度算力,单芯片最高支持32路H.264 & H.265的实时解码能力。 ● 适配Caffe/TensorFlow/MxNet/PyTorch/ ONNX/…

面试数据库八股文十问十答第一期

面试数据库八股文十问十答第一期 作者:程序员小白条,个人博客 1.MySQL常见索引、 MySQL常见索引有: 主键索引、唯一索引、普通索引、全文索引、组合索引(最左前缀)主键索引特点:唯一性,非空,自增(如果使用…

Ubuntu 安装 MySQL8 配置、授权、备份、远程连接

目录 0100 系统环境0200 下载0300 安装0400 服务管理0401 关闭、启动、重启服务0402 查看服务状态 0500 查看配置文件0600 账号管理0601 添加账号0602 删除账号0603 修改密码0604 忘记root密码 0700 自动备份0800 远程访问 0100 系统环境 [rootlocalhost ~]# cat /proc/versio…

AD生产BOM表时如何隐藏不需要的器件记录

在完成图纸设计号通常需要生产BOM表,以便采购等,如果不做一些操作,往往输出的BOM表中包含一些非需要采购的器件,如下图 这时就需要对原理图或者PCB图做一些处理,以原理图为例,在需要屏蔽的器件上双击&#…

Nginx实现多虚拟主机配置

Nginx实现多虚拟主机配置 Nginx为什么要进行多虚拟主机配置呢?what? Nginx实现多虚拟主机配置的主要原因是,一个服务器可能会承载多个网站或应用程序,这些网站或应用程序需要使用不同的域名或IP地址来进行访问。如果只有一个虚拟…

ctfhub技能树_web_web前置技能_HTTP

目录 一、HTTP协议 1.1、请求方式 1.2、302跳转 1.3、Cookie 1.4、基础认证 1.5、响应包源代码 一、HTTP协议 1.1、请求方式 注:HTTP协议中定义了八种请求方法。这八种都有:1、OPTIONS :返回服务器针对特定资源所支持的HTTP请求方法…

通过查看ThreadLocal的源码进行简单理解

目录 为什么要使用ThreadLocal? 简单案例 ThreadLocal源码分析 断点跟踪 为什么要使用ThreadLocal 在多线程下,如果同时修改公共变量可能会存在线程安全问题,JDK虽然提供了同步锁与Lock等方法给公共访问资源加锁,但在高并发…

力扣 --- H指数

题目描述: 给你一个整数数组 citations ,其中 citations[i] 表示研究者的第 i 篇论文被引用的次数。计算并返回该研究者的 h 指数。 根据维基百科上 h 指数的定义:h 代表“高引用次数” ,一名科研人员的 h 指数 是指他&#xff…

【HDFS】调试慢节点pipiline ack信息

Client - DN1 - DN2 - DN3 DN3 send ack:[0][d3]。 DN2 send ack: [从dn2入队到收到dn3的ack耗时,0] [d2,d3]。 DN1 send ack: [pkt从dn1入队到收到dn2的ack耗时,pkt从dn2入队到收到dn3的ack耗时,0] [d1,d2,d3]。 Client receive: 就是DN1发送过来数据。 客户端收到的第一个…

【论文笔记】Universal Guidance for Diffusion Models

Abstract 典型的扩散模型经过训练以接受特定形式的条件作用(最常见的是文本),并且如果不经过重新训练就不能接受其他形式的条件的作用。 这项工作中提出了一种通用制导算法(universal guidance algorithm),使扩散模型能够通过任意…

python弹球小游戏

import pygame import random# 游戏窗口大小 WIDTH 800 HEIGHT 600# 定义颜色 WHITE (255, 255, 255) BLACK (0, 0, 0) RED (255, 0, 0) GREEN (0, 255, 0) BLUE (0, 0, 255)# 球的类 class Ball:def __init__(self):self.radius 10self.speed [random.randint(2, 4),…

DELL EMC unity 存储系统日志收集方法

对于一些非简单的硬件故障,解决故障最有效、最快速的方法就是收集日志,而不是瞎搞。常见的乱搞方法就是 1. reimage系统‘ 2. 更换控制器;3, 重启。 本文详细介绍了图形界面GUI和命令行CLI下如何收集DELL EMC Unity日志的方法和常…