大模型中 .safetensors 文件、.ckpt文件和.pth以及.bin文件区别、加载和保存以及转换方式

目录

模型格式介绍

加载以及保存

- 加载.safetensors文件:

- 保存/加载.pth文件:

- 保存/加载.ckpt文件:

- 处理.bin文件:

模型之间的互相转换

pytorch-lightning 和 pytorch

ckpt和safetensors


模型格式介绍

在大型深度学习模型的上下文中,.safetensors.bin 和 .pth ckpt 文件的用途和区别如下:

  1. .safetensors 文件

    • 这是由 Hugging Face 推出的一种新型安全模型存储格式,特别关注模型安全性、隐私保护和快速加载。
    • 它仅包含模型的权重参数,而不包括执行代码,这样可以减少模型文件大小,提高加载速度。
    • 加载方式:使用 Hugging Face 提供的相关API来加载 .safetensors 文件,例如 safetensors.torch.load_file() 函数。
  2. ckpt文件

    • ckpt 文件是 PyTorch Lightning 框架采用的模型存储格式,它不仅包含了模型参数,还包括优化器状态以及可能的训练元数据信息,使得用户可以无缝地恢复训练或执行推理。
  3. .bin 文件

    • 通常是一种通用的二进制格式文件,它可以用来存储任意类型的数据。
    • 在机器学习领域,.bin 文件有时用于存储模型权重或其他二进制数据,但并不特指PyTorch的官方标准格式。
    • 对于PyTorch而言,如果用户自己选择将模型权重以二进制格式保存,可能会使用 .bin 扩展名,加载时需要自定义逻辑读取和应用这些权重到模型结构中。
  4. .pth 文件

    • 是 PyTorch 中用于保存模型状态的标准格式。
    • 主要用于保存模型的 state_dict,包含了模型的所有可学习参数,或者整个模型(包括结构和参数)。
    • 加载方式:使用 PyTorch 的 torch.load() 函数直接加载 .pth 文件,并通过调用 model.load_state_dict() 将加载的字典应用于模型实例。

总结起来:

  • .safetensors 侧重于安全性和效率,适合于那些希望快速部署且对安全有较高要求的场景,尤其在Hugging Face生态中。
  • .ckpt 文件是 PyTorch Lightning 框架采用的模型存储格式,它不仅包含了模型参数,还包括优化器状态以及可能的训练元数据信息,使得用户可以无缝地恢复训练或执行推理。
  • .bin 文件不是标准化的模型保存格式,但在某些情况下可用于存储原始二进制权重数据,加载时需额外处理。
  • .pth 是PyTorch的标准模型保存格式,方便模型的持久化和复用,支持完整模型结构和参数的保存与恢复。

加载以及保存

加载.safetensors文件

# 用SDXL举例
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_filebase = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "/home/bino/svul/models/sdxl/sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!# Load model.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(ckpt, device="cuda"))
# unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")# Ensure sampler uses "trailing" timesteps.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")# Ensure using the same inference steps as the loaded model and CFG set to 0.
pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")

保存/加载.pth文件

 # 保存模型状态字典torch.save(model.state_dict(), "model.pth")# 加载模型状态字典到已有模型结构中model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load("model.pth"))# 或者保存整个模型,包括结构torch.save(model, "model.pth")# 加载整个模型model = torch.load("model.pth", map_location=device)

保存/加载.ckpt文件

import pytorch_lightning as pl# 定义一个 PyTorch Lightning 训练模块
class MyLightningModel(pl.LightningModule):def __init__(self):super().__init__()self.linear_layer = nn.Linear(10, 1)self.loss_function = nn.MSELoss()def forward(self, inputs):return self.linear_layer(inputs)def training_step(self, batch, batch_idx):features, targets = batchpredictions = self(features)loss = self.loss_function(predictions, targets)self.log('train_loss', loss)return loss# 初始化 PyTorch Lightning 模型
lightning_model = MyLightningModel()# 配置 ModelCheckpoint 回调以定期保存最佳模型至 .ckpt 文件
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor='val_loss',filename='best-model-{epoch:02d}-{val_loss:.2f}',save_top_k=3,mode='min'
)# 创建训练器并启动模型训练
trainer = pl.Trainer(callbacks=[checkpoint_callback],max_epochs=10
)
trainer.fit(lightning_model)# 从 .ckpt 文件加载最优模型权重
best_model = MyLightningModel.load_from_checkpoint(checkpoint_path='best-model.ckpt')# 使用加载的 .ckpt 文件中的模型进行预测
sample_input = torch.randn(1, 10)
predicted_output = best_model(sample_input)
print(predicted_output)

在此示例中,我们首先定义了一个 PyTorch Lightning 模块,该模块集成了模型训练的逻辑。然后,我们配置了 ModelCheckpoint 回调函数,在训练过程中按照验证损失自动保存最佳模型至 .ckpt 文件。接着,我们展示了如何加载 .ckpt 文件中的最优模型权重,并利用加载后的模型对随机输入数据进行预测,同样输出预测结果。值得注意的是,由于 .ckpt 文件完整记录了训练状态,它在实际应用中常被用于模型微调和进一步训练。

处理.bin文件

如果.bin文件是纯二进制权重文件,加载时需要知道模型结构并且手动将权重加载到对应的层中,例如:

 # 假设已经从.bin文件中读取到了模型权重数据weights_data = load_binary_weights("weights.bin")# 手动初始化模型并加载权重model = TheModelClass(*args, **kwargs)for name, param in model.named_parameters():if name in weights_mapping:  # 需要预先知道权重映射关系param.data.copy_(weights_data[weights_mapping[name]])

模型之间的互相转换

pytorch-lightning 和 pytorch

由于 PyTorch Lightning 模型本身就是 PyTorch 模型,因此不存在严格意义上的转换过程。你可以直接通过 LightningModule 中定义的神经网络层来进行保存和加载,就像普通的 PyTorch 模型一样:

# 假设 model 是一个 PyTorch Lightning 模型实例
model = MyLightningModel()# 保存模型权重
torch.save(model.state_dict(), 'lightning_model.pth')# 加载到一个新的 PyTorch 模型实例
new_model = MyLightningModel()
new_model.load_state_dict(torch.load('lightning_model.pth'))# 或者加载到一个普通的 PyTorch Module 实例(假设结构一致)
plain_pytorch_model = MyPlainPytorchModel()
plain_pytorch_model.load_state_dict(torch.load('lightning_model.pth'))

ckpt和safetensors

转换后的模型在stable-diffussion-webui中使用过没有问题,不知道有没有错误,或者没转换成功

import torch
import os
import safetensors
from typing import Dict, List, Optional, Set, Tuple
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_filedef ckpt2safetensors():loaded = torch.load('v1-5-pruned-emaonly.ckpt')if "state_dict" in loaded:loaded = loaded["state_dict"]safetensors.torch.save_file(loaded, 'v1-5-pruned-emaonly.safetensors')def st2ckpt():# 加载 .safetensors 文件data = safetensors.torch.load_file('v1-5-pruned-emaonly.safetensors.bk')data["state_dict"] = data# 将数据保存为 .ckpt 文件torch.save(data, os.path.splitext('v1-5-pruned-emaonly.safetensors')[0] + '.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/726654.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Pygame教程05:帧动画原理+边界值检测,让小球来回上下运动

------------★Pygame系列教程★------------ Pygame教程01:初识pygame游戏模块 Pygame教程02:图片的加载缩放旋转显示操作 Pygame教程03:文本显示字体加载transform方法 Pygame教程04:draw方法绘制矩形、多边形、圆、椭圆、弧…

baidu, google和chatgpt -- 翻译对比

原文 That ChatGPT can automatically generate something that reads even superficially like human-written text is remarkable, and unexpected. But how does it do it? And why does it work? My purpose here is to give a rough outline of what’s going on inside…

Context

在 Android 开发中,Context 是一个表示应用程序环境的类,它提供了访问应用程序资源和执行应用程序级操作的接口。它是一个抽象类,具体的实现类是 ContextImpl。 Context 类的实例在整个 Android 应用程序中广泛使用,它可以用于执…

Linux-socket套接字

前言 在当今数字化时代,网络通信作为连接世界的桥梁,成为计算机科学领域中至关重要的一部分。理解网络编程是每一位程序员必备的技能之一,而掌握套接字编程则是深入了解网络通信的关键。本博客将深入讨论套接字编程中的基本概念、常见API以及…

国际数字影像产业园:全面推进“AI+”行动,加快标准建设,厚植创新沃土

人工智能作为数字经济时代的重要基础设施、关键技术、先导产业以及赋能引擎,将长期为我国各行业转型升级和数字经济发展提供核心驱动力。树莓集团总部国际数字影像产业园,作为新时代科技与数字产业的交汇点,正全面推进“AI”行动,…

小白在VMware Workstation Pro上安装部署SinoDB V16.8

一、安装环境说明 CPU:2核或以上,内存:2G或以上;磁盘10G或以上;网卡:千兆 1.1检查服务器内存大小 命令:free -m 1.2检查服务器磁盘空间大小 命令:df -h 1.3检查服务器网络配置信息 命…

bunx 使用文档

注意 — bunx 是 bun x 的别名。安装 bun 时,bunx CLI 将自动安装。 使用 bunx 从 npm 自动安装和运行包。它相当于 npx 或 yarn dlx。 bunx cowsay "Hello world!" ⚡️ 速度 — 由于 Bun 的启动时间很快,对于本地安装的软件包,b…

服务器防火墙和安全组放开

问题 我的项目上传后安全组也放开了但是访问项目地址404,最后发现是服务器防火墙没放行。 下面介绍一下如何排查防火墙问题。 服务器防火墙操作命令 查看防火墙状态:systemctl status firewalld 禁用防火墙:systemctl stop firewalld 启…

Linux系统安装Dashy服务结合内网穿透实现公网访问本地导航页

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起,形成自己的导航…

OPENWRT本地局域网模拟域名多IP

本地配置MINIO服务时,会遇到域名多IP的需求。当某一个节点失效时,可以通过域名访问平滑过渡到其它的节点继续服务。 【MINIO搭建过程略】 搭建完毕后,有4个节点,对应的docker搭建命令: docker run --nethost --rest…

基于OpenCV的图形分析辨认05(补充)

目录 一、前言 二、实验内容 三、实验过程 一、前言 编程语言:Python,编程软件:vscode或pycharm,必备的第三方库:OpenCV,numpy,matplotlib,os等等。 关于OpenCV,num…

第十二篇 - IAB 标准技术条款和定义-我为什么要翻译介绍美国人工智能科技巨头IAB公司?

前言 这是2021年IAB公司发布的《市场营销人工智能使用案例及最佳实践报告》的最后一篇译文。翻译工作不难,但是非常考验一个人的态度,需要译者忠于自己的初心,严谨对待所有文字、数据、信息、技术和观点。时代变化如此之快,3年前…

学生信息管理展示-h5版(uniapp+springboot+vue)

记录一下做的第一个完整的h5业务。 一、登录 二、个人中心 三、首页(管理员) 四、首页(学生) 五、视频展示 学生信息管理展示(h5)完整版

人工智能英文缩写术语

人工智能术语: FFNN FFNN:feedforward neural network,前馈神经网络; RNN RNN:Recurrent Neural Network,循环神经网络 CNN CNN:Convolutional Neural Network,卷积神经网络 在…

京东数据分析平台(京东店铺数据分析工具)推荐

京东店铺数据分析能够帮助商家了解自己的经营状况,优化商品策略,提高销售效率。以下是京东店铺数据分析的一些基本步骤和方法: 首先,在进行京东店铺数据分析时,我们需要借助一些电商数据分析工具来获取相关数据&#…

Linux Ubuntu部署SVN服务端结合内网穿透实现客户端公网访问

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

2016年认证杯SPSSPRO杯数学建模C题(第一阶段)如何有效的抑制校园霸凌事件的发生解题全过程文档及程序

2016年认证杯SPSSPRO杯数学建模 C题 如何有效的抑制校园霸凌事件的发生 原题再现: 近年来,我国发生的多起校园霸凌事件在媒体的报道下引发了许多国人的关注。霸凌事件对学生身体和精神上的影响是极为严重而长远的,因此对于这些情况我们应该…

Openwrt(IstoreOS)安装iventoy

背景 目前家里有两台不用的旧主机,平时没事在家里折腾这两台机器。经常换装各种系统。最早是将镜像刷入u盘作为启动盘,这样需要重复装系统就特别麻烦。后来用了ventoy以后一个U盘可以放多个系统镜像,还能做口袋系统(SystemToGo&a…

OpenXR 超详细的spec--Instance介绍

4. Instance OpenXR instance是一个允许OpenXR application和runtime进行通信的句柄对象。application通过调用xrCreateInstance()和接收一个XrInstance对应的handle完成通信。 XrInstance对象存储和追踪OpenXR相关应用的状态,不需要在application的全局地址空间中…

红队攻击手“实战”特训

伴随着新的一年的到来,我们最新一期的红队攻防,也如约而至~ 每一期我们都会做二次学员反馈,根据同学们的真实反馈和需求,来调整讲师及授课内容 新的一期我们增加了C基础,python基础,汇编基础的课程&#…