pytorch——保存‘类别名与类别数量’到权值文件中

前言

不知道大家有没有像我一样,每换一次不一样的模型,就要输入不同的num_classes和name_classes,反正我是很头疼诶,尤其是项目里面不止一个模型的时候,更新的时候看着就很头疼,然后就想着直接输入模型权值文件的path该多好,然后我就搞起来了。

在自己的类中加入想要加入数据信息

class your_nets(nn.Module):def __init__(self, num_classes = 21,name_classes=None):super(your_nets, self).__init__()self.num_classes = num_classesself.name_classes = name_classes

训练过程之保存文件

      
model = your_nets(num_classes=num_classes, name_classes=name_classes)save_dict = {'state_dict': model.state_dict(),'num_classes': model.num_classes,'name_classes': model.name_classes}torch.save(save_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

使用 

model = get_nets_class(model_path=model_path)class get_nets_class(object):def __init__(self ,**kwargs):device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')load_dict  = torch.load(self.model_path, map_location=device)state_dict =load_dict['state_dict']num_classes = load_dict['num_classes']name_classes = load_dict['name_classes']if num_classes is not None and name_classes is not None:self.num_classes =num_classesself.name_classes = name_classesself.net = your_nets(num_classes=self.num_classes,name_classes=name_classe)self.net.load_state_dict(state_dict)else:self.net = your_nets(num_classes=self.num_classes, backbone=self.backbone)self.net.load_state_dict(load_dict)self.net = self.net.eval()def predict(self,image,name_classes,object_list):#你的预处理操作,没有就忽略image_data = preprocess(image)with torch.no_grad():# 推理pr = self.net(images)[0]# softmax 得出概率 pr.permute(1, 2, 0), dim=-1为我自己的操作,没有请忽略pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()#你的后处理操作,没有就忽略pr = postprocess(pr)#这一步与object_list有关 object_list是你想要模型去预测的内容# 例如你训练了识别cat、dog、pig、person的类别 那么你想只识别人,那么就object_list=['person'] if object_list is not None:model_object_list = [name_classes.index(i) for i in object_list if i in name_classes]temp_list = [i for i in range(len(name_classes))]remove_list = [i for i in temp_list if i not in model_object_list]for i in remove_list:pr[pr==i] = 0retuen pr

我是觉得已经很详细了,大家要是不懂可以再问,我可以慢慢改进,每个人的写法都不一样 。

欢迎大家点赞加收藏哟~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/669985.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【极简】Pytorch中的register_buffer()

register buffer 定义模型能用torch.save保存的、但是不更新参数。 使用:只要是nn.Module的子类就能直接self.调用使用: class A(nn.Module): #... self.register_buffer(betas, torch.linspace(beta_1, beta_T, T).double()) #...手动定义参数 上述…

C++之字符串

C风格字符串 字符串处理在程序中应用广泛&#xff0c;C风格字符串是以\0&#xff08;空字符&#xff09;来结尾的字符数组。对字符串进行操作的C函数定义在头文件<string.h>或中。常用的库函数如下&#xff1a; //字符检查函数(非修改式操作) size_t strlen( const char …

linux麒麟系统安装mongodb7.0

1.mogedb下载 下载的是他tar包 https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel80-7.0.5.tgz wget -o https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-rhel80-7.0.5.tgz 也可以下载rpm包 2.将包上传至服务器并解压 #进入目录 并解压 cd /opt/ tar …

【 buuctf-另外一个世界】

flag 就隐藏在这一串二进制数中&#xff0c;可以利用在线工具转换得到 flag&#xff0c;本次讲一下用代码怎么转换。将二进制数转换成 ascii 字母&#xff0c;手写的话两种思路&#xff1a; 1.将二进制数四位一组先转成十六进制数&#xff0c;再将十六进制数两位一组&#xff…

JavaScript的表单、控件

&#x1f9d1;‍&#x1f393; 个人主页&#xff1a;《爱蹦跶的大A阿》 &#x1f525;当前正在更新专栏&#xff1a;《VUE》 、《JavaScript保姆级教程》、《krpano》、《krpano中文文档》 ​ ​ ✨ 前言 表单是 web 开发中不可或缺的一部分&#xff0c;用于收集用户输入。本…

React 错误边界组件 react-error-boundary 源码解析

文章目录 捕获错误 hook创建错误边界组件 Provider定义错误边界组件定义边界组件状态捕捉错误渲染备份组件重置组件通过 useHook 控制边界组件 捕获错误 hook getDerivedStateFromError 返回值会作为组件的 state 用于展示错误时的内容 componentDidCatch 创建错误边界组件 P…

vscode 突然连接不上服务器了(2024年版本 自动更新从1.85-1.86)

vscode日志 ll192.168.103.5s password:]0;C:\WINDOWS\System32\cmd.exe [17:09:16.886] Got some output, clearing connection timeout [17:09:16.887] Showing password prompt [17:09:19.688] Got password response [17:09:19.688] "install" wrote data to te…

AE2023 After Effects 2023

After Effects 2023是一款非常强大的视频编辑软件&#xff0c;提供了许多新功能和改进&#xff0c;使得视频编辑和合成更加高效和灵活。以下是一些After Effects 2023的特色功能&#xff1a; 新合成预设列表&#xff1a;After Effects 2023彻底修改了预设列表&#xff0c;使其…

Spring Boot项目整合Seata AT模式

目录 1、添加依赖2.、配置Seata3、创建AT模式表4、使用Seata分布式事务 1、添加依赖 <dependency><groupId>io.seata</groupId><artifactId>seata-spring-boot-starter</artifactId></dependency>上述依赖适用于springboot项目 如果你的项…

【Iceberg学习五】Iceberg中性能和可靠性保证

Performance 性能 Iceberg 旨在处理巨大的表格&#xff0c;在生产环境中使用&#xff0c;单个表格可以包含数十PB&#xff08;拍字节&#xff09;的数据。即使是多PB级别的表格&#xff0c;也可以从单个节点读取&#xff0c;无需依赖分布式SQL引擎来筛查表格元数据。 扫描计…

算法——二分查找算法

1. 二分算法是什么&#xff1f; 简单来说&#xff0c;"二分"指的是将查找的区间一分为二&#xff0c;通过比较目标值与中间元素的大小关系&#xff0c;确定目标值可能在哪一半区间内&#xff0c;从而缩小查找范围。这个过程不断重复&#xff0c;每次都将当前区间二分…

Kafka零拷贝技术与传统数据复制次数比较

读Kafka技术书遇到困惑: "对比传统的数据复制和“零拷贝技术”这两种方案。假设有10个消费者&#xff0c;传统复制方式的数据复制次数是41040次&#xff0c;而“零拷贝技术”只需110 11次&#xff08;一次表示从磁盘复制到页面缓存&#xff0c;另外10次表示10个消费者各自…

黑马头条 Kafka

我是南城余&#xff01;阿里云开发者平台专家博士证书获得者&#xff01; 欢迎关注我的博客&#xff01;一同成长&#xff01; 一名从事运维开发的worker&#xff0c;记录分享学习。 专注于AI&#xff0c;运维开发&#xff0c;windows Linux 系统领域的分享&#xff01; 知…

C语言中的结构体

在C语言中&#xff0c;结构体&#xff08;struct&#xff09;是一种可以封装多个不同类型数据的数据结构。通过使用结构体&#xff0c;我们可以将多个相关的变量组合成一个单一的实体&#xff0c;从而方便地进行管理和操作。在本篇博客中&#xff0c;我们将通过示例代码来详细探…

4.0 Zookeeper Java 客户端搭建

本教程使用的 IDE 为 IntelliJ IDEA&#xff0c;创建一个 maven 工程&#xff0c;命名为 zookeeper-demo&#xff0c;并且引入如下依赖&#xff0c;可以自行在maven中央仓库选择合适的版本&#xff0c;介绍原生 API 和 Curator 两种方式。 IntelliJ IDEA 相关介绍&#xff1a;…

macOS Sonoma 14系统安装包

macOS Sonoma 14是苹果公司最新推出的操作系统&#xff0c;为Mac用户带来了全新的使用体验。Sonoma是苹果继Catalina之后的又一重要更新&#xff0c;它在改善系统性能、增加新功能、优化用户界面等方面做出了显著贡献。 macOS Sonoma 14系统有许多令人兴奋的新功能和改进&…

【DDD】学习笔记-数据模型与对象模型

在建立数据设计模型时&#xff0c;我们需要注意表设计与类设计之间的差别&#xff0c;这事实上是数据模型与对象模型之间的差别。 数据模型与对象模型 我们首先来分析在设计时对冗余的考虑。前面在讲解数据分析模型时就提及&#xff0c;在确定数据项模型时&#xff0c;需要遵…

Sentinel(理论版)

Sentinel 1.什么是Sentinel Sentinel 是一个开源的流量控制组件&#xff0c;它主要用于在分布式系统中实现稳定性与可靠性&#xff0c;如流量控制、熔断降级、系统负载保护等功能。简单来说&#xff0c;Sentinel 就像是一个交通警察&#xff0c;它可以根据系统的实时流量&…

Windows启动一个进程CreateProcess

CreateProcess 函数创建独立于创建进程运行的新进程。 参数接口BOOL CreateProcessA([in, optional] LPCSTR lpApplicationName,[in, out, optional] LPSTR lpCommandLine,[in, optional] LPSECURITY_ATTRIBUTES lpProcessAttributes…

算法学习——LeetCode力扣链表篇1

算法学习——LeetCode力扣链表篇1 203. 移除链表元素 203. 移除链表元素 - 力扣&#xff08;LeetCode&#xff09; 描述 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 示例 …