softmax从零开始实现

softmax从零开始实现

  • 代码
  • 结果

代码

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils import data# H,W,C -> C,H,W
mnist_train = torchvision.datasets.FashionMNIST(root="./data", train=True, download=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root="./data", train=False, download=True,transform=transforms.ToTensor())
batch_size = 256
# 随机读取⼩批量
train_loader = data.DataLoader(mnist_train, batch_size, shuffle=True)
test_loader = data.DataLoader(mnist_test, batch_size, shuffle=True)# feature, label = mnist_train[0]
# print(feature.shape, label) # torch.Size([1, 28, 28]) 9num_inputs = 784
num_outputs = 10def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)  # 按行return X_exp / partition  # 这⾥应⽤了⼴播机制def net(X):return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)def cross_entropy(y_hat, y):return - torch.log(y_hat.gather(1, y.view(-1, 1)))def sgd(params, lr, batch_size):for param in params:param.data -= lr * param.grad / batch_size  # 注意这⾥更改param时⽤的param.datadef accuracy(y_hat, y):return (y_hat.argmax(dim=1) == y).float().mean().item()W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_()
b.requires_grad_()num_epochs, lr = 10, 0.1
loss = cross_entropy
optimizer = sgd
for epoch in range(1, 1 + num_epochs):total_loss = 0.0train_sample = 0.0train_acc_sum = 0for x, y in train_loader:y_hat = net(x)l = loss(y_hat, y) # 256,1# 梯度清零l.sum().backward()sgd([W, b], lr, batch_size)  # 使用参数的梯度更新参数W.grad.data.zero_()b.grad.data.zero_()total_loss += l.sum().item()train_sample += y.shape[0]train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()print('epoch %d, loss %.4f, train acc %.3f' % (epoch, total_loss / train_sample, train_acc_sum / train_sample,))with torch.no_grad():total_loss = 0.0test_sample = 0.0test_acc_sum = 0for x, y in test_loader:y_hat = net(x)l = loss(y_hat, y)  # 256,1total_loss += l.sum().item()test_sample += y.shape[0]test_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()print('loss %.4f, test acc %.3f' % (total_loss / test_sample, test_acc_sum / test_sample,))

结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/865385.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java静态代理-被代理对象,代理对象的概念(图+代码解释)

案例是老师类,这个老师生病请假了,需要请另外一个老师临时帮忙,这个过来帮忙的老师就是代理对象,生病的老师就是被代理对象,其中我们需要代理对象和被代理对象都implement这个ITeacherDao接口,实现里面的te…

8款你不一定知道的良心软件!

AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频https://aitools.jurilu.com/我们使用一些流行的软件的时候,往往会忽略一些功能非常强大的软件,因为这些软件的众 多,都因为看不见而丢失&a…

udp发送数据如果超过1个mtu时,抓包所遇到的问题记录说明

最近在测试Syslog udp发送相关功能,测试环境是centos udp头部的数据长度是2个字节,最大传输长度理论上是65535,除去头部这些字节,可以大概的说是64k。 写了一个超过64k的数据(随便用了一个7w字节的buffer)发送demo,打…

java顺序查找

其中有一个常用的编程思想: 由于是遍历查找,不能用if-else来输出没有找到,而应该设置一个索引index,如果找到就将index的值设置成下标的值,如果遍历结束后index仍为初始值,才是没有找到 //2024.07.03impor…

从百数教学看产品设计:掌握显隐规则,打造极致用户体验

字段显隐规则允许通过一个控件(如复选框、单选按钮或下拉菜单)来控制其他控件(如文本框、日期选择器等)和标签页(如表单的不同部分)的显示或隐藏。 这种规则通常基于用户的选择或满足特定条件来触发&#…

龙迅 国产原装 低成本高性能转换器 Type-C with 2lane@8.1Gbps/lane 4K60

2.一般说明 LT8711UXE1是一款高性能的Type-C/DP1.2至HDMI2.0转换器,设计用于将USBType-C源或DP1.2源连接至HDMI2.0收发器。该LT8711UXE1集成了一个DP1.2兼容接收器,和一个HDMI2.0兼容发射器。此外,还包括用于CC通信的两个CC控制器&#xff0c…

红酒与建筑:品味历史与艺术的交汇

在时间的长河中,红酒与建筑都是人类智慧的结晶,它们各自承载着历史的厚重与艺术的韵味。当这两者交汇时,仿佛是一场穿越时空的对话,将我们带入一个既古老又现代、既深沉又温柔的世界。今天,就让我们一起走进这个奇妙的…

PMP报考条件是什么?很多人都没读懂...

最近正值8月份考试报名期,想计划考8月份考试的宝子可以准备起来了,下面是报名时间和考试安排 8月考试时间安排: 👉报名时间在7.9日—12日 👉考试时间在8.31日(周六) 一、PMP报名条件是什么&am…

windows环境下socket使用

环境&#xff1a;vscodecmake 源码中引用头文件&#xff1a; #include <winsock2.h> #include <ws2tcpip.h> #pragma comment(lib, "ws2_32.lib") #pragma comment(lib, "wsock32.lib") 需要先调用&#xff1a;WSAStartup初始化&#xf…

炎黄数智人:万科集团——智能催收专员‘崔筱盼’,引领财务管理数字化转型

在数字化时代的浪潮中&#xff0c;人工智能&#xff08;AI&#xff09;技术的飞速发展正深刻改变着商业世界的面貌。万科集团&#xff0c;作为中国房地产行业的翘楚&#xff0c;一直致力于探索和实践最前沿的科技创新。此次&#xff0c;万科集团推出的数字员工“崔筱盼”&#…

十 .pfc,bus纹波分析与抑制方法

以apfc为例 在分析时用 uin 和 iin 表示输入电压和输入电流&#xff0c;uo 和 io&#xff0c;表示输出电压和输出电流&#xff0c;Uin 和 Iin 表示输入电压和输入电流的幅值&#xff0c;则输入电压和输入电流可以分别表示为&#xff1a; 从式&#xff08;3-3&#xff09;可以…

c->c++(二):class

本文主要探讨C类的相关知识。 构造和析构函数 构造函数(可多个)&#xff1a;对象产生时调用初始化class属性、分配class内部需要的动态内存 析构函数&#xff08;一个&#xff09;&#xff1a;对对象消亡时调用回收分配动态内存 C提供默认构造和析构,…

AI是在帮助开发者还是取代他们

目录 1.概述 1.1.AI助力开发者 1.2.AI对开发者的挑战 2.AI工具现状 2.1. GitHub Copilot 2.2. TabNine 2.3.小结 3.AI对开发者的影响 3.1.对开发者的影响 3.2.开发者需要掌握的新技能 3.3.在AI辅助的环境中保持竞争力的策略 4.AI开发的未来 5.总结 1.概述 生成式…

OA系统多少钱一套 用低代码开发OA系统需要多少钱

在数字化时代&#xff0c;企业对办公自动化(OA)系统的需求日益增长&#xff0c;以提高工作效率和优化管理流程。低代码开发平台以其快速开发和部署的能力&#xff0c;成为构建OA系统的热门选择。本文将介绍低代码开发OA系统的成本效益&#xff0c;并以白码低代码平台为例&#…

C# 类型转换之显式和隐式

文章目录 1、显式类型转换2. 隐式类型转换3. 示例4. 类型转换的注意事项5. 类型转换的应用示例总结 在C#编程中&#xff0c;类型转换是一个核心概念&#xff0c;它允许我们在程序中处理不同类型的数据。类型转换可以分为两大类&#xff1a;显式类型转换&#xff08;Explicit Ca…

MAVEN 重新配置参考

【笔记04】下载、配置 MAVEN&#xff08;配置 MAVEN 本地仓库&#xff09;&#xff08;MAVEN 的 setting.xml&#xff09;-阿里云开发者社区 windows 系统环境变量 MAVEN_HOME 也可以改一下

如何对GD32 MCU进行加密?

GD32 MCU有哪些加密方法呢&#xff1f;大家在平时项目开发的过程中&#xff0c;最后都可能会面临如何对出厂产品的MCU代码进行加密&#xff0c;避免产品流向市场被别人读取复制。 下面为大家介绍GD32 MCU所支持的几种常用的加密方法&#xff1a; 首先GD32 MCU本身支持防硬开盖…

Q-Vision新功能发布 | CANReplay-enable发送

Q-Vision是一款网络分析与ECU测试工具软件&#xff0c;支持CAN&#xff08;FD&#xff09;、LIN、以太网、LVDS等车载网络标准&#xff0c;以及CCP/XCP/UDS/OBD等协议&#xff0c;并能导入DBC/LDF/ARXML/A2L/ODX等格式的数据库。 使用Q-Vision可实现对多种总线网络的在线记录、…

基于Springboot的人格障碍诊断系统

结构图&#xff1a; 效果图&#xff1a; 后台&#xff1a; 前台:

基于STM32的智能仓储温湿度监控系统

目录 引言环境准备智能仓储温湿度监控系统基础代码实现&#xff1a;实现智能仓储温湿度监控系统 4.1 数据采集模块4.2 数据处理与分析4.3 控制系统实现4.4 用户界面与数据可视化应用场景&#xff1a;温湿度监控与管理问题解决方案与优化收尾与总结 1. 引言 智能仓储温湿度监…