PyTorch构建自定义模型

PyTorch 提供了灵活的方式来构建自定义神经网络模型。下面我将详细介绍从基础到高级的自定义模型构建方法,包含实际代码示例和最佳实践。

一、基础模型构建

1. 继承 nn.Module 基类

所有自定义模型都应该继承 torch.nn.Module 类,并实现两个基本方法:

import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self):super().__init__()  # 必须调用父类初始化# 在这里定义网络层self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 50, 5)self.fc1 = nn.Linear(4*4*50, 500)self.fc2 = nn.Linear(500, 10)def forward(self, x):# 定义前向传播逻辑x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2, 2)x = x.view(-1, 4*4*50)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

2. 模型使用方式 

model = MyModel()
output = model(input_tensor)  # 自动调用forward方法
loss = criterion(output, target)
loss.backward()

二、中级构建技巧

1. 使用 nn.Sequential

nn.Sequential 是一种用于快速构建顺序神经网络的容器类,适用于模块按线性顺序排列的模型。

class MySequentialModel(nn.Module):def __init__(self):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.classifier = nn.Sequential(nn.Linear(128 * 8 * 8, 512),nn.ReLU(inplace=True),nn.Linear(512, 10))def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x

 2. 参数初始化

def initialize_weights(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)model.apply(initialize_weights)  # 递归应用初始化函数

三、高级构建模式

1. 残差连接 (ResNet风格)

class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1,stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)return F.relu(out)

2. 自定义层 

class MyCustomLayer(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.weight = nn.Parameter(torch.randn(output_dim, input_dim))self.bias = nn.Parameter(torch.randn(output_dim))def forward(self, x):return F.linear(x, self.weight, self.bias)

 

四、模型保存与加载

1. 保存整个模型

torch.save(model, 'model.pth')  # 保存
model = torch.load('model.pth')  # 加载

2. 保存状态字典 (推荐)

torch.save(model.state_dict(), 'model_state.pth')  # 保存
model.load_state_dict(torch.load('model_state.pth'))  # 加载

五、模型部署准备

1. 模型导出为TorchScript

scripted_model = torch.jit.script(model)  # 或 torch.jit.trace
scripted_model.save('model_scripted.pt')

2. ONNX格式导出

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx", input_names=["input"], output_names=["output"])

六、完整示例:自定义CNN分类器

import torch
from torch import nn
from torch.nn import functional as Fclass CustomCNN(nn.Module):"""自定义CNN图像分类器Args:num_classes (int): 输出类别数dropout_prob (float): dropout概率,默认0.5"""def __init__(self, num_classes=10, dropout_prob=0.5):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(32, 64, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(64, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2))self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(p=dropout_prob),nn.Linear(128 * 6 * 6, 512),nn.ReLU(inplace=True),nn.Dropout(p=dropout_prob),nn.Linear(512, num_classes))# 初始化权重for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self, x: torch.Tensor) -> torch.Tensor:"""前向传播Args:x (torch.Tensor): 输入张量,形状为[B, C, H, W]Returns:torch.Tensor: 输出logits,形状为[B, num_classes]"""x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

七、注意事项

  1. 输入输出维度匹配

    • 需确保相邻模块的输入/输出维度兼容。例如,卷积层后接全连接层时需通过 Flatten 或自适应池化调整维度‌。
  2. 调试与验证

    • 可通过模拟输入数据验证模型结构,如:
      input = torch.ones(64, 3, 32, 32)  # 模拟 batch_size=64 的输入
      output = model(input)
      print(output.shape)  # 检查输出形状是否符合预期

       

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/76052.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI智算-K8s如何利用GPFS分布式并行文件存储加速训练or推理

文章目录 GPFS简介核心特性存储环境介绍存储软件版本客户端存储RoCEGPFS 管理(GUI)1. 创建 CSI 用户2. 检查GUI与k8s通信文件系统配置1. 开启配额2. 启用filesetdf文件系统3. 验证文件系统配置4. 启用自动inode扩展存储集群配置1. 启用对根文件集(root fileset)配额2. igno…

gbase8s之逻辑导出导入脚本(完美版本)

该脚本dbexport.sh用于快速导出库和导入库(使用多并发unload,和多并发dbload的方式) #!/bin/sh #脚本功能:将数据导出成文本,迁移至其他实例 #最后更新时间:2023-12-19 #使用方法: #1.执行该脚…

springMVC-拦截器详解

拦截器 概述 SpringMVC的处理器拦截器类似于Servlet开发中的过滤器Filter,用于对处理器进行预处理和后处理。开发者可以自己定义一些拦截器来实现特定的功能。 过滤器与拦截器的区别:拦截器是AOP思想的具体应用。 过滤器 servlet规范中的一部分,任何ja…

网络安全应急响应-系统排查

在网络安全应急响应中,系统排查是快速识别潜在威胁的关键步骤。以下是针对Windows和Linux系统的系统基本信息排查指南,涵盖常用命令及注意事项: 一、Windows系统排查 1. 系统信息工具(msinfo32.exe) 命令执行&#x…

基于YOLO的半自动化标注方法:提升铁路视频缺陷检测效率

论文地址:https://arxiv.org/pdf/2504.01010 1. 论文结构概述 本文提出了一种半自动化标注方法,旨在解决铁路缺陷检测中大规模图像/视频数据集标注成本高、耗时长的问题。论文结构清晰,分为以下核心部分: ​引言(Introduction)​ 强调传统手动标注的痛点(耗时、易错、…

Linux驱动开发:SPI驱动开发原理

前言 本文章是根据韦东山老师的教学视频整理的学习笔记https://video.100ask.net/page/1712503 SPI 通信协议采用同步全双工传输机制,拓扑架构支持一主多从连接模式,这种模式在实际应用场景中颇为高效。其有效传输距离大致为 10m ,传输速率…

Android Hilt 教程

Android Hilt 教程 —— 一看就懂,一学就会 1. 什么是 Hilt?为什么要用 Hilt? Hilt 是 Android 官方推荐的 依赖注入(DI)框架,基于 Dagger 开发,能够大大简化依赖注入的使用。 为什么要用 Hi…

【算法手记11】NC41 最长无重复子数组 NC379 重排字符串

🦄个人主页:修修修也 🎏所属专栏:刷题 ⚙️操作环境:牛客网 目录 一.NC41 最长无重复子数组 题目详情: 题目思路: 解题代码: 二.NC379 重排字符串 题目详情: 题目思路: 解题代码: 结语 一.NC41 最长无重复子数组 牛客网题目链接(点击即可跳转):NC41 最长…

C语言:字符串处理函数strstr分析

在 C 语言中,strstr 函数用于查找一个字符串中是否存在另一个字符串。它的主要功能是搜索指定的子字符串,并返回该子字符串在目标字符串中第一次出现的位置的指针。如果没有找到子字符串,则返回 NULL。 详细说明: 头文件&#xf…

在windows下安装spark

在windows下安装spark完成 安装过程:

MongoDB常见面试题总结(上)

MongoDB 基础 MongoDB 是什么? MongoDB 是一个基于 分布式文件存储 的开源 NoSQL 数据库系统,由 C 编写的。MongoDB 提供了 面向文档 的存储方式,操作起来比较简单和容易,支持“无模式”的数据建模,可以存储比较复杂…

【Java设计模式】第2章 UML急速入门

2-1 本章导航 UML类图与时序图入门 UML定义 统一建模语言(Unified Modeling Language):第三代非专利建模语言。特点:开放方法,支持可视化构建面向对象系统,涵盖模型、流程、代码等。UML分类(2.2版本) 结构式图形:系统静态建模(类图、对象图、包图)。行为式图形:事…

【4】搭建k8s集群系列(二进制部署)之安装master节点组件(kube-apiserver)

一、下载k8s二进制文件 下载地址: https://github.com/kubernetes/kubernetes/blob/master/CHANGELOG/CHANGELOG -1.20.md 注:打开链接你会发现里面有很多包,下载一个 server 包就够了,包含了 Master 和 Worker Node 二进制文件。…

电子电气架构 --- AUTOSAR 的信息安全架构

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 周末洗了一个澡,换了一身衣服,出了门却不知道去哪儿,不知道去找谁,漫无目的走着,大概这就是成年人最深的孤独吧! 旧人不知我近况,新人不知我过…

ROS2与OpenAI Gym集成指南:从安装到自定义环境与强化学习训练

1.理解 ROS2 和 OpenAI Gym 的基本概念 ROS2(Robot Operating System 2):是一个用于机器人软件开发的框架。它提供了一系列的工具、库和通信机制,方便开发者构建复杂的机器人应用程序。例如,ROS2 可以处理机器人不同组…

【设计模式】创建型 -- 单例模式 (c++实现)

文章目录 单例模式使用场景c实现静态局部变量饿汉式(线程安全)懒汉式(线程安全)懒汉式(线程安全) 智能指针懒汉式(线程安全)智能指针call_once懒汉式(线程安全)智能指针call_onceCRTP 单例模式 单例模式是…

C语言之九九乘法表

一、代码展示 二、运行结果 三、代码分析 首先->是外层循环是小于等于9的 然后->是内层循环是小于等于外层循环的 最后->就是\n让九九乘法表的格式更加美观(当然 电脑不同 有可能%2d 也有可能%3d) 四、与以下素数题目逻辑相似 五、运行结果

自动化备份全网服务器数据平台

自动化备份全网服务器数据平台 项目背景知识 总体需求 某企业里有一台Web服务器,里面的数据很重要,但是如果硬盘坏了数据就会丢失,现在领导要求把数据做备份,这样Web服务器数据丢失在可以进行恢复。要求如下:1.每天0…

stm32+esp8266+机智云手机app

现在很多大学嵌入式毕设都要求云端控制,本文章就教一下大家如何使用esp8266去连接机智云的app去进行显示stm32的外设传感器数据啊,控制一些外设啊等。 因为本文章主要教大家如何移植机智云的代码到自己的工程,所以前面的一些准备工作&#x…

时序数据库 TDengine Cloud 私有连接实战指南:4步实现数据安全传输与成本优化

小T导读:在物联网和工业互联网场景下,企业对高并发、低延迟的数据处理需求愈发迫切。本文将带你深入了解 TDengineCloud 如何通过全托管服务与私有连接,帮助企业实现更安全、更高效、更低成本的数据采集与传输,从架构解析到实际配…