使用阿里云微调chatglm2

完整的代码可以参考:https://files.cnblogs.com/files/lijiale/chatglm2-6b.zip?t=1691571940&download=true

# %% [markdown]
# # 微调前# %%
model_path = "/mnt/workspace/ChatGLM2-6B/chatglm2-6b"from transformers import AutoTokenizer, AutoModel
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)from IPython.display import display, Markdown, clear_outputdef display_answer(model, query, history=[]):for response, history in model.stream_chat(tokenizer, query, history=history):clear_output(wait=True)display(Markdown(response))return historymodel = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")# %% [markdown]
# # 微调后的效果
# # %%
import os
import torch
from transformers import AutoConfig
from transformers import AutoTokenizer, AutoModelmodel_path = "/mnt/workspace/ChatGLM2-6B/chatglm2-6b"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join("/mnt/workspace/ChatGLM2-6B/ptuning/output/adgen-chatglm2-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():if k.startswith("transformer.prefix_encoder."):new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
print(response)# %%
!pip install torchkeras# %%
#导入模块
import numpy as np
import pandas as pd 
import torch
from torch import nn 
from torch.utils.data import Dataset,DataLoader from argparse import Namespace
cfg = Namespace()from argparse import Namespace
cfg = Namespace()#dataset
cfg.prompt_column = 'prompt'
cfg.response_column = 'response'
cfg.history_column = None
cfg.source_prefix = '' #添加到每个prompt开头的前缀引导语cfg.max_source_length = 128 
cfg.max_target_length = 128#model
cfg.model_name_or_path = '/mnt/workspace/ChatGLM2-6B/chatglm2-6b'  #远程'THUDM/chatglm-6b' 
cfg.quantization_bit = None #仅仅预测时可以选 4 or 8 #train
cfg.epochs = 100 
cfg.lr = 5e-3
cfg.batch_size = 1
cfg.gradient_accumulation_steps = 16 #梯度累积import transformers
from transformers import  AutoModel,AutoTokenizer,AutoConfig,DataCollatorForSeq2Seqconfig = AutoConfig.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, trust_remote_code=True)model = AutoModel.from_pretrained(cfg.model_name_or_path,config=config,trust_remote_code=True).half() #先量化瘦身
if cfg.quantization_bit is not None:print(f"Quantized to {cfg.quantization_bit} bit")model = model.quantize(cfg.quantization_bit)#再移动到GPU上
model = model.cuda();# 通过注册jupyter魔法命令可以很方便地在jupyter中测试ChatGLM 
from torchkeras.chat import ChatGLM 
chatglm = ChatGLM(model,tokenizer)# %%
%%chatglm
类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞# %%
#定义一条知识样本~
import json
keyword = '梦中情炉'description = '''梦中情炉一般指的是炼丹工具torchkeras。
这是一个通用的pytorch模型训练模版工具。
torchkeras是一个三好炼丹炉:好看,好用,好改。
她有torch的灵动,也有keras的优雅,并且她的美丽,无与伦比。
所以她的作者一个有毅力的吃货给她取了一个别名叫做梦中情炉。'''#对prompt使用一些简单的数据增强的方法,以便更好地收敛。
def get_prompt_list(keyword):return [f'{keyword}', f'你知道{keyword}吗?',f'{keyword}是什么?',f'介绍一下{keyword}',f'你听过{keyword}吗?',f'啥是{keyword}?',f'{keyword}是何物?',f'何为{keyword}?',]# data =[{'prompt':x,'response':description} for x in get_prompt_list(keyword) ]
data = []
with open("/mnt/workspace/ChatGLM2-6B/ptuning/AdvertiseGen_Simple/train.json", "r", encoding="utf-8") as f:lines = f.readlines()for line in lines:d = json.loads(line)data.append({'prompt':d['content'],'response':d['summary']})dfdata = pd.DataFrame(data)
display(dfdata) # %%
import datasets 
#训练集和验证集一样
ds_train_raw = ds_val_raw = datasets.Dataset.from_pandas(dfdata)# %%
def preprocess(examples):max_seq_length = cfg.max_source_length + cfg.max_target_lengthmodel_inputs = {"input_ids": [],"labels": [],}for i in range(len(examples[cfg.prompt_column])):if examples[cfg.prompt_column][i] and examples[cfg.response_column][i]:query, answer = examples[cfg.prompt_column][i], examples[cfg.response_column][i]history = examples[cfg.history_column][i] if cfg.history_column is not None else Noneprompt = tokenizer.build_prompt(query, history)prompt = cfg.source_prefix + prompta_ids = tokenizer.encode(text=prompt, add_special_tokens=True, truncation=True,max_length=cfg.max_source_length)b_ids = tokenizer.encode(text=answer, add_special_tokens=False, truncation=True,max_length=cfg.max_target_length)context_length = len(a_ids)input_ids = a_ids + b_ids + [tokenizer.eos_token_id]labels = [tokenizer.pad_token_id] * context_length + b_ids + [tokenizer.eos_token_id]pad_len = max_seq_length - len(input_ids)input_ids = input_ids + [tokenizer.pad_token_id] * pad_lenlabels = labels + [tokenizer.pad_token_id] * pad_lenlabels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]model_inputs["input_ids"].append(input_ids)model_inputs["labels"].append(labels)return model_inputsds_train = ds_train_raw.map(preprocess,batched=True,num_proc=4,remove_columns=ds_train_raw.column_names
)ds_val = ds_val_raw.map(preprocess,batched=True,num_proc=4,remove_columns=ds_val_raw.column_names
)data_collator = DataCollatorForSeq2Seq(tokenizer,model=None,label_pad_token_id=-100,pad_to_multiple_of=None,padding=False
)dl_train = DataLoader(ds_train,batch_size = cfg.batch_size,num_workers = 2, shuffle = True, collate_fn = data_collator )
dl_val = DataLoader(ds_val,batch_size = cfg.batch_size,num_workers = 2, shuffle = False, collate_fn = data_collator )for batch in dl_train:break
print(len(dl_train))# %%
!pip install peft# %%
from peft import get_peft_model, AdaLoraConfig, TaskType#训练时节约GPU占用
model.config.use_cache=Falsemodel.supports_gradient_checkpointing = True  #
model.gradient_checkpointing_enable()
model.enable_input_require_grads()peft_config = AdaLoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False,r=8,lora_alpha=32, lora_dropout=0.1,target_modules=["query", "value"]
)peft_model = get_peft_model(model, peft_config)peft_model.is_parallelizable = True
peft_model.model_parallel = Truepeft_model.print_trainable_parameters()# %%
from torchkeras import KerasModel 
from accelerate import Accelerator class StepRunner:def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, optimizer = None, lr_scheduler = None):self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stageself.optimizer,self.lr_scheduler = optimizer,lr_schedulerself.accelerator = accelerator if accelerator is not None else Accelerator() if self.stage=='train':self.net.train() else:self.net.eval()def __call__(self, batch):#losswith self.accelerator.autocast():loss = self.net(input_ids=batch["input_ids"],labels=batch["labels"]).loss#backward()if self.optimizer is not None and self.stage=="train":self.accelerator.backward(loss)if self.accelerator.sync_gradients:self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)self.optimizer.step()if self.lr_scheduler is not None:self.lr_scheduler.step()self.optimizer.zero_grad()all_loss = self.accelerator.gather(loss).sum()#losses (or plain metrics that can be averaged)step_losses = {self.stage+"_loss":all_loss.item()}#metrics (stateful metrics)step_metrics = {}if self.stage=="train":if self.optimizer is not None:step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']else:step_metrics['lr'] = 0.0return step_losses,step_metricsKerasModel.StepRunner = StepRunner #仅仅保存lora相关的可训练参数
def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):unwrap_net = accelerator.unwrap_model(self.net)unwrap_net.save_pretrained(ckpt_path)def load_ckpt(self, ckpt_path='checkpoint'):self.net = self.net.from_pretrained(self.net.base_model.model,ckpt_path)self.from_scratch = FalseKerasModel.save_ckpt = save_ckpt 
KerasModel.load_ckpt = load_ckpt # %%
optimizer = torch.optim.AdamW(peft_model.parameters(),lr=cfg.lr) 
keras_model = KerasModel(peft_model,loss_fn = None,optimizer=optimizer) 
ckpt_path = 'single_chatglm3'# %%
keras_model.fit(train_data = dl_train,val_data = dl_val,epochs=100,patience=20,monitor='val_loss',mode='min',ckpt_path = ckpt_path,mixed_precision='fp16',gradient_accumulation_steps = cfg.gradient_accumulation_steps)# %%
#验证模型
from peft import PeftModel 
ckpt_path = 'single_chatglm3'
model_old = AutoModel.from_pretrained(cfg.model_name_or_path,load_in_8bit=False, trust_remote_code=True)
peft_loaded = PeftModel.from_pretrained(model_old,ckpt_path).cuda()
model_new = peft_loaded.merge_and_unload() #合并lora权重chatglm = ChatGLM(model_new,tokenizer,max_chat_rounds=20) #支持多轮对话,可以从之前对话上下文提取知识。# %%
chatglm = ChatGLM(model_new,tokenizer,max_chat_rounds=0) #支持多轮对话,可以从之前对话上下文提取知识。# %%
%%chatglm
类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞# %%
save_path = "chatglm2-6b-adgen"
model_new.save_pretrained(save_path, max_shard_size='2GB')
tokenizer.save_pretrained(save_path)# %%
!cp ChatGLM2-6B/chatglm2-6b/*.py chatglm2-6b-adgen/# %%
from transformers import  AutoModel,AutoTokenizermodel_name = "chatglm2-6b-adgen" 
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name,trust_remote_code=True).half().cuda()
response,history = model.chat(tokenizer,query = '你听说过梦中情炉吗?',history = [])
print(response)# %%
response,history = model.chat(tokenizer,query = '类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞',history = [])
print(response)# %%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/30546.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elasticsearch 性能调优指南

目录 1、通用优化策略 1.1 通用最小化法则 1.2 职责单一原则 1.3 其他 2、写性能调优 2.1 基本原则 2.2 优化手段 2.2.1 增加 flush 时间间隔, 2.2.2 增加refresh_interval的参数值 2.2.3 增加Buffer大小, 2.2.4 关闭副本 2.2.5 禁用swap 2…

嘉楠勘智k230开发板上手记录(四)--HHB神经网络模型部署工具

按照K230_AI实战_HHB神经网络模型部署工具.md,HHB文档,RISC-V 编译器和模拟器安装来 一、环境 1. 拉取docker 镜像然后创建docker容器并进入容器 docker pull hhb4tools/hhb:2.4.5 docker run -itd --namehhb2_4 -p 22 "hhb4tools/hhb:2.4.5"…

【CSS】背景图定位问题适配不同机型

需求 如图, 实现一个带有飘带的渐变背景 其中头像必须显示飘带凹下去那里 , 需要适配不同的机型, 一不下心容易错位 实现 因为飘带背景是版本迭代中更新的, 所以飘带和渐变背景实则两个div 飘带切图如下 , 圆形部分需要契合头像 <view class"box-bg"><…

Linux ——实操篇

Linux ——实操篇 前言vi 和 vim 的基本介绍vi和vim常用的三种模式正常模式插入模式命令行模式 vi和vim基本使用各种模式的相互切换vi和vim快捷键关机&重启命令基本介绍注意细节 用户登录和注销基本介绍使用细节 用户管理基本介绍添加用户基本语法应用案例细节说明 指定/修…

Java基础(九)数组工具类

数组工具类 1. Arrays类 a. 导入方法 import java.util.Arrays;b. Arrays类常用的方法 方法返回类型说明equals(array1, array2)boolean比较两个数组是否相等sort(array)void对数组 array 的元素进行排序toString(array)String把一个数组 array 转换成一个字符串fill(array,…

获取接口的所有实现

一、获取接口所有实现类 方法1&#xff1a;JDK自带的ServiceLoader实现 ServiceLoader是JDK自带的一个类加载器&#xff0c;位于java.util包当中&#xff0c;作为 A simple service-provider loading facility。 &#xff08;1&#xff09;创建接口 package com.example.dem…

ArcGIS API for JavaScript 4.x 教程(四) 添加点、线和多边形

了解如何在地图中显示点、线和多边形图形。 图形是用于在地图或场景中显示点、线、多边形和文本的视觉元素。图形由几何图形、符号和属性组成&#xff0c;单击时可以显示弹出窗口。您通常使用图形来显示未连接到数据库&#xff08;即GPS位置&#xff09;的地理数据。 在本教程…

Springboot中拦截GET请求获取请求参数验证合法性

目录 目的 核心方法 完整代码 创建拦截器 注册拦截器 测试效果 目的 在Springboot中创建拦截器拦截所有GET类型请求&#xff0c;获取请求参数验证内容合法性防止SQL注入&#xff08;该方法仅适用拦截GET类型请求&#xff0c;POST类型请求参数是在body中&#xff0c;所以下面…

K8S资源管理方式

K8S资源管理方式 文章目录 K8S资源管理方式一、陈述式资源管理1.基础命令操作2.创建pod3.查看资源状态4.查看pod中的容器日志5.进入pod中的容器6.删除pod资源7.pod扩容8.项目生命周期管理&#xff08;创建-->发布-->更新-->回滚-->删除&#xff09;8.1创建services…

3.1 计算机网络和网络设备

数据参考&#xff1a;CISP官方 目录 计算机网络基础网络互联设备网络传输介质 一、计算机网络基础 1、ENIAC&#xff1a;世界上第一台计算机的诞生 1946年2月14日&#xff0c;宾夕法尼亚大学诞生了世界上第一台计算机&#xff0c;名为电子数字积分计算机&#xff08;ENIAC…

准确率、召回率和F1数值区别

目录 准确率、召回率和F1数值 准确率、召回率和F1数值 一、准确率与召回率(Precision & Recall) 准确率和召回率是广泛用于信息检索和统计学分类领域的两个度量值,用来评价结果的质量。 其中精度是检索出相关文档数与检索出的文档总数的比率,衡量的是检索系统的查准…

【Autolayout案例02-距离四周边距 Objective-C语言】

一、好,来看第二个案例 1.第二个案例,是什么意思呢,第二个案例,要求屏幕中间,有一个UIView UIView,是个红色的UIView UIView的大小,我不限定 但是无论你是什么屏幕下 这个UIView距离上边,始终是50 距离右边,始终是50, 距离下边,始终是50, 距离左边,始终是5…

Nginx跳转模块——location与rewrite

一、location 1、location作用 用于匹配uri&#xff08;文件、图片、视频&#xff09; uri&#xff1a;统一资源标识符。是一种字符串标识&#xff0c;用于标识抽象的或物理资源文件、图片、视频 2、locatin分类 1、精准匹配&#xff1a;location / {...} 2、一般匹配&a…

PROFINET转DeviceNet网关普通网线能代替profinet吗

捷米JM-DNT-PN这款神器&#xff0c;连接PROFINET和DeviceNet网络&#xff0c;让两边数据轻松传输。 这个网关不仅从ETHERNET/IP和DEVICENET一侧读写数据&#xff0c;还可以将缓冲区数据交换&#xff0c;这样就可以在两个网络之间愉快地传递数据了&#xff01;而且&#xff0c;…

虚幻引擎游戏开发过程中,游戏鼠标如何双击判定?

UE虚幻引擎对于游戏开发者来说都不陌生&#xff0c;市面上有47%主机游戏使用虚幻引擎开发游戏。作为是一款游戏的核心动力&#xff0c;它的功能十分完善&#xff0c;囊括了场景制作、灯光渲染、动作镜头、粒子特效、材质蓝图等。本文介绍了虚幻引擎游戏开发过程中游戏鼠标双击判…

CSDN付费专栏写作协议

一、总则 1.1、欢迎您选用CSDN付费专栏服务&#xff08;“本服务”&#xff09;。以下所述条款和条件即构成您与CSDN就使用本服务所达成的协议&#xff08;“本协议&#xff09;。本协议被视为《CSDN用户服务条款》&#xff08;链接&#xff1a;https://passport.csdn.net/ser…

springboot+mybatis实现简单的增、删、查、改

这篇文章主要针对java初学者&#xff0c;详细介绍怎么创建一个基本的springboot项目来对数据库进行crud操作。 目录 第一步&#xff1a;准备数据库 第二步&#xff1a;创建springboot项目 方法1&#xff1a;通过spring官网的spring initilizer创建springboot项目 方法2&am…

JavaScript基础 第三天

1.for循环 2.数组的基本使用和操作 3.数组排序 一.for循环 ① 语法&#xff1a;把声明起始值&#xff0c;循环条件&#xff0c;变量值写到一起&#xff0c;让人一目了然 for(变量起始值;终止条件;变量变化量) {// 循环体 }举例&#xff1a; for (let i 0; i < 100; i)…

SQL SERVER ip地址改别名

SQL server在使用链接服务器时必须使用别名&#xff0c;使用ip地址就会把192.188.0.2这种点也解析出来 解决方案&#xff1a; 1、物理机ip 192.168.0.66 虚拟机ip 192.168.0.115 2、在虚拟机上找到 C:\Windows\System32\drivers\etc 下的 &#xff08;我选中的文件&a…

批量打印-----jsPDF将图片转为pdf,并合并pdf

注意一、 使用jspdf将图片&#xff08;jpg/jpeg/png/bmp&#xff09;转pdf&#xff08;记为pdfA&#xff09;&#xff0c;得到的pdf&#xff08;pdfA&#xff09;和需要合并的pdf(记为pdfB)类型不一致&#xff0c;需要将pdfA转为pdfB类型&#xff0c;才能合并&#xff0c;使用a…