卷积神经网络|迁移学习-猫狗分类完整代码实现

还记得这篇文章吗?迁移学习|代码实现

在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。

同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!

我们仍然按照这个步骤开始我们的模型的训练

  • 准备一个可迭代的数据集

  • 定义一个神经网络

  • 将数据集输入到神经网络进行处理

  • 计算损失

  • 通过梯度下降算法更新参数

import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models

数据集准备

cifar10_train = torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = True,    download = True)cifar10_test=torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = False,    download = True)
transform = transforms.Compose([        transforms.ToTensor(),        transforms.Resize((224,224))    ])cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)

数据集使用CIFAR-10数据集中的猫和狗

CIFAR-10数据集类别

种类       标签

  • plane       0

  • car           1

  • bird         2

  • cat           3

  • deer         4

  • dog          5

  • frog         6

  • horse       7

  • ship         8

  • truck        9

可以看到其中cat和dog的标签分别为3和5

借助:

[3,5].index(label)

我们可以将cat标签变为0dog标签变为1,从而回到二分类问题。

举个例子:

>>> [3,5].index(3)0>>> [3,5].index(5)1

定义模型

参考这篇文章:迁移学习|代码实现

#网络搭建network=models.resnet18(pretrained=True)
for param in network.parameters():    param.requires_grad=False
network.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network.fc.parameters(),lr=0.01,momentum=0.9)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")network=network.to(device)

训练模型:

for epoch in range(10):    total_loss = 0    total_correct = 0    for batch in train_loader:   # Get batch        images, labels =batch        images=images.to(device)        labels=labels.to(device)                    optimizer.zero_grad()  #告诉优化器把梯度属性中权重的梯度归零,否则pytorch会累积梯度        preds = network(images)        loss = criterion(preds, labels)        loss.backward()        optimizer.step()                total_loss += loss.item()        _,prelabels=torch.max(preds,dim=1)        total_correct += int((prelabels==labels).sum())    accuracy = total_correct/len(cifar2_train)    print("Epoch:%d  ,  Loss:%f  , Accuracy:%f "%(epoch,total_loss,accuracy))
  • Epoch:0  ,  Loss:78.549439  , Accuracy:0.788900

  • Epoch:1  ,  Loss:77.828066  , Accuracy:0.801500

  • Epoch:2  ,  Loss:66.151785  , Accuracy:0.828100

  • Epoch:3  ,  Loss:76.204446  , Accuracy:0.816800

  • Epoch:4  ,  Loss:68.886606  , Accuracy:0.828100

  • Epoch:5  ,  Loss:71.129405  , Accuracy:0.821200

  • Epoch:6  ,  Loss:66.096364  , Accuracy:0.829900

  • Epoch:7  ,  Loss:65.504227  , Accuracy:0.827700

  • Epoch:8  ,  Loss:76.303878  , Accuracy:0.817100

  • Epoch:9  ,  Loss:70.546953  , Accuracy:0.820700

测试模型:

correct=0total=0network.eval()with torch.no_grad():    for batch in test_loader:        imgs,labels=batch        imgs=imgs.cuda()        labels=labels.cuda()                preds=network(imgs)        _,prelabels=torch.max(preds,dim=1)        #print(prelabels.size())        total=total+labels.size(0)        correct=correct+int((prelabels==labels).sum())    #print(total)    accuracy=correct/total    print("Accuracy: ",accuracy)

Accuracy:  0.8025

这里使用的预训练模型是resnet18,我们也可以使用VGG16模型,同时记得改变最后一个全连接层的输出参数,使得其满足我们自己的任务。

除了预训练模型之外,我们还可以对一些超参数进行调整,使最后的效果变得更好!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

qtday1(2024/1/8)

#include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QMainWindow(parent) {//设置界面固定大小this->resize(1728,972);this->setFixedSize(1728,972);this->setWindowIcon(QIcon("C:\\Users\\78507\\Desktop\\pic\\qq1.png"));this->…

高级RAG(五):TruLens 评估-扩大和加速LLM应用程序评估

之前我们介绍了,RAGAs评估,今天我们再来介绍另外一款RAG的评估工具:TruLens , trulens是TruEra公司的一款开源软件工具,它可帮助您使用反馈功函数客观地评估基于 LLM 的应用程序的质量和有效性。反馈函数有助于以编程方式评估输入、输出和中间…

vue3 内置组件

文章目录 前言一、过渡效果相关的组件1、Transition2、TransitionGroup 二、状态缓存组件(KeepAlive)三、传送组件(Teleport )四、异步依赖处理组件(Suspense) 前言 在vue3中 其提供了5个内置组件 Transiti…

antv/x6_2.0学习使用(四、边)

一、添加边 节点和边都有共同的基类 Cell,除了从 Cell 继承属性外,还支持以下选项。 属性名类型默认值描述sourceTerminalData-源节点或起始点targetTerminalData-目标节点或目标点verticesPoint.PointLike[]-路径点routerRouterData-路由connectorCon…

猫咪吃哪种猫粮好?主食冻干猫粮哪种性价比高

由于猫咪是肉食动物,对蛋白质的需求很高,如果摄入的蛋白质不足,就会影响猫咪的成长。而冻干猫粮本身因为制作工艺的原因,能保留原有的营养成分和营养元素,所以冻干猫粮蛋白含量比较高,营养又高,…

第二十七周:文献阅读笔记

第二十七周:文献阅读笔记 摘要AbstractDenseNet 网络1. 文献摘要2. 引言3. ResNets4. Dense Block5. Pooling layers6. Implementation Details7. Experiments8. Feature Reuse9. 代码实现 总结 摘要 DenseNet(密集连接网络)是一种深度学习神…

工智能基础知识总结--词嵌入之FastText

什么是FastText FastText是Facebook于2016年开源的一个词向量计算和文本分类工具,它提出了子词嵌入的方法,试图在词嵌入向量中引入构词信息。一般情况下,使用fastText进行文本分类的同时也会产生词的embedding,即embedding是fastText分类的产物。 FastText流程 FastText的架…

计算机组成原理简答题

目录 1、指令和数据在计算机内部以几进制存储,又是如何区分的呢? 2、计算机内部为什么要使用二进制? 3、简单描述计算机系统的层次结构 4、DRAM为什么要进行刷新,如何刷新的? 5、简述不同操作码的指令格式&#xf…

FileStream文件管理

文件管理 FileStream:是一个用于读写文件的一个类。它提供了基于流的方式操作文件,可以进行读取、写入、查找和关闭等操作。 第一个参数:path(路径) 相对路径:相对于当前项目的bin目录下的Debug和Realse来…

[嵌入式AI从0开始到入土]10_yolov5在昇腾上应用

[嵌入式AI从0开始到入土]嵌入式AI系列教程 注:等我摸完鱼再把链接补上 可以关注我的B站号工具人呵呵的个人空间,后期会考虑出视频教程,务必催更,以防我变身鸽王。 第一章 昇腾Altas 200 DK上手 第二章 下载昇腾案例并运行 第三章…

【AI视野·今日NLP 自然语言处理论文速览 第七十一期】Fri, 5 Jan 2024

AI视野今日CS.NLP 自然语言处理论文速览 Fri, 5 Jan 2024 Totally 28 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers LLaMA Pro: Progressive LLaMA with Block Expansion Authors Chengyue Wu, Yukang Gan, Yixiao Ge, Zeyu Lu, …

java导出word套打

这篇文档手把手教你完成导出word套打&#xff0c;有这个demo&#xff0c;其他word套打导出都通用。 1、主要依赖 <!--hutool--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.3.0</ve…

IPv6路由协议---IPv6动态路由(RIPng)

IPv6动态路由协议 动态路由协议有自己的路由算法,能够自动适应网络拓扑的变化,适用于具有一定数量三层设备的网络。缺点是配置对用户要求比较高,对系统的要求高于静态路由,并将占用一定的网络资源和系统资源。 路由表和FIB表 路由器转发数据包的关键是路由表和FIB表,每…

CreateDIBSection失败的问题记录

错误记录 [ERROR] (:0, ): QPixmap::fromWinHICON(), failed to GetIconInfo() (操作成功完成。) [ERROR] (:0, ): QPixmap::fromWinHICON(), failed to GetIconInfo() (参数错误。) [ERROR] (:0, ): QPixmap::fromWinHICON(), failed to GetIconInfo() (参数错误。) [ERROR] …

升级 Vite 5 出现警告 The CJS build of Vite‘s Node API is deprecated.

&#x1f680; 作者主页&#xff1a; 有来技术 &#x1f525; 开源项目&#xff1a; youlai-mall &#x1f343; vue3-element-admin &#x1f343; youlai-boot &#x1f33a; 仓库主页&#xff1a; Gitee &#x1f4ab; Github &#x1f4ab; GitCode &#x1f496; 欢迎点赞…

数仓建设学习路线(一)

前言 数仓建设实践路线是语兴发布在B站的系列课程&#xff0c;搜索语兴呀即可学习完整的数仓建设理论。 大数据相关岗位 大数据常见的岗位主要包括实时开发、数据治理、数据安全、数据资产等。 其中&#xff1a; 实时开发组的主要任务是实时可视化制作(大屏/彩蛋/战报&…

前端结合MQTT实现连接 订阅发送信息等操作 VUE3

MQTT客户端下载 使用测试 在我之前文章中 MQTT下载基础使用 下面记录一下前端使用的话的操作 1.安装 npm i mqtt引入 import * as mqtt from "mqtt/dist/mqtt.min"; //VUE3 import mqtt from mqtt //VUE2 一、MQTT协议中的方法 Connect。等待与服务器建立连接…

[VUE]2-vue的基本使用

目录 vue基本使用方式 1、vue 组件 2、文本插值 3、属性绑定 4、事件绑定 5、双向绑定 6、条件渲染 7、axios 8、⭐跨域问题 &#x1f343;作者介绍&#xff1a;双非本科大三网络工程专业在读&#xff0c;阿里云专家博主&#xff0c;专注于Java领域学习&#xff0c;擅…

气膜建筑:舒适、智能、可持续

气膜建筑之所以能够拥有广阔的发展空间&#xff0c;源于其融合了诸多优势特点&#xff0c;使其成为未来建筑领域的前沿趋势。 气膜建筑注重环境可持续性和能源效率。在材料和设计上&#xff0c;它采用可回收材料、提高热保温效果&#xff0c;并积极利用太阳能等可再生能源&…

【洛谷学习自留】p9226 糖果

解题思路&#xff1a; 简单的计算题&#xff0c;用n对k取余&#xff0c;如果余数为0&#xff0c;则输出k的值&#xff0c;否则输出&#xff08;k-余数&#xff09;的值。 代码实现&#xff1a; import java.util.Scanner;public class p9226 {public static void main(Strin…