根据Pytorch源码实现的 ResNet18

 一,类模块定义: 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensorclass ResBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1) -> None:super(ResBlock, self).__init__()# 这里定义了残差块内连续的2个卷积层self.conv1 = nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(outchannel)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(outchannel)self.downsample = nn.Sequential()if stride != 1 or inchannel != outchannel:# shortcut,这里为了跟2个卷积层的结果结构一致,要做处理self.downsample = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))def forward(self, x):out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = out + self.downsample(x)out = self.relu(out)return outclass ResNet18(nn.Module):def __init__(self, ResBlock, num_classes=1000) -> None:super(ResNet18, self).__init__()self.inchannel = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False)self.layer1 = self.make_layer(ResBlock, 64, 2, stride=1)self.layer2 = self.make_layer(ResBlock, 128, 2, stride=2)self.layer3 = self.make_layer(ResBlock, 256, 2, stride=2)self.layer4 = self.make_layer(ResBlock, 512, 2, stride=2)self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))self.fc = nn.Linear(512, num_classes)def forward(self, x: Tensor) -> Tensor:out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.maxpool(out)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avgpool(out)out = out.view(out.size(0), -1)out = self.fc(out)return outdef make_layer(self, block, channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.inchannel, channels, stride))self.inchannel = channelsreturn nn.Sequential(*layers)if __name__ == '__main__':model = ResNet18(ResBlock)print(model)

二,对比Pytorch官方提供的预训练模型 加载xxx.pht文件 

# 方案一: 使用官方自带的resnet18加载预训练模型
from torchvision import models# 当 xxx.pth预训练模型不存在时,可以联网直接下载
# model = models.resnet18(weights=ResNet18_Weights.DEFAULT)   # 载入预训练模型
model = models.resnet18()# 加载与训练模型
weights_dict = torch.load('C:\\Users\\torch\\hub\\checkpoints\\resnet18-f37072fd.pth')model.load_state_dict(weights_dict, strict=True)
print(model)# 方案二: 使用自定义的ResNet18加载预训练模型
model = ResNet18(ResBlock)
weights_dict = torch.load('C:\\Users\\torch\\hub\\checkpoints\\resnet18-f37072fd.pth')model.load_state_dict(weights_dict, strict=True)
print(model)

三,用自定义的ResNet18记载Pytorch官网提供的预训练模型,训练自己的图像分类数据,完整代码 

import matplotlib.pyplot as plt
from torchvision.models import ResNet18_Weightsimport warnings
warnings.filterwarnings("ignore")   # 忽略烦人的红色提示import time
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F# 导入训练需使用的工具包
from torchvision import models
import torch.optim as optim
from torch.optim import lr_schedulerfrom sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score''' 运行一个 batch 的训练,返回当前 batch 的训练日志 '''
log_train = {}
def train_one_batch(images, labels, epoch, batch_idx):# 获得一个 batch 的数据和标注images = images.to(device)labels = labels.to(device)# images = [32, 3, 224, 224]outputs = model(images)  # 输入模型,执行前向预测(mat1 and mat2 shapes cannot be multiplied (32x25088 and 512x30))loss = criterion(outputs, labels)  # 计算当前 batch 中,每个样本的平均交叉熵损失函数值# 优化更新权重optimizer.zero_grad()loss.backward()optimizer.step()# 获取当前 batch 的标签类别和预测类别_, preds = torch.max(outputs, 1)  # 获得当前 batch 所有图像的预测类别preds = preds.cpu().numpy()loss = loss.detach().cpu().numpy()outputs = outputs.detach().cpu().numpy()labels = labels.detach().cpu().numpy()log_train['epoch'] = epochlog_train['batch'] = batch_idx# 计算分类评估指标log_train['train_loss'] = losslog_train['train_accuracy'] = accuracy_score(labels, preds)log_train['train_precision'] = precision_score(labels, preds, average='macro')log_train['train_recall'] = recall_score(labels, preds, average='macro')log_train['train_f1-score'] = f1_score(labels, preds, average='macro')return log_train''' 在整个测试集上评估,返回分类评估指标日志 '''
def evaluate_testset(epoch):loss_list = []labels_list = []preds_list = []with torch.no_grad():for images, labels in test_loader:  # 生成一个 batch 的数据和标注images = images.to(device)labels = labels.to(device)outputs = model(images)  # 输入模型,执行前向预测loss = criterion(outputs, labels)  # 由 logit,计算当前 batch 中,每个样本的平均交叉熵损失函数值# 获取整个测试集的标签类别和预测类别_, preds = torch.max(outputs, 1)  # 获得当前 batch 所有图像的预测类别preds = preds.cpu().numpy()loss = loss.detach().cpu().numpy()outputs = outputs.detach().cpu().numpy()labels = labels.detach().cpu().numpy()loss_list.append(loss)labels_list.extend(labels)preds_list.extend(preds)log_test = {}log_test['epoch'] = epoch# 计算分类评估指标log_test['test_loss'] = np.mean(loss_list)log_test['test_accuracy'] = accuracy_score(labels_list, preds_list)log_test['test_precision'] = precision_score(labels_list, preds_list, average='macro')log_test['test_recall'] = recall_score(labels_list, preds_list, average='macro')log_test['test_f1-score'] = f1_score(labels_list, preds_list, average='macro')return log_testdef saveLog():# 训练日志-训练集df_train_log = pd.DataFrame()log_train = {}log_train['epoch'] = 0log_train['batch'] = 0images, labels = next(iter(train_loader))log_train.update(train_one_batch(images, labels, 0, 0))df_train_log = df_train_log.append(log_train, ignore_index=True)# 训练日志-测试集df_test_log = pd.DataFrame()log_test = {}log_test['epoch'] = 0log_test.update(evaluate_testset(0))df_test_log = df_test_log.append(log_test, ignore_index=True)return df_train_log, df_test_logif __name__ == '__main__':ntime = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))ntime = str(ntime)device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])# 数据集文件夹路径dataset_dir = 'D:\\dl_workspace\\datasets\\fruit30_split'train_path = os.path.join(dataset_dir, 'train')test_path = os.path.join(dataset_dir, 'val')print('训练集路径', train_path)print('测试集路径', test_path)from torchvision import datasetstrain_dataset = datasets.ImageFolder(train_path, train_transform)   # 载入训练集test_dataset = datasets.ImageFolder(test_path, test_transform)  # 载入测试集# 各类别名称class_names = train_dataset.classesn_class = len(class_names)train_dataset.class_to_idx  # 映射关系:类别 到 索引号idx_to_labels = {y: x for x, y in train_dataset.class_to_idx.items()}  # 映射关系:索引号 到 类别# 保存为本地的 npy 文件# np.save('idx_to_labels.npy', idx_to_labels)# np.save('labels_to_idx.npy', train_dataset.class_to_idx)from torch.utils.data import DataLoaderBATCH_SIZE = 256# 训练集的数据加载器train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)from Utils import pyutils# 只微调训练模型最后一层(全连接分类层)# model = models.resnet18(weights=ResNet18_Weights.DEFAULT)   # 载入预训练模型# model = models.resnet18()# print(model)# print('pymodel:', pyutils.getOrderedDictKeys(model.state_dict()))from ResNet18_Model import ResNet18, ResBlock, ResNetmodel = ResNet18(ResBlock)# 给自定义模型,加载预训练模型权重,(strict=False 可以看到具有相同网络层名称的网络被初始化,不具有的网络层的参数不会被初始化)weights_dict = torch.load('C:\\Users\\Administrator/.cache\\torch\\hub\\checkpoints\\resnet18-f37072fd.pth')model.load_state_dict(weights_dict, strict=True)# 修改全连接层,使得全连接层的输出与当前数据集类别数对应(新建的层默认 requires_grad=True)# 只微调训练最后一层全连接层的参数,其它层冻结(1000分类改成30分类)model.fc = nn.Linear(model.fc.in_features, n_class)optimizer = optim.Adam(model.fc.parameters())print(model)# 训练配置model = model.to(device)criterion = nn.CrossEntropyLoss()   # 交叉熵损失函数EPOCHS = 30     # 训练轮次 Epoch(训练集当中所有的训练数据扫一遍算作一个epoch)'''学习率的降低优化策略,每经过5个epoch,学习率降低为原来的一半(lr = lr*gamma)'''lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)   # 学习率降低策略# df_train_log, df_test_log = saveLog()df_train_log = pd.DataFrame()df_test_log = pd.DataFrame()epoch = 0batch_idx = 0best_test_accuracy = 0# 运行训练for epoch in range(1, EPOCHS + 1):print(f'Epoch {epoch}/{EPOCHS}')## 训练阶段model.train()for images, labels in tqdm(train_loader):  # 获得一个 batch 的数据和标注batch_idx += 1log_train = train_one_batch(images, labels, epoch, batch_idx)df_train_log = df_train_log.append(log_train, ignore_index=True)# wandb.log(log_train)lr_scheduler.step()  # 学习率优化策略,跟新学习率## 测试阶段model.eval()    # 将模型的模式从训练模式改成评估模式log_test = evaluate_testset(epoch)  # 在整个测试集上评估,并且返回测试结果df_test_log = df_test_log.append(log_test, ignore_index=True)# wandb.log(log_test)# 保存最新的最佳模型文件if log_test['test_accuracy'] > best_test_accuracy:# 删除旧的最佳模型文件(如有)old_best_checkpoint_path = 'checkpoint/best-{:.3f}.pth'.format(best_test_accuracy)if os.path.exists(old_best_checkpoint_path):os.remove(old_best_checkpoint_path)# 保存新的最佳模型文件best_test_accuracy = log_test['test_accuracy']new_best_checkpoint_path = './checkpoint/{0}_best-{1:.3f}.pth'.format(ntime, log_test['test_accuracy'])torch.save(model, new_best_checkpoint_path)print('保存新的最佳模型', './checkpoint/{0}_best-{1:.3f}.pth'.format(ntime, best_test_accuracy))best_test_accuracy = log_test['test_accuracy']print(f'测试准确率:  {best_test_accuracy} / {epoch}')df_train_log.to_csv('训练日志-训练集-{0}.csv'.format(ntime), index=False)df_test_log.to_csv('训练日志-测试集-{0}.csv'.format(ntime), index=False)#  测试集上的准确率为 87.662 %

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/22051.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CountDownLatch的使用(判断多个线程是否都执行完毕)

使用CountDownLatch的情景 有一些”多线程下载器",可以把一个大的文件给拆分成多个小的部分,使用多个线程分别下载,每个线程负责下载一部分,此时每个线程都是一个网络连接,这样就可以大幅度提高下载速度。 假设&a…

Spring Boot使用的关键点

目录 1. 构建系统 1.1. 依赖管理 1.2. Maven Maven项目结构 1.3. Starter 2. 代码结构 2.1. “default” 包 2.2. 启动类的位置 3. Configuration 类 3.1. 导入额外的 Configuration 类 3.2. 导入 XML Configuration 4. 自动装配(配置) 4.1…

【云原生】深入掌握k8s中Pod和生命周期

个人主页:征服bug-CSDN博客 kubernetes专栏:kubernetes_征服bug的博客-CSDN博客 目录 1 什么是 Pod 2 Pod 基本操作 3 Pod 运行多个容器 4 Pod 的 Labels(标签) 5 Pod 的生命周期 1 什么是 Pod 摘取官网: Pod | Kubernetes 1.1 简介 Pod 是可以在 …

idea打开传统eclipse项目

打开传统web项目 1.打开后选择项目文件 2.选择项目结构 3.设置jdk版本 4.导入当前项目模块 5.选择eclipse 6. 设置保存目录 7.右键模块,添加spring和web文件 8. 设置web目录之类的,并且创建打包工具 9.如果有本地lib,添加为库 最后点击应用&…

Databend 开源周报第 104 期

Databend 是一款现代云数仓。专为弹性和高效设计,为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务:https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展,遇到更贴近你心意的 Databend 。 从 Kafka 载入数…

安装win版本的neo4j(2023最新版本)

安装win版本的neo4j 写在最前面安装 win版本的neo4j1. 安装JDK2.下载配置环境变量(也可选择直接点击快捷方式,就可以不用配环境了)3. 启动neo4j 测试代码遇到的问题及解决(每次环境都太离谱了,各种问题)连接…

Java系列之数据库geometry的相关函数

文章目录 前言一、geometry是什么?二、geometry常用函数1.OGC标准函数①管理函数②几何对象关系函数③几何对象处理函数④几何对象存取函数⑤几何对象构造函数 2.PostGIS扩展函数①管理函数②几何操作符③几何量测函数④几何对象输出⑤几何对象创建⑥几何对象编辑⑦…

小研究 - 微服务系统服务依赖发现技术综述(二)

微服务架构得到了广泛的部署与应用, 提升了软件系统开发的效率, 降低了系统更新与维护的成本, 提高了系统的可扩展性. 但微服务变更频繁、异构融合等特点使得微服务故障频发、其故障传播快且影响大, 同时微服务间复杂的调用依赖关系或逻辑依赖关系又使得其故障难以被及时、准确…

AI算法图形化编程加持|OPT(奥普特)智能相机轻松适应各类检测任务

OPT(奥普特)基于SciVision视觉开发包,全新推出多功能一体化智能相机,采用图形化编程设计,操作简单、易用;不仅有上百种视觉检测算法加持,还支持深度学习功能,能轻松应对计数、定位、…

Spring mvc:SpringServletContainerInitializer

SpringServletContainerInitializer实现了Servlet3.0规范中定义的ServletContainerInitializer&#xff1a; public interface ServletContainerInitializer {void onStartup(Set<Class<?>> c, ServletContext ctx) throws ServletException; }SpringServletCont…

Android 获取网关 ip 和 DNS ip

参考下方 PingUtil.java 代码 import android.content.Context; import android.net.DhcpInfo; import android.net.wifi.WifiManager; import android.text.format.Formatter;import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; impor…

记一次 .NET 某物流API系统 CPU爆高分析

一&#xff1a;背景 1. 讲故事 前段时间有位朋友找到我&#xff0c;说他程序CPU直接被打满了&#xff0c;让我帮忙看下怎么回事&#xff0c;截图如下&#xff1a; 看了下是两个相同的程序&#xff0c;既然被打满了那就抓一个 dump 看看到底咋回事。 二&#xff1a;为什么会打…

新人如何高效写 API 文档

什么是 API 文档&#xff1f; 在深入研究 API 文档之前&#xff0c;让我简要解释一下 API 是什么以及它的基本功能。 API 是应用程序编程接口的首字母缩写。 ​ 编辑 切换为居中 通过 API 将设备连接到数据库 无论你是初学者还是高级开发人员&#xff0c;你都会在软件开发…

Mr. Cappuccino的第53杯咖啡——Mybatis源码分析

Mybatis源码分析 Mybatis源码分析入口1. 读取配置文件总结 2. 解析配置文件核心代码&#xff08;一&#xff09;核心代码&#xff08;二&#xff09;分析parse()方法分析build()方法 总结 3. 获取SqlSession总结 4. 获取mapper代理对象总结 5. 使用mapper代理对象执行Sql语句二…

MySQL操作命令详解:增删改查

文章目录 一、CRUD1.1 数据库操作1.2 表操作1.2.1 五大约束1.2.2 创建表1.2.3 修改表1.2.3 删除表1.2.4 表数据的增删改查1.2.5 去重方式 二、高级查询2.1 基础查询2.2 条件查询2.3 范围查询2.4 判空查询2.5 模糊查询2.6 分页查询2.7 查询后排序2.8 聚合查询2.9 分组查询2.10 联…

【移动机器人运动规划】02 —— 基于采样的规划算法

文章目录 前言相关代码整理:相关文章&#xff1a; 基本概念概率路线图&#xff08;Probabilistic Road Map&#xff09;基本流程预处理阶段查询阶段 优缺点&#xff08;pros&cons&#xff09;一些改进算法Lazy collision-checking Rapidly-exploring Random Tree算法伪代码…

数据结构 10-排序4 统计工龄 桶排序/计数排序(C语言)

给定公司名员工的工龄&#xff0c;要求按工龄增序输出每个工龄段有多少员工。 输入格式: 输入首先给出正整数&#xff08;≤&#xff09;&#xff0c;即员工总人数&#xff1b;随后给出个整数&#xff0c;即每个员工的工龄&#xff0c;范围在[0, 50]。 输出格式: 按工龄的递…

【Jmeter】配置不同业务请求比例,应对综合场景压测

目录 前言 Jmeter5.0新特性 核心改进 其他变化 资料获取方法 前言 Jmeter 5.0这次的核心改进是在许多地方改进了对 Rest 的支持&#xff0c;此外还有调试功能、录制功能的增强、报告的改进等。 我也是因为迁移到了Mac&#xff0c;准备在Mac上安装Jmeter的时候发现它已经…

Java 中的 7 种重试机制

随着互联网的发展项目中的业务功能越来越复杂&#xff0c;有一些基础服务我们不可避免的会去调用一些第三方的接口或者公司内其他项目中提供的服务&#xff0c;但是远程服务的健壮性和网络稳定性都是不可控因素。 在测试阶段可能没有什么异常情况&#xff0c;但上线后可能会出…

html学习6(xhtml)

1、xhtml是以xml格式编写的html。 2、xhtml与html的文档结构区别&#xff1a; DOCTYPE是强制性的<html>、<head>、<title>、<body>也是强制性的<html>中xmlns属性是强制性的 3、 元素语法区别&#xff1a; xhtml元素必须正确嵌套xhtml元素必…