Gemma

Gemma

  • 1.使用
  • 2.RAG
  • 3.LoRA
    • 3.1LoRA分类任务
    • 3.2LoRA中文建模任务

1.使用

首先是去HF下载模型,但一直下载不了,所以去了HF镜像网站,下载gemma需要HF的Token,按照步骤就可以下载。代码主要是Kaggle论坛里面的分享内容。

huggingface-cli download --token hf_XXX --resume-download google/gemma-7b --local-dir gemma-7b-mirror

这里我有时是2b有时是7b,换着用。

from transformers import AutoTokenizer, AutoModelForCausalLM  
tokenizer = AutoTokenizer.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")
Gemma = AutoModelForCausalLM.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")
def answer_the_question(question):input_ids = tokenizer(question, return_tensors="pt")generated_text = Gemma.generate(**input_ids,max_length=256)answer = tokenizer.decode(generated_text[0], skip_special_tokens=True)return answer
question = "给我写一首优美的诗歌?"
answer = answer_the_question(question)
print(answer)

2.RAG

参考

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
##2.1 根据question检索sentence chunk
import os
def get_all_pdfs(directory):pdf_files = []for root, dirs, files in os.walk(directory):for file in files:if file.endswith(".pdf"):pdf_files.append(os.path.join(root, file))return pdf_filesclass RAG:def __init__(self, num_retrieved_docs=5, pdf_folder_path='D:/Gemma/PDF'):pdf_files = get_all_pdfs(pdf_folder_path)print("Documents used", pdf_files)loaders = [PyPDFLoader(pdf_file) for pdf_file in pdf_files]all_documents = []for loader in loaders:raw_documents = loader.load()text_splitter = CharacterTextSplitter(separator="\n\n",chunk_size=10,chunk_overlap=1,# length_function=len,)documents = text_splitter.split_documents(raw_documents)all_documents.extend(documents)embeddings = HuggingFaceEmbeddings(model_name="D:/Projects/model/m3e-base")    self.db = FAISS.from_documents(all_documents, embeddings)self.retriever = self.db.as_retriever(search_kwargs={"k": num_retrieved_docs})def search(self, query):docs = self.retriever.get_relevant_documents(query)return docs
retriever = RAG()
##2.2根据sentence chunk和question去回答
class Assistant:def __init__(self):self.tokenizer = AutoTokenizer.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")self.Gemma = AutoModelForCausalLM.from_pretrained("D:/Gemma/gemma-2b-int-mirror2")def create_prompt(self, query, retrieved_info):prompt = f"""你是人工智能助手,需要根据Relevant information里面的相关内容回答用户的Instruction,其中相关信息如下:Instruction: {query}Relevant information: {retrieved_info}Output:"""print(prompt)return promptdef reply(self, query, retrieved_info):prompt = self.create_prompt(query, retrieved_info)input_ids = self.tokenizer(query, return_tensors="pt").input_ids# Generate text with a focus on factual responsesgenerated_text = self.Gemma.generate(input_ids,do_sample=True,max_length=500,temperature=0.7, # Adjust temperature according to the task, for code generation it can be 0.9)# Decode and return the answeranswer = self.tokenizer.decode(generated_text[0], skip_special_tokens=True)return answer
chatbot = Assistant()
## 2.3开始使用RAG
def generate_reply(query):related_docs = retriever.search(query)#print('related docs', related_docs)reply = chatbot.reply(query, related_docs)return reply
reply = generate_reply("存在的不足及后续的优化工作")
for s in reply.split('\n'):print(s)

3.LoRA

3.1LoRA分类任务

参考
使用nlp-getting-started数据集训练模型做二分类任务。首先拿到源model

from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments,pipeline
from peft import prepare_model_for_int8_training,LoraConfig, TaskType, get_peft_model
import numpy as np
NUM_CLASSES = 2#模型输出分类的类别数
BATCH_SIZE,EPOCHS,R,LORA_ALPHA,LORA_DROPOUT = 8,5,64,32,0.1#LoRA训练的参数
MODEL_PATH="D:/Gemma/gemma-2b-int-mirror2"#模型地址
# 1.源model,设置输出二分类
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH,num_labels=NUM_CLASSES)
print(model)

处理csv数据,将输入文字经过tokenizer编码处理

#2.处理dataset,输入过长进行truncation(tokenizer处理后)
dataset = load_dataset('csv', data_files='D:/Gemma/nlp-getting-started/train.csv')
dataset['test'] = dataset['train']
dataset = dataset.remove_columns(['id', 'keyword', 'location'])
dataset = dataset.rename_column("target", "label")#csv最后只保留了text列和label列
tokenized_dataset = {}#train和test
for split in dataset.keys():tokenized_dataset[split] = dataset[split].map(lambda x: tokenizer(x["text"], truncation=True), batched=True)
print(tokenized_dataset["train"])
print(tokenized_dataset["train"][1])

在源model基础上配置LoRA的参数,形成lora_model

#3.LoRA模型参数设置
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(r=R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,task_type=TaskType.SEQ_CLS,#SEQ_CLS:序列分类任务;TOKEN_CLS命名实体识别;SEQ2SEQ机器翻译;LM语言建模任务target_modules='all-linear'#all-linear所有线性层;embeddings嵌入层;convs卷积层
)
lora_model = get_peft_model(model, lora_config)
print(lora_model)
print(lora_model.print_trainable_parameters())#LoRA模型要训练的参数

配置lora_model的训练参数

#4.LoRA训练参数设置(损失计算等)
def compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return {"accuracy": (predictions == labels).mean()}trainer = Trainer(model=lora_model,args=TrainingArguments(output_dir="./LoAR_data/",learning_rate=2e-5,per_device_train_batch_size=BATCH_SIZE,per_device_eval_batch_size=BATCH_SIZE,evaluation_strategy="epoch",save_strategy="epoch",num_train_epochs=EPOCHS,weight_decay=0.01,load_best_model_at_end=True,logging_steps=10,report_to="none"),train_dataset=tokenized_dataset["train"],eval_dataset=tokenized_dataset["test"],tokenizer=tokenizer,data_collator=DataCollatorWithPadding(tokenizer=tokenizer),compute_metrics=compute_metrics,
)

开始训练并保存使用模型

#5.训练并评估
print("Evaluating the Model Before Training!")
trainer.evaluate()
print("Training the Model")
trainer.train()
print("Evaluating the trained model")
trainer.evaluate()
#6.保存并使用
lora_model.save_pretrained('fine-tuned-model')
clf = pipeline("text-classification", lora_model, tokenizer=MODEL_PATH)#LoRA训练后的模型

3.2LoRA中文建模任务

参考
首先拿到源model和config

from transformers import AutoConfig,AutoTokenizer,AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model,prepare_model_for_kbit_training,PeftModel
import torch
import  datasets
from tqdm import tqdm
import json
BATCH_SIZE,EPOCHS,R,LORA_ALPHA,LORA_DROPOUT = 8,5,64,32,0.1#LoRA训练的参数
MODEL_PATH="D:/Gemma/gemma-2b-int-mirror2"#模型地址
device = torch.device('cuda:0')
# 1.源model和model的config
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True)
config.is_causal = True  #确保模型在生成文本时只能看到左侧的上下文
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH,device_map="auto", config=config,trust_remote_code=True)

根据模型和config处理json数据

#2.根据model的config处理dataset(tokenizer处理后),并保存加载
def preprocess(tokenizer, config, file_path, max_seq_length, prompt_key, target_key, skip_overlength=False):with open(file_path, "r", encoding="utf8") as f:for line in tqdm(f.readlines()):example = json.loads(line)prompt_ids = tokenizer.encode(example[prompt_key], max_length=max_seq_length, truncation=True)target_ids = tokenizer.encode(example[target_key], max_length=max_seq_length, truncation=True)input_ids = prompt_ids + target_ids + [config.eos_token_id]if skip_overlength and len(input_ids) > max_seq_length:continueinput_ids = input_ids[:max_seq_length]yield {"input_ids": input_ids,"seq_len": len(prompt_ids)}
dataset = datasets.Dataset.from_generator(lambda: preprocess(tokenizer, config, "D:/Gemma/try/hc3_chatgpt_zh_specific_qa.json", max_seq_length=2000, prompt_key="q",target_key="a",))dataset.save_to_disk("h3c-chinese")  # 保存处理后的数据集
train_set = datasets.load_from_disk("h3c-chinese")#加载处理后的数据集

配置Lora参数

#3.LoRA模型参数设置
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(r=R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,task_type="CAUSAL_LM",target_modules='all-linear'
)
lora_model = get_peft_model(model, lora_config)
print(lora_model)
print(lora_model.print_trainable_parameters())#LoRA模型要训练的参数

配置lora的训练参数,包括损失计算compute_metrics,并对输入的input_ids构造输入样本列表批次处理。

tokenizer.pad_token_id = config.pad_token_id
def data_collator(features):#封装每一批数据forward前预处理的函数len_ids = [len(feature["input_ids"]) for feature in features]longest = max(len_ids)input_ids = []labels_list = []for ids_l, feature in sorted(zip(len_ids, features), key=lambda x: -x[0]):ids = feature["input_ids"]seq_len = feature["seq_len"]labels = ([-100] * (seq_len) + ids[seq_len:] + [-100] * (longest - ids_l))ids = ids + [tokenizer.pad_token_id] * (longest - ids_l)input_ids.append(torch.LongTensor(ids))labels_list.append(torch.LongTensor(labels))return {"input_ids": torch.stack(input_ids),"labels": torch.stack(labels_list),}
def compute_metrics(inputs):  # 使用模型计算损失  loss = model(input_ids=inputs["input_ids"], labels=inputs["labels"]).loss  return {  "loss": loss.item()}# 将Tensor转换为Python数字  
trainer = Trainer(model=lora_model,args=TrainingArguments(output_dir="./LoAR_data2/",learning_rate=2e-5,per_device_train_batch_size=BATCH_SIZE,save_strategy="epoch",num_train_epochs=EPOCHS,weight_decay=0.01,logging_steps=10,report_to="none"),train_dataset=train_set,tokenizer=tokenizer,data_collator=data_collator,compute_metrics=compute_metrics
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/706373.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

3D可视化项目,选择unity3D还是three.js,是时候挑明了。

2023-08-10 23:07贝格前端工场 Hi,我是贝格前端工场,在开发3D可视化项目中,是选择U3D还是three,js时,很多老铁非常的迷茫,本文给老铁们讲清楚该如何选择,欢迎点赞评论分享转发。 一、Unity3D和three.js简…

RTCA DO-178C 机载系统和设备认证中的软件注意事项-附录 B

ANNEX B 附录 B 缩略语和术语表 ACRONYMS AND GLOSSARY OF TERMS 缩写 Acronym 释义 Meaning 译文 Translate ARP Aerospace Recommended Practice 航空航天推荐做法 ATM Air Traffic Management 空中交通管理 CAST Certification Authorities Software Team 认证机…

小程序里.vue界面中传值的两种方式

1.跳转携带参数后通过生命周期取值 1.1跳转 function juMp(){let arr JSON.stringify(specs.specs_data)wx.navigateTo({url:/pages/specs/specs?sku arr})}1.2取值 import {onLoad} from dcloudio/uni-apponLoad((event)>{let Arr JSON.parse(event.sku)})2.通过监听器…

String类-equals和==的区别-遍历-SubString()-StringBuilder-StringJoiner-打乱字符串

概述 String 类代表字符串,Java 程序中的所有字符串文字(例如“abc”)都被实现为此类的实例。也就是说,Java 程序中所有的双引号字符串,都是 String 类的对象。String 类在 java.lang 包下,所以使用的时候…

jquery实现select2插件鼠标点击任意地方时默认选中该输入框内的值

jquery实现select2插件鼠标点击任意地方时默认选中该输入框内的值 最近发现一个问题,插件select2中的select2可输入可选择的下拉框,在你输入值后鼠标点击别的地方,输入框内的值会被清空,特此记录一下这里的优化,这里修…

[Mac软件]Adobe Substance 3D Stager 2.1.4 3D场景搭建工具

应用介绍 Adobe Substance 3D Stager,您设备齐全的虚拟工作室。在这个直观的舞台工具中构建和组装 3D 场景。设置资产、材质、灯光和相机。导出和共享媒体,从图像到 Web 和 AR 体验。 处理您的最终图像 Substance 3D Stager 可让您在上下文中做出创造性…

网络原理——HTTPS

HTTPS是 在HTTP的基础上,引入了一个加密层(SSL)。 1. 为什么需要HTTPS 在我们使用浏览器下载一些软件时,相信大家都遇到过这种情况:明明这个链接显示的是下载A软件,点击下载时就变成了B软件,这种情况是运…

计算机设计大赛 深度学习手势检测与识别算法 - opencv python

文章目录 0 前言1 实现效果2 技术原理2.1 手部检测2.1.1 基于肤色空间的手势检测方法2.1.2 基于运动的手势检测方法2.1.3 基于边缘的手势检测方法2.1.4 基于模板的手势检测方法2.1.5 基于机器学习的手势检测方法 3 手部识别3.1 SSD网络3.2 数据集3.3 最终改进的网络结构 4 最后…

深入理解Python中的JSON模块:基础大总结与实战代码解析【第102篇—JSON模块】

深入理解Python中的JSON模块:基础大总结与实战代码解析 在Python中,JSON(JavaScript Object Notation)模块是处理JSON数据的重要工具之一。JSON是一种轻量级的数据交换格式,广泛应用于Web开发、API通信等领域。本文将…

WinForms中的Timer探究:Form Timer与Thread Timer的差异

WinForms中的Timer探究:Form Timer与Thread Timer的差异 在Windows Forms(WinForms)应用程序开发中,定时器(Timer)是一个常用的组件,它允许我们执行定时任务,如界面更新、周期性数据…

Matlab 矩阵基础

Matlab 基础 MATLAB 是“矩阵实验室matrix laboratory”的缩写。其他编程语言大多一次处理一个数字,MATLAB 主要用于处理整个矩阵和数组。 所有 MATLAB 变量都是多维数组,无论数据类型如何。矩阵是常用于线性代数的二维数组。 若要创建一个包含单行中…

osi模型,tcp/ip模型(名字由来+各层介绍+中间设备介绍)

目录 网络协议如何分层 引入 osi模型 tcp/ip模型 引入 命名由来 介绍 物理层 数据链路层 网络层 传输层 应用层 中间设备 网络协议如何分层 引入 我们已经知道了网络协议是层状结构,接下来就来了解了解下网络协议如何分层 常见的网络协议分层模型是OSI模型 和 …

Flink CDC 3.0 Starrocks建表失败会导致任务卡主!

Flink CDC 3.0 Starrocks建表失败会导致任务卡主! 现象 StarRocks建表失败,然后任务自动重启,重启完毕后数据回放,jobMaster打印下面日志后,整个任务会卡主 There are already processing requests. Wait for proce…

windows 连接 Ubuntu 失败 -- samba服务

1. windows10连接ubuntu的时候,提示不允许一个用户使用一个以上用户名与服务器或共享资源的多重连接,中断与此服务器或共享资源的所有连接,然后再试一次 2. 换一台同事的电脑却又可以连上,我之前一直能用的,隔一段时间…

PostgreSQL创建数据库、数据库管理员用户、该库的只读用户

1.创建用户: create user pgdbAdmin with password "Pgdb_15432";2.创建数据库: create database pgdb owner pgdbAdmin;3.创建SCHEMA; create schema pgdbAdmin;4.赋予数据库管理员用户权限: grant all privileges…

UE5 C++ 单播 多播代理 动态多播代理

一. 代理机制,代理也叫做委托,其作用就是提供一种消息机制。 发送方 ,接收方 分别叫做 触发点和执行点。就是软件中的观察者模式的原理。 创建一个C Actor作为练习 二.单播代理 创建一个C Actor MyDeligateActor作为练习 在MyDeligateAc…

【蓝桥杯】包子凑数(DP)

一.题目描述 二.输入描述 三.输出描述 四.问题分析 几个两两互质的数,最大公约数是1,最小公倍数是他们的乘积。 两个互质的数a和b最小不能表示的数就是(a-1)(b-1)-1,即,两个互质的数…

uniapp_微信小程序日历

一、需求要求这样 二、代码实现 <view class"calender" click"showriliall"><text class"lineText">探视日期&#xff1a;</text><text class"middleText">{{timerili}}</text><image src"/s…

Ubuntu服务器fail2ban的使用

作用&#xff1a;限制ssh远程登录&#xff0c;防止被人爆破服务器&#xff0c;封禁登录ip 使用lastb命令可查看到登录失败的用户及ip&#xff0c;无时无刻的不在爆破服务器 目录 一、安装fail2ban 二&#xff0c;配置fail2ban封禁ip的规则 1&#xff0c;进入目录并创建ssh…

CVE-2024-0713 Monitorr 服务配置 upload.php 无限制上传漏洞

### Monitorr是一个自托管的PHP网络应用&#xff0c;可以监控本地和远程网络服务、网站和应用的状态。经过分析&#xff0c;该系统存在文件上传漏洞&#xff0c;攻击者可以通过该漏洞上传webshell至目标系统从而获取目标系统权限。 漏洞的位置在 assets\php\upload.php &#…