使用 Hugging Face 的 Transformers 库加载预训练模型遇到的问题

题意:

Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch

这个错误信息 "Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch" 通常出现在使用 Hugging Face 的 Transformers 库加载预训练模型时,模型的某些参数与预训练模型检查点(checkpoint)中的参数形状不匹配。

问题背景:

I want to finetune an LLM. I am able to successfully finetune LLM. But when reload the model after save, gets error. Below is the code

import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLMfrom trl import DPOTrainer, DPOConfig
def preprocess_data(item):return {'prompt': 'Instruct: ' + item['prompt'] + '\n','chosen': 'Output: ' + item['chosen'],'rejected': 'Output: ' + item['rejected']}        def main():parser = argparse.ArgumentParser()parser.add_argument("--epochs", type=int, default=1)parser.add_argument("--beta", type=float, default=0.1)parser.add_argument("--batch_size", type=int, default=4)parser.add_argument("--lr", type=float, default=1e-6)parser.add_argument("--seed", type=int, default=2003)parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-14m")parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()# Determine device based on local_rankdevice = torch.device("cuda", args.local_rank) if torch.cuda.is_available() else torch.device("cpu")tokenizer = AutoTokenizer.from_pretrained(args.model_name)tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)dataset = load_dataset(args.dataset_name, split="train")dataset = dataset.map(preprocess_data)# Split the dataset into training and validation setsdataset = dataset.train_test_split(test_size=0.1, seed=args.seed)train_dataset = dataset['train']val_dataset = dataset['test']training_args = DPOConfig(learning_rate=args.lr,num_train_epochs=args.epochs,per_device_train_batch_size=args.batch_size,logging_steps=10,remove_unused_columns=False,max_length=1024,max_prompt_length=512,fp16=True        )# Verify and print embedding dimensions before finetuningprint("Base model embedding dimension:", model.config.hidden_size)model.train()ref_model.eval()dpo_trainer = DPOTrainer(model,ref_model,beta=args.beta,train_dataset=train_dataset,eval_dataset=val_dataset,tokenizer=tokenizer,args=training_args,)dpo_trainer.train()# Evaluateevaluation_results = dpo_trainer.evaluate()print("Evaluation Results:", evaluation_results)save_model_name = 'finetuned_model'model.save_pretrained(save_model_name)if __name__ == "__main__":main()

Error I was getting as below

return model_class.from_pretrained(File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained) = cls._load_pretrained_model(File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4349, in _load_pretrained_modelraise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

After finetuning, model works perfectly. But after reloading the saved trained model its not working. Any idea why gets this error when reloading the model ?

问题解决:

Instead of

model.save_pretrained(save_model_name)

try this

dpo_trainer.save_model(save_model_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869411.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Elasticsearch基础(四):Elasticsearch语法与案例介绍

文章目录 Elasticsearch语法与案例介绍 一、Restful API 二、查询语法 1、ES分词器 2、ES查询 2.1、match 2.2、match_phrase 2.3、multi_match 2.4、term 2.5、terms 2.6、fuzzy 2.7、range 2.8、bool Elasticsearch语法与案例介绍 一、Restful API Elastics…

服务攻防——中间件Jboss

文章目录 一、Jboss简介二、Jboss渗透2.1 JBoss 5.x/6.x 反序列化漏洞(CVE-2017-12149)2.2 JBoss JMXInvokerServlet 反序列化漏洞(CVE-2015-7501)2.3 JBossMQ JMS 反序列化漏洞(CVE-2017-7504)2.4 Adminis…

Java如何自定义注解及在SpringBoot中的应用

注解 注解(Annotation),也叫元数据。一种代码级别的说明。它是JDK1.5及以后版本引入的一个特性,与类、接口、枚举是在同一个层次。它可以声明在包、类、字段、方法、局部变量、方法参数等的前面,用来对这些元素进行说…

在Spring Boot中实现RESTful API设计

在Spring Boot中实现RESTful API设计 大家好,我是微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 1. RESTful API简介 1.1 什么是RESTful API? RESTful API是一种设计风格,基于HTTP协议…

leetcode:LCR 018. 验证回文串(python3解法)

难度:简单 给定一个字符串 s ,验证 s 是否是 回文串 ,只考虑字母和数字字符,可以忽略字母的大小写。 本题中,将空字符串定义为有效的 回文串 。 示例 1: 输入: s "A man, a plan, a canal: Panama" 输出: t…

【C++】开源:坐标转换和大地测量GeographicLib库配置使用

😏★,:.☆( ̄▽ ̄)/$:.★ 😏 这篇文章主要介绍坐标转换和大地测量GeographicLib库配置使用。 无专精则不能成,无涉猎则不能通。——梁启超 欢迎来到我的博客,一起学习,共同进步。 喜欢的朋友可以关…

Effective C++笔记之二十一:One Definition Rule(ODR)

ODR细节有点复杂,跨越各种情况。基本内容如下: ●普通(非模板)的noninline函数和成员函数、noninline全局变量、静态数据成员在整个程序中都应当只定义一次。 ●class类型(包括structs和unions)、模板&…

2018-2022 年份微博签到数据集

前阵子接到一个实验室老师的需求,采集五年前(2024-52019)过年前后的北京微博签到数据。 前两年采集的深圳签到数据是 2022 年是当年的尚可,这次虽然时间跨度只有两个月,但是由于时间太过久远,但是颇费了一…

Spring学习04-[Spring容器核心技术AOP学习]

AOP学习 AOP介绍使用对业务方法添加计算时间的增强 EnableAspectJAutoProxyAOP的术语通知前置通知Before后置通知After返回通知AfterReturning异常通知AfterThrowing总结-通知执行顺序 切点表达式的提取-使用Pointcut进行抽取切点表达式的详细用法execution和annotation组合 Sp…

STM32快速搭建项目框架

注:编写本博客的原因,学习期间基于复习之前知识点的需要,故撰写本教程,即是复习前面的知识点也是作为博客的补充 1.0 文件夹的创建 创建一个STM32项目为模版工程,问价夹下分别包含4个子文件夹,一个是Librar…

Java 使用POI 读取Excel 中的全部图片,非嵌入式图片

示例代码 excel文件格式 : xlsx public static Map<String, XSSFPictureData> getPictures(Sheet sheet) throws Exception {Map<String, XSSFPictureData> pictures new HashMap<>();// 对于XLSX文件if (sheet instanceof XSSFSheet) {XSSFDrawing drawin…

嘉立创EDA学习笔记

嘉立创EDA学习笔记 PCB引线一、设计规则间距安全间距其他间距 物理导线网络长度差分对过孔尺寸 平面铺铜 PCB布线 作为一个嵌入式开发潜力工程师&#xff0c;咱们必须得学会如何绘制开发板以满足顾客各种功能的需求&#xff0c;因此小编去学习了一下嘉立创&#xff0c;写这篇文…

VSCode用ssh连接ubuntu虚拟机实现远程访问文件夹

1. ubuntu安装ssh服务 1.1 安装 sudo apt-get install ssh sudo apt-get install openssh-server1.2 启动ssh服务 sudo service ssh start sudo service ssh status # 查看状态 ## 或者用下面方式重启ssh服务 ## /etc/init.d/ssh restart1.3 ssh服务加入开机启动 sudo syst…

C++实现对结构体信息排序

思路解读&#xff1a; 定义结构体 Student: 结构体 Student 用来表示学生信息&#xff0c;包含两个成员变量&#xff1a;name&#xff08;学生姓名&#xff09;和 score&#xff08;学生分数&#xff09;。Student 结构体定义了一个构造函数&#xff0c;用于初始化 name 和 sco…

代码随想录算法训练营第7天

454.四数相加 题目链接&#xff1a;454. 四数相加 II - 力扣&#xff08;LeetCode&#xff09; 视频/文档链接&#xff1a;代码随想录 (programmercarl.com) 第一想法 遍历数组num1,num2&#xff0c;计算其和出现的数量&#xff0c;放入map集合中&#xff0c;键为和&#xff0…

HTML语言常见标签

语法 HEAD部分的HTML标签 1 标题标签 <title>标题内容</title> 2 段落标签 <meta charset"utf-8"/> BODY部分的HTML标签 1标题标签&#xff08;独占一行&#xff09;<h1>标题内容</h1> 2段落标签&#xff08;独占一行&#xff09;…

大模型一些概念的理解 - 线性层、前向传播、后向传播

文章目录 前言一、线性层1. 什么是线性层&#xff1f;2. 通俗解释3. 示例 二、前向传播1. 什么是前向传播&#xff1f;2. 通俗解释3. 示例 三、后向传播1. 什么是后向传播&#xff1f;2. 通俗解释3. 具体步骤 四、示例五、在 PyTorch 中的后向传播 前言 最近提问里有问到一些名…

TK 检查输入框是否为空

在Python的Tkinter库中&#xff0c;你可以使用事件绑定或者在按钮点击事件中检查输入框的值是否为空来实现这个功能。以下是一个简单的例子&#xff1a; import tkinter as tk from tkinter import messageboxdef check_input():entry input_box.get()if not entry:messagebo…

TLP152 光耦合器:工程师的可靠选择

东芝的 TLP152 光耦合器是一款稳健且多功能的组件&#xff0c;能够满足各种高速和高可靠性应用中的工程师需求。本文将深入探讨 TLP152 的技术特性、优点和应用&#xff0c;突出其在市场中的独特性。 主要特点和规格 TLP152 光耦合器集成了一颗铝镓砷&#xff08;GaAlAs&…

C#字符串操作:判断一个字符串按特定字符分割后的子字符串是否有重复的几种常用方法

C#判断一个字符串按特定字符分割后的子字符串是否有重复的几种常用方法&#xff1a; 方法一&#xff1a;使用 LINQ 你可以使用 LINQ 来简化检查重复子字符串的过程&#xff1a; using System; using System.Linq;class Program {static void Main(){string input "CCT…