batchsize和数据量设置比例_Keras - GPU ID 和显存占用设定步骤

初步尝试 Keras (基于 Tensorflow 后端)深度框架时, 发现其对于 GPU 的使用比较神奇, 默认竟然是全部占满显存, 1080Ti 跑个小分类问题, 就一下子满了. 而且是服务器上的两张 1080Ti.

服务器上的多张 GPU 都占满, 有点浪费性能.

因此, 需要类似于 Caffe 等框架的可以设定 GPU ID 和显存自动按需分配.

实际中发现, Keras 还可以限制 GPU 显存占用量.

这里涉及到的内容有:

GPU ID 设定

GPU 显存占用按需分配

GPU 显存占用限制

GPU 显存优化

1. GPU ID 设定

#! -- coding: utf-8 --*--

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

这里将 GPU ID 设为 1.

GPU ID 从 0 开始, GPUID=1 即表示第二块 GPU.

2. GPU 显存占用按需分配

#! -- coding: utf-8 --*--

import tensorflow as tf

import keras.backend.tensorflow_backend as ktf

# GPU 显存自动调用

config = tf.ConfigProto()

config.gpu_options.allow_growth=True

session = tf.Session(config=config)

ktf.set_session(session)

3. GPU 显存占用限制

#! -- coding: utf-8 --*--

import tensorflow as tf

import keras.backend.tensorflow_backend as ktf

# 设定 GPU 显存占用比例为 0.3

config = tf.ConfigProto()

config.gpu_options.per_process_gpu_memory_fraction = 0.3

session = tf.Session(config=config)

ktf.set_session(session )

这里虽然是设定了 GPU 显存占用的限制比例(0.3), 但如果训练所需实际显存占用超过该比例, 仍能正常训练, 类似于了按需分配.

设定 GPU 显存占用比例实际上是避免一定的显存资源浪费.

4. GPU ID 设定与显存按需分配

#! -- coding: utf-8 --*--

import os

import tensorflow as tf

import keras.backend.tensorflow_backend as ktf

# GPU 显存自动分配

config = tf.ConfigProto()

config.gpu_options.allow_growth=True

#config.gpu_options.per_process_gpu_memory_fraction = 0.3

session = tf.Session(config=config)

ktf.set_session(session)

# 指定GPUID, 第一块GPU可用

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

5. 利用fit_generator最小化显存占用比例/数据Batch化

#! -- coding: utf-8 --*--

# 将内存中的数据分批(batch_size)送到显存中进行运算

def generate_arrays_from_memory(data_train, labels_train, batch_size):

x = data_train

y=labels_train

ylen=len(y)

loopcount=ylen // batch_size

while True:

i = np.random.randint(0,loopcount)

yield x[i*batch_size:(i+1)*batch_size],y[i*batch_size:(i+1)*batch_size]

# load数据到内存

data_train=np.loadtxt("./data_train.txt")

labels_train=np.loadtxt('./labels_train.txt')

data_val=np.loadtxt('./data_val.txt')

labels_val=np.loadtxt('./labels_val.txt')

hist=model.fit_generator(generate_arrays_from_memory(data_train,

labels_train,

batch_size),

steps_per_epoch=int(train_size/bs),

epochs=ne,

validation_data=(data_val,labels_val),

callbacks=callbacks )

5.1 数据 Batch 化

#! -- coding: utf-8 --*--

def process_line(line):

tmp = [int(val) for val in line.strip().split(',')]

x = np.array(tmp[:-1])

y = np.array(tmp[-1:])

return x,y

def generate_arrays_from_file(path,batch_size):

while 1:

f = open(path)

cnt = 0

X =[]

Y =[]

for line in f:

# create Numpy arrays of input data

# and labels, from each line in the file

x, y = process_line(line)

X.append(x)

Y.append(y)

cnt += 1

if cnt==batch_size:

cnt = 0

yield (np.array(X), np.array(Y))

X = []

Y = []

f.close()

补充知识:Keras+Tensorflow指定运行显卡以及关闭session空出显存

Step1: 查看GPU

watch -n 3 nvidia-smi #在命令行窗口中查看当前GPU使用的情况, 3为刷新频率

Step2: 导入模块

导入必要的模块

import os

import tensorflow as tf

from keras.backend.tensorflow_backend import set_session

from numba import cuda

Step3: 指定GPU

程序开头指定程序运行的GPU

os.environ['CUDA_VISIBLE_DEVICES'] = '1' # 使用单块GPU,指定其编号即可 (0 or 1or 2 or 3)

os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' # 使用多块GPU,指定其编号即可 (引号中指定即可)

Step4: 创建会话,指定显存使用百分比

创建tensorflow的Session

config = tf.ConfigProto()

config.gpu_options.per_process_gpu_memory_fraction = 0.1 # 设定显存的利用率

set_session(tf.Session(config=config))

Step5: 释放显存

确保Volatile GPU-Util显示0%

程序运行完毕,关闭Session

K.clear_session() # 方法一:如果不关闭,则会一直占用显存

cuda.select_device(1) # 方法二:选择GPU1

cuda.close() #关闭选择的GPU

以上这篇Keras - GPU ID 和显存占用设定步骤就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/287940.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Blazor University (17)使用 RenderFragments 模板化组件

原文链接:https://blazor-university.com/templating-components-with-renderfragements/使用 RenderFragments 模板化组件源代码[1]到目前为止,我们已经创建了基于参数生成 100% 渲染输出的组件,但组件并不总是那么简单。有时我们需要创建将…

OpenGL® ES 3.0 Programming Guide - Book Website

OpenGL ES 3.0 Programming Guide - Book Website http://opengles-book.com sample codes in GitHub: https://github.com/danginsburg/opengles3-book/

ArcGIS实验教程——实验六:空间数据格式转换

ArcGIS实验视频教程合集:《ArcGIS实验教程从入门到精通》(附配套实验数据) 【实验描述】 空间数据从一个GIS平台跨到另一个GIS,必须经过格式转换,才能实现数据信息共享。本实验主要讲述空间数据矢栅互转、CAD(DWG)数据和Shapefile数据互转、栅格数据与ASCII文件之间的转换…

php基础教程 第五步 逻辑控制

逻辑判断 在开发项目时,竟然会出现逻辑控制。例如当用户输入“hello”时你需要自动回复“hello 欢迎”,当用户设置的定时时间到达时,你需要提醒用户时间已经结束;再举个例子,在玩网络游戏时,用户控制的角色…

在春意盎然的季节里初识GIT

Git 与 SVN 区别 GIT不仅仅是个版本控制系统,它也是个内容管理系统(CMS),工作管理系统等。 如果你是一个具有使用SVN背景的人,你需要做一定的思想转换,来适应GIT提供的一些概念和特征。 Git 与 SVN 区别点: 1、GIT是分布式的&…

WinForm混合Blazor(下)

有时,为了省事,我们也可以把窗体的控件注入到ServiceCollection中,在razor中订阅事件,这样就省了中间的桥梁,直接用控件当桥梁,下面以一个Button和Timer为例,来展示使用方式。本例是把Button和T…

ArcGIS实验教程——实验七:矢量数据空间校正(Spatial Adjustment)

ArcGIS实验视频教程合集:《ArcGIS实验教程从入门到精通》(附配套实验数据) 【实验描述】 本系列实验教程实验二讲述了栅格数据的数字化之前必须进行的操作--地理配准(地理配配准完整操作步骤),栅格地理配准和矢量空间校正都属于几何校正的内容,关于空间校正、地理配准、…

数据结构之冒泡排序

1 冒泡排序 冒泡排序(Bubble Sort)也是一种简单直观的排序算法。它重复地走访过要排序的数列,一次比较两个元素,如果他们的顺序错误就把他们交换过来 算法过程如下: 比较相邻的元素。如果第一个比第二个大,就交换他们两个。 对每一对相邻元素作同样的工作,从开始第一…

博图程序需要手动同步_贴吧求助帖博图实例单按钮控制灯的程序

接上一期在贴吧看见的求助帖(上图看得见水印),因为没人回复,发帖的楼主好像删除了帖子。结果我抽时间用博图15.1,S71200做了一个,希望给需要帮助的新人能够起到作用,感觉有用的话可以关注一下我的公众号低压电工&#…

操作系统,,,也考完了【流坑】

操作系统博大精深岂是区区2学分能容?学习了一些机制和理论模型之后课外还是要继续学习转载于:https://www.cnblogs.com/learn-to-rock/p/5447750.html

php基础教程 第六步 学习数组以及条件判断switch补充

条件语句 switch 在上一节的学习中&#xff0c;学习了php的条件语句if。在php编程中进行条件判断还可以使用switch语句。switch语句语法如下&#xff1a; <?php switch (值或表达式) { case 值等于值1:当值等于值1时要执行的代码break; case 值等于值2:当值等于值2时要执…

ArcGIS实验教程——实验八:矢量数据拼接

ArcGIS实验视频教程合集:《ArcGIS实验教程从入门到精通》(附配套实验数据) 【实验描述】 数字化工作都是分工完成的,那么数字化完成之后,怎样将各部分数字化的成果拼接成一个完整的矢量数据呢?本实验针对线状和面状矢量数据,讲解矢量化数据拼接的常用方法:合并(merge)…

iOS 类库列表

1. LinqToObjectiveC #import "NSArrayLinqExtensions.h" 它为NSArray添加了许多方法&#xff0c;能让你用流式API来转换、排序、分组和过滤其中的数据。转载于:https://www.cnblogs.com/SimonGao/p/4747065.html

dotnet-exec 小工具

dotnet-exec 小工具Intro在之前的文章中很多会有一些示例代码&#xff0c;这些代码一般都是一些很小的示例&#xff0c;尤其是介绍一些新特性的示例&#xff0c;基本上不会引用其他包&#xff0c;只有 SDK 就可以执行&#xff0c;对于这些示例&#xff0c;一般会每个实例单独一…

安卓手机抓包charles乱码_charles-抓包Andriod 手机的设置

长按弹出修改后&#xff1a;charles如果不配置SSL通用证书&#xff1b;会导致HPPTS协议的域名抓取失败/乱码的现象&#xff1b;现在SSL越来越多&#xff0c;很多博客都上了SSL&#xff0c;支付相关的行业更是基础配置&#xff1b;charles配置SSL证书&#xff0c;算起来很简单&a…

分布式服务下的关键技术(转)

系统架构演化历程-初始阶段架构 初始阶段的小型系统 应用程序、数据库、文件等所有的资源都在一台服务器上通俗称为LAMP&#xff08;linux、apache、mysql、php&#xff09;。 特征&#xff1a; 应用程序、数据库、文件等所有的资源都在一台服务器上。 描述&#xff1a; 通常服…

ArcGIS实验教程——实验九:矢量数据提取

ArcGIS实验视频教程合集:《ArcGIS实验教程从入门到精通》(附配套实验数据) 【实验描述】本实验以矢量数据为实验数据,讲解矢量数据的提取方法及注意事项。 一、实验内容 1、直接选取,导出(所有要素) 2、导出视图范围中的所有要素 3、按指定的裁剪框裁剪数据 4、按指…

linux之sort命令

1 sort命令的参数 sort 参数(可以省略) file 具体参数如下 -b:忽略每行前面开始的空格字符,空格数量不固定时,该选项几乎是必须要使用的("-n"选项隐含该选项,测试发现都隐含) -c:检查文件是否已经按照顺序排序,如未排序,会提示从哪一行开始乱序 -C:类似于&q…

php基础教程 第七步数组补充及循环基础

键值对 上一节中简单的了解了数组的定义、取值及存储&#xff0c;这一节补充一下上一节数组的内容。 在上一节中&#xff0c;我们知道索引是用来标记值的位置&#xff0c;通过索引可以取得当前位置的值。这种一个索引对应着一个值的关系是一个映射关系&#xff0c;称为键值对。…

vs2013 c# 中调用 c 编写的dll出错的可能错误

先说出错原因: 堆栈调用顺序 解决办法: 使用 __stdcall 或 使用C#属性 CallingConvention 起因是我想在c#中调用c函数结果出错了 如下 C 头文件 #define DLLExport extern "C" __declspec(dllexport)DLLExport int func(int a, int b);DLLExport void init…