Pytroch搭建全连接神经网络识别MNIST手写数字数据集

编写步骤

之前已经记录国多次的编写步骤了,无需多言。
(1)准备数据集
这里我们使用MNIST数据集,有官方下载渠道。我们直接使用torchvison里面提供的数据读取功能包就行。如果不使用这个,自己像这样子构建也一样。

# 自己构建数据读取模块
#(1) 数据读取模块
class Mydataset(Dataset):def __init__(self,filepath):xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)self.len=xy.shape[0]self.x_data=torch.from_numpy(xy[:,:-1])self.y_data=torch.from_numpy(xy[:,[-1]])#魔法方法,容许用户通过索引index得到值def __getitem__(self,index):return self.x_data[index],self.y_data[index]def __len__(self):return self.len

这里直接使用torchvison里面的工具

#准备数据集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])trainset = torchvision.datasets.MNIST(root=r'../data/mnist',train=True,download=True,transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = torchvision.datasets.MNIST(root=r'../data/mnist',train=False,download=True,transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

(2) 构建模型
这次我们使用不带dropout的全连接模型

# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(784, 100)self.linear2 = nn.Linear(100, 20)self.linear3 = nn.Linear(20, 10)def forward(self, x):x=x.view(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return x

(3) 选择损失和优化器

# 构建模型和损失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

(4)训练模型

def train(epoch):running_loss = 0.0for batch_idx, (inputs, targets) in enumerate(trainloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()#需要将张量转换为浮点数运算running_loss += loss.item()if batch_idx % 100 == 0:print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))running_loss = 0

(5)测试模型

def test(epoch):correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(testloader):outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct=correct+(predicted.eq(targets).sum()*1.0)print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))

全部代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
#准备数据集
batch_size = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])trainset = torchvision.datasets.MNIST(root=r'../data/mnist',train=True,download=True,transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)testset = torchvision.datasets.MNIST(root=r'../data/mnist',train=False,download=True,transform=transforms)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)# 定义模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.linear1 = nn.Linear(784, 100)self.linear2 = nn.Linear(100, 20)self.linear3 = nn.Linear(20, 10)def forward(self, x):x=x.view(x.size(0), -1)x = F.relu(self.linear1(x))x = F.relu(self.linear2(x))x = self.linear3(x)return x
# 构建模型和损失
model=Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)def train(epoch):running_loss = 0.0for batch_idx, (inputs, targets) in enumerate(trainloader):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()#需要将张量转换为浮点数运算running_loss += loss.item()if batch_idx % 100 == 0:print('Train Epoch: {}, Loss: {:.6f}'.format(epoch, loss.item()))running_loss = 0
def test(epoch):correct = 0total = 0with torch.no_grad():for batch_idx, (inputs, targets) in enumerate(testloader):outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += targets.size(0)correct=correct+(predicted.eq(targets).sum()*1.0)print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct/total))
if __name__ == '__main__':for epoch in range(10):train(epoch)test(epoch)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/74805.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 基本数据类型 vs 包装类(引用数据类型)

一、核心概念对比(以 int vs Integer 为例) 特性基本数据类型(int)包装类(Integer)数据类型原始值(Primitive Value)对象(Object)默认值0null内存位置栈&…

什么是 强化学习(RL):以DQN、PPO等经典模型

什么是 强化学习(RL):以DQN、PPO等经典模型 DQN(深度 Q 网络)和 PPO(近端策略优化)共同属于强化学习(Reinforcement Learning,RL)这一领域。强化学习是机器学习中的一个重要分支,其核心在于智能体(Agent)通过与环境进行交互,根据环境反馈的奖励信号来学习最优的…

【Sql Server】在SQL Server中生成雪花ID(Snowflake ID)

大家好,我是全栈小5,欢迎来到《小5讲堂》。 这是《Sql Server》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 前言认识雪花ID…

HTML 表单处理进阶:验证与提交机制的学习心得与进度(一)

引言 在前端开发的广袤领域中,HTML 表单处理堪称基石般的存在,是构建交互性 Web 应用不可或缺的关键环节。从日常频繁使用的登录注册表单,到功能多样的搜索栏、反馈表单,HTML 表单如同桥梁,紧密连接着用户与 Web 应用…

C# CancellationTokenSource CancellationToken Task.Run传入token 取消令牌

基本使用方法创建 CancellationTokenSource获取 CancellationToken将 CancellationToken 传递给任务***注意*** 在任务中检查取消状态请求取消处理取消异常 高级用法设置超时自动取消或者使用 CancelAfter 方法关联多个取消令牌注册回调 注意事项 CancellationTokenSource 是 …

Git 之配置ssh

1、打开 Git Bash 终端 2、设置用户名 git config --global user.name tom3、生成公钥 ssh-keygen -t rsa4、查看公钥 cat ~/.ssh/id_rsa.pub5、将查看到的公钥添加到不同Git平台 6、验证ssh远程连接git仓库 ssh -T gitgitee.com ssh -T gitcodeup.aliyun.com

cli命令编写

新建文件夹 template-cli template-cli下运行 npm init生成package.json 新建bin文件夹和index.js文件 编写index.js #! /usr/bin/env node console.log(hello cli)package.json增加 bin 字段注册命令template-cli template-cli命令对应执行的内容文件 bin/index.js 运行 n…

vue3自定义动态锚点列表,实现本页面锚点跳转效果

需求&#xff1a;当前页面存在多个模块且内容很长时&#xff0c;需要提供一个锚点列表&#xff0c;可以快速查看对应模块内容 实现步骤&#xff1a; 1.每个模块添加唯一id&#xff0c;添加锚点列表div <template><!-- 模块A --><div id"modalA">…

L2TP实验

一、实验拓扑 二、实验内容 手工部署IPec VPN 三、实验步骤 1、配置接口IP和安全区域 [PPPoE Client]firewall zone trust [PPPoE Client-zone-trust]add int g 1/0/0[NAS]firewall zone untrust [NAS-zone-untrust]add int g 1/0/1 [NAS]firewall zone trust [NAS-zon…

青少年编程与数学 02-012 SQLite 数据库简介 01课题、数据库概要

青少年编程与数学 02-012 SQLite 数据库简介 01课题、数据库概要&#xff09; 一、特点二、功能 课题摘要:SQLite 是一种轻量级的嵌入式关系型数据库管理系统。 一、特点 轻量级 它不需要单独的服务器进程来运行。不像 MySQL 或 PostgreSQL 这样的数据库系统需要一个专门的服务…

分布式系统面试总结:3、分布式锁(和本地锁的区别、特点、常见实现方案)

仅供自学回顾使用&#xff0c;请支持javaGuide原版书籍。 本篇文章涉及到的分布式锁&#xff0c;在本人其他文章中也有涉及。 《JUC&#xff1a;三、两阶段终止模式、死锁的jconsole检测、乐观锁&#xff08;版本号机制CAS实现&#xff09;悲观锁》&#xff1a;https://blog.…

Ubuntu 系统上完全卸载 Docker

以下是在 Ubuntu 系统上完全卸载 Docker 的分步指南 一.卸载验证 二.卸载步骤 1.停止 Docker 服务 sudo systemctl stop docker.socket sudo systemctl stop docker.service2.卸载 Docker 软件包 # 移除 Docker 核心组件 sudo apt-get purge -y \docker-ce \docker-ce-cli …

Postman 版本信息速查:快速定位版本号

保持 Postman 更新至最新版本是非常重要的&#xff0c;因为这能让我们享受到最新的功能&#xff0c;同时也保证了软件的安全性。所以&#xff0c;如何快速查看你的 Postman 版本信息呢&#xff1f; 如何查看 Postman 的版本信息教程

EF Core 异步方法

文章目录 前言一、为什么使用异步方法二、核心异步方法1&#xff09;查询数据2&#xff09;保存数据3&#xff09;事务处理 三、异步查询最佳实践1&#xff09;始终使用 await2&#xff09;组合异步操作3&#xff09;并行查询&#xff08;谨慎使用&#xff09; 四、异常处理五、…

装饰器模式介绍和典型实现

装饰器模式&#xff08;Decorator Pattern&#xff09;是一种结构型设计模式&#xff0c;它允许你通过将对象放入包含行为的特殊封装对象中来为原对象添加新的功能。装饰器模式的主要优点是可以在运行时动态地添加功能&#xff0c;而不需要修改原对象的代码。这使得代码更加灵活…

【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的日志管理:Logback 的集成

<前文回顾> 点击此处查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、开篇整…

神经网络知识点整理

目录 ​一、深度学习基础与流程 二、神经网络基础组件 三、卷积神经网络&#xff08;CNN&#xff09;​编辑 四、循环神经网络&#xff08;RNN&#xff09;与LSTM 五、优化技巧与调参 六、应用场景与前沿​编辑 七、总结与展望​编辑 一、深度学习基础与流程 机器学习流…

【sql优化】where 1=1

文章目录 where 11问题描述错误实现正确实现性能对比测试 where 11 问题描述 在动态 SQL 拼接场景中&#xff0c;开发者常使用 WHERE 11 简化条件拼接逻辑&#xff08;避免处理首个条件的 AND&#xff09;。理论上&#xff0c;数据库优化器会忽略 11&#xff0c;但字符串拼接…

车载以太网网络测试 -24【SOME/IP概述】

目录 1 摘要2 车载SOME/IP 概述2.1发展背景以及应用2.1.1车载 SOME/IP 背景2.1.2 车载 SOME/IP 应用场景 2.3 什么是SOME/IP2.3.1 SOME/IP定义2.3.2 SOME/IP在协议栈中的位置 3 SOA是什么4 SOME/IP主要功能5 SOME/IP标准 1 摘要 本文主要介绍SOME/IP的背景以及在车载行业的发展…

vue3中,route4,获取当前页面路由的问题

首先应用场景如下&#xff1a; 在main.js里面&#xff0c;引入的是路由的配置文件&#xff0c;如下&#xff1a; import {router} from /router; app.use(router); 路由配置文件router.js如下&#xff1a; import { createRouter, createWebHistory } from vue-router; imp…