【CenterFusion】模型的创建、导入、保存CenterFusion/src/lib/model/model.py

文件内容:CenterFusion/src/lib/model/model.py
文件作用:模型的创建、导入、保存

model.py 具体内容如下:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport torchvision.models as models
import torch
import torch.nn as nn
import osfrom .networks.dla import DLASeg
from .networks.resdcn import PoseResDCN
from .networks.resnet import PoseResNet
from .networks.dlav0 import DLASegv0
from .networks.generic_network import GenericNetwork_network_factory = {'resdcn': PoseResDCN,'dla': DLASeg,'res': PoseResNet,'dlav0': DLASegv0,'generic': GenericNetwork
}def create_model(arch, head, head_conv, opt=None):num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0'''处理字符串 arch = dla_34 ,将下划线后半部分取出最后 num_layers = 34'''arch = arch[:arch.find('_')] if '_' in arch else arch'''将 arch = dla_34 中下划线前半部分取出最后 arch = 'dla''''model_class = _network_factory[arch]'''根据 arch = 'dla' 获取 _network_factory 中的值最后 model_class = DLASegDLASeg 类定义在 CenterFusion/src/lib/model/networks/dla.py 第 594 行'''model = model_class(num_layers, heads=head, head_convs=head_conv, opt=opt)'''配置模型'''return modeldef load_model(model, model_path, opt, optimizer=None):start_epoch = 0'''设定初始轮次 = 0'''checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)print('loaded {}, epoch {}'.format(model_path, checkpoint['epoch']))'''torch.load() 函数:用来加载 torch.save() 保存的模型文件'''state_dict_ = checkpoint['state_dict']'''获取 checkpoint 模型文件中的 state_dict 属性这个属性存放训练过程中需要学习的权重和偏执系数state_dict 作为 python 的字典对象将每一层的参数映射成 tensor 张量需要注意的是 torch.nn.Module 模块中的 state_dict 只包含卷积层和全连接层的参数'''state_dict = {}for k in state_dict_:if k.startswith('module') and not k.startswith('module_list'):state_dict[k[7:]] = state_dict_[k]else:state_dict[k] = state_dict_[k]'''startswith(str) 函数:检测字符串 str,检测到返回 True,否则返回 False这里只执行了 else 语句,相当于保存导入模型的网络参数'''model_state_dict = model.state_dict()'''浅拷贝 main.py 中创建的新模型 DLA 的网络参数'''for k in state_dict:'''遍历导入的模型中的每层网络参数'''if k in model_state_dict:'''判断新模型的网络参数中是否有导入的模型的参数是有的,因为导入的模型也是 DLA 模型'''if (state_dict[k].shape != model_state_dict[k].shape) or \(opt.reset_hm and k.startswith('hm') and (state_dict[k].shape[0] in [80, 1])):'''第一个条件为 True其余条件全部为 False'''if opt.reuse_hm:'''不执行'''print('Reusing parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))# todo: bug in next line: both sides of < are the sameif state_dict[k].shape[0] < state_dict[k].shape[0]:model_state_dict[k][:state_dict[k].shape[0]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:model_state_dict[k].shape[0]]state_dict[k] = model_state_dict[k]elif opt.warm_start_weights:'''不执行'''try:print('Partially loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))if state_dict[k].shape[1] < model_state_dict[k].shape[1]:model_state_dict[k][:,:state_dict[k].shape[1]] = state_dict[k]else:model_state_dict[k] = state_dict[k][:,:model_state_dict[k].shape[1]]state_dict[k] = model_state_dict[k]except:print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]else:'''执行该 else 中的语句'''print('Skip loading parameter {}, required shape{}, '\'loaded shape{}.'.format(k, model_state_dict[k].shape, state_dict[k].shape))state_dict[k] = model_state_dict[k]'''将新模型的网络参数赋值给导入的模型中'''else:print('Drop parameter {}.'.format(k))for k in model_state_dict:if not (k in state_dict):print('No param {}.'.format(k))state_dict[k] = model_state_dict[k]'''给导入的模型添加没有的参数'''model.load_state_dict(state_dict, strict=False)'''使用 state_dict 反序列化模型参数字字典,用来加载模型参数将 state_dict 中的 parameters 和 buffers 复制到此 module 及其子节点中简述:给模型对象加载训练好的模型参数,即加载模型参数'''#冻结骨干网,没有执行if opt.freeze_backbone:for (name, module) in model.named_children():if name in opt.layers_to_freeze:for (name, layer) in module.named_children():for param in layer.parameters():param.requires_grad = False# 恢复优化器参数,没有执行if optimizer is not None and opt.resume:if 'optimizer' in checkpoint:start_epoch = checkpoint['epoch']start_lr = opt.lrfor step in opt.lr_step:if start_epoch >= step:start_lr *= 0.1for param_group in optimizer.param_groups:param_group['lr'] = start_lrprint('Resumed optimizer with start lr', start_lr)else:print('No optimizer parameters in checkpoint.')if optimizer is not None:'''执行该 if 语句'''return model, optimizer, start_epochelse:return modeldef save_model(path, epoch, model, optimizer=None):if isinstance(model, torch.nn.DataParallel):'''isinstance(object, classinfo) 判断一个函数 object 是否是一个已知的类型 classinfo是则返回 True,反之返回 False'''state_dict = model.module.state_dict()else:state_dict = model.state_dict()'''获取模型的参数矩阵'''data = {'epoch': epoch,'state_dict': state_dict}if not (optimizer is None):data['optimizer'] = optimizer.state_dict()'''获取模型的优化器'''torch.save(data, path)'''保存模型'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/752206.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【教学类-44-07】20240318 0-9数字描字帖 A4横版整页(宋体、黑体、文鼎虚线体、print dashed 德彪行书行楷)

背景需求: 前文制作了三种字体的A4横版数字描字帖 【教学类-44-06】20240318 0-9数字描字帖 A4横版整页&#xff08;宋体、黑体、文鼎虚线体&#xff09;-CSDN博客【教学类-44-06】20240318 0-9数字描字帖 A4横版整页&#xff08;宋体、黑体、文鼎虚线体&#xff09;https://…

将VSCode添加至右键的菜单栏

懒得bb&#xff0c;直接转发别人的博客&#xff0c;链接 但是我在win11上面弄了之后&#xff0c;除了文件夹其他格式都生效了&#xff0c;只需要在这个路径HKEY_CLASSES_ROOT\Directory\shell重复上面的操作&#xff0c;看Directory就知道是文件夹

担忧关于ChatGPT潜在风险的声音正在增强,但暂停人工智能是否明智?

深度学习算法的风险与挑战&#xff1a;ChatGPT的潜在风险引发关注 引言 随着人工智能技术的快速发展&#xff0c;特别是像ChatGPT这样的大型语言模型的广泛应用&#xff0c;人们对其潜在风险的关注也在不断升温。本文将探讨这些风险&#xff0c;并分析是否应该暂停AI的发展。…

事务、并发、锁机制的实现

配置全局事务 DATABASES {default: {ENGINE: django.db.backends.mysql,NAME: mydb,USER:root,PASSWORD:pass,HOST:127.0.0.1,PORT:3306,ATOMIC_REQUESTS: True, # 全局开启事务&#xff0c;绑定的是http请求响应整个过程# (non_atomic_requests可局部实现不让事务控制)} } …

stable diffusion webui 搭建和初步使用

官方repo: GitHub - AUTOMATIC1111/stable-diffusion-webui: Stable Diffusion web UI 关于stable-diffusion的介绍&#xff1a;Stable Diffusion&#xff5c;图解稳定扩散原理 - 知乎 一、环境搭建和启动 准备在容器里面搞一下 以 ubuntu22.04 为基础镜像&#xff0c;新建…

UnityShader(十六)凹凸映射

前言&#xff1a; 纹理的一种常见应用就是凹凸映射&#xff08;bump mapping&#xff09;。凹凸映射目的就是用一张纹理图来修改模型表面的法线&#xff0c;让模型看起来更加细节&#xff0c;这种方法不会改变模型原本的顶点位置&#xff08;也就是不会修改模型的形状&#xf…

数据结构之顺序存储-顺序表的基本操作c/c++(创建、初始化、赋值、插入、删除、查询、替换、输出)

学习参考博文&#xff1a;http://t.csdnimg.cn/Qi8DD 学习总结&#xff0c;同时更正原博主在顺序表中插入元素的错误。 数据结构顺序表——基本代码实现&#xff08;使用工具&#xff1a;VS2022&#xff09;&#xff1a; #define _CRT_SECURE_NO_WARNINGS #include <stdi…

uniapp canvas文字和元素居中

文字居中&#xff1a;ctx.textAlign "center"; 元素居中&#xff1a;ctx.arc(screenWidth / 2, 122, 40, 0, 2 * Math.PI); ctx.arc()的x轴为当前屏幕的宽度/2&#xff1b; let screenWidth 540; let screenHeight 960; // 头像 if (photoimg) {ctx.setFillSty…

gitlab cicd问题整理

1、docker设置数据目录&#xff1a; 原数据目录磁盘空间不足&#xff0c;需要更换目录&#xff1a; /etc/docker/daemon.json //写入/etc/docker/daemon.json {"data-root": "/data/docker" } 2、Dockerfile中ADD指令不生效 因为要ADD的文件被.docker…

指南:在各主流操作系统上安装与配置Apache Tomcat

指南&#xff1a;在各主流操作系统上安装与配置Apache Tomcat Apache Tomcat作为一款广受欢迎的开源Java Servlet容器&#xff0c;为用户提供了一个纯Java环境下的Web服务器和Servlet容器。本文将详细介绍如何在不同的操作系统上安装Apache Tomcat&#xff0c;并进行基本的配置…

【计算机网络】什么是http?

​ 目录 前言 1. 什么是HTTP协议&#xff1f; 2. 为什么使用HTTP协议&#xff1f; 3. HTTP协议通信过程 4. 什么是url&#xff1f; 5. HTTP报文 5.1 请求报文 5.2 响应报文 6. HTTP请求方式 7. HTTP头部字段 8. HTTP状态码 9. 连接管理 长连接与短连接 管线化连接…

smartmontools-5.43交叉编译Smartctl

嵌入式系统的sata盘经常故障&#xff0c;需要使用smatctl工具监控和诊断sata故障。 1. 从网上下载开源smartmontools-5.43包。 2. 修改makefile进行交叉编译。 由于软件包中已经包含Makefile.am&#xff0c;Makefile.in。直接运行 automake --add-missing 生成Makefile。 3.…

自动部署SSL证书到阿里云腾讯云CDN

项目地址&#xff1a;https://github.com/yxzlwz/ssl_update 项目简介 目前&#xff0c;自动申请和管理免费SSL证书的项目有很多&#xff0c;如个人正在使用的 acme.sh。然而在申请后&#xff0c;如果我们的需求不仅限于服务器本地的使用&#xff0c;证书的部署也是一件麻烦事…

Gin 框架中实现路由的几种方式介绍

本文将为您详细讲解 Gin 框架中实现路由的几种方式&#xff0c;并给出相应的简单例子。Gin 是一个高性能的 Web 框架&#xff0c;用于构建后端服务。在 Web 应用程序中&#xff0c;路由是一种将客户端请求映射到特定处理程序的方法。以下是几种常见的路由实现方式&#xff1a; …

JavaScript | 检测文档在垂直方向已滚动的像素值用pageYOffset在webstorm上显示弃用了,是否应该继续使用?还是用其他替代?

在学习JavaScript的时候&#xff0c;深入学习时会遇到一些实际案例需要检测文档在垂直方向已滚动的像素值。 例如&#xff0c;当前页面内容很多&#xff0c;我想要滚动鼠标滑轮或者拖拽滚动条来浏览网页下面的内容。这时候一动滚动条&#xff0c;一些绝对固定的盒子却想要随着…

python图形化编程turtle小乌龟

文章目录&#xff1a; 一&#xff1a;导入包&#xff08;常用的&#xff09; 二&#xff1a;布局 1.设置世界坐标系 2.窗体 3.画布屏幕screen 三&#xff1a;线条画笔海龟 1.运动 2.样式 3.外观 4.其他 四&#xff1a;颜色 五&#xff1a;文字 六&#xff1a;图…

【Kubernetes】k8s删除master节点后重新加入集群

目录 前言一、思路二、实战1.安装etcdctl指令2.重置旧节点的k8s3.旧节点的的 etcd 从 etcd 集群删除4.在 master03 上&#xff0c;创建存放证书目录5.把其他控制节点的证书拷贝到 master01 上6.把 master03 加入到集群7.验证 master03 是否加入到 k8s 集群&#xff0c;检查业务…

Unity触发器的使用

1.首先建立两个静态精灵&#xff08;并给其中一个物体添加"jj"标签&#xff09; 2.添加触发器 3.给其中一个物体添加刚体组件&#xff08;如果这里是静态的碰撞的时候将不会触发效果&#xff0c;如果另一个物体有刚体可以将它移除&#xff0c;或者将它的刚体属性设置…

c++pair的用法

pair简单来说就是可以存储两种类型数据的一个类&#xff0c;其内部是使用模板实现的&#xff0c;所以可以指定其内部的类型。 pair在#include <utility> pair的构造 pair<int, string> p1({ 1,"张三" });pair<int, string> p2;pair<int, str…

文件的基础

一、文件 什么是文件 文件流&#xff1a; 一、1、文件的相关操作 创建文件的三种方式&#xff1a; public class FileCreate {public static void main(String[] args) {}//方式1 new File(String pathname)Testpublic void create01() {String filePath "e:\\news1.…