pytorch-手写数字识别之全连接层实现

目录

  • 1. 背景
  • 2. nn.Linear线性层
  • 2. 实现MLP网络
  • 3. train
  • 4. 完整代码

1. 背景

上一篇https://blog.csdn.net/wyw0000/article/details/137622977?spm=1001.2014.3001.5502中实现手撸代码的方式实现了手写数字识别,本文将使用pytorch的API实现。

2. nn.Linear线性层

相当于x = x@w1.t() + b1
因此可以使用nn.Linear代替,上一篇文中中的

w1, b1 = torch.randn(200, 784, requires_grad=True),\torch.zeros(200, requires_grad=True)
x = x@w1.t() + b1

使用nn.Linear定义的三层网络,如下图所示:
在这里插入图片描述
增加激活函数relu
在这里插入图片描述

2. 实现MLP网络

  • 实现__init__函数
  • 将网络各层放到序列化容器Sequential中
    见下图使用nn.Sequential创建一个小的model,运行时输入首先传给nn.Linear(784, 200),nn.Linear(784, 200)的输出再传给nn.ReLU(inplace=True),这样依次传递下去,直至结束。
  • 实现forward
    调用将输入x作为model的参数调用model,并将结果返回。
    在这里插入图片描述

3. train

在这里插入图片描述

4. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.ReLU(inplace=True),nn.Linear(200, 200),nn.ReLU(inplace=True),nn.Linear(200, 10),nn.ReLU(inplace=True),)def forward(self, x):x = self.model(x)return xnet = MLP()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/825664.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【感受C++的魅力】:用C++演奏歌曲《起风了》——含完整源码

文章目录 一、运行效果二、代码实现1. 引入部分2. 枚举3. 音色定义4. 演奏速度定义5. 特殊定义6. 模拟风声 三、完整代码 一、运行效果 【C的魅力】&#xff1a;用C演奏歌曲《起风了》 二、代码实现 1. 引入部分 #include <iostream> #include <Windows.h> #prag…

开发一个农场小游戏需要多少钱

开发一个农场小游戏的费用因多个因素而异&#xff0c;包括但不限于游戏的规模、复杂性、功能需求、设计复杂度、开发团队的规模和经验&#xff0c;以及项目的时间周期等。因此&#xff0c;无法给出确切的费用数字。 具体来说&#xff0c;游戏的复杂程度和包含的功能特性数量会直…

企业文档知识库建设,数据安全如何保障?

随着现代市场经济的高速发展&#xff0c;企业的竞争优势越来越多体现在人才和科技的优势。而随着员工流动率的提升&#xff0c;随之流失的则是员工积累多年的宝贵工作经验&#xff0c;如果缺乏有效的内部知识库的建设和管理&#xff0c;企业的竞争优势将难以维系。「企业网盘」…

jQuery 性能优化 —— 学习笔记 详细版

1.总是从 ID 选择器开始继承 在 jQuery 中最快的选择器是 ID 选择器,因为它直接来自于 JavaScript 的 getElementById() 方法。例如有一段 HTML 代码: <div id="content"> <form method="post" action="#"> <h2>交通信号…

每天学习一个Linux命令之ufw

每天学习一个Linux命令之ufw 在Linux系统中&#xff0c;操作防火墙是一项重要工作。ufw&#xff08;Uncomplicated Firewall&#xff09;是一个简单易用的防火墙管理工具&#xff0c;它使得配置和管理防火墙规则变得非常简单。本文将介绍ufw命令的使用方法以及可用的选项。 1…

Claude和chatgpt的区别

ChatGPT是OpenAI开发的人工智能的聊天机器人&#xff0c;它可以生成文章、代码并执行各种任务。是Open AI发布的第一款大语言模型&#xff0c;GPT4效果相比chatgpt大幅提升。尤其是最新版的模型&#xff0c;OpenAI几天前刚刚发布的GPT-4-Turbo-2024-04-09版本&#xff0c;大幅超…

架构设计-流程引擎的架构设计

1、什么是流程引擎 流程引擎是一个底层支撑平台&#xff0c;是为提供流程处理而开发设计的。流程引擎和流程应用&#xff0c;以及应用程序的关系如下图所示。 常见的支撑场景有&#xff1a;Workflow、BPM、流程编排等。本次分享&#xff0c;主要从 BPM 流程引擎切入&#xff0…

【前端】3. CSS【万字长文】

CSS 是什么 层叠样式表 (Cascading Style Sheets). CSS 能够对网页中元素位置的排版进行像素级精确控制, 实现美化页面的效果. 能够做到页面的样式和结构分离. CSS 就是 “东方四大邪术” 之化妆术. 基本语法规范 选择器 {一条/N条声明} 选择器决定针对谁修改 (找谁)声明决…

钉钉直播回放怎么下载到本地

钉钉直播回放如何下载到本地,本文就给大家解密如何下载到本地 工具我已经给大家打包好了 钉钉直播回放下载软件链接&#xff1a;https://pan.baidu.com/s/1_4NZLfENDxswI2ANsQVvpw?pwd1234 提取码&#xff1a;1234 --来自百度网盘超级会员V10的分享 1.首先解压好我给大家…

【Qt】Qt Hello World 程序

文章目录 1、Qt Hello World 程序1.1 使用按钮实现1.1.1 使用可视化方式实现 1.1.2 纯代码方式实现 label创建堆&#xff08;内存泄漏&#xff09;或者栈问题Qt基础类&#xff08;Qstring、Qvector、Qlist&#xff09;乱码问题零散知识 1、Qt Hello World 程序 1.1 使用按钮实…

Swin Transformer 浅析

Swin Transformer 浅析 文章目录 Swin Transformer 浅析引言Swin Transformer 的网络结构W-MSA 窗口多头注意力机制SW-MSA 滑动窗口多头注意力机制Patch Merging 图块合并 引言 因为ViT无法实现CNN中的层次化构建以及局部信息&#xff0c;由此微软团队提出了Swin Transformer来…

C语言(二维数组)

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸各位能阅读我的文章&#xff0c;诚请评论指点&#xff0c;关注收藏&#xff0c;欢迎欢迎~~ &#x1f4a5;个人主页&#xff1a;小羊在奋斗 &#x1f4a5;所属专栏&#xff1a;C语言 本系列文章为个人学习笔记&#x…

15.7 2011年42题真题讲解

2&#xff0c;4&#xff0c;6&#xff0c;8&#xff0c;11&#xff0c;13&#xff0c;15&#xff0c;17&#xff0c;19&#xff0c;20 可以推出题目的一个隐含条件&#xff1a;偶数个元素的中位数是靠前的那一个 应试技巧&#xff1a;如果实在想不出高效的算法&#xff0c;那…

Linux下跟踪某个进程的内核处理时延消耗情况

1.利用系统自动的trace功能&#xff0c;编辑如下脚本&#xff0c;vim trace_process.sh #!/bin/sh cd /sys/kernel/debug/tracing/ #清空原有跟踪信息 echo > trace echo nop > current_tracer #设置要跟踪的进程 echo "pid281255" echo 281255 > set_ftra…

基于springboot+vue+Mysql的房产销售平台

开发语言&#xff1a;Java框架&#xff1a;springcloudJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a…

【详细讲解CentOS常用的命令】

&#x1f308;个人主页: 程序员不想敲代码啊 &#x1f3c6;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f44d;点赞⭐评论⭐收藏 &#x1f91d;希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出指正&#xff0c;让我们共…

SQLite FTS5 扩展(三十)

返回&#xff1a;SQLite—系列文章目录 上一篇:SQLite的知名用户(二十九) 下一篇&#xff1a;SQLite—系列文章目录 1. FTS5概述 FTS5 是一个 SQLite 虚拟表模块&#xff0c;它为数据库应用程序提供全文搜索功能。在最基本的形式中&#xff0c; 全文搜索引擎允许用户有…

Dinov2 + Faiss 图片检索

MetaAI 通过开源 DINOv2&#xff0c;在计算机视觉领域取得了一个显着的里程碑&#xff0c;这是一个在包含1.42 亿张图像的令人印象深刻的数据集上训练的模型。产生适用于图像级视觉任务&#xff08;图像分类、实例检索、视频理解&#xff09;以及像素级视觉任务&#xff08;深度…

【leetcode面试经典150题】57. 环形链表(C++)

【leetcode面试经典150题】专栏系列将为准备暑期实习生以及秋招的同学们提高在面试时的经典面试算法题的思路和想法。本专栏将以一题多解和精简算法思路为主&#xff0c;题解使用C语言。&#xff08;若有使用其他语言的同学也可了解题解思路&#xff0c;本质上语法内容一致&…

服务的状态

1、服务的状态 有状态的服务&#xff1a; 即服务器保存服务有关的个性化参数。 比如客户登陆后&#xff0c;将 客户的权限信息 保存在服务器&#xff0c;每次拿到客户请求后&#xff0c;服务器从自身的数据存储中取出客户角色信息。判断是否响应客户请求。 无状态的服务&…