PyTorch预训练和微调:以VGG16为例

文章目录

    • 预训练和微调代码
    • 测试结果
    • 参考来源

预训练和微调代码

数据集:CIFAR10
CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。数据集介绍来自:CIFAR10

在这里插入图片描述
图片来源:https://paperswithcode.com/dataset/cifar-10

预训练模型: vgg16

代码

# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
num_epochs = 2class Identity(nn.Module):def __init__(self):super(Identity, self).__init__()def forward(self, x):return x# Load pretrain model & modify itmodel = torchvision.models.vgg16(weights='DEFAULT')
# # If you want to do finetuning then set requires_grad = False
# # Remove these two lines if you want to train entire model,
# # and only want to load the pretrain weights.
# for param in model.parameters():
#     param.requires_grad = False
for param in model.parameters():param.requires_grad = Falsemodel.avgpool = Identity() # 站位层,使得该层啥事不做
model.classifier = nn.Sequential(nn.Linear(512, 100),nn.ReLU(),nn.Linear(100, 10)) # 修改原模型的后几层
model.to(device)# Load Data
train_dataset = datasets.CIFAR10(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train Network
for epoch in range(num_epochs):losses = []for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):# Get data to cuda if possibledata = data.to(device=device)targets = targets.to(device=device)# forwardscores = model(data)loss = criterion(scores, targets)losses.append(loss.item())# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()print(f"Cost at epoch {epoch} is {sum(losses)/len(losses):.5f}")# Check accuracy on training & test to see how good our modeldef check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad():for x, y in loader:x = x.to(device=device)y = y.to(device=device)scores = model(x)_, predictions = scores.max(1)num_correct += (predictions == y).sum()num_samples += predictions.size(0)print(f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")model.train()check_accuracy(train_loader, model)

测试结果

Checking accuracy on training data
Got 29449 / 50000 with accuracy 58.90%

可以看到本次预训练模型的导入,测试结果并不理想。但并不妨碍我们对Pytorch预训练和微调的学习。

参考来源

【1】 https://www.youtube.com/watch?v=qaDe0qQZ5AQ&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=8
【2】https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_pretrain_finetune.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SolidUI AI生成可视化,0.1.0版本模块划分以及源码讲解

1.背景 随着文本生成图像的语言模型兴起,SolidUI想帮人们快速构建可视化工具,可视化内容包括2D,3D,3D场景,从而快速构三维数据演示场景。SolidUI 是一个创新的项目,旨在将自然语言处理(NLP)与计算机图形学相…

【微信小程序-uniapp】CustomDialog 居中弹窗组件

1. 效果图 2. 组件完整代码 <template><uni-popup :ref="ref" type="center" @change

Ubuntu下配置Redis哨兵集群

目录 准备实例和配置 启动哨兵集群 测试配置 搭建一个三节点形成的Sentinel集群&#xff0c;来监管Redis主从集群。 三个sentinel哨兵实例信息如下&#xff1a; 节点IPPORTs1192.168.22.13527001s2192.168.22.13527002s3192.168.22.13527003 准备实例和配置 要在同一台虚…

组合式API

文章目录 前言了解组合式API简单类型 ref封装对象类型 user.name子组件数组类型 reactive封装 组合式 API 基础练习基础练习优化 前言 Vue 3 的组合式 API&#xff08;Composition API&#xff09;是一组函数和语法糖&#xff0c;用于更灵活和可组合地组织 Vue 组件的代码逻辑…

leetcode 538. 把二叉搜索树转换为累加树

2023.7.16 这道题利用中序遍历&#xff08;右中左&#xff09;的操作不断修改节点的值即刻&#xff0c;直接看代码&#xff1a; class Solution { public:TreeNode* convertBST(TreeNode* root) {stack<TreeNode*> stk;//前面的累加值int pre_value 0;TreeNode* cur r…

review回文子串

给你一个字符串 s&#xff0c;请你将 s 分割成一些子串&#xff0c;使每个子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正着读和反着读都一样的字符串。 class Solution {List<List<String>> lists new ArrayList<>(); // 用于存储所有可能…

【数据挖掘】时间序列教程【二】

2.4 示例:颗粒物浓度 在本章中,我们将使用美国环境保护署的一些空气污染数据作为运行样本。该数据集由 2 年和 5 年空气动力学直径小于或等于 3.2017 \(mu\)g/m\(^2018\) 的颗粒物组成。 我们将特别关注来自两个特定监视器的数据,一个在加利福尼亚州弗雷斯诺,另一个在密…

图片文字对齐 图片文字居中对齐

方法一: 用 vertical-align: middle; <view class="container"><view class="search"><image src="../../images/icon/search.png" alt="" /><text class="tex">搜索</text></view>&…

透彻!127.0.0.1和0.0.0.0之间的区别总算听明白了!

参考视频&#xff1a;透彻&#xff01;127.0.0.1和0.0.0.0之间的区别总算听明白了&#xff01;_哔哩哔哩_bilibili 0.0.0.0不是一个ip地址&#xff0c;而是一个通配符&#xff0c;通配当前主机上面所有的网卡&#xff08;包括虚拟网卡&#xff09;。

深度学习环境安装|PyCharm,Anaconda,PyTorch,CUDA,cuDNN等

本文参考了许多优秀博主的博客&#xff0c;大部分安装步骤可在其他博客中找到&#xff0c;鉴于我本人第一次安装后&#xff0c;时隔半年&#xff0c;我忘记了当时安装的许多细节和版本信息&#xff0c;所以再一次报错时&#xff0c;重装花费了大量时间。因此&#xff0c;我觉得…

Profibus DP主站转Modbus TCP网关profibus从站地址范围

远创智控YC-DPM-TCP网关。这款产品在Profibus总线侧实现了主站功能&#xff0c;在以太网侧实现了ModbusTcp服务器功能&#xff0c;为我们的工业自动化网络带来了全新的可能。 远创智控YC-DPM-TCP网关是如何实现这些功能的呢&#xff1f;首先&#xff0c;让我们来看看它的Profib…

【新版系统架构】第十七章-通信系统架构设计理论与实践

软考-系统架构设计师知识点提炼-系统架构设计师教程&#xff08;第2版&#xff09; 第一章-绪论第二章-计算机系统基础知识&#xff08;一&#xff09;第二章-计算机系统基础知识&#xff08;二&#xff09;第三章-信息系统基础知识第四章-信息安全技术基础知识第五章-软件工程…

Redis---缓存双写一致性

目录 一、什么是缓存双写一致性呢&#xff1f; 1.1 双检加锁机制 二、数据库和缓存一致性的更新策略 2.1、先更新数据库&#xff0c;后更新缓存 2.2 、先更新缓存&#xff0c;后更新数据库 2.3、先删除缓存&#xff0c;在更新数据库 延时双删的策略&#xff1a; 2.4.先更新数…

Matplotlib---3D图

1. 3D图 # 3D引擎 from mpl_toolkits.mplot3d.axes3d import Axes3D fig plt.figure(figsize(8, 5)) x np.linspace(0, 100, 400) y np.sin(x) z np.cos(x)# 三维折线图 axes Axes3D(fig, auto_add_to_figureFalse) fig.add_axes(axes) axes.plot(x,y,z) plt.savefi…

Runner 介绍

Runner 介绍 概述 Runner是用来批量调用collection里某个文件夹里的全部接口的。 (注意&#xff0c;我说的是文件夹内所有接口,可以是一级文件夹&#xff0c;也可是二级文件夹) 示意图 打开runner&#xff0c;如图所示 说明 历史记录 历史执行记录 2.导入 导入别人或之…

h5页面如何与原生交互

本文讲述h5页面跟原生通信&#xff0c;比如在app内&#xff0c;调用相机&#xff0c;获取相册内的图片&#xff0c;在app内拉起微信小程序等等&#xff0c;h5页面没有这么多权限能够直接调用移动端的原生能力&#xff0c;这个时候就需要与原生进行通讯&#xff0c;传递一个信号…

Go实现在线词典翻译(三种翻译接口,结合sync)

火山翻译 首先介绍用火山翻译英译汉。 package mainimport ("bufio""bytes""encoding/json""fmt""io""log""net/http""os""strings""unicode" )type DictRequestHS st…

单片机第一季:零基础6——按键

目录 1&#xff0c;独立按键 2&#xff0c;矩阵按键 &#xff08;注意&#xff1a;文章中的代码仅供参考学习&#xff0c;实际使用时要根据需要修改&#xff09; 1&#xff0c;独立按键 按键管脚两端距离长的表示默认是导通状态&#xff0c;距离短的默认是断开状态&#xf…

Web APIs

文章目录 1.Web APIs 和 JS 基础关联性1.1 JS 的组成 2. API 和 Web API2.1 API2.2 Web API 1.Web APIs 和 JS 基础关联性 1.1 JS 的组成 2. API 和 Web API 2.1 API **API&#xff08;Application Programming Interface,应用程序编程接口&#xff09;**是一些预先定义的函…

观察者模式(下):如何实现一个异步非阻塞的EventBus框架?

上一节课中&#xff0c;我们学习了观察者模式的原理、实现、应用场景&#xff0c;重点介绍了不同应用场景下&#xff0c;几种不同的实现方式&#xff0c;包括&#xff1a;同步阻塞、异步非阻塞、进程内、进程间的实现方式。 同步阻塞是最经典的实现方式&#xff0c;主要是为了…