pytorch-多分类实战之手写数字识别

目录

  • 1. 网络设计
  • 2. 代码实现
    • 2.1 网络代码
    • 2.2 train
  • 3. 完整代码

1. 网络设计

输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示:
在这里插入图片描述

2. 代码实现

2.1 网络代码

第一层将784降维到200,第二次使用200不降维,输出层200降维到10,每一层之后加一个激活函数relu,每一层都需要梯度信息所以requires_grad=True;
forward函数最后不要加softmax,因为后面CrossEntropyLoss中包含了softmax操作。
在这里插入图片描述

2.2 train

优化目标是w1、b1、w2、b2、w3、b3,使用SGD优化器,使用CrossEntropyLoss计算loss
在这里插入图片描述

3. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)w1, b1 = torch.randn(200, 784, requires_grad=True),\torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\torch.zeros(10, requires_grad=True)# torch.nn.init.kaiming_normal_(w1)
# torch.nn.init.kaiming_normal_(w2)
# torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() + b1x = F.relu(x)x = x@w2.t() + b2x = F.relu(x)x = x@w3.t() + b3x = F.relu(x)return xoptimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = forward(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

如下图:
未使用torch.nn.init.kaiming_normal_(w1)初始化参数的情况,可以看出Loss在2.302585后就不下降了。
在这里插入图片描述
如下图:使用了torch.nn.init.kaiming_normal_(w1)初始化参数的情况下,Loss下降还是比较快的。
在这里插入图片描述
因此使用好的初始化参数对网络的训练起到至关重要的作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/811943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT基础(二) ChatGPT的使用和调优

文章目录 ChatGPT的特性采用关键词进行提问给ChatGPT指定身份提升问答质量的策略1.表述方式上的优化2.用"继续"输出长内容3.营造场景4.由浅入深,提升问题质量5.预设回答框架和风格 ChatGPT的特性 1.能够联系上下文进行回答 ChatGPT回答问题是有上下文的&…

uni-app web端使用getUserMedia,摄像头拍照

<template><view><video id"video"></video></view> </template> 摄像头显示在video标签上 var opts {audio: false,video: true }navigator.mediaDevices.getUserMedia(opts).then((stream)> {video document.querySelec…

【python】在pycharm创建一个新的项目

双击打开pycharm,选择create new project 选择create,后进入项目 右键项目根目录,选择new一个新的python file 随意命名一下 输入p 然后后面就会出现智能补全提示,此时轻敲一下tab,代码就写好了,非常的方便 右键执行一下代码,下面两个直接运行和debug运行都是可以的 小结 …

CentOS 8服务器搭建L2TP服务器(over IPsec)操作指南

正文共&#xff1a;1234 字 14 图&#xff0c;预估阅读时间&#xff1a;2 分钟 之前发过把我自己的服务器搬上公网的文章&#xff08;我用100块钱把物理服务器放到了公网&#xff0c;省了几万块&#xff01;&#xff09;&#xff0c;当时L2TP拨号用的是网络上的解决方案&#x…

微服务整合Spring Cloud Gateway动态路由

前置 创建 Spring Cloud项目 参考&#xff1a;创建Spring Cloud Maven工程-CSDN博客 1. 创建一个maven jar类型项目 在idea中右键父工程-》New-》Module 创建一个maven工程 2. 引入相关依赖 在POM文件中引入下面的依赖 <project xmlns"http://maven.apache.org/P…

C++设计模式|创建型 1.单例模式

1.什么是单例模式&#xff1f; 单例模式的的核⼼思想在创建类对象的时候保证⼀个类只有⼀个实例&#xff0c;并提供⼀个全局访问点来访问这个实例。 只有⼀个实例的意思是&#xff0c;在整个应⽤程序中&#xff0c;只存在该类的⼀个实例对象&#xff0c;⽽不是创建多个相同类…

【JAVA基础篇教学】第八篇:Java中List详解说明

博主打算从0-1讲解下java基础教学&#xff0c;今天教学第八篇&#xff1a;Java中List详解说明。 在 Java 编程中&#xff0c;List 接口是一个非常常用的集合接口&#xff0c;它代表了一个有序的集合&#xff0c;可以包含重复的元素。List 接口提供了一系列操作方法&#xff0c;…

72V电瓶电压降5V1.5A恒压WT7039

72V电瓶电压降5V1.5A恒压WT7039 WT6039是一款12V至72V宽电压降压DC-DC转换器芯片&#xff0c;集成了开关控制、参考电源、误差放大器等多重功能&#xff0c;并具备过热、限流和短路保护。它能在广泛输入功率下实现2A连续输出电流&#xff0c;并通过编程调整输出电压。适用于高…

文心一言

文章目录 前言一、首页二、使用总结 前言 今天给大家带来百度的文心一言,它基于百度的文心大模型,是一种全新的生成式人工智能工具。 一、首页 首先要登录才能使用,左侧可以看到以前的聊天历史 3.5的目前免费用,但是4.0的就需要vip了 二、使用 首先在最下方文本框输入你想要搜…

npm问题合集以及npm常规操作整理

问题一&#xff1a;npm install 卡在“sill idealTree buildDeps“ 可以尝试将网络切换为其他网络&#xff0c;我的是这么解决的。可以尝试重新设置镜像。 问题二&#xff1a;npm install 项目依赖报WARN deprecated 原因 依赖版本问题&#xff0c;下载最新版本。 解决方案 …

ElasticSearch中使用bge-large-zh-v1.5进行向量检索(一)

一、准备 系统&#xff1a;MacOS 14.3.1 ElasticSearch&#xff1a;8.13.2 Kibana&#xff1a;8.13.2 BGE是一个常见的文本转向量的模型&#xff0c;在很多大模型RAG应用中常常能见到&#xff0c;但是ElasticSearch中默认没有。BGE模型有很多版本&#xff0c;本次采用的是bg…

vue和react通用后台管理系统权限控制方案

1. 介绍 在任何企业级应用中&#xff0c;尤其是后台管理系统&#xff0c;权限控制是一个至关重要的环节。它确保了系统资源的安全性&#xff0c;防止非法访问和操作&#xff0c;保障业务流程的正常进行。本文件将详细解析后台管理系统中的权限控制机制及其实施策略。 那么权限…

算法思想总结:分治思想

一、颜色划分 . - 力扣&#xff08;LeetCode&#xff09; class Solution { public:void sortColors(vector<int>& nums) {//三路划分的思想int nnums.size();int left-1, rightn,cur0;while(cur<right){if(nums[cur]0) swap(nums[left],nums[cur]);else if(nums…

ChatGPT加持,需求分析再无难题

简介 在实际工作过程中&#xff0c;常常需要拿到产品的PRD文档或者原型图进行需求分析&#xff0c;为产品的功能设计和优化提供建议。 而使用ChatGPT可以很好的帮助分析和整理用户需求。 实践演练 接下来&#xff0c;需要使用ChatGPT 辅助我们完成需求分析的任务 注意&…

GMSSL-通信

死磕GMSSL通信-C/C++系列(一) 最近再做国密通信的项目开发,以为国密也就简单的集成一个库就可以完事了,没想到能有这么多坑。遂写下文章,避免重复踩坑。以下国密通信的坑有以下场景 1、使用GMSSL guanzhi/GmSSL进行通信 2、使用加密套件SM2-WITH-SMS4-SM3 使用心得 ​…

电动汽车原理视频笔记

看到了一个讲的不错的系列视频 新能源维修猿老罗的个人空间-新能源维修猿老罗个人主页-哔哩哔哩视频 新能源汽车上的安全防护系统就是这么多&#xff01;-绝缘检测等_哔哩哔哩_bilibili 新能源汽车居然是这样上电的-高低压上下电详细流程_哔哩哔哩_bilibili

机器学习和深度学习-- 李宏毅(笔记与个人理解)Day 14

Day 14 Classfication (short version) 二分类的时候 用sigmoid 那不就是 logistic 回归嘛&#xff08;softmax 的二分类等价&#xff09; Loss 哦 今天刚学的 &#xff0c;KL散度 &#xff0c;看来cross-entropy 和KL散度是等价的咯~ 我感觉我的直觉没错 这里MSE离得很远的时候…

php未能在vscode识别?

在设置里搜php&#xff0c;找到settings.json&#xff0c;设置你的安装路径即可。 成功

HubSpot如何通过自动化和优化客户服务流程?

在当今竞争激烈的市场环境中&#xff0c;提供卓越的客户服务体验已经成为企业赢得客户忠诚、推动业务增长的关键所在。HubSpot&#xff0c;作为一款领先的客户关系管理软件&#xff0c;通过自动化和优化客户服务流程&#xff0c;为企业带来了革命性的服务体验提升。 HubSpot通…

【opencv】示例-grabcut.cpp 使用OpenCV库的GrabCut算法进行图像分割

left mouse button - set rectangle SHIFTleft mouse button - set GC_FGD pixels CTRLleft mouse button - set GC_BGD pixels 这段代码是一个使用OpenCV库的GrabCut算法进行图像分割的C程序。它允许用户通过交互式方式选择图像中的一个区域&#xff0c;并利用GrabCut算法尝试…