最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码

大家好,我是微学AI,今天给大家介绍一下最快速度与最简代码搭建卷积神经网络,并快速训练模型,每日坚持手撕默写代码。随着人工智能的快速发展,去年有强大的大模型ChatGPT横空出世,国内的大模型也紧追其后的发布,主要包括:文心一言、ChatGLM、通义千问、百川大模型等,他们可以帮助我们编写代码,但是在实际中,高度依赖于大模型则会缺乏思考的能力,缺乏编写代码的感觉,在别人问的时候,缺乏熟练度。坚持多写代码反复进行,可以提高熟练程度,提高开发效率,锻炼记忆力。本文尝试利用最短的代码实现数据集、卷积神经网络的搭建、模型的训练,模型的评估的整个流程代码,快速熟练手打出来。

在这里插入图片描述

一、坚持手撕默写代码的意义:

关于坚持手撕默写代码的意义,我总结一下几点:

1.提高熟练程度:

通过手撕默写代码,我能够更加深入地理解代码的逻辑和工作原理,加深对代码的理解,并提高对编程语言和算法的熟练程度。

2.培养思维逻辑与开发效率:
手撕默写代码需要你对算法和语法有较为全面的理解,同时需要你将思路转化为具体的代码实现。这种过程能够培养我的思维逻辑能力,提高问题解决能力,提高模型库包的快速调用与开发效率。

3.探索学习新知识:
通过手撕默写代码,你会遇到各种问题和挑战,需要不断查阅资料、学习和探索,从中获得新的知识和技能。

4.锻炼记忆力:
反复手写代码可以加强对语法和细节的记忆,提高记忆力和代码的熟悉程度。

二、卷积神经网络的快速搭建

关于pytorch框架,我们经常用到的第三方库有torch,torch.nn,torchvision,这些我们要烂熟于心。

torch:torch是PyTorch的核心库,提供了张量操作、数学函数、自动求导等功能。它是一个多维数组的库,类似于NumPy,但具有GPU加速和用于深度学习的其他扩展功能。

torch.nn:torch.nn模块是PyTorch中用于构建神经网络模型的模块。它提供了各种层(如全连接层、卷积层、循环层等)和损失函数(如交叉熵损失、均方误差损失等),以及优化算法(如随机梯度下降等)的实现。

torchvision.transforms:torchvision.transforms模块提供了一系列用于图像预处理和数据增强的函数。通过该模块,可以对输入图像进行常见的操作,如裁剪、缩放、旋转、归一化等,以便更好地适应模型的输入要求。

torch.utils.data.DataLoader:torch.utils.data.DataLoader是PyTorch中用于加载和迭代数据集的工具。它可以将数据集封装成可迭代的数据加载器,支持批量加载、多线程加载和数据打乱等功能。

torchvision.datasets.FakeData:torchvision.datasets.FakeData是用于生成虚拟数据集的类。它可以根据指定的数据样式和大小生成虚拟的图像数据集,用于模型调试和测试。本文利用FakeData进行快速训练

第三方库的导入与卷积神经网络搭建:

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import FakeDataclass CNNnet(nn.Module):def __init__(self):super(CNNnet, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3,32,3,1,1),nn.ReLU(),nn.MaxPool2d(2),nn.Conv2d(32,64,3,1,1),nn.ReLU(),nn.MaxPool2d(2))self.linear = nn.Linear(int((32/4)*(32/4)*64),2)def forward(self, x):x = self.conv1(x)x =x.view(x.size(0),-1)x = self.linear(x)return x

在上述的CNNnet网络模型中,nn.Linear(int((32/4)*(32/4)64),2)中的int((32/4)(32/4)*64)是指线性层的输入特征数。在该模型中,线性层的输入来自于卷积层输出的特征图,经过reshape处理后得到的一维向量。具体地,假设输入图像的大小为 W x H,卷积核大小为 k x k,卷积层的输出通道数为 n,则经过两次最大池化后,卷积层的输出特征图的大小为 (W/4) x (H/4) x n。因此,线性层的输入特征数 num = (W/4) x (H/4) x n。
我们这里设置输入图像的大小为 32x32,卷积核大小为 3x3,卷积层的输出通道数为 64,则经过两次最大池化后,卷积层的输出特征图的大小为 (32/4)x(32/4)x64=8x8x64=4096。因此,线性层的输入特征数 num=4096。

三、模型训练代码快速编写

model = CNNnet()  # 实例化模型criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 建立Adam优化器dataset = FakeData(size=1000,image_size=(3,32,32),num_classes=2,transform=transforms.ToTensor())
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)for epoch in range(25):running_loss = 0.0correct = 0total = 0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i + 1) % 10 == 0:print('[Epoch %d, Batch %5d] Loss: %.3f | Accuracy: %.3f%%' %(epoch + 1, i + 1, running_loss / 5, 100 * correct / total))running_loss = 0.0correct = 0total = 0

四、模型评估代码快速编写

# 模型评估
model.eval()
total = 0
correct = 0with torch.no_grad():for data in train_loader:inputs, labels = dataoutputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on the training dataset: %.3f%%' % (100 * correct / total))

上面的代码将模型设置为评估模式(model.eval()),然后使用torch.no_grad()上下文管理器来禁用梯度计算,以提高运行效率。在遍历训练集数据进行预测时,统计正确预测的样本数,并计算准确率。
该评估代码是在训练集上进行评估,如果需要在测试集上评估模型,需要使用测试集的数据进行评估。这里没有做扩展。

运行结果:

...
[Epoch 18, Batch    10] Loss: 0.292 | Accuracy: 99.062%
[Epoch 18, Batch    20] Loss: 0.264 | Accuracy: 100.000%
[Epoch 18, Batch    30] Loss: 0.245 | Accuracy: 100.000%
[Epoch 19, Batch    10] Loss: 0.208 | Accuracy: 100.000%
[Epoch 19, Batch    20] Loss: 0.218 | Accuracy: 100.000%
[Epoch 19, Batch    30] Loss: 0.215 | Accuracy: 99.688%
[Epoch 20, Batch    10] Loss: 0.201 | Accuracy: 100.000%
[Epoch 20, Batch    20] Loss: 0.183 | Accuracy: 100.000%
[Epoch 20, Batch    30] Loss: 0.165 | Accuracy: 100.000%
[Epoch 21, Batch    10] Loss: 0.136 | Accuracy: 100.000%
[Epoch 21, Batch    20] Loss: 0.137 | Accuracy: 100.000%
[Epoch 21, Batch    30] Loss: 0.119 | Accuracy: 100.000%
[Epoch 22, Batch    10] Loss: 0.108 | Accuracy: 100.000%
[Epoch 22, Batch    20] Loss: 0.102 | Accuracy: 100.000%
[Epoch 22, Batch    30] Loss: 0.098 | Accuracy: 100.000%
[Epoch 23, Batch    10] Loss: 0.087 | Accuracy: 100.000%
[Epoch 23, Batch    20] Loss: 0.083 | Accuracy: 100.000%
[Epoch 23, Batch    30] Loss: 0.086 | Accuracy: 100.000%
[Epoch 24, Batch    10] Loss: 0.072 | Accuracy: 100.000%
[Epoch 24, Batch    20] Loss: 0.075 | Accuracy: 100.000%
[Epoch 24, Batch    30] Loss: 0.075 | Accuracy: 100.000%
[Epoch 25, Batch    10] Loss: 0.068 | Accuracy: 100.000%
[Epoch 25, Batch    20] Loss: 0.060 | Accuracy: 100.000%
[Epoch 25, Batch    30] Loss: 0.065 | Accuracy: 100.000%
Accuracy on the training dataset: 100.000%

本文只是将模型训练的过程跑通,手打快速训练卷积神经网络网络的过程。实际应用场景中还需要将数据集分为训练集、验证集、测试集,详细的过程可以看我的往期文章。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/583651.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于ssm西安旅游管理系统论文

摘 要 在如今社会上,关于信息上面的处理,没有任何一个企业或者个人会忽视,如何让信息急速传递,并且归档储存查询,采用之前的纸张记录模式已经不符合当前使用要求了。所以,对西安旅游信息管理的提升&#x…

Mediapipe绘制实时3d铰接骨架图——Mediapipe实时姿态估计

一、前言 大约两年前,基于自己的理解我曾写了几篇关于Mediapipe的文章,似乎帮助到了一些人。这两年,忙于比赛、实习、毕业、工作和考研。上篇文章已经是一年多前发的了。这段时间收到很多私信和评论,请原谅无法一一回复了。我将尝…

Redis缓存与数据库如何保证一致性

数据库和缓存如何保证一致性? 目录 数据库和缓存如何保证一致性?背景方案先更新数据库,还是先更新缓存?先更新数据库,再更新缓存先更新缓存,再更新数据库 先更新数据库,还是先删除缓存&#xff…

安装 PyQt5 保姆级教程

作者:billy 版权声明:著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处 前言 博主之前做应用层开发用的一直是 Qt,这次尝试一下在 python 中使用 Pyqt5 模块来开发 UI 界面,这里做一些…

certum的ip证书购买流程

Certum是成立于欧洲的CA认证机构,经过二十几年的发展Certum已经成为欧洲知名的CA认证机构之一,拥有广泛的客户群体和合作伙伴。IP证书是Certum为只有公网IP地址的网站准备的数字加密服务。今天就随SSL盾小编了解购买Certum旗下的IP证书流程。 第一步&am…

数据库进阶教学——读写分离(Mycat1.6+Ubuntu22.04主+Win10从)

目录 1、概述 2、环境准备 3、读写分离实验 3.1、安装jdk 3.2、安装Mycat 3.3、配置Mycat 3.3.1、配置schema.xml ​​​​3.3.2、配置server.xml 3.4、修改主从机远程登陆权限 3.4.1、主机 3.4.2、从机 3.5、启动Mycat 3.6、登录Mycat 3.7、验证 1、概述 读写分…

如何合理配置云服务器的CPU和内存?

​  提到云服务器性能,大抵有两个主要影响因素,CPU 核心数量和内存容量 ,它们决定了云服务器的速度和可靠性。日常运用中,我们如何判断网站需要需要更多或更少?如何扩大或缩小它们以优化网站的性能? 一般来说,您拥…

视频遥测终端机的设计需求

目录 1.目的 2.参考文件 3.总体描述 4.硬件资源描述 4.1微控制单元 4.2视频处理单元 4.3性能指标 5.功能要求 5.1系统参数要求 5.1.1系统管理 5.1.2系统配置 5.1.2.1一般参数 5.1.2.2编码参数 5.1.2.3网络参数 5.1.2.4网络服务 5.1.2.5OSD参数 5.1.2.6抓拍 5.…

java设计模式学习之【访问者模式】

文章目录 引言访问者模式简介定义与用途实现方式 使用场景优势与劣势在Spring框架中的应用电脑示例代码地址 引言 设想你是一个艺术馆的管理员,艺术馆里有各种各样的艺术品。每当有游客来访时,根据他们的兴趣,他们可能只想看画、雕塑或特定的…

【开源】基于Vue+SpringBoot的房屋出售出租系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 房屋销售模块2.2 房屋出租模块2.3 预定意向模块2.4 交易订单模块 三、系统展示四、核心代码4.1 查询房屋求租单4.2 查询卖家的房屋求购单4.3 出租意向预定4.4 出租单支付4.5 查询买家房屋销售交易单 五、免责说明 一、摘…

spring security oauth2搭建认证服务器

如图(上面图片的代码在业务项目中),第一步在独立的业务项目中,先获取授权码(也叫jsessionId)、获取授权码的路径就是 /oauth2/authorize,这个路径是oauth2的框架中被OAuth2AuthorizationEndpoin…

如何从RTP包的AP类型包,获取h265的PPS、SPS、VPS信息

ffmpeg播放rtp流,为了降低首开延迟,需要在SDP文件中指定PPS、SPS、VPS信息。抓包后发现wireshark无法解析AP包。需要自己进行AP包解析。RTP协议AP包格式如下: 根据如上信息,我们可以解析AP包,效果如下 40 01&#xff…

图像质量评估:使用 SSIM 计算图像相似性

在图像处理领域,衡量两幅图像之间相似性的一种常见方法是使用结构相似性指数(SSIM)。SSIM 是一种全参考的图像质量评估指标,它不仅考虑了图像的亮度、对比度,还考虑了结构信息。在本文中,我们将介绍一个使用…

【Vue2+3入门到实战】(13)插槽<slot>详细示例及自定义组件的创建与使用代码示例 详解

目录 一、学习目标1.插槽2.综合案例:商品列表 一、插槽-默认插槽1.作用2.需求3.问题4.插槽的基本语法5.代码示例6.总结 二、插槽-后备内容(默认值)1.问题2.插槽的后备内容3.语法4.效果5.代码示例 三、插槽-具名插槽1.需求2.具名插槽语法3.v-s…

【JAVA】使用OPENGL

从这个网址下载对应的库: LWJGL - Lightweight Java Game Libraryhttps://www.lwjgl.org/browse/release/3.3.3下载这个压缩包(实际上有很多版本3.3.3是比较新的版本:LWJGL - Lightweight Java Game Library): https…

关于log4j的那些坑

背景:工程中同时存在log4j.xml&log4j2.xml maven依赖如下: 此时工程实际使用的日志文件为log4j.xml 1、当同时设置log4j和log4j2的桥接依赖时 maven依赖如下: 此时启动会有警告日志: 点击告警日志链接:https://…

【Vue2 + ElementUI】el-table中校验表单

一. 案例 校验金额 阐述&#xff1a;校验输入的金额是否正确。如下所示&#xff0c;点击【编辑图标】会变为input输入框当&#xff0c;输入金额。当输入框失去焦点时&#xff0c;若正确则调用接口更新金额且变为不可输入状态&#xff0c;否则返回不合法金额提示 <templat…

计算机网络复习4

网络层——点到点 文章目录 网络层——点到点功能路由算法IPV4NAT 网络地址转换子网划分与子网掩码、CIDR地址解析协议ARP&#xff1a;根据IP地址找到MAC地址动态主机配置协议DHCP网际控制报文协议ICMPIPV6内部网关协议&#xff08;IGP&#xff09;外部网关协议(EGP) 功能 异构…

数字人直播系统——打破时空限制的新媒体时代

随着科技的不断进步&#xff0c;新媒体开始逐渐成为人们获取信息和娱乐的首选方式。其中&#xff0c;数字人直播系统作为一种创新的传媒形式&#xff0c;正以其独特的优势受到越来越多人的关注和喜爱。数字人直播系统通过将虚拟人物与现实世界紧密结合&#xff0c;打破了时间和…

SuperMap YashanDB联合解决方案发布,赋能更强大的地理智慧

近期&#xff0c;深圳计算科学研究院&#xff08;简称“深算院”&#xff09;携手超图软件集团&#xff08;简称“超图 ”&#xff09;重磅推出基于崖山数据库的空间数据管理解决方案&#xff0c;基于YashanDB空间数据库能力&#xff0c;与超图SuperMap GIS平台深度适配&#x…