【llm 微调code-llama 训练自己的数据集 一个小案例】

这也是一个通用的方案,使用peft微调LLM。

准备自己的数据集

根据情况改就行了,jsonl格式,三个字段:context, answer, question

import pandas as pd
import random
import jsondata = pd.read_csv('dataset.csv')
train_data = data[['prompt','Code']]
train_data = train_data.values.tolist()random.shuffle(train_data)train_num = int(0.8 * len(train_data))with open('train_data.jsonl', 'w') as f:for d in train_data[:train_num]:d = {'context':'','question':d[0],'answer':d[1]}f.write(json.dumps(d)+'\n')
with open('val_data.jsonl', 'w') as f:for d in train_data[train_num:]:d = {'context':'','question':d[0],'answer':d[1]}f.write(json.dumps(d)+'\n')

初始化

from datetime import datetime
import os
import sysimport torchfrom peft import (LoraConfig,get_peft_model,get_peft_model_state_dict,prepare_model_for_int8_training,
)
from transformers import (AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM,TrainingArguments, Trainer, DataCollatorForSeq2Seq)# 加载自己的数据集
from datasets import load_datasettrain_dataset = load_dataset('json', data_files='train_data.jsonl', split='train')
eval_dataset = load_dataset('json', data_files='val_data.jsonl', split='train')# 读取模型
base_model = 'CodeLlama-7b-Instruct-hf'model = AutoModelForCausalLM.from_pretrained(base_model,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",low_cpu_mem_usage=True
)tokenizer = AutoTokenizer.from_pretrained(base_model)

微调前的效果

tokenizer.pad_token = tokenizer.eos_token
prompt = """You are programming coder.Now answer the question:{}"""
prompts = [prompt.format(train_dataset[i]['question']) for i in [1,20,32,45,67]]model_input = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")model.eval()
with torch.no_grad():outputs = model.generate(**model_input, max_new_tokens=300)outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)print(outputs)

进行微调

tokenizer.add_eos_token = True
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"def tokenize(prompt):result = tokenizer(prompt,truncation=True,max_length=512,padding=False,return_tensors=None,)# "self-supervised learning" means the labels are also the inputs:result["labels"] = result["input_ids"].copy()return resultdef generate_and_tokenize_prompt(data_point):full_prompt =f"""You are a powerful programming model. Your job is to answer questions about a database. You are given a question.You must output the code that answers the question.### Input:
{data_point["question"]}### Response:
{data_point["answer"]}
"""return tokenize(full_prompt)tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)model.train() # put model back into training mode
model = prepare_model_for_int8_training(model)config = LoraConfig(r=16,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj",
],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
if torch.cuda.device_count() > 1:model.is_parallelizable = Truemodel.model_parallel = Truebatch_size = 128
per_device_train_batch_size = 32
gradient_accumulation_steps = batch_size // per_device_train_batch_size
output_dir = "code-llama-ft"training_args = TrainingArguments(per_device_train_batch_size=per_device_train_batch_size,gradient_accumulation_steps=gradient_accumulation_steps,warmup_steps=100,max_steps=400,learning_rate=3e-4,fp16=True,logging_steps=10,optim="adamw_torch",evaluation_strategy="steps", # if val_set_size > 0 else "no",save_strategy="steps",eval_steps=20,save_steps=20,output_dir=output_dir,load_best_model_at_end=False,group_by_length=True, # group sequences of roughly the same length together to speed up trainingreport_to="none", # if use_wandb else "none", wandbrun_name=f"codellama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", # if use_wandb else None,)trainer = Trainer(model=model,train_dataset=tokenized_train_dataset,eval_dataset=tokenized_val_dataset,args=training_args,data_collator=DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True),
)

开始训练

model.config.use_cache = Falseold_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(model, type(model)
)
if torch.__version__ >= "2" and sys.platform != "win32":print("compiling the model")model = torch.compile(model)
trainer.train()

进行测试

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizerbase_model = 'CodeLlama-7b-Instruct-hf'
model = AutoModelForCausalLM.from_pretrained(base_model,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)output_dir = "code-llama-ft"
model = PeftModel.from_pretrained(model, output_dir)eval_prompt = """You are a powerful programming model. Your job is to answer questions about a database. You are given a question.You must output the code that answers the question.### Input:
Write a function in Java that takes an array and returns the sum of the numbers in the array, or 0 if the array is empty. Except the number 13 is very unlucky, so it does not count any 13, or any number that immediately follows a 13.### Response:
"""model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")model.eval()
with torch.no_grad():outputs = model.generate(**model_input, max_new_tokens=100)[0]
print(tokenizer.decode(outputs, skip_special_tokens=True))

主要参考icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/660933421

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/632748.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pyspark 笔记:窗口函数window

窗口函数相关的概念和基本规范可以见:pyspark笔记:over-CSDN博客 1 创建Pyspark dataFrame from pyspark.sql.window import Window import pyspark.sql.functions as F employee_salary [("Ali", "Sales", 8000),("Bob&qu…

USACO介绍 报名流程 成绩查询方式详解(文末有备赛资料)

USACO美国计算机奥林匹克活动 2023-2024新赛季的时间线安排是怎么样的? 2023-2024USACO竞赛时间 一般来说,USACO竞赛时间在12月-3月期间,每月都有一场比赛每次3-5小时,并在规定时间内完成3-4道题。23-24年USACO竞赛时间安排如下&a…

uniapp h5 生成 ubuntu桌面程序 并运行方法

uniapp h5 生成 ubuntu桌面程序 并运行方法,在window环境下开发,发布到ubuntu桌面,并运行 1、安装Nodejs 安装包官方下载地址:https://www.nodejs.com.cn/ 安装完后cmd,如图,即安装成功 2、通过Nodejs安装 electron…

[flutter]GIF速度极快问题的两种解决方法

原因: 当GIF图没有设置播放间隔时间时,电脑上会默认间隔0.1s,而flutter默认0s。 解决方法一: 将图片改为webp格式。 解决方法二: 为图片设置帧频率,添加播放间隔。例如可以使用GIF依赖组件设置每秒运行…

【音视频】基于NGINX如何播放rtmp视频流

背景 现阶段直播越来越流行,直播技术发展也越来越快。Webrtc、rtmp、rtsp是比较火热的技术,而且应用也比较广泛。本文通过实践来展开介绍关于rtmp如何播放。 概要 本文重点介绍基于NGINX如何播放rtmp视频流 正文 1、构造rtsp视频流 可以参考上一篇…

Cacti 前台SQL注入漏洞复现(CVE-2023-39361)

0x01 产品简介 Cacti 是一套基于 PHP,MySQL,SNMP 及 RRDTool 开发的网络流量监测图形分析工具。 0x02 漏洞概述 该漏洞存在于graph_view.php文件中。默认情况下,访客用户无需身份验证即可访问graph_view.php,在启用情况下使用时会导致SQL注入漏洞。 攻击者可能利用此漏洞…

HCIP-7

IPV6: 为什么使用IPV6: V4地址数量不够V4使用NAT,破坏了端到端原则 IPV6的优点: 全球单播地址聚合性强(IANA组织进行合理的分配)多宿主----一个接口可以配置N个地址--且这些地址为同一级别自动配置---1)…

Fastapi+Jsonp实现前后端跨域请求

文章目录 一、实现方法1.后端部分【Fastapi】2.前端部分【JS】二、测试一、实现方法 1.后端部分【Fastapi】 # coding:utf-8import json from fastapi import FastAPI, Response from fastapi.middleware.cors import CORSMiddlewareapp = FastAPI(

IPhone、IPad、安卓手机、平板以及鸿蒙系统使用惠普无线打印教程

演示机型:惠普M281fdw,测试可行机型:惠普M277,惠普M452、惠普M283 点击右上角图标。 点击WI-FI Direct 开,(如果WI-FI Direct关闭,请打开!) 记录打印机的wifi名称(SSID)和密码。 打开IPhone、I…

django后台进行加密手机号字段,加密存储,解密显示

需求: 1 :员工在填写用户的手机号时,直接填写,在django后台中输入 2:当员工在后台确认要存储到数据库时,后台将会把手机号进行加密存储,当数据库被黑之后,手机号字段为加密字符 3:员…

AD导出BOM表 导出PDF

1.Simple BOM: 这种模式下,最好在pcb界面,这样的导出的文件名字是工程名字,要是在原理图界面导出,会以原理图的名字命名表格。 直接在菜单栏 报告->Simple BOM 即可导出物料清单,默认导出 comment pattern qu…

253:vue+openlayers 加载HERE多种地图(v2软件版本)

第253个 点击查看专栏目录 本示例的目的是介绍演示如何在vue+openlayers中添加HERE地图,并且含多种的表现形式。包括地图类型,文字标记的设置、语言的选择、PPI的设定。 直接复制下面的 vue+openlayers源代码,操作2分钟即可运行实现效果 文章目录 示例效果图配置方式示例源…

2023年移远车载全面开花,智能座舱加速进击

作为汽车智能化的关键组件,车载模组正发挥着越来越重要的作用。 移远通信进入车载模组领域近十年,已形成了完善的车载产品队列,不但在5G/4G车载通信、智能座舱、C-V2X车路协同等领域打造了一枝独秀的产品线,也推出了车规级Wi-Fi/蓝…

LaWGPT安装和使用教程的复现版本【细节满满】

文章目录 前言一、下载和部署1.1 下载1.2 环境安装1.3 模型推理 总结 前言 LaWGPT 是一系列基于中文法律知识的开源大语言模型。该系列模型在通用中文基座模型(如 Chinese-LLaMA、ChatGLM等)的基础上扩充法律领域专有词表、大规模中文法律语料预训练&am…

【FastAPI】请求体

在 FastAPI 中,请求体(Request Body)是通过请求发送的数据,通常用于传递客户端提交的信息。FastAPI 使得处理请求体变得非常容易。 请求体是客户端发送给 API 的数据。响应体是 API 发送给客户端的数据 注:不能使用 …

1.6 面试经典150题 - 跳跃游戏

跳跃游戏 给你一个非负整数数组 nums ,你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标,如果可以,返回 true ;否则,返回 false 。 class Solution…

2024年回炉计划之排序算法(一)

算法是计算机科学和信息技术中的重要领域,涉及到问题求解和数据处理的方法。要学习算法,你可能需要掌握以下一些基本知识: 基本数据结构: 了解和熟练使用各种数据结构,如数组、链表、栈、队列、树和图等。数据结构是算…

vue中data和props的区别

一、两者区别 区别一: data不需要用户(开发者)传值,自身维护 props需要用户(开发者)传值 区别二: 1、data上的数据都是可读可写的, 2、props上的数据只可以读的,无…

Qt固件映像 Raspberry Pi 嵌入式C++(Qt)编程

Qt C创建突围游戏应用示例 在我们的游戏中,我们有一个桨、一个球和三十块砖。 计时器用于创建游戏周期。 我们不处理角度,我们只是改变方向:上、下、左、右。 Qt5 库是为创建计算机应用程序而开发的。尽管如此,它也可以用来创建…

Java导出Excel并合并单元格

需求:需要在导出excel时合并指定的单元格 ruoyi excel 项目基于若伊框架二次开发,本着能用现成的就不自己写的原则,先是尝试了Excel注解中needMerge属性 /*** 是否需要纵向合并单元格,应对需求:含有list集合单元格)*/public boolean needMer…