softmax实现

import matplotlib.pyplot as plt
import torch
from IPython import display
from d2l import torch as d2lbatch_size = 256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
test_iter.num_workers = 0
train_iter.num_workers = 0
num_inputs = 784   # 将图片数据拉伸成一个向量  28*28=784
num_outputs = 10   # 类别数量w = torch.normal(0,0.01,size = (num_inputs,num_outputs),requires_grad=True)
b = torch.zeros(num_outputs,requires_grad=True)
def softmax(x):x_exp = torch.exp(x)partition = x_exp.sum(1,keepdim=True)return x_exp/partition   # 使用了广播机制 使得矩阵所有元素均大于0,且可解释为概率
# 验证softmax
x = torch.normal(0,1,(2,5))
x_prob = softmax(x)
x_prob,x_prob.sum(1)
# 实现softmax回归模型,得到可解释为概率的张量
def net(x):
#     x.reshape为268*784的矩阵return softmax(torch.matmul(x.reshape((-1,w.shape[0])),w)+b)
# 拿出预测索引,其中包含两个样本在三个类别的预测
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.3,0.6],[0.3,0.2,0.5]])y_hat[[0,1],y]"""[0,1]指的是真实样本的下标,对于第0个样本,拿出y[0]样本类别的预测值,
对于第1个样本,拿出y[1]样本类别的预测值。拿出真实标号类的预测值。"""# 交叉熵损失函数
def cross_entropy(y_hat,y):return -torch.log(y_hat[range(len(y_hat)),y])cross_entropy(y_hat,y)
# 比较预测值和真实y
def accuracy(y_hat,y):if len(y_hat.shape)>1 and y_hat.shape[1]>1:# 元素最大的那个下表存到y_hat里面y_hat = y_hat.argmax(axis=1)#把y_hat转为y的数据类型再与y做比较,存入cmpcmp = y_hat.type(y.dtype)==y#返回预测正确的aggravatereturn float(cmp.type(y.dtype).sum())
accuracy(y_hat,y)/len(y)
def evaluate_accuracy(net,data_iter):"""计算指定数据集上的精度"""if isinstance(net,torch.nn.Module):"""将模型设置为评估模式"""net.eval()"""正确预测数,预测总数"""metric = Accumulator(2)for x,y in data_iter:metric.add(accuracy(net(x),y),y.numel())return metric[0] / metric[1]
class Accumulator:"""在n个变量上累加"""def __init__(self,n):self.data = [0,0]*ndef add(self,*args):self.data = [a+float(b) for a,b in zip(self.data,args)]def reset(self):self.data = [0.0]*len(self.data)def __getitem__(self,idx):return self.data[idx]evaluate_accuracy(net,test_iter)
# softmax回归训练
def train_epoch_ch3(net,train_iter,loss,updater):if isinstance(net,torch.nn.Module):net.train()"""长度为3的迭代器来累加信息"""metric = Accumulator(3)for x,y in train_iter:y_hat = net(x)l = loss(y_hat,y)if isinstance(updater,torch.optim.Optimizer):
#     梯度置0updater.zero_grad()
#     计算梯度l.backward()
#     更新参数updater.step()
#metric.add(float(l)*len(y),accuracy(y_hat,y),y.size().numel())else:l.sum().backward()updater(x.shape[0])metric.add(float(l.sum()),accuracy(y_hat,y),y.numel())
#      返回的是损失,所有loss的累加除以样本总数,  分类正确是样本数除以样本总数return metric[0]/metric[2],metric[1]/metric[2]
class Animator:  #save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []d2l.use_svg_display()self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()d2l.plt.draw()d2l.plt.pause(0.001)display.display(self.fig)display.clear_output(wait=True)
# 训练函数
def train_ch3(net,train_iter,test_iter,loss,num_epochs,updater):animator = Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0.3, 0.9],legend=['train loss', 'train acc', 'test acc'])for epoch in range(num_epochs):train_metrics = train_epoch_ch3(net, train_iter, loss, updater)test_acc = evaluate_accuracy(net, test_iter)animator.add(epoch + 1, train_metrics + (test_acc,))train_loss, train_acc = train_metricslr = 0.1
def updater(batch_size):return d2l.sgd([w,b],lr,batch_size)
# 训练模型10个迭代周期
num_epochs = 10
train_ch3(net,train_iter,test_iter,cross_entropy,num_epochs,updater)
d2l.plt.show()

一开始不出图,后来 再add函数中加

d2l.plt.draw()
d2l.plt.pause(0.001)

最后加d2l.plt.show()

参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192429.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】-日志系统

一、背景介绍 MySQL中提供了各种各样的日志,每一个日志在不同的阶段有不同的作用,对数据的一致性和正确性得到保障,为数据恢复也提供至关重要的作用,那今天我们一起来讨论讨论MySQL中的各个日志 二、正文 binlog:…

NIO--07--Java lO模型详解

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 何为 IO?先从计算机结构的角度来解读一下I/o.再从应用程序的角度来解读一下I/O 阻塞/非阻塞/同步/异步IO阻塞IO非阻塞IO异步IO举例 Java中3种常见的IO模型BIO (Blo…

Redis缓存的使用

什么是缓存 缓存就是数据交换的缓冲区,是存储数据的临时地方,一般读写性能较高。 缓存的作用: 降低后端负载提高读写效率,降低响应时间 缓存的成本: 数据一致性成本代码维护成本运维成本 Redis特点 键值型数据库…

C语言之结构体

一.前言引入. 我们知道在C语言中有内置类型,如:整型,浮点型等。但是只有这些内置类 型还是不够的,假设我想描述学⽣,描述⼀本书,这时单⼀的内置类型是不⾏的。描述⼀个学⽣需要名字、年龄、学号、⾝⾼、体…

Spark经典案例分享

Spark经典案例 链接操作案例二次排序案例 链接操作案例 案例需求 数据介绍 代码如下: package base.charpter7import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.a…

四、Zookeeper节点类型

目录 1、临时节点 2、永久节点 Znode有两种,分别为临时节点和永久节点。 节点的类型在创建时即被确定,并且不能改变。 1、临时节点 临时节点的生命周期依赖于创建它们的会话。一旦会话结束,临时节点将被自动删除,

OpenCV-Python:计算机视觉介绍

目录 1.背景 2.计算机视觉发展历史 3.计算机视觉主要任务 4.计算机视觉应用场景 5.知识笔记 1.背景 OpenCV是计算机视觉的一个框架,想要学习OpenCV,需要对计算机视觉有一个大致的了解。计算机视觉是指通过计算机技术和算法来模拟人类视觉系统的能力…

Redis高效缓存:加速应用性能的利器

目录 引言 1. Redis概述 1.1 什么是Redis? 1.2 Redis的特点 2. Redis在缓存中的应用 2.1 缓存的重要性 2.2 Redis作为缓存的优势 2.3 缓存使用场景 3. Redis在实时应用中的应用 3.1 实时数据处理的挑战 3.2 Redis的实时数据处理优势 3.3 实时应用中的Red…

mediapipe+opencv实现保存图像中的人脸,抹去其他信息

mediapipeopencv MediaPipe本身不提供图像处理功能,它主要用于检测和跟踪人脸、手势、姿势等。如果您想要从图像中仅提取人脸主要信息并去除其他信息. # codingutf-8 """project: teatAuthor:念卿 刘file: test.pydate&…

Kubernetes学习笔记-Part.09 K8s集群构建

目录 Part.01 Kubernets与docker Part.02 Docker版本 Part.03 Kubernetes原理 Part.04 资源规划 Part.05 基础环境准备 Part.06 Docker安装 Part.07 Harbor搭建 Part.08 K8s环境安装 Part.09 K8s集群构建 Part.10 容器回退 第九章 K8s集群构建 9.1.集群初始化 集群初始化是首…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《余电上网/制氢方式下微电网系统全生命周期经济性评估》

该标题涉及到对微电网系统的全生命周期经济性进行评估,其重点关注两种运营方式:余电上网和制氢。以下是对标题的解读: 微电网系统: 微电网是指一种小规模的电力系统,通常包括分布式能源资源(如太阳能、风能…

ES通过抽样agg聚合性能提升3-5倍

一直以来,es的agg聚合分析性能都比较差(对应sql的 group by)。特别是在超多数据中做聚合,在搜索的条件命中特别多结果的情况下,聚合分析会非常非常的慢。 一个聚合条件:聚合分析请求的时间 search time a…

部署springboot项目到GKE(Google Kubernetes Engine)

GKE是 Google Cloud Platform 提供的托管 Kubernetes 服务,允许用户在 Google 的基础设施上部署、管理和扩展容器。本文介绍如何部署一个简单的springboot项目到GKE. 本文使用podman. 如果你用的是docker, 只需要把本文中所有命令中的podman替换成docker即可 非H…

java+springboot物资连锁仓库经营商业管理系统+jsp

主要任务:通过网络搜集与本课题相关的素材资料,认真分析连锁经营商业管理系统的可行性和要实现的功能,做好需求分析,确定该系统的主要功能模块,依据数据库设计的原则对数据库进行设计。最后通过编码实现本系统功能并测…

Linux周期任务

我自己博客网站里的文章 Linux周期任务:at和crontab 每个人或多或少都有一些约会或者是工作,有的工作是长期周期性的, 例如: 每个月一次的工作报告每周一次的午餐会报每天需要的打卡…… 有的工作则是一次性临时的&#xff0…

Prometheus+Grafana搭建日志采集

介绍 一、什么是日志数据采集 日志数据采集是指通过各种手段获取应用程序运行时产生的各类日志信息,并将这些信息存储到特定的地方,以便后续分析和使用。通常情况下,这些日志信息包括系统运行状态、错误信息、用户操作记录等等。通过对这些…

牛客算法题 【HJ97 记负均正】 golang实现

题目 HJ97 记负均正 描述 首先输入要输入的整数个数n,然后输入n个整数。输出为n个整数中负数的个数,和所有正整数的平均值,结果保留一位小数。 0即不是正整数,也不是负数,不计入计算。如果没有正数,则平均…

大文件分片上传、分片进度以及整体进度、断点续传(一)

大文件分片上传 效果展示 前端 思路 前端的思路&#xff1a;将大文件切分成多个小文件&#xff0c;然后并发给后端。 页面构建 先在页面上写几个组件用来获取文件。 <body><input type"file" id"file" /><button id"uploadButton…

动态规划学习——回文串

目录 一&#xff0c;回文子串 1.题目 2.题目接口 3&#xff0c;解题代码及其思路 解题代码&#xff1a; 二&#xff0c; 分割回文串II 1&#xff0c;题目 2&#xff0c;题目接口 3&#xff0c;解题思路及其代码 一&#xff0c;回文子串 1.题目 给你一个字符串 s &…