最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码

大家好,我是微学AI,今天给大家介绍一下最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码。随着人工智能的快速发展,去年有强大的大模型ChatGPT横空出世,国内的大模型也紧追其后的发布,主要包括:文心一言、ChatGLM、通义千问、百川大模型等,他们可以帮助我们编写代码,但是在实际中,高度依赖于大模型则会缺乏思考的能力,缺乏编写代码的感觉,在别人问的时候,缺乏熟练度。坚持多写代码反复进行,可以提高熟练程度,提高开发效率,锻炼记忆力。本文尝试利用最短的代码实现数据集、卷积神经网络的搭建、模型的训练,模型的评估的整个流程代码,快速熟练手打出来。

在这里插入图片描述

一、坚持手撕默写代码的意义:

关于坚持手撕默写代码的意义,我总结一下几点:

1.提高熟练程度:

通过手撕默写代码,我能够更加深入地理解代码的逻辑和工作原理,加深对代码的理解,并提高对编程语言和算法的熟练程度。

2.培养思维逻辑与开发效率:
手撕默写代码需要你对算法和语法有较为全面的理解,同时需要你将思路转化为具体的代码实现。这种过程能够培养我的思维逻辑能力,提高问题解决能力,提高模型库包的快速调用与开发效率。

3.探索学习新知识:
通过手撕默写代码,你会遇到各种问题和挑战,需要不断查阅资料、学习和探索,从中获得新的知识和技能。

4.锻炼记忆力:
反复手写代码可以加强对语法和细节的记忆,提高记忆力和代码的熟悉程度。

二、卷积神经网络的快速搭建

关于pytorch框架,我们经常用到的第三方库有torch,torch.nn,torchvision,这些我们要烂熟于心。

torch:torch是PyTorch的核心库,提供了张量操作、数学函数、自动求导等功能。它是一个多维数组的库,类似于NumPy,但具有GPU加速和用于深度学习的其他扩展功能。

torch.nn:torch.nn模块是PyTorch中用于构建神经网络模型的模块。它提供了各种层(如全连接层、卷积层、循环层等)和损失函数(如交叉熵损失、均方误差损失等),以及优化算法(如随机梯度下降等)的实现。

torchvision.transforms:torchvision.transforms模块提供了一系列用于图像预处理和数据增强的函数。通过该模块,可以对输入图像进行常见的操作,如裁剪、缩放、旋转、归一化等,以便更好地适应模型的输入要求。

torch.utils.data.DataLoader:torch.utils.data.DataLoader是PyTorch中用于加载和迭代数据集的工具。它可以将数据集封装成可迭代的数据加载器,支持批量加载、多线程加载和数据打乱等功能。

torchvision.datasets.FakeData:torchvision.datasets.FakeData是用于生成虚拟数据集的类。它可以根据指定的数据样式和大小生成虚拟的图像数据集,用于模型调试和测试。本文利用FakeData进行快速训练

第三方库的导入与卷积神经网络搭建:

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import FakeDataclass CNNnet(nn.Module):def __init__(self):super(CNNnet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,32,3,1,1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32,64,3,1,1),nn.ReLU(),nn.MaxPool2d(2))self.linear = nn.Linear(int((32/4)*(32/4)*64),2)def forward(self, x):x = self.conv1(x)x =x.view(x.size(0),-1)x = self.linear(x)return x

在上述的CNNnet网络模型中,nn.Linear(int((32/4)*(32/4)64),2)中的int((32/4)(32/4)*64)是指线性层的输入特征数。在该模型中,线性层的输入来自于卷积层输出的特征图,经过reshape处理后得到的一维向量。具体地,假设输入图像的大小为 W x H,卷积核大小为 k x k,卷积层的输出通道数为 n,则经过两次最大池化后,卷积层的输出特征图的大小为 (W/4) x (H/4) x n。因此,线性层的输入特征数 num = (W/4) x (H/4) x n。
我们这里设置输入图像的大小为 32x32,卷积核大小为 3x3,卷积层的输出通道数为 64,则经过两次最大池化后,卷积层的输出特征图的大小为 (32/4)x(32/4)x64=8x8x64=4096。因此,线性层的输入特征数 num=4096。

三、模型训练代码快速编写

model = CNNnet()  # 实例化模型criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立Adam优化器dataset = FakeData(size=1000,image_size=(3,32,32),num_classes=2,transform=transforms.ToTensor())
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)for epoch in range(25):running_loss = 0.0correct = 0total = 0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i + 1) % 10 == 0:print('[Epoch %d, Batch %5d] Loss: %.3f | Accuracy: %.3f%%' %(epoch + 1, i + 1, running_loss / 5, 100 * correct / total))running_loss = 0.0correct = 0total = 0

四、模型评估代码快速编写

# 模型评估
model.eval()
total = 0
correct = 0with torch.no_grad():for data in train_loader:inputs, labels = dataoutputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on the training dataset: %.3f%%' % (100 * correct / total))

上面的代码将模型设置为评估模式(model.eval()),然后使用torch.no_grad()上下文管理器来禁用梯度计算,以提高运行效率。在遍历训练集数据进行预测时,统计正确预测的样本数,并计算准确率。
该评估代码是在训练集上进行评估,如果需要在测试集上评估模型,需要使用测试集的数据进行评估。这里没有做扩展。

运行结果:

...
[Epoch 18, Batch    10] Loss: 0.292 | Accuracy: 99.062%
[Epoch 18, Batch    20] Loss: 0.264 | Accuracy: 100.000%
[Epoch 18, Batch    30] Loss: 0.245 | Accuracy: 100.000%
[Epoch 19, Batch    10] Loss: 0.208 | Accuracy: 100.000%
[Epoch 19, Batch    20] Loss: 0.218 | Accuracy: 100.000%
[Epoch 19, Batch    30] Loss: 0.215 | Accuracy: 99.688%
[Epoch 20, Batch    10] Loss: 0.201 | Accuracy: 100.000%
[Epoch 20, Batch    20] Loss: 0.183 | Accuracy: 100.000%
[Epoch 20, Batch    30] Loss: 0.165 | Accuracy: 100.000%
[Epoch 21, Batch    10] Loss: 0.136 | Accuracy: 100.000%
[Epoch 21, Batch    20] Loss: 0.137 | Accuracy: 100.000%
[Epoch 21, Batch    30] Loss: 0.119 | Accuracy: 100.000%
[Epoch 22, Batch    10] Loss: 0.108 | Accuracy: 100.000%
[Epoch 22, Batch    20] Loss: 0.102 | Accuracy: 100.000%
[Epoch 22, Batch    30] Loss: 0.098 | Accuracy: 100.000%
[Epoch 23, Batch    10] Loss: 0.087 | Accuracy: 100.000%
[Epoch 23, Batch    20] Loss: 0.083 | Accuracy: 100.000%
[Epoch 23, Batch    30] Loss: 0.086 | Accuracy: 100.000%
[Epoch 24, Batch    10] Loss: 0.072 | Accuracy: 100.000%
[Epoch 24, Batch    20] Loss: 0.075 | Accuracy: 100.000%
[Epoch 24, Batch    30] Loss: 0.075 | Accuracy: 100.000%
[Epoch 25, Batch    10] Loss: 0.068 | Accuracy: 100.000%
[Epoch 25, Batch    20] Loss: 0.060 | Accuracy: 100.000%
[Epoch 25, Batch    30] Loss: 0.065 | Accuracy: 100.000%
Accuracy on the training dataset: 100.000%

本文只是将模型训练的过程跑通,手打快速训练卷积神经网络网络的过程。实际应用场景中还需要将数据集分为训练集、验证集、测试集,详细的过程可以看我的往期文章。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/583651.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

写一个java状态模式的详细实例

以下是一个示例的 Java 状态模式实现: java Copy code // 定义状态接口 interface State { void handleState(Context context); } // 具体状态类 1 class ConcreteState1 implements State { public void handleState(Context context) { System…

基于ssm西安旅游管理系统论文

摘 要 在如今社会上,关于信息上面的处理,没有任何一个企业或者个人会忽视,如何让信息急速传递,并且归档储存查询,采用之前的纸张记录模式已经不符合当前使用要求了。所以,对西安旅游信息管理的提升&#x…

JavaScript----字符串拼接

1、字符串拼接 字符串拼接使用: "" 运算符 var iNum1 10; var fNum2 11.1; var sStr abc;result iNum1 fNum2; alert(result); // 弹出21.1result fNum2 sStr; alert(result); // 弹出11.1abc说明 数字和字符串拼接会自动进行类型转换(隐士类型转换)&#…

使用pandas绘图,并保存,支持中文

使用pandas绘图,并保存,支持中文 支持中文标题绘图创建DataFrame绘制图形添加其他绘图细节保存图形显示图形 支持中文标题 import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties import matplotlib.font_manager as fm…

深入理解Java集合框架

导语: Java集合框架是Java提供的一组用于管理对象的类和接口,它是Java编程中非常重要的一部分。Java集合框架通过提供诸如List、Set、Map等数据结构,为程序员提供了一种方便、高效的管理对象的方式。本文将深入理解Java集合框架,包…

RuntimeError: “slow_conv2d_cpu“ not implemented for ‘Half‘

目录 临时解决方法: RuntimeError: "slow_conv2d_cpu" not implemented for Half train_lora.py中: 原因:cpu不支持fp16类型, 临时解决方法: 注释掉fp16模式, weight_dtype torch.float32i…

Mediapipe绘制实时3d铰接骨架图——Mediapipe实时姿态估计

一、前言 大约两年前,基于自己的理解我曾写了几篇关于Mediapipe的文章,似乎帮助到了一些人。这两年,忙于比赛、实习、毕业、工作和考研。上篇文章已经是一年多前发的了。这段时间收到很多私信和评论,请原谅无法一一回复了。我将尝…

Redis缓存与数据库如何保证一致性

数据库和缓存如何保证一致性? 目录 数据库和缓存如何保证一致性?背景方案先更新数据库,还是先更新缓存?先更新数据库,再更新缓存先更新缓存,再更新数据库 先更新数据库,还是先删除缓存&#xff…

安装 PyQt5 保姆级教程

作者:billy 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 前言 博主之前做应用层开发用的一直是 Qt,这次尝试一下在 python 中使用 Pyqt5 模块来开发 UI 界面,这里做一些…

certum的ip证书购买流程

Certum是成立于欧洲的CA认证机构,经过二十几年的发展Certum已经成为欧洲知名的CA认证机构之一,拥有广泛的客户群体和合作伙伴。IP证书是Certum为只有公网IP地址的网站准备的数字加密服务。今天就随SSL盾小编了解购买Certum旗下的IP证书流程。 第一步&am…

WPF Grid

Resource 在 WPF 中,“Grid” 是一种用于布局的面板控件,而 “Resource” 是一种用于定义可重用对象的机制。您可以将资源定义为 Grid 控件的一部分,以便在整个应用程序中共享和重用。 使用资源可以帮助您简化界面的创建和维护。在 Grid 控件…

总结一些好用的函数

1. <string.h>/<cstring>头文件 中的 memset函数 作用&#xff1a;用于将一段内存区域设置为特定的值(它作用的基本单位是字节) 可以对变量&#xff0c;数组&#xff08;一维数组和二维数组&#xff09;&#xff0c;结构体进行初始化&#xff0c;但是不能对vecto…

数据库进阶教学——读写分离(Mycat1.6+Ubuntu22.04主+Win10从)

目录 1、概述 2、环境准备 3、读写分离实验 3.1、安装jdk 3.2、安装Mycat 3.3、配置Mycat 3.3.1、配置schema.xml ​​​​3.3.2、配置server.xml 3.4、修改主从机远程登陆权限 3.4.1、主机 3.4.2、从机 3.5、启动Mycat 3.6、登录Mycat 3.7、验证 1、概述 读写分…

如何合理配置云服务器的CPU和内存?

​  提到云服务器性能&#xff0c;大抵有两个主要影响因素&#xff0c;CPU 核心数量和内存容量 &#xff0c;它们决定了云服务器的速度和可靠性。日常运用中&#xff0c;我们如何判断网站需要需要更多或更少?如何扩大或缩小它们以优化网站的性能? 一般来说&#xff0c;您拥…

视频遥测终端机的设计需求

目录 1.目的 2.参考文件 3.总体描述 4.硬件资源描述 4.1微控制单元 4.2视频处理单元 4.3性能指标 5.功能要求 5.1系统参数要求 5.1.1系统管理 5.1.2系统配置 5.1.2.1一般参数 5.1.2.2编码参数 5.1.2.3网络参数 5.1.2.4网络服务 5.1.2.5OSD参数 5.1.2.6抓拍 5.…

java设计模式学习之【访问者模式】

文章目录 引言访问者模式简介定义与用途实现方式 使用场景优势与劣势在Spring框架中的应用电脑示例代码地址 引言 设想你是一个艺术馆的管理员&#xff0c;艺术馆里有各种各样的艺术品。每当有游客来访时&#xff0c;根据他们的兴趣&#xff0c;他们可能只想看画、雕塑或特定的…

【开源】基于Vue+SpringBoot的房屋出售出租系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 房屋销售模块2.2 房屋出租模块2.3 预定意向模块2.4 交易订单模块 三、系统展示四、核心代码4.1 查询房屋求租单4.2 查询卖家的房屋求购单4.3 出租意向预定4.4 出租单支付4.5 查询买家房屋销售交易单 五、免责说明 一、摘…

华为云服务器重启后无法连接故障解决

华为云服务器系统为麒麟银河V10sp1&#xff0c;升级sp1.1后继续升级sp2. 升级后重启发现无法连接服务器了。 登录华为云控制台&#xff0c;使用用vnc方式连接成功。 推测应该是网络问题。使用ip addr命令检查&#xff0c;发现网卡eth0处于down状态。 使用命令启动网卡&…

spring security oauth2搭建认证服务器

如图&#xff08;上面图片的代码在业务项目中&#xff09;&#xff0c;第一步在独立的业务项目中&#xff0c;先获取授权码&#xff08;也叫jsessionId&#xff09;、获取授权码的路径就是 /oauth2/authorize&#xff0c;这个路径是oauth2的框架中被OAuth2AuthorizationEndpoin…

如何从RTP包的AP类型包,获取h265的PPS、SPS、VPS信息

ffmpeg播放rtp流&#xff0c;为了降低首开延迟&#xff0c;需要在SDP文件中指定PPS、SPS、VPS信息。抓包后发现wireshark无法解析AP包。需要自己进行AP包解析。RTP协议AP包格式如下&#xff1a; 根据如上信息&#xff0c;我们可以解析AP包&#xff0c;效果如下 40 01&#xff…