tensorflow定制模型和训练算法

1.从自定义损失函数开始

这里先实现一个Lenet-5模型

input_data = keras.layers.Input(shape = (28,28,1))
conv_1 = keras.layers.Conv2D(filters=6, kernel_size=(5,5), strides=1, activation='relu', padding = 'same')(input_data)
pool_1 = keras.layers.AveragePooling2D(pool_size=(2,2), strides=2, padding='valid')(conv_1)
s1 = keras.layers.Normalization()(pool_1)conv_2 = keras.layers.Conv2D(filters=16, kernel_size=(5,5),strides=1,activation='relu', padding = 'valid')(s1)
pool_2 = keras.layers.AveragePooling2D(pool_size=(2,2),strides=2, padding = 'valid')(conv_2)
s2 = keras.layers.Normalization()(pool_2)conv_3 = keras.layers.Conv2D(filters=120, kernel_size = 5, strides=1, activation='relu', padding = 'valid')(s2)
flatten = keras.layers.Flatten()(conv_3)
dense = keras.layers.Dense(120, activation='relu')(flatten)
output = keras.layers.Dense(10, activation='softmax')(dense)
model = keras.Model(inputs = [input_data], outputs= [output])checkPoint = ModelCheckpoint(filepath='model.h5', monitor='val_acc', verbose=1, save_best_only=True)
callback = [checkPoint]
model.compile(loss=’categorical_crossentropy', optimizer='sgd', metrics=['acc'])
hist = model.fit(x_train, y_train, batch_size=32, epochs=10, validation_split=0.3, callbacks=callback)

tensorflow构建模型有两种方法,一种是顺序式,一种是函数式,我用的是函数式。

神经网络架构:

输入28,28,1
卷积层15,5,6
平均池化层12,2
卷积层25,5,16
平均池化层22,2
卷积层35,5,120
全连接层184
输出层10

PS:中间隐藏层使用的是函数为tanh,末尾输出为softmax

我修改了激活函数,使用的是RELU,并在每一层输出后使用了归一化层.

依次为例,如果需要修改这个模型的组件,应该有哪些步骤呢?比如先修改一下损失函数,这里使用的是交叉熵。

修改为自己提供的损失函数。然后保存模型。应该怎么做?

先写出损失函数,考虑到批量梯度下降算法,应该是矩阵的算法

def My_loss(y_true, y_pred):error = 0.5 * tf.square(y_true, y_pred)return error

然后修改model.compile里面的loss参数

model.compile(loss=My_loss, optimizer='sgd', metrics=['acc'])

这样就能使用自己定义的组件了。

如果我想要修改0.5为0.1怎么办?

你可能会说,在函数里面修改一下不就得了嘛,但是我想保存模型的时候,把这个数字作为自定义的类型呢?

也就是当我们保存模型之后,再次加载模型,然后输入自己的参数。

2.保存和加载自定义的组件

每一个组件都有自己对应的父类,我们把自己的组件作为子类,然后继承父类,实现几个必要的函数就行了。然后使用字典进行初始化。这么说,肯定有点抽象,我们来做一件具体的事情

def My_loss(y_true, y_pred, r):error = r * tf.square(y_true, y_pred)return error

这里的r就是可变的,但是我们没办法使用到模型中,因为有三个参数,没办法灵活的定义r

model.compile(loss=My_loss(0.1), optimizer='sgd', metrics=['acc'])

你可能会想这么写

显然是错的,程序会崩溃。

解决办法是

def new_loss(r):def My_loss(y_true, y_pred):error = r * tf.square(y_true - y_pred)return errorreturn My_loss

model.compile(loss=new_loss(0.1), optimizer='sgd', metrics=['acc'])

model.save_model('model.h5')

这样在保存的模型不会保存你之前设定的阈值,这里就必须自己设定了。

mymodel = keras.Models.load_model('model.h5', custum_objects={‘My_loss’:new_loss(0.1)})

如果想要保存下来,就必须用继承的方式

class NewLoss(keras.Losses.Loss):def __init__(self,r, *kwargs):self.r = rsuper().__self__(*kwargs)def call(self, y_true, y_y_pred):def myloss(y_true, y_pred):error = self.r * tf.square(y_true - y_pred)return errordef get_config(self):base_config = super().get_config()return {**base_config, 'r' :self.r}

model = keras.models.load_model('model.h5', custom_objects={'NewLoss':NewLoss})

然后加载保存的模型 ,就不需要指定参数的大小

不仅仅是这样的损失函数可以保存,还可以自定义包括正则化,初始化,激活函数,层等等,方法与上面的损失一样。

3.建立自己的评价指标

        评价指标有时候和损失函数混用,但是他们毕竟不是一个东西,我导师经常逼我用F1分数给神经网络调参,我真的好无语,如果我说损失函数必须可微才能执行梯度下降,她就会觉得我代码写错了,我最后只好说,啊对对对,我用的是f1分数调参。太痛了。总之,评价函数和损失函数是两个不一样的物种。

        keras.metrics里面有许多评价指标,精确度,召回率等等,在训练的时候,如果想要计算自己的指标,也可以按照上面提到的继承法来设置自己的指标。

class My_mertrics(keras.metrics.Metric):def __init__(self, r =0.5, **kwargs):super().__init__(**kwargs)self.r = rself.myloss = NewLoss(r)self.total = self.add_weight('total', initializer='zeros')self.count = self.add_weight('count', initializer='zeros')def update_state(self, y_true, y_pred, sample_weight = None):metric = self.myloss(y_true, y_pred)self.total.assign_add(tf.reduce_sum(metric))self.count.assign_add(tf.cast(tf.size(y_true), tf.float32))def result(self):return self.total / self.countdef get_config(self):base_config = super().get_config()return {**base_config, "r":self.r}

update_state是在每个训练批次进行调用的函数,keras会自动跟踪数据,最后每次返回前面所积累的数据 。

model.compile(loss=new_loss(0.1), optimizer='sgd', metrics=[My_metrics(0.1)])

就会输出自定义的标准了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/225315.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动态规划——斐波那契数列模型:1137.第N个泰波那契数

文章目录 题目描述算法原理1.状态表示(最重要的)什么是状态表示?状态表示怎么来的呢?本题的状态表示 2.状态转移方程(最难的)本题的状态转移方程 3.初始化(后三步完成剩下百分之一的细节问题)本题的初始化 4.填表顺序本…

第23节: Vue3 绑定 HTML 类

在UniApp中使用Vue3框架时&#xff0c;你可以使用类绑定语法来动态地添加或移除HTML元素的类。 下面是一个示例&#xff0c;演示了如何在UniApp中使用Vue3框架使用绑定HTML类&#xff1a; <template> <view> <button click"toggleClass">Toggl…

C语言实现快速傅立叶(FFT)(一)

1. FFT理论相关知识 FFT&#xff08;快速傅里叶变换&#xff09;其本质就是DFT&#xff0c;只不过可以快速的计算出DFT结果&#xff0c;所以首先应该理解DFT&#xff0c;DFT(Discrete Fourier Transform) 离散傅里叶变换的缩写&#xff0c;FFT(Fast Fourier Transform)快速傅里…

【算法与数据结构】376、LeetCode摆动序列

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析&#xff1a;本题难点在于要考虑到不同序列的情况&#xff0c;具体来说要考虑一下几种特殊情况&#xff1a; 1、上…

4.qml 3D-Light、DirectionalLight、PointLight、SpotLight、AxisHelper类深入学习

今天我们学习灯光类 首先来学习Light类&#xff0c;它是所有灯光的虚基类&#xff0c;该类是无法创建的&#xff0c;主要是为子类提供很多公共属性。 常用属性如下所示&#xff1a; ambientColor : color&#xff0c;该属性定义在被该光照亮之前应用于材质的环境颜色。默认值…

oracle 锁表解决办法

相关表介绍 V$LOCKED_OBJECT&#xff08;记录锁信息的表&#xff09;v$session&#xff08;记录会话信息的表&#xff09;v$sql&#xff08;记录 sql 执行的表&#xff09;dba_objects&#xff08;用来管理对象&#xff0c;表、库等等&#xff09; 查询锁表的 SID select b.…

Cockpit upload文件上传漏洞(CVE-2023-1313)

0x01 产品简介 Cockpit 是一个自托管、灵活且用户友好的无头内容平台,用于创建自定义数字体验。 0x02 漏洞概述 Cockpit assetsmanager/upload接口处存在文件上传漏洞,攻击者可通过该漏洞在服务器端任意上传代码,写入后门,获取服务器权限,进而控制整个web服务器。 0x0…

基于SpringBoot的语言课学习系统

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 🍅文末获取源码联系🍅 项目介绍 基于SpringBoot的语言课学习系统,java项…

Web开发伴侣 Prepros 7.17 Crack

您友好的 Web 开发伙伴&#xff0c;Prepros 编译您的文件、转译您的 JavaScript、重新加载您的浏览器并 使开发变得非常容易测试您的网站&#xff0c;以便您可以专注于制作 他们完美。 编译一切 Prepros 可以编译 Sass、Less、Stylus、Pug/Jade、Haml、Slim、CoffeeScript 和 …

linux中core调度器

背景 开始把core调度器当成了linux的主调度器&#xff0c;导致查找网上资料时总觉得对不上&#xff0c;最后从linux的rust文档中明白了&#xff0c;core调度器是为了解决超线程场景下缓存漏洞&#xff08;如mds、L1HF&#xff09;而存在的。简单来说就是一个cpu上同时运行两个…

翻译: 为什么需要微调大模型 Why Fine-tuning LLM

虽然RAG提供了一种方式来给大型语言模型提供额外的信息&#xff0c;但还有另一种叫做微调&#xff08;fine-tuning&#xff09;的技术&#xff0c;也是给它更多信息的一种方式。特别是&#xff0c;如果你有的上下文比大型语言模型的输入长度或上下文窗口长度更大&#xff0c;那…

如何使用ArcGIS Pro拼接影像

为了方便数据的存储和传输&#xff0c;我们在网上获取到的影像一般都是分块的&#xff0c;正式使用之前需要对这些影像进行拼接&#xff0c;这里为大家介绍一下ArcGIS Pro中拼接影像的方法&#xff0c;希望能对你有所帮助。 数据来源 本教程所使用的数据是从水经微图中下载的…

ArcGIS Pro SDK文件选择对话框

文件保存对话框 // 获取默认数据库var gdbPath Project.Current.DefaultGeodatabasePath;//设置文件的保存路径SaveItemDialog saveLayerFileDialog new SaveItemDialog(){Title "Save Layer File",OverwritePrompt true,//获取或设置当同名文件已存在时是否出现…

PPT插件-好用的插件-PPT 素材该怎么积累-大珩助手

PPT 素材该怎么积累&#xff1f; 使用大珩助手中的素材库功能&#xff0c;将Word中的&#xff0c;或系统中的文本文件、图片、其他word文档、pdf&#xff0c;所有见到的好素材&#xff0c;一键收纳。 步骤&#xff1a;选中文件&#xff0c;按住鼠标左键拖到素材库界面中&…

微服务架构之争:Quarkus VS Spring Boot

在容器时代&#xff08;“Docker时代”&#xff09;&#xff0c;无论如何&#xff0c;Java仍然活着。Java在性能方面一直很有名&#xff0c;主要是因为代码和真实机器之间的抽象层&#xff0c;多平台的成本&#xff08;一次编写&#xff0c;随处运行——还记得吗&#xff1f;&a…

轻松入门:Python 中的 Scipy 库初探

写在开头 Python在科学计算领域中的强大地位得益于其丰富的库和工具&#xff0c;而Scipy库则是这个生态系统中的一颗璀璨明珠。本文将带你轻松入门Scipy库&#xff0c;深入探索其基本用途和功能。 1.scipy库的简介 Scipy库是Scientific Python的缩写&#xff0c;是建立在Num…

虚拟电厂 能源物联新方向

今年有多热&#xff1f;据上海市气象局官微消息&#xff0c;5月29日13时09分&#xff0c;徐家汇站气温达36.1℃&#xff0c;打破了百年来的当地5月份气温*高纪录。不仅如此&#xff0c;北京、四川、江西、湖南、广东、广西等地也频频发布高温预警。 伴随着居民用电急剧攀升&am…

什么是PSR标准?有哪些常见的PSR标准?

PSR 是 PHP Standard Recommendation&#xff08;PHP 标准推荐&#xff09;的缩写&#xff0c;是由 PHP-FIG&#xff08;PHP Framework Interop Group&#xff09;组织提出并维护的一系列 PHP 编程规范。这些规范旨在促进 PHP 生态系统中各种项目的互操作性和可维护性。以下是一…

Gitee:远程仓库步骤

第一步&#xff1a;新建仓库 第二步&#xff1a;初始化本地仓库&#xff0c;git init 创建分支 git branch 新分支名 第三步&#xff1a;git add . &#xff1a;添加到暂存区 第四步&#xff1a;git config –global user.email关联邮箱&#xff0c;user.name用户名 第…

LeetCode137. Single Number II

文章目录 一、题目二、题解 一、题目 Given an integer array nums where every element appears three times except for one, which appears exactly once. Find the single element and return it. You must implement a solution with a linear runtime complexity and u…