[pytorch、学习] - 3.6 softmax回归的从零开始实现

参考

3.6 softmax回归的从零开始实现

import torch
import torchvision
import numpy as np
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

3.6.1. 获取和读取数据

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.2. 初始化模型参数

num_inputs = 784
num_outputs = 10W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)), dtype=torch.float)  # torch.Size([784, 10])
b = torch.zeros(num_outputs, dtype=torch.float)   # torch.Size([10])# 同之前一样,我们需要模型参数梯度。
W.requires_grad_(requires_grad=True)
b.requires_grad_(requires_grad=True)

在这里插入图片描述

3.6.3. 实现softmax运算

def softmax(X):X_exp = X.exp()partition = X_exp.sum(dim=1, keepdim=True)return X_exp / partition

3.6.4. 定义模型

# 传入特征,给出预测值
def net(X):return softmax(torch.mm(X.view((-1, num_inputs)), W) + b)

3.6.5. 定义损失函数

def cross_entropy(y_hat, y):return -torch.log(y_hat.gather(1, y.view(-1, 1)))

3.6.6. 计算分类准确率

def accuracy(y_hat, y):return (y_hat.argmax(dim=1) ==y).float().mean().item()def evaluate_accuracy(data_iter, net):acc_sum, n = 0.0, 0for X, y in data_iter:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()n += y.shape[0]return acc_sum /n

3.6.7. 训练模型

  • d2lzh
num_epochs, lr = 5, 0.1# 本函数已保存在d2lzh包中方便以后使用
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:d2l.sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

在这里插入图片描述

3.6.8. 预测

X, y = iter(test_iter).next()true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]d2l.show_fashion_mnist(X[0:9], titles[0:9])

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250191.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Django基础必备三件套: HttpResponse render redirect

1. HttpResponse : 它的作用是内部传入一个字符串参数, 然后发给浏览器 def index(request):return HttpResponse(ok) 2. render : 可以接收三个参数, 一是request参数, 二是待渲染的 html 模板文件, 三是保存具体数据的字典参数 def index(request):return render(request, …

React 简单实例 (React-router + webpack + Antd )

React Demo Github 地址 经过React Native 的洗礼之后,写了这个 demo ;React 是为了使前端的V层更具组件化,能更好的复用,同时可以让你从操作dom中解脱出来,只需要操作数据就会改变相应的dom; 而React Nat…

[pytorch、学习] - 3.7 softmax回归的简洁实现

参考 3.7. softmax回归的简洁实现 使用pytorch实现softmax import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.7.1. 获取和读取数据 batch_size 256 train_iter…

【模板】NTT

NTT模板 #include<bits/stdc.h> using namespace std; #define LL long long const int MAXL22; const int MAXN1<<MAXL; const int Mod998244353; int rev[MAXN],A[MAXN],B[MAXN],C[MAXN]; int fast_pow(int a,int b){int ans1;while(b){if(b&1)ans1ll*ans*a%…

centos 7 php7 yum源

rpm -Uvh https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpmrpm -Uvh https://mirror.webtatic.com/yum/el7/webtatic-release.rpm 转载于:https://www.cnblogs.com/myJuly/p/10008252.html

[pytorch、学习] - 3.9 多重感知机的从零开始实现

参考 3.9 多重感知机的从零开始实现 import torch import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.9.1. 获取和读取数据 batch_size 256 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size)3.9.2. 定义模型参…

C语言逗号运算符和逗号表达式基础总结

逗号运算符的作用&#xff1a; 1&#xff0c;起分隔符的作用&#xff1a; 定义变量用于分隔变量&#xff1a;int a,b输入或输出时用于分隔输出表列 printf("%d%d",a,b) 2,用于逗号表达式的顺序运算符 语法&#xff1a;表达式1&#xff0c;表达式2&#xff0c;...,表达…

java基础-泛型举例详解

泛型 泛型是JDK5.0增加的新特性&#xff0c;泛型的本质是参数化类型&#xff0c;即所操作的数据类型被指定为一个参数。这种类型参数可以在类、接口、和方法的创建中&#xff0c;分别被称为泛型类、泛型接口、泛型方法。 一、认识泛型 在没有泛型之前,通过对类型Object的引用来…

MySQL数据库视图(view),视图定义、创建视图、修改视图

原文链接&#xff1a;https://blog.csdn.net/moxigandashu/article/details/63254901转载于:https://www.cnblogs.com/chrdai/p/9131881.html

[pytorch、学习] - 3.10 多重感知机的简洁实现

参考 3.10. 多重感知机的简洁实现 import torch from torch import nn from torch.nn import init import numpy as np import sys sys.path.append("..") import d2lzh_pytorch as d2l3.10.1. 定义模型 num_inputs, num_outputs, num_hiddens 784, 10, 256 # 参…

【汇编语言】——第三章课后总结

第三章 的书本上主要有以下几个内容&#xff1a; 1.内存中字的存储 字单元&#xff1a;即存放一个字型数据&#xff08;16位&#xff09;的内存单元&#xff0c;由两个地址连续的内存单元组成。 小端法&#xff1a;高地址内存单元中存放字型数据的高位字节&#xff0c;低地址内…

如何从 Android 手机免费恢复已删除的通话记录/历史记录?

有一个有合作意向的人给我打电话&#xff0c;但我没有接听。更糟糕的是&#xff0c;我错误地将其删除&#xff0c;认为这是一个骚扰电话。那么有没有办法从 Android 手机恢复已删除的通话记录呢&#xff1f;” 塞缪尔问道。如何在 Android 上恢复已删除的通话记录&#xff1f;如…

springBoot 登录拦截器

1、首选创建一个继承HandlerInterceptor的拦截器 import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse;import org.springframework.web.servlet.HandlerInterceptor; import org.springframework.web.servlet.ModelAndView; /*** 拦…

[pytorch、学习] - 3.11 模型选择、欠拟合和过拟合

参考 3.11 模型选择、欠拟合和过拟合 3.11.1 训练误差和泛化误差 在解释上述现象之前&#xff0c;我们需要区分训练误差&#xff08;training error&#xff09;和泛化误差&#xff08;generalization error&#xff09;。通俗来讲&#xff0c;前者指模型在训练数据集上表现…

关于'java' 不是内部或外部命令,也不是可运行的程序 或批处理文件 和 错误: 找不到或无法加载主类 helloworld的问题...

一、前几天电脑重装了一次系统将java配置的环境变量都弄没了&#xff0c;自己添加了两个新的变量JAVA_HOME&#xff08;自己jdk的地址&#xff09;以及在path中添加%JAVA_HOME%\bin;%JAVA_HOME%\jre\bin; 然后因为这几天都是用eclipse进行编程的&#xff0c;没有出现问题&#…

spring-boot注解详解(一)

spring-boot注解详解(一) SpringBootApplication SpringBootApplication (默认属性)Configuration EnableAutoConfiguration ComponentScan。 Configuration&#xff1a;提到Configuration就要提到他的搭档Bean。使用这两个注解就可以创建一个简单的spring配置类&#xf…

前端基础-jQuery的优点以及用法

一、jQuery介绍 jQuery是一个轻量级的、兼容多浏览器的JavaScript库。jQuery使用户能够更方便地处理HTML Document、Events、实现动画效果、方便地进行Ajax交互&#xff0c;能够极大地简化JavaScript编程。它的宗旨就是&#xff1a;“Write less, do more.“二、jQuery的优势 一…

[pytorch、学习] - 3.12 权重衰减

参考 3.12 权重衰减 本节介绍应对过拟合的常用方法 3.12.1 方法 正则化通过为模型损失函数添加惩罚项使学出的模型参数更小,是应对过拟合的常用手段。 3.12.2 高维线性回归实验 import torch import torch.nn as nn import numpy as np import sys sys.path.append("…

Scapy之ARP询问

引言 校园网中&#xff0c;有同学遭受永恒之蓝攻击&#xff0c;但是被杀毒软件查下&#xff0c;并知道了攻击者的ip也是校园网。所以我想看一下&#xff0c;这个ip是PC&#xff0c;还是路由器。 在ip视角&#xff0c;路由器和pc没什么差别。 实现 首先是构造arp报文&#xff0c…

spring-boot注解详解(二)

ResponseBody 作用&#xff1a; 该注解用于将Controller的方法返回的对象&#xff0c;通过适当的HttpMessageConverter转换为指定格式后&#xff0c;写入到Response对象的body数据区。使用时机&#xff1a; 返回的数据不是html标签的页面&#xff0c;而是其他某种格式的数据时…