交叉熵Loss多分类问题实战(手写数字)

1、import所需要的torch库和包
在这里插入图片描述
2、加载mnist手写数字数据集,划分训练集和测试集,转化数据格式,batch_size设置为200在这里插入图片描述
3、定义三层线性网络参数w,b,设置求导信息
在这里插入图片描述
4、初始化参数,这一步比较关键,是否初始化影响到数据质量以及后续网络学习效果
在这里插入图片描述
5、自定义三层线性网络
在这里插入图片描述
6、选定优化器激活函数和loss函数
在这里插入图片描述
7、训练及测试,并记录每轮训练的loss变化和在测试集上的效果。第一轮就达到了98的准确度,判断是初始化效果较好,在前几次测试中根据初始化的情况不同,初始准确率为50%-85%不等
在这里插入图片描述
完整代码:

import torch
import torchvision
import torch.nn.functional as Ftrain_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=True, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('mnist_data', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307, ), (0.3081, ))])),batch_size=200, shuffle=True)w1 = torch.randn(200, 784, requires_grad=True)
b1 = torch.randn(200, requires_grad=True)
w2 = torch.randn(200, 200, requires_grad=True)
b2 = torch.randn(200, requires_grad=True)
w3 = torch.randn(10, 200, requires_grad=True)
b3 = torch.randn(10, requires_grad=True)torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)def forward(x):x = x@w1.t() +b1x = F.relu(x)x = x@w2.t() +b2x = F.relu(x)x = x@w3.t() +b3x = F.relu(x)return xoptimizer = torch.optim.Adam([w1, b1, w2, b2, w3, b3], lr=0.001)
criterion = torch.nn.CrossEntropyLoss()for epoch in range(10):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)logits = forward(data)loss = criterion(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if (batch_idx+1) % 150 == 0:print('Train Epoch:{} [{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(epoch, (batch_idx+1) * len(data), len(train_loader.dataset),100. * (batch_idx+1) / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28*28)logits = forward(data)test_loss += criterion(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader)print('\nTest Set:Average Loss:{:.4f}, Accuracy:{}/{}({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107519.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用内网穿透实现U8用友ERP本地部署并远程访问办公?

文章目录 前言1. 服务器本机安装U8并调试设置2. 用友U8借助cpolar实现企业远程办公2.1 在被控端电脑上,点击开始菜单栏,打开设置——系统2.2 找到远程桌面2.3 启用远程桌面 3. 安装cpolar内网穿透3.1 注册cpolar账号3.2 下载cpolar客户端 4. 获取远程桌面…

VMware使用ubuntu安装增强功能实现自动缩放

VMware使用ubuntu安装增强功能实现自动缩放 1.下载 VMware Tools2.安装tool 1.下载 VMware Tools 1.需要先弹出DVD 2.虚拟机-安装VMware Tools 进入终端 3.把media下的VMware压缩包拷贝到home/下 4.去home下解压 2.安装tool 进入vmware-tools-distrib sudo ./vmware-ins…

G.711语音编解码器详解

语音编解码利用人听觉上的冗余对语音信息进行压缩从而达到节省带宽的目的。值得注意的是,本文说的是语音编解码器,也就Speech codec,而常用的还有另一种编解码器称作音频编解码器,英文是Audio codec,它们的区别如下。 以前在学校的时候研究了很多VoIP的编解码器从G.723到A…

Linux:将mysql数据导入mongodb

mysql和mongodb都要同时开启 进入mysql创建一个数据库为aaa create database aaa; 创建一个tarro表结构为 (id int,name varchar(20)) create table tarro(id int,name varchar(20)); 插入几个数据,等会把这里的数据导过去 insert in…

卡尔曼家族从零解剖-(00)目录最新无死角讲解

讲解关于slam一系列文章汇总链接:史上最全slam从零开始,针对于本栏目讲解的 卡尔曼家族从零解剖 链接 :卡尔曼家族从零解剖-(00)目录最新无死角讲解:https://blog.csdn.net/weixin_43013761/article/details/133846882 文末正下方中心提供了本人 联系…

【7-1 CEmployee类的友元函数改名】 武汉理工大学

7-1 CEmployee类的友元函数改名 分数 15 作者 谢颂华 单位 武汉理工大学 定义一个CEmployee类,其中包括姓名、街道地址、城市和邮编等属性,以及带参的构造函数实现初始化、友元函数change_name()和成员函数display()。要求: 1.函数display()显…

香港学界呼吁RWA“在港先发”,构建基于港元稳定币的Web3生态!

2023年以来,市场对于RWA(Real World Assets)即真实世界资产“代币化”的讨论愈发频繁,一些观点认为 RWA将在下一轮加密资产牛市中成为焦点,部分Web3创业者和传统金融企业也快速将业务方向瞄准相关赛道,而被…

Sanic​——Python函数变成API的神器

今天给大家介绍一个超好用的框架,迅速将Python函数变成API,它就是最近越来越火的异步Web框架Sanic。 1. Sanic简介 Sanic 是 Python3.7 Web 服务器和 Web 框架,旨在提高性能。它允许使用 Python3.5 中添加的async/await语法,这使…

深度学习小工具:Linux 环境下的用户命令和脚本顺序执行器

前言 深度学习跑代码的时候,需要跑很多个对比实验,要么开多个窗口并行执行代码,要么就写在一个 .sh 文件里面顺序执行,前面一种并行执行多个任务出结果很慢,而后一种如果想添加任务或者删除某个任务就得全部停止&…

[牛客习题]“幸运的袋子”

习题链接:幸运的袋子_牛客题霸_牛客网 题目分析 由题意可知:“幸运的袋子”的概念是——小球的数值之和大于小球的数值之积。 假如现在有5个小球:1,1,3,5,7,并将他们编号a0~a4.我们…

【05】基础知识:React组件实例三大核心属性 - props

一、props 了解 理解 1、每个组件对象都会有 props(properties的简写)属性 2、组件标签的所有属性都保存在 props 中 作用 通过标签属性从组件外向组件内传递变化的数据 注意 组件内部不要修改 props 数据 二、案例 需求:自定义用来…

算法通关村第一关|青铜|链表笔记

1.理解 Java 如何构造出链表 在 Java 中,我们创建一个链表类,类中应当有两个属性,一个是结点的值 val ,一个是该结点指向的下一个结点 next 。 next 通俗讲是一个链表中的指针,但是在链表类中是一个链表类型的引用变量…

【数据结构】链表

⭐ 作者:小胡_不糊涂 🌱 作者主页:小胡_不糊涂的个人主页 📀 收录专栏:浅谈数据结构 💖 持续更文,关注博主少走弯路,谢谢大家支持 💖 链表 1. ArrayList的缺陷2. 链表2.1…

异步使用langchain

文章目录 一.先利用langchain官方文档的AI功能问问二.langchain async api三.串行,异步速度比较 一.先利用langchain官方文档的AI功能问问 然后看他给的 Verified Sources 这个页面里面虽然有些函数是异步函数,但是并非专门讲解异步的 二.langchain asy…

大模型引发“暴力计算”,巨头加速推进液冷“降温”

点击关注 文|姚悦 编|王一粟 一进入部署了液冷服务器的数据中心,不仅没有嘈杂的风扇声,甚至在不开空调的夏日也完全没有闷热感。 在大模型引发“暴力计算”的热潮下,数据中心的上下游,正在加紧推进液冷“…

【跳槽必备】2023常用手写面试题知识点总结

前言 想想几年前一个月随便投出去一天至少3面试一个月排满面试的场景对于现在已经灭绝了,基本只有外包和驻场有回应,今年很多人都只能猥琐发育,市场上不仅岗位变少,money也少了很多。目前环境的不景气,面试难度也增加…

特种设备怎么运输到国外

特种设备的运输需要考虑多个因素,包括设备的尺寸、重量、敏感度等。以下是一些常用的运输方式: 海运:海运是运输特种设备的主要方式之一,通常采用货运集装箱进行装载。在运输前需要进行妥善包装和固定,以保证设备的安全…

二十六、【颜色调整】

文章目录 1、色相/饱和度2、色彩平衡3、曲线4、可选颜色 1、色相/饱和度 色相其实就是颜色的亮度,就是我们往颜色里边加白色,白色越多颜色越淡。饱和度就是我们往颜色里边加黑色,黑色越多颜色越浓。如下图,我们调整拾色器里边的颜…

2.1 初探大数据

文章目录 零、学习目标一、导入新课二、新课讲解(一)什么是大数据(二)大数据的特征1、Volume - 数据量大2、Variety - 数据多样3、Velocity - 数据增速快4、Value - 数据价值低5、Veracity - 数据真实性 (三&#xff0…

互联网摸鱼日报(2023-10-11)

互联网摸鱼日报(2023-10-11) 36氪新闻 走向平衡:生成式AI的开源与专有模型之争 麦当劳和可乐们最大的威胁,居然是“减肥药” 束从轩5000万“宴请全国”,老乡鸡会去港股吗? 威马汽车回应破产重整 特斯拉电动皮卡,还…