交叉熵Loss多分类问题实战(手写数字)

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as Ftrain_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() +b1x = F.relu(x)x = x@w2.t() +b2x = F.relu(x)x = x@w3.t() +b3x = F.relu(x)return xoptimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criterion(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if (batch_idx+1) % 150 == 0:print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(epoch, (batch_idx+1) * len(data), len(train_loader.dataset),100. * (batch_idx+1) / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28*28)logits = forward(data)test_loss += criterion(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader)print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107519.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用内网穿透实现U8用友ERP本地部署并远程访问办公?

文章目录 前言1. 服务器本机安装U8并调试设置2. 用友U8借助cpolar实现企业远程办公2.1 在被控端电脑上,点击开始菜单栏,打开设置——系统2.2 找到远程桌面2.3 启用远程桌面 3. 安装cpolar内网穿透3.1 注册cpolar账号3.2 下载cpolar客户端 4. 获取远程桌面…

VMware使用ubuntu安装增强功能实现自动缩放

VMware使用ubuntu安装增强功能实现自动缩放 1.下载 VMware Tools2.安装tool 1.下载 VMware Tools 1.需要先弹出DVD 2.虚拟机-安装VMware Tools 进入终端 3.把media下的VMware压缩包拷贝到home/下 4.去home下解压 2.安装tool 进入vmware-tools-distrib sudo ./vmware-ins…

G.711语音编解码器详解

语音编解码利用人听觉上的冗余对语音信息进行压缩从而达到节省带宽的目的。值得注意的是,本文说的是语音编解码器,也就Speech codec,而常用的还有另一种编解码器称作音频编解码器,英文是Audio codec,它们的区别如下。 以前在学校的时候研究了很多VoIP的编解码器从G.723到A…

Linux:将mysql数据导入mongodb

mysql和mongodb都要同时开启 进入mysql创建一个数据库为aaa create database aaa; 创建一个tarro表结构为 (id int,name varchar(20)) create table tarro(id int,name varchar(20)); 插入几个数据,等会把这里的数据导过去 insert in…

香港学界呼吁RWA“在港先发”,构建基于港元稳定币的Web3生态!

2023年以来,市场对于RWA(Real World Assets)即真实世界资产“代币化”的讨论愈发频繁,一些观点认为 RWA将在下一轮加密资产牛市中成为焦点,部分Web3创业者和传统金融企业也快速将业务方向瞄准相关赛道,而被…

Sanic​——Python函数变成API的神器

今天给大家介绍一个超好用的框架,迅速将Python函数变成API,它就是最近越来越火的异步Web框架Sanic。 1. Sanic简介 Sanic 是 Python3.7 Web 服务器和 Web 框架,旨在提高性能。它允许使用 Python3.5 中添加的async/await语法,这使…

[牛客习题]“幸运的袋子”

习题链接:幸运的袋子_牛客题霸_牛客网 题目分析 由题意可知:“幸运的袋子”的概念是——小球的数值之和大于小球的数值之积。 假如现在有5个小球:1,1,3,5,7,并将他们编号a0~a4.我们…

【数据结构】链表

⭐ 作者:小胡_不糊涂 🌱 作者主页:小胡_不糊涂的个人主页 📀 收录专栏:浅谈数据结构 💖 持续更文,关注博主少走弯路,谢谢大家支持 💖 链表 1. ArrayList的缺陷2. 链表2.1…

异步使用langchain

文章目录 一.先利用langchain官方文档的AI功能问问二.langchain async api三.串行,异步速度比较 一.先利用langchain官方文档的AI功能问问 然后看他给的 Verified Sources 这个页面里面虽然有些函数是异步函数,但是并非专门讲解异步的 二.langchain asy…

大模型引发“暴力计算”,巨头加速推进液冷“降温”

点击关注 文|姚悦 编|王一粟 一进入部署了液冷服务器的数据中心,不仅没有嘈杂的风扇声,甚至在不开空调的夏日也完全没有闷热感。 在大模型引发“暴力计算”的热潮下,数据中心的上下游,正在加紧推进液冷“…

二十六、【颜色调整】

文章目录 1、色相/饱和度2、色彩平衡3、曲线4、可选颜色 1、色相/饱和度 色相其实就是颜色的亮度,就是我们往颜色里边加白色,白色越多颜色越淡。饱和度就是我们往颜色里边加黑色,黑色越多颜色越浓。如下图,我们调整拾色器里边的颜…

2.1 初探大数据

文章目录 零、学习目标一、导入新课二、新课讲解(一)什么是大数据(二)大数据的特征1、Volume - 数据量大2、Variety - 数据多样3、Velocity - 数据增速快4、Value - 数据价值低5、Veracity - 数据真实性 (三&#xff0…

RabbitMQ消息中间件概述

1.什么是RabbitMQ RabbitMQ是一个由erlang开发的AMQP(Advanced Message Queue )的开源实现。AMQP 的出现其实也是应了广大人民群众的需求,虽然在同步消息通讯的世界里有很多公开标准(如 COBAR的 IIOP ,或者是 SOAP 等&…

当出现“无法成功完成操作,因为文件包含病毒或潜在的垃圾软件“时的解决办法

安装补丁或其他安装包时,被系统识别为病毒垃圾 具体解决步骤是: 1.在开始菜单,打开Windows 安全中心 找到主页的病毒和威胁防护 找到管理设置 最后将确认安全的文件或安装包添加到排除项即可

华为eNSP配置专题-VLAN和DHCP的配置

文章目录 华为eNSP配置专题-VLAN和DHCP的配置1、前置环境1.1、宿主机1.2、eNSP模拟器 2、基本环境搭建2.1、基本终端构成和连接 3、VLAN的配置3.1、两台PC先配置静态IP3.2、交换机上配置VLAN 4、接口方式的DHCP的配置4.1、在交换机上开启DHCP4.2、在PC上开启DHCP 5、全局方式的…

零售数据分析模板鉴赏-品类销售结构报表

不管是服装零售,还是连锁超市或者其他,只要是零售行业就绕不过商品数据分析,那么商品数据分析该怎么做?奥威BI的零售数据分析方案早早就预设好相关报表模板,点击应用后,一键替换数据源,立得新报…

新版Android Studio搜索不到Lombok以及无法安装Lombok插件的问题

前言 在最近新版本的Android Studio中,使用插件时,在插件市场无法找到Lombox Plugin,具体表现如下图所示: 1、操作步骤: (1)打开Android Studio->Settings->Plugins,搜索Lom…

【JVM】JVM的内存区域划分

JVM的内存区域划分 堆Java虚拟机栈程序计数器方法区运行时常量池 堆 程序中创建的所有对象都保存在堆中 Java虚拟机栈 Java虚拟机栈的生命周期和线程相同,描述的是Java方法执行的内存模型,每个方法在执行的时候都会同时创建一个栈帧用于存储局部变量表,操作栈,动态链接,方法…

git log 美化配置

编辑 vim ~/.gitconfig 添加配置 [alias]lg log --graph --abbrev-commit --decorate --dateformat:%m-%d %H:%M:%S --formatformat:%C(bold blue)%h%C(reset) - %s %C(bold yellow)% d%C(reset) %n %C(dim white) (%ad) - %an%C(reset) --allgit lg 效果

微信小程序------框架

目录 视图层 WXML 数据绑定 列表渲染 条件渲染 模板 wsx事件 逻辑层 生命周期 跳转 视图层 WXML WXML(WeiXin Markup Language)是框架设计的一套标签语言,结合基础组件、事件系统,可以构建出页面的结构。 先在我们的项目中…