简单的知识蒸馏

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

import keras
from keras import layers
from keras import ops
import numpy as np

# 随着训练的进行,由于总损失(loss)是学生损失(student_loss)和蒸馏损失(distillation_loss)的加权和,
# 模型会同时考虑减少自身的预测损失(即提高预测准确性)和与教师模型预测分布的相似性。通过调整 alpha 参数,您
# 可以控制这两个目标之间的权衡。较大的 alpha 值将使模型更关注自身的预测准确性,而较小的 alpha 值则会使模型
# 更关注与教师模型预测分布的相似性。

# 知识蒸馏一般应该是从复杂的精度高的模型到简单的模型,让学生模型去学习教师模型的预测分布,但这个例子,因为简单的模型也
# 能达到不错的精度,所以没看出来性能提升

#教师模型一般应该是预训练模型,在高分辨率的图片数据集上训练过的,学生模型用来学习教师模型的预测概率分布
# 学生模型的结构和复杂性应该根据任务的要求、数据的特性以及资源限制来仔细选择。如果学生模型过于简单,它可能无法
# 捕捉到教师模型学习到的复杂模式和特征,导致性能不佳。另一方面,如果学生模型过于复杂,虽然它可能能够学习到更多
# 的细节和特征,但也可能导致过拟合,并且在计算资源上可能不高效。因此,在选择学生模型的结构时,需要进行权衡和实验。
# 一种常见的做法是从一个简单的模型开始,并逐步增加其复杂性,以观察性能如何变化。通过这种方式,可以找到在给定资源
# 和任务要求下性能最佳的学生模型结构。

#知识提炼器
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher#教师模型
        self.student = student#学生模型
    #编译,保存一些优化器,损失函数,权重,温度等的参数
    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        alpha=0.1,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn#学生损失函数
        self.distillation_loss_fn = distillation_loss_fn#蒸馏损失函数
        self.alpha = alpha#蒸馏权重
        self.temperature = temperature#温度参数
    #计算损失
    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, allow_empty=False
    ):
        teacher_pred = self.teacher(x, training=False)#获取教师模型的预测
        student_loss = self.student_loss_fn(y, y_pred)#根据学生损失函数计算学生损失
        # 计算蒸馏损失。这通常涉及将教师模型和学生模型的预测都通过softmax函数(使用温度参数进行缩放),
        # 然后计算两者之间的差异。这里乘以(self.temperature**2)是一个常见的调整,用于平衡蒸馏损失。
        distillation_loss = self.distillation_loss_fn(
            ops.softmax(teacher_pred / self.temperature, axis=1),
            ops.softmax(y_pred / self.temperature, axis=1),
        ) * (self.temperature**2)
        # 根据alpha参数,将学生损失和蒸馏损失组合成一个总损失
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        return loss
    def call(self, x):
        return self.student(x)

#教师模型比较大,学生模型比较小
teacher = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"),#(14,14,256)
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),#(14,14,256)
        layers.Conv2D(512, (3, 3), strides=(2, 2), padding="same"),#(7,7,512)
        layers.Flatten(),#(7*7*512)
        layers.Dense(10),
    ],
    name="teacher",
)

student = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(16, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(negative_slope=0.2),
        layers.MaxPooling2D(pool_size=(2, 2), strides=(1, 1), padding="same"),
        layers.Conv2D(32, (3, 3), strides=(2, 2), padding="same"),
        layers.Flatten(),
        layers.Dense(10),
    ],
    name="student",
)
student_scratch = keras.models.clone_model(student)#新模型与原模型具有相同的结构,但不用原模型的权重,优化器等等

batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))

print(x_train.shape,x_test.shape,x_train.dtype,np.max(x_train),np.min(x_train))

teacher.summary()#3*3*1*256+256,3*3*256*512+512,flatten:7*7*512,把像素值展平成向量,25088*10+10

teacher.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#损失函数:多元交叉熵
    metrics=[keras.metrics.SparseCategoricalAccuracy()],#指标:准确率
)

teacher.fit(x_train, y_train, epochs=5)

teacher.evaluate(x_test, y_test)

distiller = Distiller(student=student, teacher=teacher)#构建知识提炼器

distiller.compile(
    optimizer=keras.optimizers.Adam(),#优化器
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],#度量指标
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),#学生损失函数:多元交叉熵
    distillation_loss_fn=keras.losses.KLDivergence(),#知识提炼损失:kld
    alpha=0.1,#提炼权重(用来设定学生和提炼损失的占比)
    temperature=10,#温度,缩放系数
)

distiller.fit(x_train, y_train, epochs=3)#提炼教师到学生(让学生模型能学习教师模型的预测分布)

distiller.evaluate(x_test, y_test)

student_scratch.compile(#这个拷贝模型的职责就是衡量知识提炼中学生模型究竟从教师模型中学到了多少东西
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy('acc')],
)

student_scratch.fit(x_train, y_train, epochs=3)

student_scratch.evaluate(x_test, y_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/7112.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大模型日报2024-05-06

大模型日报 2024-05-06 大模型技术 智谱AI 正研发对标Sora的国产文生视频模型,最快年内发布 摘要: 估值超200亿的国内 A1大模型独角兽公司“智谱 A“正在研发对标OpenAl Sora的高质量文生视频模型,预计最快年内发布。据悉,成立于2019年的智谱…

selenium解放双手--记某电力学校的刷课脚本

免责声明:本文仅做技术交流与学习... 重难点: 1-对目标网站的html框架具有很好的了解,定位元素,精准打击. 2-自动化过程中窗口操作的转换. 前置知识: python--selenium模块的操作使用 前端的html代码 验证码自动化操作 Chrome & Chromedriver : Chrome for Testing ava…

[机器学习-02] 数据可视化神器:Matplotlib和Seaborn工具包实战图形大全

目录 引言 正文 01-Matplotlib包的使用示例 1)Matplotlib导入方式 2)折线图绘制 3)散点图绘制 4)柱状图绘制 5)饼图绘制 6)等高线图绘制 7)箱线图绘制 8)较为复杂…

7zip如何只压缩文件不带上级目录?

在使用7zip进行文件压缩的时候,如果直接选择要压缩的文件进行压缩,得到的压缩包则会多包含一层顶层目录,解压缩之后需要点击两次才能进入到实际目录中,为了解决这个问题,本文根据探索找到了一种解决办法。 如下是一个演…

表空间的概述

目录 表空间的属性 表空间的类型 永久性表空间(PermanentTablespace) 临时表空间(Temp Tablespace ) 撤销表空间(Undo Tablespace) 大文件表空间(BigfileTablespace) 表空间的状态 联机状态(Online) 读写状态(Read Write) 只读状态(Read) 脱机状态(Offline) Oracle从…

Java_从入门到JavaEE_09

一、构造方法/构造器 含义:和new一起是创建对象的功能 特点: 与类名相同的方法没有返回项 注意: 当类中没有写构造方法时,系统会默认添加无参构造(无参数的构造方法)构造方法可以重载的 有参构造好处&…

JavaWeb入门-HTML

一、HTML 1.HTML 网络的骨架 超文本标记语言 ①超文本 图片、音频、视频、普通文本。。。 ②标记语言 语法&#xff1a;通过标签的形式展示 a.双标签 <html>内容</html> b.单标签 <br> 2.HelloWorld ①新建网页文件&#xff08;后…

代码随想录算法训练营第四十三天| 1049. 最后一块石头的重量 II,494. 目标和,474.一和零

题目链接&#xff1a;1049. 最后一块石头的重量 II 思路 把石头分成重量尽量相同的两堆&#xff0c;这样就能保证最后一块石头的重量最小。转换为01背包问题&#xff0c;重量和价值都是stone。 ①dp数组&#xff0c;dp[j]表示容量为j的背包可以装的最大价值为dp[j] ②递推公式…

探索Linux目录结构:深入理解Linux文件系统

探索Linux目录结构&#xff1a;深入理解Linux文件系统 Linux操作系统以其强大的稳定性和灵活性而闻名&#xff0c;其中一个关键特征就是其独特的文件系统结构。深入了解Linux目录结构对于系统管理员和开发人员至关重要。本文将带您深入探索Linux文件系统的目录结构&#xff0c…

透明加密软件选哪个好?选择时一定要注意以下三点

透明加密软件哪个好&#xff1f; 这是许多企事业单位在面临数据防泄漏问题时经常思考的问题。随着信息技术的发展&#xff0c;企业的数据安全变得越来越重要。透明加密技术作为一种有效的数据保护手段&#xff0c;被越来越多的企业所采用。然而&#xff0c;市场上的透明加密软…

delphi获取进程版本信息

结构体声明 typeTFileInfo packed recordCommpanyName: widestring;FileDescription: widestring;FileVersion: widestring;InternalName: widestring;LegalCopyright: widestring;LegalTrademarks: widestring;OriginalFileName: widestring;ProductName: widestring;Produc…

Django高级表单处理与验证实战

title: Django高级表单处理与验证实战 date: 2024/5/6 20:47:15 updated: 2024/5/6 20:47:15 categories: 后端开发 tags: Django表单验证逻辑模板渲染安全措施表单测试重定向管理最佳实践 引言&#xff1a; 在Web应用开发中&#xff0c;表单是用户与应用之间进行交互的重要…

OpenHarmony实战开发-请求自绘制内容绘制帧率

对于基于XComponent进行Native开发的业务&#xff0c;可以请求独立的绘制帧率进行内容开发&#xff0c;如游戏、自绘制UI框架对接等场景。 接口说明 开发步骤 说明&#xff1a; 本范例是通过Drawing在Native侧实现图形的绘制&#xff0c;并将其呈现在NativeWindow上 1.定义Ark…

《第一行代码》第二版学习笔记(7)——使用通知和摄像头

文章目录 一、使用通知二、调用摄像头 介绍了通知基于8.0的使用方法和如何调用摄像头拍照 一、使用通知 public void onClick(View v) {if (v.getId() R.id.send_notice){Intent intent new Intent(this,NotificationActivity.class);PendingIntent pi PendingIntent.getAct…

【哈希表】Leetcode 14. 最长公共前缀

题目讲解 14. 最长公共前缀 算法讲解 我们使用当前第一个字符串中的与后面的字符串作比较&#xff0c;如果第一个字符串中的字符没有出现在后面的字符串中&#xff0c;我们就直接返回&#xff1b;反之当容器中的所有字符串都遍历完成&#xff0c;说明所有的字符串都在该位置…

源码拾贝三则

目录 一 一种枚举类型的新型使用方式 二 Eigen库中的LDLT分解 三 Eigen中的访问者模式 一 一种枚举类型的新型使用方式 ///D:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30133\include\xiosbase enum _Iostate { // consta…

springcloud第4季 springcloud-alibaba之分布式事务seata

一 seata介绍 1.1 seata介绍 1.seata是一款解决分布式事务的解决方案&#xff0c;致力于在微服务架构下提供高性能和简单易用的分布式事务服务。 2.seata的几种术语&#xff1a;一个中心&#xff1a;全局事务id TC(Transaction Coordinator):事务协调者。负责维护全局和分…

双非二本找工作前的准备day21

学习目标&#xff1a; 每天复习代码随想录上的题目1-2道算法&#xff08;时间充足可以继续&#xff09; 今日碎碎念&#xff1a; 1&#xff09;今天开始是二叉树系列 2&#xff09;出租屋里不知道干啥&#xff0c;看看书啊刷刷算法&#xff0c;打打游戏&#xff0c;学学技术…

通过iMock学习Jvmsandbox

Jvm-sandbox Jvm-sandbox基于Jvm-sandbox的Mock平台iMockiMock的工程学习iMock怎么写的&#xff08;sandbox的module应该怎么写&#xff09; Jvm-sandbox Jvm-sandbox是阿里开源的一款java的沙箱&#xff0c;看网上的介绍在沙箱里你可以做你能想到的奇妙的事情。 基于Jvm-san…