LLM | Gemma的初体验

一起来体验一下吧~

技术报告书:jgoogle/gemma-7b-it · Hugging Facegemma-report.pdf (storage.googleapis.com)

代码1  :google-deepmind/gemma: Open weights LLM from Google DeepMind. (github.com)

代码2 :https://github.com/google/gemma_pytorch

代码3 :

技术报告书:jgoogle/gemma-7b-it · Hugging Face

1.论文详解

谷歌介绍的Gemma的主要特点如下。

  • 新车型将提供两种变体:Gemma 2B 和 Gemma 7B。这两种类型分别带有预训练和指令调整的变体。
  • 新的负责任的生成式 AI 工具包提供指导和基本工具,帮助您使用 Gemma 构建更安全的 AI 应用程序。
  • Native Keras 3.0 在 JAX、PyTorch 和 TensorFlow 等领先框架中提供了用于监督微调 (SFT) 的工具链。
  • 它配备了即用型 Colab 和 Kaggle 笔记本电脑,以及 Hugging Face、MaxText 和 NVIDIA NeMo 等通用工具,使 Gemma 易于用户使用。
  • 预先训练和指令调整的 Gemma 模型可在笔记本电脑、工作计算机甚至 Google Cloud 上使用,并且使用 Vertex AI 和 Google Kubernetes Engine (GKE) 轻松安装。
  • Gemma 针对各种 AI 硬件平台进行了优化,确保了行业领先的性能,包括 NVIDIA GPU 和 Google Cloud TPU。
  • 条款和条件允许负责任的商业用途和分发给各种规模的企业。

Gemma 是由 Google 推出的一系列轻量级、先进的开源模型,基于 Google Gemini 模型的研究和技术而构建。它们是一系列text generation,decoder-only的大型语言模型,对英文的支持较好,具有模型权重开源、并提供预训练版本(base模型)和指令微调版本(chat模型)。

本次 Gemma 开源提供了四个大型语言模型,提供了 2B 和 7B 两种参数规模的版本,每种都包含了预训练版本(base模型)和指令微调版本(chat模型)。

官方除了提供 pytorch 版本之外,也提供了GGUF版本,可在各类消费级硬件上运行,无需数据量化处理,并拥有高达 8K tokens 的处理能力,Gemma 7B模型的预训练数据高达6万亿Token,也证明了通过大量的高质量数据训练,可以大力出奇迹,小模型也可以持续提升取得好的效果。


1.1.模型构造

采用transformer的解码器

下图是针对不同任务/数据集的结果对比,取得了一个很好的结果。

微调

  • 使用 QLoRA 对 UltraChat 数据集执行监督微调 (SFT) 的脚本
  • 在 TPU 设备上使用 FSDP 执行 SFT 的脚本
  • 可以在免费套餐 Google Colab 实例上运行的笔记本,用于对英语报价数据集执行 SFT

2.实战 

对于 Gemma 型号的 7B 指令版本。还可以选择 2B 基本模型、7B 基本模型和 2B 指导模型的模型卡。

2.0环境设置

pip install -U transformers
pip install packaging
pip install accelerate
pip install -U scikit-learn scipy matplotlib

更多请参考A1 

git clone https://huggingface.co/google/gemma-7b-it/ 
cd gemma-7b-it

 

2.1.在CPU上运行,这里省略

from transformers import AutoTokenizer, AutoModelForCausalLMtokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it")input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt")outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

2.2.在单个/多个 GPU 上运行模型

本文RAM24G,2张TITAN卡

# pip install accelerate
from transformers import AutoTokenizer, AutoModelForCausalLMtokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it", device_map="auto")input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")outputs = model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))

运行后

默认 max_length` =20

可能会出现如下情况,

修改输出outputs = model.generate(**input_ids,max_length=64)就好啦~

2.3.模型微调

本文使用代码数据集MBPP进行微调

import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM,AutoTokenizer,BitsAndBytesConfig,AutoTokenizer,TrainingArguments,
)
from trl import SFTTrainermodel_name = "google/gemma-7b-it"
#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, add_eos_token=True, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id =  tokenizer.eos_token_id
tokenizer.padding_side = 'left'#ds = load_dataset("timdettmers/openassistant-guanaco")
ds = load_dataset("mbpp") 
#ds = load_dataset("Muennighoff/mbpp")  # 974 rows test only
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=compute_dtype,bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map={"": 0}
)
model = prepare_model_for_kbit_training(model)
#Configure the pad token in the model
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching
peft_config = LoraConfig(lora_alpha=16,lora_dropout=0.05,r=16,bias="none",task_type="CAUSAL_LM",target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj', "gate_proj", "down_proj", "up_proj"]
)
training_arguments = TrainingArguments(output_dir="./results_qlora",evaluation_strategy="steps",do_eval=True,optim="paged_adamw_8bit",per_device_train_batch_size=2,per_device_eval_batch_size=2,log_level="debug",save_steps=50,logging_steps=50,learning_rate=2e-5,eval_steps=50,max_steps=300,warmup_steps=30,lr_scheduler_type="linear",
)
trainer = SFTTrainer(model=model,train_dataset=ds['train'],eval_dataset=ds['test'],peft_config=peft_config,dataset_text_field="text",max_seq_length=512,tokenizer=tokenizer,args=training_arguments,
)
trainer.train()

仅仅测试训练,所以设置4 个,但是训练相对来说蛮快的,4个epoch30分钟左右。

 

过程中遇到的问题及解决【PS】

【PS1】 Traceback (most recent call last):
  File "/root/anaconda3/envs/sam/lib/python3.8/site-packages/huggingface_hub/utils/_errors.py", line 304, in hf_raise_for_status
    response.raise_for_status()
  File "/root/anaconda3/envs/sam/lib/python3.8/site-packages/requests/models.py", line 1021, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/google/gemma-7b-it/resolve/main/config.json

from huggingface_hub import login
login()

在自己的huggingface账号上创建一个token

点击自己账号->Access Tokens->New token -> 输入名称->选择可写入->Generate a token 

生成token后,复制到服务器里~

登录

【PS2】ImportError: cannot import name 'ClusterInfo' from 'triton._C.libtriton.triton' (unknown location)

 尝试

# 怀疑是triton版本问题pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly#或者
pip install Triton==2.1.0#之前的版本是2.0.0输入
python -m bitsandbytes

 

 然后就可以啦~

扩展

A1

accelerate==0.27.2
bitsandbytes==0.42.0
huggingface-hub==0.21.3
peft==0.9.0

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/726821.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

报名开启丨掘金海外,探寻泛娱乐社交APP出海新风口

随着国内泛娱乐行业用户规模趋于见顶,泛娱乐社交APP转向出海是必然趋势。 根据行业数据显示,有超过35%的国内实时社交企业已启动或者正在规划出海,而其中出海商户的音视频流量增长均超过了100%。尤其是在东南亚、中东、拉美等新兴…

Maya笔记 软选择

文章目录 1什么是软选择2注意3如何打开软选择3.1方法一3.2方法二 4调整软选择的范围5衰减模式5.1体积模式5.2表面模式 6衰减曲线 1什么是软选择 也就是渐变选择,从中心点向外影响力度越来越小 软选择针对的是点线面这些模型元素 下图中展示了对被软选择的区域移动…

Rust入门:Rust如何调用C静态库的函数

关于Rust调用C,因为接口比较复杂,貌似Rust不打算支持。而对于C函数,则相对支持较好。 如果要研究C/Rust相互关系的话,可以参考: https://docs.rs/cxx/latest/cxx/ Rust ❤️ C 这里只对调用C静态库做一个最简短的介…

干货教程【软件篇】如何在Windows上安装Python环境以及设置国内源(Miniconda/Anaconda安装)

本文章涉及的所有安装包均在文章下方公众号中,回复python即可获取资源。 也可关注我们的官方网站: 考拉AI 小白安装前须了解一下 Python解释器是用来解释运行我们编写的Python代码。 Python标准库是Python自带的一系列标准模块,提供了各种…

浏览器修改接口返回数据展示在页面上

前端自己调试,想修改接口返回来的数据,然后展示在页面上 举例 接口返回了数据,想要修改此数据 这时就可以修改数据了,修改完成保存 然后刷新页面就会使用本地保存的数据了

Linux编程3.4 进程-进程标识

1、相关函数 #include<unistd.h> #include<sys/types.h> pid_t getpid(void); 获得当前进程ID uid_t getuid(void); 获得当前进程的实际用户ID uit_t geteuid(void); 获得当前进程的有效用户ID git_t getgid(void); 获得当前进程的用户组ID pit_t getppid(…

信号处理--基于单通道脑电信号EEG的睡眠分期评估

背景 睡眠对人体健康很重要。监测人体的睡眠分期对于人体健康和医疗具有重要意义。 亮点 架构在第一层使用两个具有不同滤波器大小的 CNN 和双向 LSTM。 CNN 可以被训练来学习滤波器&#xff0c;以从原始单通道 EEG 中提取时不变特征&#xff0c;而双向 LSTM 可以被训练来将…

数据库-DDL

show databases; 查询所有数据库 select database(); 查询当前数据库 use 数据库名&#xff1b; 使用数据库 creat database[if not exists] 数据库名…

vue2中如何实现添加一个空标签的效果,</>

前言&#xff1a; 众所周知&#xff0c;vue3突破了每一个vue文件中只能有一个根标签的限制&#xff0c;但是我们还有很多项目都是vue2的项目&#xff0c;如果让vue2中实现一个类似</>的效果呢&#xff0c;像react的16.2.0的版本之后&#xff0c;可以直接用<></&…

2024 年 AI 辅助研发趋势

随着人工智能技术的持续发展与突破&#xff0c;2024年AI辅助研发正成为科技界和工业界瞩目的焦点。从医药研发到汽车设计&#xff0c;从软件开发到材料科学&#xff0c;AI正逐渐渗透到研发的各个环节&#xff0c;变革着传统的研发模式。在这一背景下&#xff0c;AI辅助研发不仅…

【动态规划】【数论】【区间合并】3041. 修改数组后最大化数组中的连续元素数目

作者推荐 视频算法专题 本文涉及知识点 动态规划汇总 数论 区间合并 LeetCode3041. 修改数组后最大化数组中的连续元素数目 给你一个下标从 0 开始只包含 正 整数的数组 nums 。 一开始&#xff0c;你可以将数组中 任意数量 元素增加 至多 1 。 修改后&#xff0c;你可以从…

Spring Boot 3核心技术与最佳实践

&#x1f482; 个人网站:【 海拥】【神级代码资源网站】【办公神器】&#x1f91f; 基于Web端打造的&#xff1a;&#x1f449;轻量化工具创作平台&#x1f485; 想寻找共同学习交流的小伙伴&#xff0c;请点击【全栈技术交流群】 highlight: a11y-dark 引言 Spring Boot作为…

企业财务分析该怎么做?重点分析哪些财务指标?

在企业经营管理的过程中&#xff0c;财务分析是评估当前企业或特定部门财务状况和绩效的过程&#xff0c;这一过程通常涉及对财务报表&#xff08;如资产负债表、利润表和现金流量表&#xff09;进行定量和定性的评估&#xff0c;以便为盈利能力、偿债能力、现金流动性和资金稳…

解决 RuntimeError: “LayerNormKernelImpl“ not implemented for ‘Half‘

解决 RuntimeError: “LayerNormKernelImpl” not implemented for ‘Half’。 错误类似如下&#xff1a; Traceback (most recent call last): File “cli_demo.py”, line 21, in for results in webglm.stream_query(question): File “/root/WebGLM/model/modeling_webgl…

<C++>【继承篇】

​ ✨前言✨ &#x1f393;作者&#xff1a;【 教主 】 &#x1f4dc;文章推荐&#xff1a; ☕博主水平有限&#xff0c;如有错误&#xff0c;恳请斧正。 &#x1f4cc;机会总是留给有准备的人&#xff0c;越努力&#xff0c;越幸运&#xff01; &#x1f4a6;导航助手&#x1…

Vue+SpringBoot打造校园疫情防控管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学生2.2 老师2.3 学校管理部门 三、系统展示四、核心代码4.1 新增健康情况上报4.2 查询健康咨询4.3 新增离返校申请4.4 查询防疫物资4.5 查询防控宣传数据 五、免责说明 一、摘要 1.1 项目介绍 基于JAVAVueSpringBoot…

入门版式设计:设计小白的必备知识!

你曾经被一本华丽的杂志、一张引人注目的海报或一个优雅的网站设计所吸引吗&#xff1f;这些都是版式设计的魅力所在。作为一个设计小白&#xff0c;我们可能不熟悉版式设计&#xff0c;但事实上&#xff0c;它无处不在&#xff0c;深深影响着我们的生活。那么&#xff0c;什么…

大型网站架构演化总结

本文图解大型网站架构演化。 目录 1、单一应用服务阶段 2、应用与数据服务分离阶段 3、利用缓存提高性能阶段 4、应用服务集群阶段 5、数据库读写分离阶段 6、反向代理与CDN加速阶段 7、分布式数据库阶段 8、 NoSQL与搜索引擎阶段 9、业务拆分阶段 10、分布式服务阶…

Leetcode刷题(三十八)

旋转矩阵&#xff08;Medium&#xff09; 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。你必须在 原地 旋转图像&#xff0c;这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。示例 1&#xff1a;输入&#xff1a;mat…

基于springboot+vue的医疗挂号管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…