卷积神经网络|迁移学习-猫狗分类完整代码实现

还记得这篇文章吗?迁移学习|代码实现

在这篇文章中,我们知道了在构建模型时,可以借助一些非常有名的模型,这些模型在ImageNet数据集上早已经得到了检验。

同时torchvision模块也提供了预训练好的模型。我们只需稍作修改,便可运用到自己的实际任务中!

我们仍然按照这个步骤开始我们的模型的训练

  • 准备一个可迭代的数据集

  • 定义一个神经网络

  • 将数据集输入到神经网络进行处理

  • 计算损失

  • 通过梯度下降算法更新参数

import torch import torchvisionimport torchvision.transforms as transformsimport torch.nn as nnimport torch.optim as optimimport matplotlib.pyplot as pltfrom torchvision import models

数据集准备

cifar10_train = torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = True,    download = True)cifar10_test=torchvision.datasets.CIFAR10(    root = 'cifar10/',    train = False,    download = True)
transform = transforms.Compose([        transforms.ToTensor(),        transforms.Resize((224,224))    ])cifar2_train=[(transform(img),[3,5].index(label)) for img,label in cifar10_train if label in [3,5]]
cifar2_test=[(transform(img),[3,5].index(label)) for img,label in cifar10_test if label in [3,5]]
train_loader = torch.utils.data.DataLoader(cifar2_train, batch_size=64,shuffle=True)test_loader = torch.utils.data.DataLoader(cifar2_test, batch_size=64,shuffle=True)

数据集使用CIFAR-10数据集中的猫和狗

CIFAR-10数据集类别

种类       标签

  • plane       0

  • car           1

  • bird         2

  • cat           3

  • deer         4

  • dog          5

  • frog         6

  • horse       7

  • ship         8

  • truck        9

可以看到其中cat和dog的标签分别为3和5

借助:

[3,5].index(label)

我们可以将cat标签变为0dog标签变为1,从而回到二分类问题。

举个例子:

>>> [3,5].index(3)0>>> [3,5].index(5)1

定义模型

参考这篇文章:迁移学习|代码实现

#网络搭建network=models.resnet18(pretrained=True)
for param in network.parameters():    param.requires_grad=False
network.fc=nn.Linear(512,2)#损失函数criterion=nn.CrossEntropyLoss()#优化器optimizer=optim.SGD(network.fc.parameters(),lr=0.01,momentum=0.9)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")network=network.to(device)

训练模型:

for epoch in range(10):    total_loss = 0    total_correct = 0    for batch in train_loader:   # Get batch        images, labels =batch        images=images.to(device)        labels=labels.to(device)                    optimizer.zero_grad()  #告诉优化器把梯度属性中权重的梯度归零,否则pytorch会累积梯度        preds = network(images)        loss = criterion(preds, labels)        loss.backward()        optimizer.step()                total_loss += loss.item()        _,prelabels=torch.max(preds,dim=1)        total_correct += int((prelabels==labels).sum())    accuracy = total_correct/len(cifar2_train)    print("Epoch:%d  ,  Loss:%f  , Accuracy:%f "%(epoch,total_loss,accuracy))
  • Epoch:0  ,  Loss:78.549439  , Accuracy:0.788900

  • Epoch:1  ,  Loss:77.828066  , Accuracy:0.801500

  • Epoch:2  ,  Loss:66.151785  , Accuracy:0.828100

  • Epoch:3  ,  Loss:76.204446  , Accuracy:0.816800

  • Epoch:4  ,  Loss:68.886606  , Accuracy:0.828100

  • Epoch:5  ,  Loss:71.129405  , Accuracy:0.821200

  • Epoch:6  ,  Loss:66.096364  , Accuracy:0.829900

  • Epoch:7  ,  Loss:65.504227  , Accuracy:0.827700

  • Epoch:8  ,  Loss:76.303878  , Accuracy:0.817100

  • Epoch:9  ,  Loss:70.546953  , Accuracy:0.820700

测试模型:

correct=0total=0network.eval()with torch.no_grad():    for batch in test_loader:        imgs,labels=batch        imgs=imgs.cuda()        labels=labels.cuda()                preds=network(imgs)        _,prelabels=torch.max(preds,dim=1)        #print(prelabels.size())        total=total+labels.size(0)        correct=correct+int((prelabels==labels).sum())    #print(total)    accuracy=correct/total    print("Accuracy: ",accuracy)

Accuracy:  0.8025

这里使用的预训练模型是resnet18,我们也可以使用VGG16模型,同时记得改变最后一个全连接层的输出参数,使得其满足我们自己的任务。

除了预训练模型之外,我们还可以对一些超参数进行调整,使最后的效果变得更好!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606089.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【airsim】computer_vision 源码阅读

api文档https://github.com/Microsoft/AirSim/blob/main/docs/image_apis.md#computer-vision-mode capture_ir_segmentation AirSim\PythonClient\computer_vision\capture_ir_segmentation.py 在重新分配分段 ID 后运行。它跟踪感兴趣的物体并记录多旋翼飞行器的红外图像和…

qtday1(2024/1/8)

#include "mywidget.h"MyWidget::MyWidget(QWidget *parent): QMainWindow(parent) {//设置界面固定大小this->resize(1728,972);this->setFixedSize(1728,972);this->setWindowIcon(QIcon("C:\\Users\\78507\\Desktop\\pic\\qq1.png"));this->…

高级RAG(五):TruLens 评估-扩大和加速LLM应用程序评估

之前我们介绍了,RAGAs评估,今天我们再来介绍另外一款RAG的评估工具:TruLens , trulens是TruEra公司的一款开源软件工具,它可帮助您使用反馈功函数客观地评估基于 LLM 的应用程序的质量和有效性。反馈函数有助于以编程方式评估输入、输出和中间…

vue3 内置组件

文章目录 前言一、过渡效果相关的组件1、Transition2、TransitionGroup 二、状态缓存组件(KeepAlive)三、传送组件(Teleport )四、异步依赖处理组件(Suspense) 前言 在vue3中 其提供了5个内置组件 Transiti…

antv/x6_2.0学习使用(四、边)

一、添加边 节点和边都有共同的基类 Cell,除了从 Cell 继承属性外,还支持以下选项。 属性名类型默认值描述sourceTerminalData-源节点或起始点targetTerminalData-目标节点或目标点verticesPoint.PointLike[]-路径点routerRouterData-路由connectorCon…

猫咪吃哪种猫粮好?主食冻干猫粮哪种性价比高

由于猫咪是肉食动物,对蛋白质的需求很高,如果摄入的蛋白质不足,就会影响猫咪的成长。而冻干猫粮本身因为制作工艺的原因,能保留原有的营养成分和营养元素,所以冻干猫粮蛋白含量比较高,营养又高,…

Python3 集合

集合(set)是一个无序的不重复元素序列。 集合中的元素不会重复,并且可以进行交集、并集、差集等常见的集合操作。 可以使用大括号 { } 创建集合,元素之间用逗号 , 分隔, 或者也可以使用 set() 函数创建集合。 创建格…

第二十七周:文献阅读笔记

第二十七周:文献阅读笔记 摘要AbstractDenseNet 网络1. 文献摘要2. 引言3. ResNets4. Dense Block5. Pooling layers6. Implementation Details7. Experiments8. Feature Reuse9. 代码实现 总结 摘要 DenseNet(密集连接网络)是一种深度学习神…

技术总监写的十个方法,让我精通了lambda表达式

技术总监写的十个方法,让我精通了lambda表达式 Collection 转化为 Map使用样例代码展示 Map格式转换转换 Map 的 Value测试样例代码展示 集合类型转化Collection 和 List、Set 的转化测试样例 List、Set 类型之间的转换测试样例 前公司的技术总监写了工具类&#xf…

工智能基础知识总结--词嵌入之FastText

什么是FastText FastText是Facebook于2016年开源的一个词向量计算和文本分类工具,它提出了子词嵌入的方法,试图在词嵌入向量中引入构词信息。一般情况下,使用fastText进行文本分类的同时也会产生词的embedding,即embedding是fastText分类的产物。 FastText流程 FastText的架…

计算机组成原理简答题

目录 1、指令和数据在计算机内部以几进制存储,又是如何区分的呢? 2、计算机内部为什么要使用二进制? 3、简单描述计算机系统的层次结构 4、DRAM为什么要进行刷新,如何刷新的? 5、简述不同操作码的指令格式&#xf…

常用Java代码-Java中的多线程编程(Multi-threading)

多线程编程是Java中的一个重要概念,它允许程序在同一时刻执行多个任务,提高程序的执行效率和响应性。在Java中,多线程编程通过创建多个线程并利用线程来执行任务实现。 Java提供了Thread类和Runnable接口来实现多线程编程。Thread类是Java中…

FileStream文件管理

文件管理 FileStream:是一个用于读写文件的一个类。它提供了基于流的方式操作文件,可以进行读取、写入、查找和关闭等操作。 第一个参数:path(路径) 相对路径:相对于当前项目的bin目录下的Debug和Realse来…

EMD+包络谱故障诊断

EMD是一种信号处理方法,用于将信号分解成多个本征模态函数(Intrinsic Mode Functions,IMF),每个IMF代表信号中的一个固有振动模式。VMD在处理非平稳信号和非线性信号方面具有较好的性能。 包络谱峭度是一种用于描述信号包络频谱形状的特征。它通过对信号包络谱的谱线斜率…

[嵌入式AI从0开始到入土]10_yolov5在昇腾上应用

[嵌入式AI从0开始到入土]嵌入式AI系列教程 注:等我摸完鱼再把链接补上 可以关注我的B站号工具人呵呵的个人空间,后期会考虑出视频教程,务必催更,以防我变身鸽王。 第一章 昇腾Altas 200 DK上手 第二章 下载昇腾案例并运行 第三章…

Android 车联网——多屏多用户(十五)

前面几篇文章介绍了多用户和多屏相关的 Manager 和 Service。上一篇文章最后虽然车内乘员都根据配置有自己的对应屏幕,但默认情况下,所有车内乘员依然使用的是当前主用户(司机用户),这一篇我们继续放下看一下用户的创建与分配。 一、用户创建分配 1、创建用户 对于创建用…

【AI视野·今日NLP 自然语言处理论文速览 第七十一期】Fri, 5 Jan 2024

AI视野今日CS.NLP 自然语言处理论文速览 Fri, 5 Jan 2024 Totally 28 papers 👉上期速览✈更多精彩请移步主页 Daily Computation and Language Papers LLaMA Pro: Progressive LLaMA with Block Expansion Authors Chengyue Wu, Yukang Gan, Yixiao Ge, Zeyu Lu, …

JavaScript实现的复杂功能:自动生成带水印的图片

#程序员的崩溃瞬间 在本文中,我们将讨论一个JavaScript实现的复杂功能,该功能可以自动为图片添加水印。这个功能在许多场景中都非常有用,例如,如果你想保护你的图片版权,或者你想在你的网站上显示自定义的水印。 一、…

java导出word套打

这篇文档手把手教你完成导出word套打&#xff0c;有这个demo&#xff0c;其他word套打导出都通用。 1、主要依赖 <!--hutool--><dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.3.0</ve…

IPv6路由协议---IPv6动态路由(RIPng)

IPv6动态路由协议 动态路由协议有自己的路由算法,能够自动适应网络拓扑的变化,适用于具有一定数量三层设备的网络。缺点是配置对用户要求比较高,对系统的要求高于静态路由,并将占用一定的网络资源和系统资源。 路由表和FIB表 路由器转发数据包的关键是路由表和FIB表,每…