pytorch-激活函数与GPU加速

目录

  • 1. sigmod和tanh
  • 2. relu
  • 3. Leaky Relu
  • 4. selu
  • 5. softplus
  • 6. GPU加速
  • 7. 使用GPU加速手写数据训练

1. sigmod和tanh

sigmod梯度区间是0~1,当梯度趋近0或者1时会出现梯度弥散的问题。
tanh区间时-1~1,是sigmod经过平移和缩放而得到的,也存在梯度弥散的问题。
在这里插入图片描述

2. relu

relu函数当梯度<0时,梯度是0,梯度>0时梯度是1,不会出现梯度弥散和梯度爆炸,虽然relu函数使用广泛也不易出现梯度弥散和梯度爆炸,但是不代表它不会出现。
在这里插入图片描述

3. Leaky Relu

在梯度<0的时候,不在是等于0而是变成了a*x, a是一个比较小的系数,确保梯度小于0时不再是0
在这里插入图片描述

4. selu

由两部分组成一部分时Relu,另一部分是一个指数函数,从而使得selu在0点变成了连续的。
在这里插入图片描述

5. softplus

时relu的一个连续光滑的版本,在0处变得光滑而连续
在这里插入图片描述
总结:目前用的最大的sigmod、tanh、relu、leakyrelu,其他两种用的较少

6. GPU加速

torch.device(‘cuda:0’)中的cuda:0代表第几块显卡,如果使用CPU那么就是torch.device(‘cpu’)
使用.to(device)就把模块或者数据搬到了GPU上,然而模块和数据是有一些区别的,模块执行.to(device)返回一个reference和不使用初始化是完全一样的属于一个inplace操作,但是data就不一样了,比如:data2=data.to(device),data2和data是完全不一样的,data2是gpu数据,data是cpu数据。
注意:.cuda()方法已经不推荐使用了
在这里插入图片描述

7. 使用GPU加速手写数据训练

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

从代码中可以看到网络、loss函数和数据都搬到了GPU上,激活函数改成了LeakyRelu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/3736.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【毕设绝技】基于 SpringCloud 的在线交易平台商城的设计与实现-数据库设计(三)

毕业设计是每个大学生的困扰&#xff0c;让毕设绝技带你走出低谷迎来希望&#xff01; 基于 SpringCloud 的在线交易平台商城的设计与实现 一、数据库设计原则 在系统中&#xff0c;数据库用来保存数据。数据库设计是整个系统的根基和起点&#xff0c;也是系统开发的重要环节…

Matlab|交直流混合配电网潮流计算(统一求解法)

目录 1 主要内容 算例模型 统一求解法迭代方程 算法流程图 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序为matlab代码&#xff0c;采用统一求解法对交直流混合配电网进行潮流计算&#xff0c;统一迭代法又称统一求解法&#xff0c;其思路是将混联系统中的交流网…

强化学习的智能体概念与代码实例

在强化学习中&#xff0c;智能体&#xff08;agent&#xff09;是与环境进行交互并学习如何做出决策以达到某种目标的实体。智能体通常具有以下几个组成部分&#xff1a; 策略&#xff08;Policy&#xff09;&#xff1a;智能体的策略定义了在给定状态下选择动作的规则或概率分…

C语言 | Leetcode C语言题解之第44题通配符匹配

题目&#xff1a; 题解&#xff1a; bool allStars(char* str, int left, int right) {for (int i left; i < right; i) {if (str[i] ! *) {return false;}}return true; } bool charMatch(char u, char v) { return u v || v ?; };bool isMatch(char* s, char* p) {in…

React Hooks(常用)笔记

一、useState&#xff08;保存组件状态&#xff09; 1、基本使用 import { useState } from react;function Example() {const [initialState, setInitialState] useState(default); } useState(保存组件状态) &#xff1a;React hooks是function组件(无状态组件) &#xf…

探索 Chrome 插件开发之旅

浏览器扩展程序拥有无限可能性&#xff0c;它们能丰富我们的浏览体验&#xff0c;提升工作效率&#xff0c;甚至改变网络世界的交互方式。谷歌 Chrome 浏览器的插件生态尤为繁荣&#xff0c;本文将引导你走进 Chrome 插件开发的世界&#xff0c;从入门基础知识到实战案例&#…

IDEA生成测试类

方法一 具体流程: 选中要生成的测试类------------>选择code选项------------>选择Generate选项---------->选择test选项---------->选择要生成的方法 第一步: 光标选中需要生成测试类的类 找到code选项 选中Generate选项 选中test选项 选中你要生成的测试…

【嵌入式笔试题】C语言笔试题(4)

C语言非常之经典的笔试题。 4.综合题(18道) 4.1下面代码输出是几? int main() { int j = 2; int i = 1; if(i = 1) j = 3; if(i = 2) j = 5; printf("%d", j); } 答案: 输出为 5 。 解读: 注意 if 的条…

授人以渔 选购篇十二:路由器选购要点

文章目录 系列文章Wi-Fi 标准&#xff1a;WiFi6或以上无线速度&#xff1a;千兆以上Mesh组网有线端口LAN端口数量&#xff1a;布线拓扑结构决定LAN端口速率&#xff1a;千兆WLAN端口&#xff1a;2.5G其他&#xff1a;Link Aggregation&#xff08;链路聚合&#xff09; 附加接口…

简单的jmeter上传文件脚本

1、设置上传接口的headers的值 2、添加post请求

数据结构系列-二叉树之前序遍历

&#x1f308;个人主页&#xff1a;羽晨同学 &#x1f4ab;个人格言:“成为自己未来的主人~” 这篇文章&#xff0c;我们主要的内容是对二叉树当中的前历的算法进行讲解&#xff0c;二叉树中的算法所要求实现的是 从根到左子树再到右子树的遍历顺序&#xff0c;可能这样不太…

什么是用户体验(UX)文案,为什么它很重要?

网上购物如今比以往任何时候都更加相关。所以我们将以此为例说明什么是用户体验&#xff08;UX&#xff09;文案&#xff0c;以及为什么它很重要。 假设你去了一个在线商店。你需要执行一系列操作&#xff1a; 找到合适的部分选择你感兴趣的产品弄清楚它们是什么&#xff0c;…

2.6设计模式——Flyweight 享元模式(结构型)

意图 运用共享技术有效地支持大量细粒度的对象。 结构 其中 Flyweight描述一个接口&#xff0c;通过这个接口Flyweight可以接受并作用于外部状态。ConcreteFlyweight实现Flyweight接口&#xff0c;并作为内部状态&#xff08;如果有&#xff09;增加存储空间。ConcreteFlywe…

Unity AssetsBundle打包

为什么要使用AssetsBundle包 减少安装包的大小 默认情况下&#xff0c;unity编译打包是对项目下的Assets文件夹全部内容进行压缩打包 那么按照这个原理&#xff0c;你的Assets文件夹的大小将会影响到你最终打包出的安装包的大小&#xff0c;假如你现在正在制作一个游戏项目&…

没有文件服务器,头像存哪里合适

没有文件服务器&#xff0c;头像存哪里合适 视频在bilibili&#xff1a;没有文件服务器&#xff0c;头像存哪里合适 1. 背景 之前有同学私信我说&#xff0c;他的项目只是想存个头像&#xff0c;没有别的文件存储需求&#xff0c;不想去用什么Fastdfs之类的方案搭建文件服务…

【C++杂货铺】多态

目录 &#x1f308;前言&#x1f308; &#x1f4c1;多态的概念 &#x1f4c1; 多态的定义及实现 &#x1f4c2; 多态的构成条件 &#x1f4c2; 虚函数 &#x1f4c2; 虚函数重写 &#x1f4c2; C11 override 和 final &#x1f4c2; 重载&#xff0c;覆盖&#xff08;重写…

ARM学习(26)链接库的依赖查看

笔者今天来聊一下查看链接库的依赖。 通常情况下&#xff0c;运行一个可执行文件的时候&#xff0c;可能会出现找不到依赖库的情况&#xff0c;比如图下这种情况&#xff0c;可以看到是缺少了license.dll或者libtest.so&#xff0c;所以无法运行。怎么知道它到底缺少什么dll呢&…

HarmonyOS-Next开源三方库 MPChart:打造出色的图表体验

点击下载源码https://download.csdn.net/download/liuhaikang/89228765 简介 随着移动应用的不断发展&#xff0c;数据可视化成为提高用户体验和数据交流的重要手段之一。在 OpenAtom OpenHarmony&#xff08;简称“OpenHarmony”&#xff09;应用开发中&#xff0c;一个强大而…

机器学习day2

一、KNN算法简介 KNN 算法&#xff0c;或者称 k最邻近算法&#xff0c;是 有监督学习中的分类算法 。它可以用于分类或回归问题&#xff0c;但它通常用作分类算法。 二、KNN分类流程 1.计算未知样本到每一个训练样本的距离 2.将训练样本根据距离大小升序排列 3.取出距离最近…