【AI大模型】Transformers大模型库(五):AutoModel、Model Head及查看模型结构

目录​​​​​​​

一、引言 

二、自动模型类(AutoModel)

2.1 概述

2.2 Model Head(模型头)

2.3 代码示例

三、总结


一、引言 

 这里的Transformers指的是huggingface开发的大模型库,为huggingface上数以万计的预训练大模型提供预测、训练等服务。

🤗 Transformers 提供了数以千计的预训练模型,支持 100 多种语言的文本分类、信息抽取、问答、摘要、翻译、文本生成。它的宗旨是让最先进的 NLP 技术人人易用。
🤗 Transformers 提供了便于快速下载和使用的API,让你可以把预训练模型用在给定文本、在你的数据集上微调然后通过 model hub 与社区共享。同时,每个定义的 Python 模块均完全独立,方便修改和快速研究实验。
🤗 Transformers 支持三个最热门的深度学习库: Jax, PyTorch 以及 TensorFlow — 并与之无缝整合。你可以直接使用一个框架训练你的模型然后用另一个加载和推理。

本文重点介绍自动模型类(AutoModel)。

二、自动模型类(AutoModel)

2.1 概述

AutoModel是Hugging Face的Transformers库中的一个非常实用的类,它属于自动模型选择的机制。这个设计允许用户在不知道具体模型细节的情况下,根据给定的模型名称或模型类型自动加载相应的预训练模型。它减少了代码的重复性,并提高了灵活性,使得开发者可以轻松地切换不同的模型进行实验或应用。

2.2 Model Head(模型头)

Model Head在预训练模型的基础上添加一层或多层的额外网络结构来适应特定的模型任务,方便于开发者快速加载transformers库中的不同类型模型,不用关心模型内部细节。

  •  ForCausalLM:因果语言模型头,用于decoder类型的任务,主要进行文本生成,生成的每个词依赖于之前生成的所有词。比如GPT、Qwen
  •  ForMaskedLM:掩码语言模型头,用于encoder类型的任务,主要进行预测文本中被掩盖和被隐藏的词,比如BERT。
  •  ForSeq2SeqLM:序列到序列模型头,用于encoder-decoder类型的任务,主要处理编码器和解码器共同工作的任务,比如机器翻译或文本摘要。
  • ForQuestionAnswering:问答任务模型头,用于问答类型的任务,从给定的文本中抽取答案。通过一个encoder来理解问题和上下文,对答案进行抽取。
  • ForSequenceClassification:文本分类模型头,将输入序列映射到一个或多个标签。例如主题分类、情感分类。
  • ForTokenClassification:标记分类模型头,用于对标记进行识别的任务。将序列中的每个标记映射到一个提前定义好的标签。如命名实体识别,打标签
  • ForMultiplechoice:多项选择任务模型头,包含多个候选答案的输入,预测正确答案的选项。

2.3 代码示例

对于目前常见的LLM,比如GLM、Qwen、Baichuan等,通常使用AutoModelForCausalLM模型头进行加载,比如下面代码中使用AutoModelForCausalLM.from_pretrained加载Qwen2模型。

from modelscope import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
#model_dir = snapshot_download('ZhipuAI/glm-4-9b-chat')
model_dir = snapshot_download('Qwen/Qwen2-7B-Instruct')
import torchdevice = "cuda:2" # the device to load the model ontotokenizer = AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)prompt = "介绍一下大语言模型"
messages = [{"role": "system", "content": "你是一个智能助理."},{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)model = AutoModelForCausalLM.from_pretrained(model_dir,device_map="cuda:2",trust_remote_code=True,output_attentions=True
)gen_kwargs = {"max_length": 512, "do_sample": True, "top_k": 1}
with torch.no_grad():outputs = model.generate(**model_inputs, **gen_kwargs)outputs = outputs[:, model_inputs['input_ids'].shape[1]:] #切除system、user等对话前缀print(tokenizer.decode(outputs[0], skip_special_tokens=True))print(model)

AutoModelForCausalLM.from_pretrained常见参数:

  • model_name_or_path (str): 指定预训练模型的名称或模型文件的路径。例如,"gpt2"、"distilgpt2"或本地模型文件夹的路径。
  • config (Optional[PretrainedConfig]): 模型配置对象或其配置的字典。通常不需要手动提供,因为如果未提供,它会根据model_name_or_path自动加载。
  • tokenizer (Optional[PreTrainedTokenizer]): 与模型一起使用的分词器。如果提供,可以用于快速预处理文本数据。如果未提供,某些功能可能受限。
  • cache_dir (Optional[str]): 用于存储下载的模型文件的缓存目录路径。
  • from_tf (bool, default=False): 是否从TensorFlow检查点加载模型。
  • force_download (bool, default=False): 是否强制重新下载模型,即使本地已有。
  • resume_download (bool, default=False): 是否从上次下载中断的地方继续下载。
  • proxies (dict, optional): 如果需要通过代理服务器下载模型,可以提供代理的字典。
  • output_loading_info (bool, default=False): 是否返回加载模型时的详细信息。
  • local_files_only (bool, default=False): 是否仅从本地文件加载模型,不尝试在线下载。
  • low_cpu_mem_usage (bool, default=False): 是否优化模型加载以减少CPU内存使用,这对于大型模型特别有用。
  • device_map (Optional[Dict[str, Union[int, str]]]): 用于在多GPU或特定设备上分配模型的字典。在PyTorch 2.0及Transformers的相应版本中更为常见。
  • revision (str, optional): 指定模型版本或分支,用于从Hugging Face Hub加载特定版本的模型。
  • use_auth_token (Optional[Union[str, bool]]): 如果模型存储在私有仓库中,需要提供访问令牌。

 特别有用的一个功能就是输出模型结构,有助于快速理解模型

from modelscope import snapshot_download
model_dir = snapshot_download('Qwen/Qwen2-7B-Instruct')from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_dir)
print(model)

Qwen2的模型结构如下

Qwen2ForCausalLM((model): Qwen2Model((embed_tokens): Embedding(152064, 3584)(layers): ModuleList((0-27): 28 x Qwen2DecoderLayer((self_attn): Qwen2SdpaAttention((q_proj): Linear(in_features=3584, out_features=3584, bias=True)(k_proj): Linear(in_features=3584, out_features=512, bias=True)(v_proj): Linear(in_features=3584, out_features=512, bias=True)(o_proj): Linear(in_features=3584, out_features=3584, bias=False)(rotary_emb): Qwen2RotaryEmbedding())(mlp): Qwen2MLP((gate_proj): Linear(in_features=3584, out_features=18944, bias=False)(up_proj): Linear(in_features=3584, out_features=18944, bias=False)(down_proj): Linear(in_features=18944, out_features=3584, bias=False)(act_fn): SiLU())(input_layernorm): Qwen2RMSNorm()(post_attention_layernorm): Qwen2RMSNorm()))(norm): Qwen2RMSNorm())(lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)

三、总结

本文对使用transformers的AutoModel自动模型类进行介绍,主要用于加载transformers模型库中的大模型,文中详细介绍了应用于不同任务的Model Head(模型头)、使用模型头、输出模型结构等关于AutoModel常用的方法。希望对您有帮助。

如果您还有时间,可以看看我的其他文章:

《AI—工程篇》

AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效

AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署

AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署

AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署

AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署

《AI—模型篇》

AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用

AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat 训练与推理实战

AI智能体研发之路-模型篇(三):中文大模型开、闭源之争

AI智能体研发之路-模型篇(四):一文入门pytorch开发

AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比

AI智能体研发之路-模型篇(六):【机器学习】基于tensorflow实现你的第一个DNN网络

AI智能体研发之路-模型篇(七):【机器学习】基于YOLOv10实现你的第一个视觉AI大模型

AI智能体研发之路-模型篇(八):【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战

AI智能体研发之路-模型篇(九):【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战

《AI—Transformers应用》

【AI大模型】Transformers大模型库(一):Tokenizer

【AI大模型】Transformers大模型库(二):AutoModelForCausalLM

【AI大模型】Transformers大模型库(三):特殊标记(special tokens)

【AI大模型】Transformers大模型库(四):AutoTokenizer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25234.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用 Keras 的 Stable Diffusion 实现高性能文生图

前言 在本文中,我们将使用基于 KerasCV 实现的 [Stable Diffusion] 模型进行图像生成,这是由 stable.ai 开发的文本生成图像的多模态模型。 Stable Diffusion 是一种功能强大的开源的文本到图像生成模型。虽然市场上存在多种开源实现可以让用户根据文本…

【会议征稿,IEEE出版】第三届能源与电力系统国际学术会议 (ICEEPS 2024,7月14-16)

如今,全球能源行业正面临着前所未有的挑战。一方面,加快向清洁、可再生能源转型是遏制能源环境污染问题的最佳途径之一;另一方面,电力系统中新能源发电、人工智能技术、电力电子装备等被广泛应用和期待,以提高能源可持…

transformer - 注意力机制

Transformer 的注意力机制 Transformer 是一种用于自然语言处理任务的模型架构,依赖于注意力机制来实现高效的序列建模。注意力机制允许模型在处理一个位置的表示时,考虑输入序列中所有其他位置的信息,而不仅仅是前面的几个位置。这种机制能…

ATTCK红队评估(五)

环境搭建 靶场拓扑图: 靶机下载地址: 漏洞详情 外网信息收集 确定目标靶机地址: 发现主机192.168.135.150主机是本次攻击的目标地址。探测靶机开放的端口信息: 目标靶机开放了两个端口:80、3306,那没什么意外的话就是…

每天壁纸不重样~下载必应每日图片

下载必应每日图片 必应不知道你用过没有你下载过必应的图片没有你又没搜索过桌面图片你是不是安装过桌面图片软件你是不是为找一个好看的图片下载过很多桌面软件 必应每日图片 必应每天都会有一张不同的风景图片,画质清晰,而且不收费可以下载使用 但…

重生之我要精通JAVA--第八周笔记

文章目录 多线程线程的状态线程池自定义线程池最大并行数多线程小练习 网络编程BS架构优缺点CS架构优缺点三要素IP特殊IP常用的CMD命令 InetAddress类端口号协议UDP协议(重点)UDP三种通信方式 TCP协议(重点)三次握手四次挥手 反射…

sqlmap直接嗦 dnslog注入 sqllibs第8关

dnslog注入是解决注入的时候没有回显的情况,通过dns外带来进行得到我们想要的数据。 我们是用了dns解析的时候会留下记录,这时候就可以看见我们想要的内容。 这个时候我们还要了解unc路径以及一个函数load_file()以及concat来进行注入。看看我的笔记 unc…

sqli-labs 靶场 less-8、9、10 第八关到第十关详解:布尔注入,时间注入

SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它,我们可以学习如何识别和利用不同类型的SQL注入漏洞,并了解如何修复和防范这些漏洞。Less 8 SQLI DUMB SERIES-8判断注入点 当输入id为1时正常显示: 加上单引号就报错了 …

零基础非科班也能掌握的C语言知识19 动态内存管理

动态内存管理 1.为什么要有动态内存分配2.malloc和free2.1 malloc2.2 free 3.calloc和realloc3.1 calloc3.2realloc 4.常见的动态内存的错误4.1对NULL指针的解引用操作4.2对动态开辟空间的越界访问4.3对非动态内存开辟的空间free4.4使用free释放⼀块动态开辟内存的⼀部分4.5对同…

在Anaconda中安装keras-contrib库

文章目录 1. 有git2. 无git2.1 步骤12.2 步骤22.3 步骤3 1. 有git 如果环境里有git,直接运行以下命令: pip install githttps://www.github.com/farizrahman4u/keras-contrib.git2. 无git 2.1 步骤1 打开网址:https://github.com/keras-tea…

Vue3【十四】watchEffect自动监视多个数据实现,不用明确指出监视哪个数据

Vue3【十四】watchEffect自动监视多个数据实现&#xff0c;不用明确指出监视哪个数据 Vue3【十四】watchEffect自动监视多个数据实现&#xff0c;不用明确指出监视哪个数据 进入立即执行一次&#xff0c;并监视数据变化 案例截图 目录结构 代码 Person.vue <template>&…

Java----抽象类和接口

欢迎大家来这次博客-----抽象类和接口。 1.抽象类 1.1 抽象类概念 在Java中我们都是通过类来描述对象&#xff0c;但反过来并不是所有的类都是用来描述对象的。当一个类中没有足够的信息来描述一个具体对象&#xff0c;我们就将该类称为抽象类。 如上图中的Shape类&#xff…

通用Mapper基础学习

一、引入 二、快速入门 1.创建测试数据 2.搭建MyBatis+Spring 开发环境 3.集成Mapper 4.第一个操作 Mapper接口源码介绍: 创建测试类: 三、常见操作

统计信号处理基础 习题解答10-9

题目 某质检员的工作是监控制造出来的电阻阻值。为此他从一批电阻中选取一个并用一个欧姆表来测量它。他知道欧姆表质量较差&#xff0c;它给测量带来了误差&#xff0c;这个误差可以看成是一个的随机变量。为此&#xff0c;质检员取N个独立的测量。另外&#xff0c;他知道阻值…

FreeRTOS基础(十三):队列集

队列集&#xff08;Queue Set&#xff09;通常指的是一组队列&#xff0c;它们可以用于处理不同的任务或数据流。每个队列可以独立地处理自己的元素&#xff0c;但作为一个集群&#xff0c;它们可以协同工作来完成更复杂的任务。下面进行介绍。 目录 一、队列集简介 二、队列…

详解 Flink 的 ProcessFunction API

一、Flink 不同级别的 API Flink 拥有易于使用的不同级别分层 API 使得它是一个非常易于开发的框架最底层的 API 仅仅提供了有状态流处理&#xff0c;它将处理函数&#xff08;Process Function &#xff09;嵌入到了 DataStream API 中。底层处理函数&#xff08;Process Func…

HarmonyOS开发-鸿蒙UiAbility 组件间跳转

前言 随着春节假期结束各行各业复产复工&#xff0c;一年一度的春招也持续火热起来。最近&#xff0c;有招聘平台发布了《2024年春招市场行情周报&#xff08;第一期&#xff09;》。总体来说今年的就业市场还是人才饱和的状态&#xff0c;竞争会比较激烈。 但是&#xff0c;…

Unity编辑器扩展,快捷键的使用

代码部分 编辑器界面 使用方法&#xff1a; 使用方法和如图1一样&#xff0c;只需要在Menuitem的路径后面加上标识符号就行。 "#"对应的是shift "&"对应的是Alt "%"对应的是ctrl 比如我图中的是&#xff0c;%#s对应的是CtrlShifts&…

基于51单片机的串口乒乓球小游戏

基于51单片机的乒乓球小游戏 &#xff08;仿真&#xff0b;程序&#xff09; 功能介绍 具体功能&#xff1a; 1.用两块单片机串口进行通信&#xff1b; 2.一排LED模拟乒乓球运动&#xff08;哪里亮表示运动到哪&#xff09;&#xff1b; 3.当最左边LED亮&#xff0c;表示球…

【java、lucene、python】互联网搜索引擎课程报告二:建立搜索引擎

一、项目要求 建立并实现文本搜索功能 对经过预处理后的500个英文和中文文档/网页建立搜索并实现搜索功能对文档建立索引&#xff0c;然后通过前台界面或者已提供的界面&#xff0c;输入关键字&#xff0c;展示搜索结果前台可通过网页形式、应用程序形式、或者利用已有的界面…