python torch exp_学习Pytorch过程遇到的坑(持续更新中)

1. 关于单机多卡的处理:

在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键的在于model、device_ids这两个参数。DATA PARALLELISM​pytorch.org

但是官网的例子中没有讲到一个核心的问题:即所有的tensor必须要在同一个GPU上。这是网络运行的前提。这篇文章给了我很大的帮助,里面的例子也很好懂,很直观:pytorch: 一机多卡训练的尝试​www.jianshu.com

一般来说有两种数据迁移的方法:

1)是先定义一个device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')【这里面已经定义了device在卡0上“cuda:0”】

然后将model = torch.nn.DataParallel(model,devices_ids=[0, 1, 2])(假设有三张卡)

此后需要将tensor 也迁移到GPU上去。注意所有的tensor必须要在同一张GPU上面

即:tensor1 = tensor1.to(device), tensor2 = tensor2.to(device)等等

(可能有人会问了,我并没有指定那一块GPU啊,怎么这样也没有出错啊?

原因很简单,因为一开始的device中已经指定了那一块卡了(卡的id为0))

2)第二中方法就是直接用tensor.cuda()的方法

即先model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) (假设有三块卡, 卡的ID 为0, 1, 2)

然后tensor1 = tensor1.cuda(0), tensor2=tensor2.cuda(0)等等。(我这里面把所有的tensor全放进ID 为 0 的卡里面,也可以将全部的tensor都放在ID 为1 的卡里面)

2 关于DataParallel的封装问题

在DataParallel中,没有和nn.Module一样多的特性。但是有些时候我们可能需要使用到如.fc这样的性质(.fc性质在nn.Module中有, 但是在DataParallel中没有)这个时候我们需要一个.Module属性来进行过渡。操作如下:

model = Model() # 这里实例化Model类得到一个model

model.fc # 这样做不会报错

# DataParallel情况下

parallel_model = torch.nn.DataParallel(model)

parallel_model.fc # 会报错。解决办法,很简单, 在fc前加一个.module即可

parall_model.module.fc # 不会报错

3 Pytorch中的数据导入潜规则

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

所以我们在transform的时候可以先定义:normalized = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 然后用的时候直接调用normalized就行了。

4 python中的某些包的版本不同也会导致程序运行失败。

如,今天遇到一个pillow包的问题。原先装的包的6.0.0版本的,但是在制作数据集的时候,训练集跑的好好的,一到验证集就开始无端报错。在确定程序无误之后,将程序放在别的环境中跑(也是pytorch环境),正常运行。于是经过几番查找,发现是pillow出了问题,于是乎卸载了原来的版本,重新装一个低一点的版本问题就解决了。这种版本问题的坑其实很多,而且每个人遇到的还都不尽相同,所以需要慢慢的去摸索才能发现问题所在。

5 关于CUDA 内存溢出的问题。

这个一般是因为batch_size 设置的比较大。(8G显存的话大概batch_size < = 64都ok, 如果还是报错的话,就在对半分 64, 32, 16, 8, 4等等)。而且这个和你的数据大小没什么太大的关系。因为我刚刚开始也是想可能是我训练集太大了,于是将数据集缩小了十倍,还是同样的报错。所以就想可能 batch_size的问题。最后果然是batchsize的问题。

6 关于模型导入

一般来说如果你的模型是再GPU上面训练的,那么如果你继续再GPU上面进行其他的后续操作(如迁移学习等)那么直接使用:

import torch

from torchvision import models

pre_trained_weight = torch.load('pre_trained_weight.pt') # pre_trained_weight.pt 是我在resnet18上面训练好的模型

resnet18 = models.resnet18(pretrained=False) # 导入框架

resnet18.load_state_dict(pre_trained_weight) # load_state_dict()函数表示导入当前权值,因为这个权值都是以字典的形式保存的

# 如果你模型在GPU上训练的,而且后续操作也在GPU上进行,那上面的操作就没啥毛病。但是…………

如果你模型在GPU上训练的,后续操作是在CPU上进行的话。那么还用上面的代码的话就会报错了。因为你模型在GPU上训练,其实其内部的某些数据格式和CPU上的不大一样。所以需要一个函数将GPU上的模型转化为CPU上的模型。这个工作在pytorch里面其实很简单。只要把上面的代码简单修改一下即可:(在torch.load函数里面加一个map_location='cpu'即可!)

import torch

from torchvision import models

pre_trained_weight = torch.load('pre_trained_weight.pt',map_location='cpu') # pre_trained_weight.pt 是我在resnet18上面训练好的模型

resnet18 = models.resnet18(pretrained=False) # 导入框架

resnet18.load_state_dict(pre_trained_weight) # load_state_dict()函数表示导入当前权值,因为这个权值都是以字典的形式保存的

7. 关于两次sort操作:

前几天看SSD pytorch的源码发现了,有这样的一步操作,不得解,

于是查阅了一下资料和动手操作后发现了两次sort操作的神奇之处。

首先 sort操作没什么好说。它接收两个参数:dim和descending参数。dim表示的是从哪个维度进行排列,descending参数接收布尔类型的输入,表示结果是否按降序排列。两次sort操作的具体实施为。

import torch

x = torch.randon(3, 4)

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

_, indices = x.sort(dim=1, descending=False)

>>>indices

tensor([[2, 0, 1, 3],

[2, 1, 0, 3],

[1, 3, 0, 2]])

# 上面的是进行第一次的sort, 得到的结果关于x的每一行的元素的升序排列

# 下面进行第二次sort操作。

_, idx = indices.sort(dim=1, descending=False)

>>>idx

tensor([[1, 2, 0, 3],

[2, 1, 0, 3],

[2, 0, 3, 1]])

# 我们来分析一下这个得到的idx和原始数据x的关系。

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

按升序排列的话,x的【第一行】中的第一个元素对应的是第二小,第二个元素对应的第三小,第三个元素对应是最小, 最后一个元素应该是最大的

所以这个排列的大小和位置可以从二次sort操作的idx中能看到。现在分析idx,取其第一行【1, 2, 0, 3】, 表示的意思是x[0,0]处在x[0]这一行

的第二位,x[0, 1]处在下x[0]中的第三位, x[0, 2]处在x[0]这一行的第一位, 下x[0, 3]处在x[0]行的最后一位。

(注:这里的第几位表示的是每一行按升序排列原则,其中的元素所处的位置)

从上面的分析中可以看到,两次sort操作得到的idx的意义是: 在保证原始元素的位置不变的情况下,可以表示排序情况(升序or降序)。

以上是原理,那么两次sort究竟用在什么地方呢?

还是上面哪个例子:

>>>x

tensor([[-0.1361, 0.4076, -0.8244, 0.9163],

[-0.0997, -1.1689, -2.3145, 1.2334],

[-0.4384, -1.6083, 1.7621, -0.9648]])

我想取x的第一行元素的前1个最小值, 第二行元素的前2个最小值,第三行元素的前3个最小值。该怎么操作呢?

根据上面的两次sort操作,我们得到idx

tensor([[1, 2, 0, 3],

[2, 1, 0, 3],

[2, 0, 3, 1]])

# 定义criterion

criterion = torch.tensor([1, 2, 3]).view(3, -1)

criterion = criterion.expand_as(idx)

>>>criterion

tensor([[1, 1, 1, 1],

[2, 2, 2, 2],

[3, 3, 3, 3]])

mask = idx < criterion

>>>mask

tensor([[0, 0, 1, 0],

[0, 1, 1, 0],

[1, 1, 0, 1]], dtype=torch.uint8)

# 可以看到,mask得到的就是我们所需要的索引。可以看到mask第一行只有一个1, 第二行有两个1,第三行有三个1.这里的1表示的True的意思,即得到这个数

>>>x[mask]

tensor([-0.8244, -1.1689, -2.3145, -0.4384, -1.6083, -0.9648]) # 最终结果

8. log_sum_exp的trick:机器学习常见模式LogSumExp解密人工智能_机器人之家​www.jqr.com

参考这篇文章,写的通俗易懂。大概介绍一下问题:

发现这个问题是前几天,这里面在进行exp操作的时候用x-x_max。当时很是疑惑。后来一看上面这篇文章才明白了。

一般来说

是有一个确切的值与之对应的。但是在计算里面却不是这样的。输入torch.exp(1000), 结果是:

这样的结果并不意外,因为计算机的存储阶段误差导致的。基于这种情况的存在,所以人们想到了一个比较好的解决方法。具体怎么实现看看上面的链接便清楚了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/538951.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

php 读文件返回字符串,PHP:file_get_contents('php:// input')返回JSON消息的字符串...

我正在尝试在我的PHP应用程序中读取JSON消息&#xff0c;这是我的php代码&#xff1a;$json file_get_contents(php://input);$obj json_decode($json, TRUE);echo $obj->{S3URL};当我这样做时&#xff0c;出现以下错误&#xff1a;Trying to get property of non-object …

Android中ListView数据使用sAdapter.notifyDataSetChanged();方法不刷新的问题

原文链接&#xff1a;http://blog.csdn.net/caihongdao123/article/details/51513410 点击阅读原文 ------------------------- 1.涉及到数据库 当要动态显示更数据库改动&#xff0c;相信大家应该都用过notifyDataSetChanged();. 例如&#xff1a; ...... …

keepalived配置高可用集群

准备工作 分别在主从上安装keepalived和nginxyum install -y keepalivedyum install -y nginx关闭主从上的防火墙和SELinuxsystemctl stop firewalldsetenforce 0 配置主机 查看主机ip [rootlynn-04 ~]# ifconfig ens33: flags4163<UP,BROADCAST,RUNNING,MULTICAST> mtu…

如何快速掌握python包_如何快速掌握一个python模块?

初学者就别想快了。 我自己是这样的。先上网看看一些基础的教程&#xff0c;非常快的过一下&#xff0c;十几分钟&#xff0c;主要是了解这个module能干什么&#xff0c;特别是一些基本的功能&#xff0c;头脑中建立起初步映射。 然后就是用了&#xff0c;不用看了也白搭。我假…

python设计一个函数定义计算并返回n价调和函数_python函数的调和平均值?

我有两个函数&#xff0c;给出精度和召回分数&#xff0c;我需要做一个调和平均函数&#xff0c;定义在同一个库中&#xff0c;使用这两个分数。函数如下所示&#xff1a;功能如下&#xff1a;def precision(ref, hyp):"""Calculates precision.Args:- ref: a l…

jsp超链接到java文件,jsp页面超链接传中文终极解决办法

在做web前端页面的时候&#xff0c;经常碰到传中文问题。网上有许多方案&#xff0c;但仍不能根治&#xff0c;最终要用js或者java的encode相关方法。常规方案有三部&#xff1a;1.改tomcat的server.xml中URIEncodeing为utf-82.页面中编码设置为utf-83.整个项目编码使用utf-8我…

自定义ListView中的分割线

原文&#xff1a;http://blog.csdn.net/zuolongsnail/article/details/7187302点击阅读 --------------------------------------- ListView中每个Item项之间都有分割线&#xff0c;设置Android:footerDividersEnabled表示是否显示分割线&#xff0c;此属性默认为true。 1.不显…

隐藏域input里面放当前时间_【小A问答】Win10的隐藏小秘密,被我发现了!

无惊无险又到小A问答环节辣~~今天的小A要来给大家分享一些小秘密&#xff01;当然&#xff0c;这可不是小A自己的小秘密&#xff0c;是关于你电脑的小秘密哦&#xff01;知道吗&#xff1f;Windows10每一次升级更新&#xff0c;都会伴随着新功能的增加。这些隐藏的功能你都发现…

网络相关的面试题

1&#xff09;简述tcp/ip的三次交互过程&#xff08;个人理解&#xff1a;syn是握手信号&#xff0c;ack是确认信号&#xff0c;ack就相当于前面的syn值1&#xff0c;简单一点理解就是客户端发送握手请求&#xff0c;服务器收到握手请求后&#xff0c;回复一个包确认它接收到了…

h5文字垂直居中_CSS中垂直居中和水平垂直居中的方法

flex垂直居中&#xff1a;第一种&#xff1a;使用flex布局&#xff0c;让居中元素的父元素为flex属性,让它在交叉轴上center就可以达到居中效果了&#xff1a;html代码: <div class"father"><p>我要垂直居中</p> </div>css代码: .father {…

ListMap排序

//compareto就是比较两个数据的大小关系 大于0表示前一个数据比后一个数据大&#xff0c; 0表示相等&#xff0c;小于0表示第一个数据小于第二个数据 public static List<Map<String, String>> sortWifi(List<Map<String, String>> wifiList){if(wif…

thinkphp回调的php调用db类,请问thinkphp中model类自动完成功能 回调函数能不能获取其他字段的值?...

http://www.thinkphp.cn/api/source-class-Model.html#975protected function _validationFieldItem($data,$val) {switch(strtolower(trim($val[4]))) {case function:// 使用函数进行验证case callback:// 调用方法进行验证$args isset($val[6])?(array)$val[6]:array();if…

输入年份和月份输出该月有多少天python_Python实现用户输入年月日,程序打印出这是这一年的第多少天...

1. 自己造轮子yearint(input(请输入年份&#xff0c;如2019>>>))monthint(input(请输入月份&#xff0c;如8>>>))dayint(input(请输入日期,如25>>>))#下面这块代码是按照闰年计算if (year%40 and year%100!0) or (year%4000):calendar{1:31,2:29,3:…

Linux命令之find命令中的-mtime参数

有关find -mtime这个参数的使用有比较多的坑&#xff0c;今天把这个问题在这里记录下来&#xff1a; mtime参数的理解应该如下&#xff1a; -mtime n 按照文件的更改时间来找文件&#xff0c;n为整数。 n 表示文件更改时间距离为n天-n 表示文件更改时间距离在n天以内n 表示文件…

WifiManager的getScanResults()返回列表为0

这个问题查了好久&#xff0c;花了2个小时。就是出不来。 原来问题在android sdk 版本问题。 在android 6.0的时候&#xff0c;返回为空&#xff0c;且不为null&#xff0c;在华为mate&#xff0c;6.0手机上测试&#xff0c;也不报错。 官网和网上没有具体的解决方法。 后来…

c++直角坐标系与极坐标系的转换_平面向量的奇技淫巧——斜坐标系的一系列低级研究...

事先说明&#xff1a;笔者初三&#xff0c;如在叙述中有不严谨的地方&#xff0c;还请诸位指出&#xff0c;自当感激不尽。一.什么是斜坐标系众所周知&#xff0c;我们目前平面中使用相当广的坐标系是笛卡尔发明的平面直角坐标系。然而&#xff0c;笛卡尔真的只使用了这一种坐标…

php 字节转为kb,PHP获取文件大小并转化为KB、MB、GB单位

PHP获取文件大小并转化为KB、MB、GB单位。function getSize($filesize) {if ($filesize > 1073741824) {$filesize round($filesize / 1073741824 * 100) / 100 . GB;} elseif ($filesize > 1048576) {$filesize round($filesize / 1048576 * 100) / 100 . MB;} else…

python 重定向stdout_Python 犄角旮旯--重定向 stdout

What&#xff1f;在 Python 程序中&#xff0c;使用 print 输出调试信息的做法非常常见&#xff0c;但有的时候我们需要将 print 的内容改写到其他位置&#xff0c;比如一个文件中&#xff0c;便于随时排查。但是又不希望大面积替换 print 函数&#xff0c;这就需要一些技巧实现…

Jetty实战之 安装 运行 部署

原文地址&#xff1a;http://blog.csdn.net/kongxx/article/details/7218767 1. 首先从Jetty的官方网站http://wiki.eclipse.org/Jetty/Starting/Downloads下载最新的Jetty&#xff0c;上面有两个版本7.x和8.x&#xff0c;7.x是运行在JDK5及以上版本&#xff0c;8.x是运行在JD…

一行命令从 APK 文件中提取 Endpoint 及 URL

做IoT的人免不了要接触Android&#xff0c;接触Android的人又免不了要研究别人的App应用。 Diggy&#xff0c;一款能够从 apk 文件中提取 endpoint 及 URL 的工具&#xff0c;只要一行命令就可以帮大家提取出相关Android apk文件的安装信息和互联网访问信息。 下载地址&#xf…