Keras入门

首先当然是安装Keras。需要注意的是Keras有三种后端backend。后端是意思是Keras需要依赖他们进行张量的运算。这三种后端是:tensorflow,Theano,CNTK(微软)。这也是keras的优势:可以在多种生态中发布。一般使用Tensorflow作为后端,所以在安装Keras之前需要先安装tensorflow。在安装好Keras之后,命令行中import keras会提示using tenserflow backend,这就是在提示我们正在使用的是Tensorflow的后端,而不是错误信息。

然后介绍模型。Keras的模型有两种方式构造。

第一种是使用Sequential序贯模型。可以直接在Sequential中定义,

from keras.models import Sequential
from keras.layers import Dense, Activationmodel = Sequential([Dense(32, input_shape=(784,)),Activation('relu'),Dense(10),Activation('softmax'),
])
model = Sequential()
model.add(Dense(32, input_dim=784))
model.add(Activation('relu'))  #可以通过name设置层名字,加载已有模型可以通过名字判断是否加载这一层的权重model.pop() #删除最后添加的层 

第二种是函数式。

from keras.models import Model
from keras.layers import Input, Densea = Input(shape=(32,))
b = Dense(32)(a)
model = Model(inputs=a, outputs=b)

构建好之后使用一行函数就可以打印模型各层的参数。

model.summary()

模型编译

我理解的编译是一个静态的过程。是将模型与loss,优化器optimizer,评估标准metrics联系起来。

# 多分类问题
model.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])# 二分类问题
model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['accuracy'])# 均方误差回归问题
model.compile(optimizer='rmsprop',loss='mse')

 

数据

注意这里的label使用了onehot编码。onehot编码就是使用01字符串对类别标签进行编码,但不同于通信中力求编码长度最短,onehot的one指的就是一次编码中只出现一个1,其余全为0,这样onehot编码长度应该等于种类数。在这个函数中取的是标签列表中最大值+1.所以准确来说返回的是一个01构成的矩阵,矩阵行数是种类类别数,列数是种类标签值的最大值加1.这样做可能是认为种类标签值应该是连续的,为不存在的类别预留位置,同时也体现了类别之间的距离。

import kerasohl=keras.utils.to_categorical([1,3])
print(ohl)
<<<
[[0. 1. 0. 0.][0. 0. 0. 1.]]

对图像数据,一般使用keras提供的函数转换为array。

image = cv2.imread(path)
feature = cv2.resize(image,(IMAGE_DIMS[0],IMAGE_DIMS[1]))
feature = img_to_array(feature) #转换为类似float型的ndarray
features=feature/255.0  # tensor范围是0~1

可以使用yield(data,label)定义函数作为生成器。借助生成器的Next功能,每次只迭代一个值,减少了对内存的消耗。

模型训练

这一步是动态的过程,将模型针对数据做适应,所以使用了fit函数。

# 生成虚拟数据
import numpy as np
data = np.random.random((1000, 100))
labels = np.random.randint(10, size=(1000, 1))# 将标签转换为分类的 one-hot 编码
one_hot_labels = keras.utils.to_categorical(labels, num_classes=10)# 训练模型,以 32 个样本为一个 batch 进行迭代
model.fit(data, one_hot_labels, epochs=10, batch_size=32)

但现实中往往没有这么简单。现实中数据量很大,同时我们还需要进行数据增强,所以更多时候使用的是fit_generator方法。这种方法就利用了刚才提到的yeild构成的生成器,生成器作为fit_generator的第一个参数。在生成其函数中就可以进行一些如数据增强等预处理。因为生成器不像return一样输出就停止了,生成器会一直在上一步停止的地方开始执行,且是按照单位输出的(单位一般是batch_size),为了区别epoch的次数,我们还需要明确参数steps_per_epoch的取值。每次epoch遍历了完整的数据集,那么每次epoch就需要(数据长度/生成器每次输出长度) 次的生成器操作。所以我认为下面第二种解释是更合理的。

history = model.fit_generator(generator(),epochs=epochs,steps_per_epoch=len(x_train)//(batch_size*epochs))
history = model.fit_generator(generator(),epochs=epochs,steps_per_epoch=len(x_train)//(batch_size))

此外,还可以使用model.train_on_batch(x,y) 和 model.test_on_batch(x,y) 进行批量训练与测试。model.train_on_batch(x,y)看名字就知道,也是利用生成器,每次载入一个batch-size。https://github.com/fchollet推荐的是使用fit_generator,因为它也可以使用验证集的生成器。

训练进行中

fit_generator中还有一个可选参数是callback,虽然翻译叫做回调函数,但是它其实是一个类。而回调的意思是在训练的过程中我们可以通过它对模型的参数进行保存和调整。

checkpoint回调函数,可以对模型参数进行保存,这样即便训练过程意外中断,我们可以接着中断的地方继续训练。

LearningStepSchedule回调函数,可以对模型学习率按照策略进行调整。

#keras提供了两种学习率的更新方式
keras.callbacks.LearningRateScheduler(schedule) #第一种通过定义schedule函数,这个函数一般以epoch为自变量进行调整
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0) #第二种自动根据检测量的变化情况调整#定义checkpoint
filepath="weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint= ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')# Earlystop回调函数可以提前终止训练
early_stopping = EarlyStopping(monitor='val_loss', patience=2)#定义的lr和checkpoint两个回调函数可以同时以列表的形式写在fit参数中
model.fit(train_set_x, train_set_y, validation_split=0.1, nb_epoch=200, batch_size=256, callbacks=[checkpoint,lr]) 

虽然前面已经对数据做了一些预处理,但是这些数据很可能仍然是不平衡的,或者是其他特殊情况。比如分类中,检测欺诈交易detect fraudulent transactions,不仅不平衡,惩罚的结果也不同,比如回归时,每条数据的可信度是不同的。这就是需要在训练时对loss加权重。类别不平衡时,需要设置的参数是class_weight,这是一个字典,如{0:1.,1:50.,2:3.},int型表示类别,float型表示对应的权重,对于权重大的类别,代表我们更关心这一类,当分错时loss更大,惩罚更大。 This can be useful to tell the model to "pay more attention" to samples from an under-represented class.至于sample_weights,是在类内样本级别的加权,对一些可信度低的数据赋予更低的权重;对实时性教新的数据赋予更高的权重。

模型评估

score = model.evaluate(data_test, label_test, batch_size=32)

Reference:

https://keras.io/zh/

https://keras-cn.readthedocs.io/en/latest/models/about_model/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/493571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

利用Inception-V3训练的权重微调,实现猫狗分类(基于keras)

利用Inception-V3训练的权重微调实现猫狗的分类&#xff0c;其中权重的下载在我的博客下载资源处&#xff0c;https://download.csdn.net/download/fanzonghao/10566634 第一种权重不改变直接用mixed7层&#xff08;mixed7呆会把打印结果一放就知道了&#xff09;进行特征提取…

刘锋:互联网左右大脑结构与钱学森开放复杂巨系统

作者&#xff1a;刘锋 互联网进化论作者 计算机博士前言&#xff1a;1990年&#xff0c;钱学森提出了开放的复杂巨系统理论&#xff0c;并提出以人为主&#xff0c;人机结合&#xff0c;从定性到定量的综合集成研究方法&#xff0c;他也预见性的提出“因特网正好生动地体现了…

DL中常用的numpy

读txt文件 按行读取有三种方式&#xff0c;注意readlines和readline的区别。open是python自带打开方式&#xff0c;如果打不开&#xff0c;可以使用encoding"UTF-8"指定解码方案。 读取得到一行之后&#xff0c;行首行尾可能存在一些不需要的字符&#xff0c;就可以…

paip.获取proxool的配置 xml读取通过jdk xml 初始化c3c0在代码中总结

paip.获取proxool的配置 xml读取通过jdk xml 初始化c3c0在代码中xml读取通过jdk xml初始化c3c0在代码中。。。。。作者Attilax 艾龙&#xff0c; EMAIL:1466519819qq.com来源&#xff1a;attilax的专栏地址&#xff1a;http://blog.csdn.net/attilaxproxoolController.ini()…

手写字母数据集转换为.pickle文件

首先是数据集&#xff0c;我上传了相关的资源&#xff0c;https://download.csdn.net/download/fanzonghao/10566701 转换代码如下&#xff1a; import numpy as np import os import matplotlib.pyplot as plt import matplotlib.image as mpig import imageio import pickle…

一文看懂谷歌的AI芯片布局,边缘端TPU将大发神威

来源&#xff1a;新电子2018年7月Google在其云端服务年会Google Cloud Next上正式发表其边缘(Edge)技术&#xff0c;与另两家国际公有云服务大厂Amazon/AWS、Microsoft Azure相比&#xff0c;Google对于边缘技术已属较晚表态、较晚布局者&#xff0c;但其技术主张却与前两业者有…

JS学习笔记-1--基本知识和注意事项

1、JS开始的目的主要是验证表单的输入验证 2、是一种具有面向对象能力的、解释型语言。是基于事件驱动的相对较安全的客户端脚本语言 3、JS 特点&#xff1a;松散型&#xff1a;变量不具备一个明确的类型&#xff1b; 对象属性&#xff1a;把属性名可以映射成任意的属性值&a…

opencv图像处理中的一些滤波器+利用滤波器提取条形码(解析二维码)+公交卡倾斜矫正+物体尺寸丈量

一般来说,图像的能量主要集中在其低频部分,噪声所在的频段主要在高频段,同时图像中的细节信息也主要集中在其高频部分,因此,如何去掉高频干扰同时又保持细节信息是关键。为了去除噪声,有必要对图像进行平滑,可以采用低通滤波的方法去除高频干扰。图像平滑包括空域法和频域法两大…

智联汽车:复盘国内巨头布局

来源&#xff1a;申万宏源摘要&#xff1a;从今年阿里9月云栖大会、华为10月全联接大会、百度11月世界大会、腾讯11月合作伙伴大会可以发现BATH均高调展示了各自在汽车科技领域的研发成果;而京东、滴滴两家公司近两年来关于汽车科技领域的动态亦在频频更新。▌车联网:车载OS竞争…

即插即用+任意blur的超分辨率重建——DPSR

计算机视觉中存在许多的不适定问题ill-posed problem。先来看什么是适定问题well-posed problem&#xff0c;适定问题必须同时满足三个条件&#xff1a; 1. a solution exists 解必须存在2. the solution is unique 解必须唯一3. the solutions behavior changes c…

Tomcat基础教程(一)

Tomcat, 是Servlet和JSP容器&#xff0c;其是实现了JSP规范的servlet容器。它在servlet生命周期内包容&#xff0c;装载&#xff0c;运行&#xff0c;和停止servlet容器。 Servlet容器的三种工作模式&#xff1a; 1. 独立的Servlet容器 Servlet容器与基于JAVA技术的Web服务器集…

opencv--图像金字塔

一&#xff0c;高斯金字塔--图片经过高斯下采样 """ 高斯金字塔 """ def gauss_pyramid():img cv2.imread(./data/img4.png)lower_reso cv2.pyrDown(img)lower_reso2 cv2.pyrDown(lower_reso)plt.subplot(131), plt.imshow(img)plt.title(In…

中国移动:5G蜂窝IoT关键技术分析

来源&#xff1a;5G本文讨论了蜂窝物联网的技术现状&#xff0c;针对增强机器类通信和窄带物联网技术标准&#xff0c;提出了2种现网快速部署方案&#xff0c;并进一步指出了C-IoT面向5G的演进路径。该路径充分考虑了5G网络中网络功能虚拟化、软件定义网络、移动边缘计算和大数…

dataframe常用操作总结

初始化 可以使用arraycolumns的格式&#xff0c; dpd.DataFrame(np.arange(10).reshape(2,5)) df1 pd.DataFrame([[Snow,M,22],[Tyrion,M,32],[Sansa,F,18],[Arya,F,14]], columns[name,gender,age]) 也可以使用字典大括号的格式&#xff1a; df pd.DataFrame({a: [1, 2…

DEDE无简略标题时显示完整标题

新闻的标题需要进行字数限制&#xff0c;这就需要加入一个title属性&#xff0c;让鼠标放上去的时候显示完整标题。另外目前的调用只能同时调用一种标题方式&#xff0c;不过可 以采用以下方法&#xff0c;进行判断&#xff0c;无简略标题显示完整标题。例如dede早期版本中的”…

清华大学发布:人脸识别最全知识图谱

来源&#xff1a;智东西摘要&#xff1a;本期我们推荐来自清华大学副教授唐杰领导的学者大数据挖掘项目Aminer的研究报告&#xff0c;讲解人脸识别技术及其应用领域&#xff0c;介绍人脸识别领域的国内玩人才并预测该技术的发展趋势。自20世纪下半叶&#xff0c;计算机视觉技术…

图像变换dpi(tif->jpg),直方图均衡化,腐蚀膨胀,分水岭,模板匹配,直线检测

一.图像变换dpi 1.示例1 import numpy as np from PIL import Image import cv2 def test_dp():path./gt_1.tif# imgImage.open(path)# print(img.size)# print(img.info)imgcv2.imread(path)imgImage.fromarray(img)print(img.size)print(img.info)img.save(test.jpg, dpi(3…

CV中的经典网络模型

目标检测 目标检测&#xff0c;不仅要识别目标是什么&#xff08;分类&#xff09;&#xff0c;还要知道目标的具体位置&#xff08;可以当作回归来做&#xff09;。 RCNN Selective Search 算法获得候选框&#xff0c;Alexnet提取特征&#xff0c;SVM对每个候选框区域打分。…

无表头单链表增删改查操作

1、返回单链表中第pos个结点中的元素&#xff0c;若pos超出范围&#xff0c;则返回&#xff10;2、把单链表中第pos个结点的值修改为x的值&#xff0c;若修改成功返回&#xff11;&#xff0c;否则返回&#xff10;3、向单链表的表头插入一个元素 4、向单链表的末尾添加一个元素…

JBU联合双边上采样

很多图像处理算法&#xff0c;如立体视觉中的深度估计&#xff0c;图像上色&#xff0c;高动态范围HDR中的tone mapping&#xff0c;图像分割&#xff0c;都有一个共性的问题&#xff1a;寻找一个全局的解&#xff0c;这个解是指一个分段的piecewise平滑含糊&#xff0c;描述了…