神经网络模型底层原理与实现10-softmax的实现

import torch
from IPython import display
from d2l import torch as d2l

batch_size=256

#定义训练和验证数据集
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)

#参数初始化,把输入图片看成长度784的向量,这个数据集有十个类别,输出为10
num_inputs=784
num_outputs=10
w=torch.normal(0,0.01,size=(num_inputs,num_outputs),requires_grad=True)
b=torch.zeros(num_outputs,requires_grad=True)

#实现softmax函数
def softmax(X):
X_exp=torch.exp(X)
partition=X_exp.sum(1,keepdim=True)#保持输出维度,使它还是一个矩阵,0是按列求和,1是按行求和

#实现softmax回归模型
def net(X):
return softmax(torch.matmul(X.reshape(-1,w.shape[0]),w)+b)#matmul是矩阵乘法

#实现交叉熵损失函数
def cross_entropy(y_hat,y):#公式是-y*log(y_hat)
return -torch.log(y_hat[range(len(y_hat)),y])#log是以e为底的对数,根据前面推的公式,【】内是取出对应元素值

#将预测类别与真实类别比较,这里开始进入测试部分
def accuracy(y_hat,y):
if len(y_hat.shape)>1 and y_hat.shape[1]>1:
y_hat=y_hat.argmax(axis=1)#选出每行中最大的,也就是分类的类别
cmp=y_hat.type(y.dtype)==y
return float(cmp.type(y.dtype).sum())

#按照accuracy的思路,可以写出模型结果准确率计算函数,分子分母不断累加正确的个数和总的个数
def evaluate_accuracy(data_iter, net):
acc_sum, n = 0.0, 0
for X, y in data_iter:
acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()#item将tensor类型转为数据类型
n += y.shape[0]
return acc_sum / n

#softmax训练过程
def train_ch3(net, train_iter, test_iter, loss, num_epochs,batch_size,params=None, lr=None, optimizer=None):#num_epochs训练次数,lr学习率
for epoch in range(num_epochs):
train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
for X, y in train_iter:
y_hat = net(X)#进入网络
l = loss(y_hat, y).sum()#求损失

# 梯度清零
if optimizer is not None:
optimizer.zero_grad()
elif params is not None and params[0].grad is not None:
for param in params:
param.grad.data.zero_()

l.backward()#反向传播
if optimizer is None:
d2l.sgd(params, lr, batch_size)
else:
optimizer.step()
train_l_sum += l.item()
train_acc_sum += (y_hat.argmax(dim=1) ==y).sum().item()
n += y.shape[0]
test_acc = evaluate_accuracy(test_iter, net)
print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n,test_acc))

总结一下:写一个深度学习算法的底层就是写它的模型、损失函数和评价函数

最终输出的结果:

epoch 1, loss 0.7878, train acc 0.749, test acc 0.794
epoch 2, loss 0.5702, train acc 0.814, test acc 0.813
epoch 3, loss 0.5252, train acc 0.827, test acc 0.819
epoch 4, loss 0.5010, train acc 0.833, test acc 0.824
epoch 5, loss 0.4858, train acc 0.836, test acc 0.815

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/816656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小蚕爬树问题

小蚕爬树问题 问题描述: 编写一个函数 int day(int k,int m,int n),其功能是:返回小蚕需要多少天才能爬到树顶(树高 k 厘米,小蚕每天白天向上爬 m 厘米,每天晚上下滑 n 厘米,爬到树顶后不再下滑&#xff0…

(六)C++自制植物大战僵尸游戏关卡数据讲解

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/xjvbb 游戏关卡数据文件定义了游戏中每一个关卡的数据,包括游戏类型、关卡通关奖励的金币数量、僵尸出现的波数、每一波出现僵尸数量、每一波僵尸出现的类型等。根据不同的游戏类型,定义了不同的通…

kafka ----修改log4j、jmx、jvm参数等

1、修改log4j 日志路径 在kafka-run-class.sh文件中修改如下配置,将 LOG_DIR变量指定为自己想要存储的路径 # Log directory to use if [ "x$LOG_DIR" "x" ]; thenLOG_DIR"$base_dir/logs" fi2、修改jmx参数 在kafka-run-class.s…

C++11 数据结构3 线性表的循环链式存储,实现,测试

上一节课,我们学了线性表 单向存储结构(也就是单链表),这个是企业常用的技术,且是后面各种的基本,一定要牢牢掌握,如果没有掌握,下面的课程会云里雾里。 一 ,循环链表 1…

stm32报错问题集锦

PS:本文负责记录本人日常遇到的报错问题,以及问题描述、原因以及解决办法等,解决办法百分百亲测有效。本篇会不定期更新,更新频率就看遇到的问题多不多了 更换工程芯片型号 问题描述 例程最开始用的芯片型号是STM32F103VE&#…

c++11 标准模板(STL)本地化库 - 平面类别(std::codecvt) - 在字符编码间转换,包括 UTF-8、UTF-16、UTF-32 (四)

本地化库 本地环境设施包含字符分类和字符串校对、数值、货币及日期/时间格式化和分析,以及消息取得的国际化支持。本地环境设置控制流 I/O 、正则表达式库和 C 标准库的其他组件的行为。 平面类别 在字符编码间转换,包括 UTF-8、UTF-16、UTF-32 std::…

C++设计模式探讨(1)-工厂模式

设计模式是架构设计之术,但不是道。是技巧和方法论,但不是核心思想。设计模式是世界规律的提取,但无法体现具体的形式。所以,设计模式是一个尴尬的存在,它有一定的价值,但却又十分有限。架构设计者需要做的…

IOS 短信拦截插件

在使⽤iOS设备的时候, 我们经常会收到1069、1065开头的垃圾短信, 如果开了iMessage会更严重, 各种乱七⼋糟的垃圾信息会时不时地收到。 从iOS11开始, ⼿机可以⽀持恶短信拦截插件了. 我们可以通过该插件添加⼀些规则通过滤这些不需要的信息. ⼀. 使⽤xcode新建⼀个项⽬ 【1】…

浦大喜奔APP8.0智能升级,发力数字金融深化五大金融篇章服务

1. 浦大喜奔立足科技赋能持续迭代升级,筑牢用户体验护城河 浦发信用卡中心坚持数字科技与客户体验双轮驱动,以科技赋能发展,优化整体系统性能,全方位支撑浦大喜奔 APP提高线上客户服务能力与体验,积极服务民生消费&a…

pyqt和opencv结合01:读取图像、显示

在这里插入图片描述 1 、opencv读取图像用于pyqt显示 # image cv2.imread(file_path)image cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# 将图像转换为 Qt 可接受的格式height, width, channel image.shapebytes_per_line 3 * widthq_image QImage(image.data, width, hei…

Tomcat源码解析——Tomcat的启动流程

一、启动脚本 当我们在服务启动Tomcat时,都是通过执行startup.sh脚本启动。 在Tomcat的启动脚本startup.sh中,最终会去执行catalina.sh脚本,传递的参数是start。 在catalina.sh脚本中,前面是环境判断和初始化参数,最终…

MES生产管理系统:私有云、公有云与本地化部署的比较分析

随着信息技术的迅猛发展,云计算作为一种新兴的技术服务模式,已经深入渗透到企业的日常运营中。在众多部署方式中,私有云、公有云和本地化部署是三种最为常见的选择。它们各自具有独特的特点和适用场景,并在不同程度上影响着企业的…

kafka---broker相关配置

一、Broker 相关配置 1、一般配置 broker.id 当前kafka服务的sid(server id),在kafka集群中,该值是唯一的(unique),如果未设置此值,kafka会自动生成一个int值;为了防止自动生成的值与用户设置…

.net框架和c#程序设计第三次测试

目录 一、测试要求 二、实现效果 三、实现代码 一、测试要求 二、实现效果 数据库中的内容&#xff1a; 使用数据库中的账号登录&#xff1a; 若不是数据库中的内容&#xff1a; 三、实现代码 login.aspx文件&#xff1a; <% Page Language"C#" AutoEventW…

MybatisX的使用

MyBatisX 是一个基于 IntelliJ IDEA 平台的 MyBatis 开发插件&#xff0c;它提供了一系列的功能来简化 MyBatis 开发过程&#xff0c;包括 SQL 代码自动补全、SQL 语句格式化、Mapper 接口和 XML 配置的跳转等。让我为你详细介绍一下 MyBatisX 插件的使用方法&#xff1a; 1. …

8:系统开发基础--8.5:系统设计、8.6:系统测试 、8.7:软件维护 、8.8:软件质量保证、8.9:软件文档

转上一节&#xff1a; http://t.csdnimg.cn/X0GjWhttp://t.csdnimg.cn/X0GjW 8.5&#xff1a;系统设计 考点1&#xff1a;系统设计概述 1&#xff1a;软件设计的任务与活动 体系结构设计&#xff1a;定义软件系统各主要部件之间的关系。 数据设计&#xff1a;基于E-R图确定…

yolov7直接调用zed相机实现三维测距(python)

yolov7直接调用zed相机实现三维测距(python) 1. 相关配置2. 相关代码3. 源码下载相关链接 此项目直接调用zed相机实现三维测距,无需标定,相关内容如下: 1. yolov4直接调用zed相机实现三维测距 2.yolov5直接调用zed相机实现三维测距(python) 3. yolov8直接调用zed相机实…

OpenHarmony实战开发-异步并发概述 (Promise和async/await)。

Promise和async/await提供异步并发能力&#xff0c;是标准的JS异步语法。异步代码会被挂起并在之后继续执行&#xff0c;同一时间只有一段代码执行&#xff0c;适用于单次I/O任务的场景开发&#xff0c;例如一次网络请求、一次文件读写等操作。 异步语法是一种编程语言的特性&…

MongoDB聚合运算符:$rand

MongoDB聚合运算符&#xff1a;$rand 文章目录 MongoDB聚合运算符&#xff1a;$rand语法举例生成随机数据点从集合中随机选择条目 $rand聚合运算符用于返回一个0~1之间的随机浮点数。 语法 { $rand: {} }$rand运算符不需要任何参数。每次调用$rand都会返回一个小数点后最多17位…

java面向对象.day21(继承02--super)

说明 super父 this当前 使用super时&#xff0c;首先要继承父类&#xff0c;其次是在子类里面才能使用super。 继承父类后&#xff0c;运行子类时会同时调用父类的构造方法&#xff0c;如果要显性调用父类的构造方法必须在子类的第一行调用。 单使用super()表示调用父类构造…