理解知识蒸馏中的散度损失函数(KLDivergence/kldivloss )-以DeepSeek为例

1. 知识蒸馏简介

什么是知识蒸馏?

知识蒸馏(Knowledge Distillation)是一种模型压缩技术,目标是让一个较小的模型(学生模型,Student Model)学习一个较大、性能更优的模型(教师模型,Teacher Model)的知识。这样,我们可以在保持较高准确率的同时,大幅减少计算和存储成本。

为什么需要知识蒸馏?

  • 降低计算成本:大模型(如 DeepSeek、GPT-4)通常计算量巨大,不适合部署到移动设备或边缘设备上。
  • 加速推理:较小的模型可以更快地推理,减少延迟。
  • 减少内存占用:适用于资源受限的环境,如嵌入式设备或低功耗服务器。

知识蒸馏的核心思想是:学生模型不仅仅学习教师模型的硬标签(one-hot labels),更重要的是学习教师模型输出的概率分布,从而获得更丰富的表示能力。

2. KL 散度的数学原理

2.1 KL 散度公式

在知识蒸馏过程中,我们通常使用Kullback-Leibler 散度(KL Divergence) 来衡量两个概率分布(教师模型和学生模型)之间的差异。

2.2 直观理解

KL 散度可以理解为如果用分布 Q 来近似分布 P,会损失多少信息

  • 当 KL 散度为 0,表示两个分布完全相同。
  • KL 散度不是对称的,即 D_{KL}(P || Q) \neq D_{KL}(Q || P)

3. DeepSeek 中的 KL 散度应用

DeepSeek 作为一个强大的开源大语言模型(LLM),在模型蒸馏时广泛使用了 KL 散度。例如,在训练较小版本的 DeepSeek 时,研究人员采用了温度标度(Temperature Scaling) 来调整教师模型的输出,使其更适合学生模型学习。

教师模型的 softmax 输出使用温度参数 TT 进行调整:

当 T 增大时,softmax 输出的概率分布变得更平滑,从而让学生模型更容易学习教师模型的知识。

在 DeepSeek 的蒸馏过程中,常见的损失函数是加权组合:

其中:

  • 第一项是 KL 散度损失,使得学生模型的输出接近教师模型。
  • 第二项是交叉熵损失,确保学生模型仍然学习真实标签。
  • λ是一个超参数,控制两者的平衡。

4. 代码示例:用 Keras 进行知识蒸馏

下面我们用 TensorFlow/Keras 训练一个简单的学生模型,让它学习一个教师模型的知识。

4.1 定义教师模型

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers# 构建一个简单的教师模型
teacher_model = keras.Sequential([layers.Dense(128, activation="relu", input_shape=(784,)),layers.Dense(10, activation="softmax")
])

4.2 训练教师模型

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.reshape(-1, 784) / 255.0, x_test.reshape(-1, 784) / 255.0
y_train, y_test = keras.utils.to_categorical(y_train, 10), keras.utils.to_categorical(y_test, 10)teacher_model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
teacher_model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))

4.3 让教师模型生成 soft labels

temperature = 5.0
def soft_targets(logits):return tf.nn.softmax(logits / temperature)y_teacher = soft_targets(teacher_model.predict(x_train))

4.4 训练学生模型

student_model = keras.Sequential([layers.Dense(64, activation="relu", input_shape=(784,)),layers.Dense(10, activation="softmax")
])student_model.compile(optimizer="adam",loss=tf.keras.losses.KLDivergence(),  # 使用 KL 散度metrics=["accuracy"]
)student_model.fit(x_train, y_teacher, epochs=5, batch_size=32, validation_data=(x_test, y_test))

5. 真实应用场景

5.1 轻量级大模型

  • DistilBERT:使用 BERT 作为教师模型进行蒸馏,训练更小的 Transformer。
  • TinyBERT:针对任务优化蒸馏,提高学生模型的表现。
  • DeepSeek-Chat 小模型:使用 KL 散度训练高效版本,提高推理速度。

5.2 知识蒸馏的优势

  • 可以训练更小的模型,适用于移动端、嵌入式设备。
  • 学生模型比直接训练的模型泛化性更强,能更好地模仿教师模型。
  • 结合 KL 散度 + 交叉熵 可以提升训练效果。

结论

KL 散度损失是知识蒸馏的核心,它让学生模型学习教师模型的概率分布,从而获得更好的表现。DeepSeek 这样的 LLM 在蒸馏过程中广泛使用 KL 散度,使得较小模型也能高效推理。希望本文能帮助你理解 KL 散度在知识蒸馏中的应用!

其它

代码示例一,

假设我们有两个概率分布 p(真实分布)和 q(预测分布),我们使用 KLDivergence 计算它们之间的 KL 散度损失。

import tensorflow as tf
import numpy as np# 定义 KLDivergence 损失函数
kl_loss = tf.keras.losses.KLDivergence()# 真实分布 p (标签)
p = np.array([0.1, 0.4, 0.5], dtype=np.float32)# 预测分布 q
q = np.array([0.2, 0.3, 0.5], dtype=np.float32)# 计算 KL 散度损失
loss_value = kl_loss(p, q)print(f'KL Divergence Loss: {loss_value.numpy()}')

代码示例二,

一个完整的 Keras 代码示例,展示了如何在分类任务中使用 KLDivLoss 作为损失函数。这个示例使用一个简单的神经网络对 手写数字 MNIST 数据集 进行分类,并使用 KLDivLoss 计算真实分布和模型预测分布之间的散度。

import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# 归一化数据到 [0,1] 之间
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0# 将标签转换为概率分布 (one-hot 编码)
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)# 构建一个简单的神经网络模型
model = keras.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation="relu"),layers.Dense(10, activation="softmax")  # 输出层用 softmax 归一化
])# 编译模型,使用 KLDivLoss 作为损失函数
model.compile(optimizer="adam",loss=tf.keras.losses.KLDivergence(),metrics=["accuracy"])# 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_data=(x_test, y_test))# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"Test Accuracy: {test_acc:.4f}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67738.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Electron使用WebAassembly实现CRC-8 MAXIM校验

Electron使用WebAssembly实现CRC-8 MAXIM校验 将C/C语言代码,经由WebAssembly编译为库函数,可以在JS语言环境进行调用。这里介绍在Electron工具环境使用WebAssembly调用CRC-8 MAXIM格式校验的方式。 CRC-8 MAXIM校验函数WebAssebly源文件 C语言实现CR…

Vue3.0实战:大数据平台可视化

文章目录 创建vue3.0项目项目初始化项目分辨率响应式设置项目顶部信息条创建页面主体创建全局引入echarts和axios后台接口创建express销售总量图实现完整项目下载项目任何问题都可在评论区,或者直接私信即可。 创建vue3.0项目 创建项目: vue create vueecharts选择第三项:…

vector容器(详解)

本文最后是模拟实现全部讲解,文章穿插有彩色字体,是我总结的技巧和关键 1.vector的介绍及使用 1.1 vector的介绍 https://cplusplus.com/reference/vector/vector/(vector的介绍) 了解 1. vector是表示可变大小数组的序列容器。…

Airflow:深入理解Apache Airflow Task

Apache Airflow是一个开源工作流管理平台,支持以编程方式编写、调度和监控工作流。由于其灵活性、可扩展性和强大的社区支持,它已迅速成为编排复杂数据管道的首选工具。在这篇博文中,我们将深入研究Apache Airflow 中的任务概念,探…

开发环境搭建-4:WSL 配置 docker 运行环境

在 WSL 环境中构建:WSL2 (2.3.26.0) Oracle Linux 8.7 官方镜像 基本概念说明 容器技术 利用 Linux 系统的 文件系统(UnionFS)、命名空间(namespace)、权限管理(cgroup),虚拟出一…

JavaScript 基础 - 7

关于JS函数部分的学习和一个案例的练习 1 函数封装 抽取相同部分代码封装 优点 提高代码复用性:封装好的函数可以在多个地方被重复调用,避免了重复编写相同的代码。例如,编写一个计算两个数之和的函数,在多个不同的计算场景中都…

详解u3d之AssetBundle

一.AssetBundle的概念 “AssetBundle”可以指两种不同但相关的东西。 1.1 AssetBundle指的是u3d在磁盘上生成的存放资源的目录 目录包含两种类型文件(下文简称AB包): 一个序列化文件,其中包含分解为各个对象并写入此单个文件的资源。资源文件&#x…

微信登录模块封装

文章目录 1.资质申请2.combinations-wx-login-starter1.目录结构2.pom.xml 引入okhttp依赖3.WxLoginProperties.java 属性配置4.WxLoginUtil.java 后端通过 code 获取 access_token的工具类5.WxLoginAutoConfiguration.java 自动配置类6.spring.factories 激活自动配置类 3.com…

MySQL数据库(二)- SQL

目录 ​编辑 一 DDL (一 数据库操作 1 查询-数据库(所有/当前) 2 创建-数据库 3 删除-数据库 4 使用-数据库 (二 表操作 1 创建-表结构 2 查询-所有表结构名称 3 查询-表结构内容 4 查询-建表语句 5 添加-字段名数据类型 6 修改-字段数据类…

ARM嵌入式学习--第十天(UART)

--UART介绍 UART(Universal Asynchonous Receiver and Transmitter)通用异步接收器,是一种通用串行数据总线,用于异步通信。该总线双向通信,可以实现全双工传输和接收。在嵌入式设计中,UART用来与PC进行通信,包括与监控…

Python3 OS模块中的文件/目录方法说明十七

一. 简介 前面文章简单学习了 Python3 中 OS模块中的文件/目录的部分函数。 本文继续来学习 OS 模块中文件、目录的操作方法:os.walk() 方法、os.write()方法 二. Python3 OS模块中的文件/目录方法 1. os.walk() 方法 os.walk() 方法用于生成目录树中的文件名&a…

css三角图标

案例三角&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title><s…

线性数据结构:单向链表

放弃眼高手低&#xff0c;你真正投入学习&#xff0c;会因为找到一个新方法产生成就感&#xff0c;学习不仅是片面的记单词、学高数......只要是提升自己的过程&#xff0c;探索到了未知&#xff0c;就是学习。 目录 一.链表的理解 二.链表的分类&#xff08;重点理解&#xf…

基于PyQt5打造的实用工具——PDF文件加图片水印,可调大小位置,可批量处理!

01 项目简介 &#xff08;1&#xff09;项目背景 随着PDF文件在信息交流中的广泛应用&#xff0c;用户对图片水印的添加提出了更高要求&#xff0c;既要美观&#xff0c;又需高效处理批量文件。现有工具难以实现精确调整和快速批量操作&#xff0c;操作繁琐且效果不理想。本项…

MCU内部ADC模块误差如何校准

本文章是笔者整理的备忘笔记。希望在帮助自己温习避免遗忘的同时&#xff0c;也能帮助其他需要参考的朋友。如有谬误&#xff0c;欢迎大家进行指正。 一、ADC误差校准引言 MCU 片内 ADC 模块的误差总包括了 5 个静态参数 (静态失调&#xff0c;增益误差&#xff0c;微分非线性…

“新月智能武器系统”CIWS,开启智能武器的新纪元

新月人物传记&#xff1a;人物传记之新月篇-CSDN博客 相关文章链接&#xff1a;星际战争模拟系统&#xff1a;新月的编程之道-CSDN博客 新月智能护甲系统CMIA--未来战场的守护者-CSDN博客 “新月之智”智能战术头盔系统&#xff08;CITHS&#xff09;-CSDN博客 目录 智能武…

实验六 项目二 简易信号发生器的设计与实现 (HEU)

声明&#xff1a;代码部分使用了AI工具 实验六 综合考核 Quartus 18.0 FPGA 5CSXFC6D6F31C6N 1. 实验项目 要求利用硬件描述语言Verilog&#xff08;或VHDL&#xff09;、图形描述方式、IP核&#xff0c;结合数字系统设计方法&#xff0c;在Quartus开发环境下&#xff…

SCRM系统如何提升客户管理及业务协同的效率与价值

内容概要 在当今商业环境中&#xff0c;SCRM系统&#xff08;社交客户关系管理系统&#xff09;正逐渐受到越来越多企业的关注和重视。随着科技的发展&#xff0c;传统的客户管理方式已经无法满足快速变化的市场需求&#xff0c;SCRM系统通过整合客户数据和社交网络信息&#…

[免费]微信小程序智能商城系统(uniapp+Springboot后端+vue管理端)【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序智能商城系统(uniappSpringboot后端vue管理端)&#xff0c;分享下哈。 项目视频演示 【免费】微信小程序智能商城系统(uniappSpringboot后端vue管理端) Java毕业设计_哔哩哔哩_bilibili 项目介绍…

PID算法的数学实现和参数确定方法

目录 概述 1 算法描述 1.1 PID算法模型 1.2 PID离散化的图形描述 1.3 PID算法的特点 2 离散化的PID算法 2.1 位置式PID算法 2.2 增量式PID算法 2.3 位置式PID与增量式PID比较 3 控制器参数整定 3.1 PID参数确定方法 3.1.1 凑试法 3.1.2 临界比例法 3.1.3 经验法…