【Python】科研代码学习:十七 模型参数合并,safetensors / bin

【Python】科研代码学习:十七 模型参数合并,safetensors / bin

  • 前言
  • 解决代码
  • 知识点:safetensors 和 bin 的区别?
  • 知识点:save_pretrained 还会新增的文件
  • 知识点:在保存模型参数时,大小发生了成倍的变化

前言

  • 众所周知,LLM的模型参数一般保存在 .safetensors 或者 .bin 结尾的大文件
    在这里插入图片描述
  • 但是通过一个 RLHF 的一个训练后,使用了 FSDP 分布式训练器
    所以把文件参数保存在了 .pt 文件中
    在这里插入图片描述
  • 那么问题来了,保存的参数我如何合并到模型里去,做其他推理任务呢?

解决代码

  • 经过复杂的尝试和询问,然后使用下面的几个方法就成功了
    第一步,加载初始的模型,使用 .from_pretrained 即可加载本地模型的参数
    第二步,加载 policy.pt 里面的 state 的内容,使用 model.load_state_dict 即可使用这些参数来覆盖原始模型的参数
    第三步,保存模型参数到文件夹,使用 model.save_pretrained 即可
def FSDP_model_merge(model_path : str, pt_path : str, output_path : str):print("Loading Model")model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)print("Loading Checkpoint")model.load_state_dict(torch.load(pt_path)['state'])print("Saving Model")model.save_pretrained(output_path,safe_serialization=True, torch_dtype=torch.float16)print("Done")

知识点:safetensors 和 bin 的区别?

  • 【知乎】
    简单来说,bin 是通用的二进制存储文件
    safetensors 是更加安全的文件,专门存储张量数据
    所以这两者都可以存模型的参数
  • 如何设置保存的时候使用哪个格式?
    model.save_pretrained() 方法里面的 safe_serialization 设置成 True 的话,就会用 safetensors 格式了,注意不同 transformers 版本的该方法的 safe_serialization 的默认值是不同的(较新的版本该值默认为 True,较老的为 False
  • 看了下,貌似对于文件保存的大小来说,几乎没什么差异

知识点:save_pretrained 还会新增的文件

  • model.save_pretrained 方法调用后,在文件夹中其实还会新增/替换这几个文件:
    config.json
    generation_config.json
    model.safetensors.index.json
  • model.safetensors.index.json 的文件主要是参数和文件的存储关系映射
    以及可以从 total_size 中查看模型的参数大小
    比如这里,13476839424,除以 1 0 9 10^9 109 之后为 13 13 13,即该模型参数大小大约为 13 G 13G 13G
    然后后面可以看到保存了哪些参数权重,比如有 mlp.down_proj
    在这里插入图片描述
  • generation_config.json 主要是生成任务的参数,还有 transformers 库的版本号
    在这里插入图片描述
  • config.json 比较重要,是记录该模型的重要参数
    有模型的架构 LlamaForCausalLM,中间各种网络的参数,词汇表大小等。
    在这里插入图片描述

知识点:在保存模型参数时,大小发生了成倍的变化

  • 这次就遇到了这个问题,我一开始还以为是合并时两份参数加在一起而没有覆盖导致的
    最终文件大小加倍了
  • 但最后是发现 torch_dtype 原本是 float16,我直接保存的话类型变成了 float32,因此文件大小翻倍了
    在加载和保存处设置好数据类型即可。
    【这启示我们,对于精度类型还是得注意清楚的,比如在训练的时候使用混合精度等问题】
  • 最终发现,在model.safetensors.index.json 里面,多了一个 self_attn.rotary_emb.inv_freq 参数,但这个貌似对于内存不是特别影响,应该问题是不大的
    total_size 只打了7k多
    并且它原本是参数分成了三份,这次分成了两份,这个也会有变化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/808790.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytest教程-25-生成覆盖率报告插件-pytest-cov

领取资料,咨询答疑,请➕wei: June__Go 上一小节我们学习了pytest多重断言插件pytest-assume,本小节我们讲解一下pytest生成覆盖率报告插件pytest-cov。 测量代码覆盖率的工具在测试套件运行时观察你的代码,并跟踪哪些行被运行,…

10:00面试,10:08就出来了,问的问题有点变态。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到8月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…

【我的小工具】生成React页面类

有了数据表的结构信息,就能生成React 的页面类,快捷方便。 生成界面如下: 生成的React FrmUser.js页面如下: 只需再写里面的操作逻辑代码。

Claude使用教程

claude 3 opus面世后,网上盛传吊打了GPT-4。网上这几天也已经有了许多应用,但竟然还有很多小伙伴不知道国内怎么用gpt,也不知道怎么去用这个据说已经吊打了gpt-4的claude3。 今天我们想要进行的一项尝试就是—— 用claude3和gpt4&#xff0c…

C语言操作符详解(三)

一、表达式求值 1.1整型提升 C语言中整型算术运算总是至少以缺省整型类型的精度来进行的。 为了获得这个精度,表达式中的字符和短整型操作数在使用之前被转换为普通整型,这种转换称为整型提升。 如何进行整型提升呢? 1. 有符号整数提升是按…

不入耳开放式耳机哪个品牌好?2024年热销榜前五名品牌推荐

为何开放式耳机近年来如此火爆?首先,开放式耳机以其开放式的声学设计,打破了传统耳机的局限,为用户带来了更加自然、宽广的音质体验。其次,随着音乐文化的普及和人们对高品质生活的追求,开放式耳机作为高端…

4.9学习总结

一.File类 (一).概述: File 类的对象代表操作系统的文件(文件、文件夹),File 类提供了诸如:创建文件对象代表文件,获取文件信息(大小、修改时间)、删除文件、创建文件(文件夹)等功…

HarmonyOS开发实例:【数字管家app】

一.概述 本应用是基于RK3399开发板,使用OpenHarmony3.1-Release开发的应用。通过OpenHarmony的分布式技术,使多人能够一起画画。 1.应用运行效果图: 2.分布式画板使用示意图 如上图所示,用户1、用户2在各自本地端进行…

Stack_经典例题_最小栈

题目: 题目分析: 在满足栈的特点的同时,还需要设计一个接口,就是获取栈内的最小元素! 解题思路: 因为是栈,所以不好遍历的!所以这题的方式不能采用遍历的方式,如果采取…

分布式锁-redission可重入锁原理

5.3 分布式锁-redission可重入锁原理 在Lock锁中,他是借助于底层的一个voaltile的一个state变量来记录重入的状态的,比如当前没有人持有这把锁,那么state0,假如有人持有这把锁,那么state1,如果持有这把锁的…

DELL VMWare R730 R740 R750 iDRAC配置与ESXI安装部署

VMware vCenter Server与ESXI版本兼容对照表 ESXI下载 VMware vcenter7.0许可证 Esxi7.0许可证 VSAN 7.0许可证 DELL VMWare R730 R740 R750 iDRAC配置与ESXI安装部署 vmware vcenter server 7.0 安装教程 1. 进入BIOS界面配置iDRAC网络 开机按F10,开机点击F10选择…

国家统计局行政区划获取及入库ES实践

我们先看下最终效果: 1. ES索引新建 PUT administrative_division {"mappings": {"properties": {"province": {"type": "keyword"},"province_code": {"type": "keyword"},&q…

docker安装oracle

程序员的公众号:源1024,获取更多资料,无加密无套路! 最近整理了一波电子书籍资料,包含《Effective Java中文版 第2版》《深入JAVA虚拟机》,《重构改善既有代码设计》,《MySQL高性能-第3版》&…

C++从入门到精通——类和对象(中篇)

1. 类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。空类中什么都没有吗?并不是的,任何一个类在我们不写的情况下,都会自动生成下面6个默认成员函数。 class Date {}; 2. 构造函数 2.1 概念 对于以下的日期类&am…

Linux下使用C语言实现高并发服务器

高并发服务器 这一个课程的笔记 相关文章 协议 Socket编程 高并发服务器实现 线程池 使用多进程并发服务器时要考虑以下几点: 父进程最大文件描述个数(父进程中需要close关闭accept返回的新文件描述符)系统内创建进程个数(与内存大小相关)进程创建过多是否降低整体…

AI电商图制作解决方案助力企业高效营销

电商行业蓬勃发展,一张吸睛的电商海报或宣传视频往往能够成为企业吸引顾客、提升品牌形象的利器。然而,传统电商图制作流程繁琐,需要投入大量时间和人力资源,成为众多企业面临的难题。为了解决这一问题,美摄科技凭借其…

前端学习之路-项目实战(1)

每日吐槽:有一个奇怪的问题,怎么一眼看出一个求职者是否是培训班出来的,有的求职上写着,希望大家坦诚一点,but,你这艘诚实的泰坦尼克号终究还是撞上了社会阴暗面的冰山,OMG,不让包装…

Leetcode 239. 滑动窗口最大值和Leetcode 347. 前 K 个高频元素

目录标题 Leetcode 239. 滑动窗口最大值题目描述C语言代码和题解解题思路 Leetcode 347. 前 K 个高频元素题目描述C语言题解和思路解题思路 Leetcode 239. 滑动窗口最大值 题目描述 给你一个整数数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最…

Tensorflow(GPU版本配置)一步到位!!!

Tensorflow(GPU版本配置)一步到位!!! CUDA安装CUDA配置Tensorflow配置常见的包 CUDA安装 配置了N次的Tensorflow–Gpu版本,完成了踩坑,这里以配置Tensorflow_gpu 2.6.0为例子进行安装 以下为ten…

数学之光照亮AI之路:探究数学背景在人工智能学习中的优势

在科技日新月异的今天,人工智能(AI)已成为引领未来发展的重要力量。然而,对于许多初涉此领域的学习者来说,AI的复杂性和深度常常让他们望而却步。有趣的是,那些数学基础扎实的人在学习AI时,往往…