PyTorch预训练和微调:以VGG16为例

文章目录

    • 预训练和微调代码
    • 测试结果
    • 参考来源

预训练和微调代码

数据集:CIFAR10
CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。数据集介绍来自:CIFAR10

在这里插入图片描述
图片来源:https://paperswithcode.com/dataset/cifar-10

预训练模型: vgg16

代码

# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from tqdm import tqdmdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
num_epochs = 2class Identity(nn.Module):def __init__(self):super(Identity, self).__init__()def forward(self, x):return x# Load pretrain model & modify itmodel = torchvision.models.vgg16(weights='DEFAULT')
# # If you want to do finetuning then set requires_grad = False
# # Remove these two lines if you want to train entire model,
# # and only want to load the pretrain weights.
# for param in model.parameters():
#     param.requires_grad = False
for param in model.parameters():param.requires_grad = Falsemodel.avgpool = Identity() # 站位层,使得该层啥事不做
model.classifier = nn.Sequential(nn.Linear(512, 100),nn.ReLU(),nn.Linear(100, 10)) # 修改原模型的后几层
model.to(device)# Load Data
train_dataset = datasets.CIFAR10(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# Train Network
for epoch in range(num_epochs):losses = []for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):# Get data to cuda if possibledata = data.to(device=device)targets = targets.to(device=device)# forwardscores = model(data)loss = criterion(scores, targets)losses.append(loss.item())# backwardoptimizer.zero_grad()loss.backward()# gradient descent or adam stepoptimizer.step()print(f"Cost at epoch {epoch} is {sum(losses)/len(losses):.5f}")# Check accuracy on training & test to see how good our modeldef check_accuracy(loader, model):if loader.dataset.train:print("Checking accuracy on training data")else:print("Checking accuracy on test data")num_correct = 0num_samples = 0model.eval()with torch.no_grad():for x, y in loader:x = x.to(device=device)y = y.to(device=device)scores = model(x)_, predictions = scores.max(1)num_correct += (predictions == y).sum()num_samples += predictions.size(0)print(f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%")model.train()check_accuracy(train_loader, model)

测试结果

Checking accuracy on training data
Got 29449 / 50000 with accuracy 58.90%

可以看到本次预训练模型的导入,测试结果并不理想。但并不妨碍我们对Pytorch预训练和微调的学习。

参考来源

【1】 https://www.youtube.com/watch?v=qaDe0qQZ5AQ&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=8
【2】https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_pretrain_finetune.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/666.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SolidUI AI生成可视化,0.1.0版本模块划分以及源码讲解

1.背景 随着文本生成图像的语言模型兴起,SolidUI想帮人们快速构建可视化工具,可视化内容包括2D,3D,3D场景,从而快速构三维数据演示场景。SolidUI 是一个创新的项目,旨在将自然语言处理(NLP)与计算机图形学相…

【微信小程序-uniapp】CustomDialog 居中弹窗组件

1. 效果图 2. 组件完整代码 <template><uni-popup :ref="ref" type="center" @change

Ubuntu下配置Redis哨兵集群

目录 准备实例和配置 启动哨兵集群 测试配置 搭建一个三节点形成的Sentinel集群&#xff0c;来监管Redis主从集群。 三个sentinel哨兵实例信息如下&#xff1a; 节点IPPORTs1192.168.22.13527001s2192.168.22.13527002s3192.168.22.13527003 准备实例和配置 要在同一台虚…

Google Bard 拓展与归纳

导言&#xff1a; Bard&#xff08;谷歌人工智能语言模型“https://bard.google.com”&#xff09;在不断演进和改进中&#xff0c;为用户提供了更丰富、便捷和个性化的服务体验。本文集将深入探索 Bard 在不同方面的关键更新&#xff0c;包括语言支持扩大、图像呈现、交互方式…

组合式API

文章目录 前言了解组合式API简单类型 ref封装对象类型 user.name子组件数组类型 reactive封装 组合式 API 基础练习基础练习优化 前言 Vue 3 的组合式 API&#xff08;Composition API&#xff09;是一组函数和语法糖&#xff0c;用于更灵活和可组合地组织 Vue 组件的代码逻辑…

leetcode 538. 把二叉搜索树转换为累加树

2023.7.16 这道题利用中序遍历&#xff08;右中左&#xff09;的操作不断修改节点的值即刻&#xff0c;直接看代码&#xff1a; class Solution { public:TreeNode* convertBST(TreeNode* root) {stack<TreeNode*> stk;//前面的累加值int pre_value 0;TreeNode* cur r…

review回文子串

给你一个字符串 s&#xff0c;请你将 s 分割成一些子串&#xff0c;使每个子串都是 回文串 。返回 s 所有可能的分割方案。 回文串 是正着读和反着读都一样的字符串。 class Solution {List<List<String>> lists new ArrayList<>(); // 用于存储所有可能…

【数据挖掘】时间序列教程【二】

2.4 示例:颗粒物浓度 在本章中,我们将使用美国环境保护署的一些空气污染数据作为运行样本。该数据集由 2 年和 5 年空气动力学直径小于或等于 3.2017 \(mu\)g/m\(^2018\) 的颗粒物组成。 我们将特别关注来自两个特定监视器的数据,一个在加利福尼亚州弗雷斯诺,另一个在密…

图片文字对齐 图片文字居中对齐

方法一: 用 vertical-align: middle; <view class="container"><view class="search"><image src="../../images/icon/search.png" alt="" /><text class="tex">搜索</text></view>&…

webSocket前端+webSocket封装

一、websocket基础 /*** 初始化websocket连接*/ function initWebSocket() {let uId 1;var websocket null;if(WebSocket in window) {websocket new WebSocket("ws://localhost:8009/webSocket"uId );//请求的地址} else {alert("该浏览器不支持websocket&…

透彻!127.0.0.1和0.0.0.0之间的区别总算听明白了!

参考视频&#xff1a;透彻&#xff01;127.0.0.1和0.0.0.0之间的区别总算听明白了&#xff01;_哔哩哔哩_bilibili 0.0.0.0不是一个ip地址&#xff0c;而是一个通配符&#xff0c;通配当前主机上面所有的网卡&#xff08;包括虚拟网卡&#xff09;。

深度学习环境安装|PyCharm,Anaconda,PyTorch,CUDA,cuDNN等

本文参考了许多优秀博主的博客&#xff0c;大部分安装步骤可在其他博客中找到&#xff0c;鉴于我本人第一次安装后&#xff0c;时隔半年&#xff0c;我忘记了当时安装的许多细节和版本信息&#xff0c;所以再一次报错时&#xff0c;重装花费了大量时间。因此&#xff0c;我觉得…

全志F1C200S嵌入式驱动开发(解决reboot失败的问题)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing @163.com】 上一次做了rootfs之后,就马不停蹄地测试了几个常用的命令。比如cd、ls、date、time、reboot这样的命令。其他命令测试结果都还好,就是这个reboot命令当死就没有生效,现场的打印结…

初级 - 若依框架 - Java Spring/Spring Boot 项目理解记录

1、Autowired 自动装配的理解 一般情况下&#xff0c;我们创建对象都是 类名 类引用名 new 类名() 但是如果是不想要 等于号后面的对象实例化操作&#xff0c;那么可以使用 Autowired 注解&#xff0c;当然这是在使用 Spring 时&#xff0c;才能这样&#xff0c;不然一般情…

Profibus DP主站转Modbus TCP网关profibus从站地址范围

远创智控YC-DPM-TCP网关。这款产品在Profibus总线侧实现了主站功能&#xff0c;在以太网侧实现了ModbusTcp服务器功能&#xff0c;为我们的工业自动化网络带来了全新的可能。 远创智控YC-DPM-TCP网关是如何实现这些功能的呢&#xff1f;首先&#xff0c;让我们来看看它的Profib…

【新版系统架构】第十七章-通信系统架构设计理论与实践

软考-系统架构设计师知识点提炼-系统架构设计师教程&#xff08;第2版&#xff09; 第一章-绪论第二章-计算机系统基础知识&#xff08;一&#xff09;第二章-计算机系统基础知识&#xff08;二&#xff09;第三章-信息系统基础知识第四章-信息安全技术基础知识第五章-软件工程…

Redis---缓存双写一致性

目录 一、什么是缓存双写一致性呢&#xff1f; 1.1 双检加锁机制 二、数据库和缓存一致性的更新策略 2.1、先更新数据库&#xff0c;后更新缓存 2.2 、先更新缓存&#xff0c;后更新数据库 2.3、先删除缓存&#xff0c;在更新数据库 延时双删的策略&#xff1a; 2.4.先更新数…

Matplotlib---3D图

1. 3D图 # 3D引擎 from mpl_toolkits.mplot3d.axes3d import Axes3D fig plt.figure(figsize(8, 5)) x np.linspace(0, 100, 400) y np.sin(x) z np.cos(x)# 三维折线图 axes Axes3D(fig, auto_add_to_figureFalse) fig.add_axes(axes) axes.plot(x,y,z) plt.savefi…

Runner 介绍

Runner 介绍 概述 Runner是用来批量调用collection里某个文件夹里的全部接口的。 (注意&#xff0c;我说的是文件夹内所有接口,可以是一级文件夹&#xff0c;也可是二级文件夹) 示意图 打开runner&#xff0c;如图所示 说明 历史记录 历史执行记录 2.导入 导入别人或之…

h5页面如何与原生交互

本文讲述h5页面跟原生通信&#xff0c;比如在app内&#xff0c;调用相机&#xff0c;获取相册内的图片&#xff0c;在app内拉起微信小程序等等&#xff0c;h5页面没有这么多权限能够直接调用移动端的原生能力&#xff0c;这个时候就需要与原生进行通讯&#xff0c;传递一个信号…