DeepSeek R1中提到“知识蒸馏”到底是什么

在 DeepSeek-R1 中,知识蒸馏(Knowledge Distillation)是实现模型高效压缩与性能优化的核心技术之一。在DeepSeek的论文中,使用 DeepSeek-R1(教师模型)生成 800K 高质量训练样本,涵盖数学、编程、科学推理等任务,并通过规则过滤混合语言、冗余段落和代码块,样本包括结构化推理过程(如 <think> 标签内的思考链)和最终答案。

蒸馏能将大模型(DeepSeek-R1)的复杂推理模式(如长链思考、自我验证)迁移至小模型。例如,DeepSeek-R1-Distill-Qwen-32B 在 AIME 2024 上达到 72.6%(Pass@1),显著优于 QwQ-32B-Preview(50.0%)。

DeepSeek R1 通过蒸馏将大模型的推理能力“压缩”至小模型,兼顾性能与成本,推动推理技术的广泛落地,并为社区提供高效开源工具。

1. 什么是知识蒸馏?

想象你是一个刚学做菜的新手,想复刻米其林大厨的招牌菜。如果只告诉你最终味道(比如“酸甜适中”),你很难完美复制。但如果你能知道大厨做菜时的 每个细节(比如火候调整顺序、调料配比、食材处理技巧),你就能学得更像。

深度学习中的知识蒸馏(Knowledge Distillation) 就是类似的过程,知识蒸馏中,有两个重要的角色:

  • 老师模型(Teacher Model):一个复杂的大模型(比如GPT-3、ResNet-152),性能强大但计算成本高。
  • 学生模型(Student Model):一个简单的小模型(比如MobileNet),轻量但性能较弱。

 

假设有一个经验丰富的老师(比如一个大而复杂的机器学习模型),它知识渊博,但反应慢、体积大(比如需要很强的算力才能运行)。
现在,你想培养一个学生(比如一个小而轻的模型),让它也能掌握老师的知识,但反应快、体积小(比如能在手机或小设备上运行)。 

这时就可以用蒸馏(Distillation)——让老师把自己的“经验”提炼出来,教给学生。

知识蒸馏目标

让学生模型通过“观察”老师模型的决策过程(而不仅是最终结果),继承老师的“经验”,最终达到接近老师的性能。

蒸馏的关键:学老师的「判断方式」

  • 传统方法:学生直接学“正确答案”(比如标签:“这张图是猫”)。

  • 蒸馏方法:学生不仅学答案,还学老师更细致的“思考过程”。比如,老师可能会说:“这张图有99%概率是猫,0.8%是狗,0.2%是狐狸……” 这种概率分布(也叫“软标签”)比单纯的答案(“是猫”)包含更多信息。

学生通过模仿老师的这种细致判断,能学得更像老师的思维方式,最终达到接近老师的效果,但体积和速度却好得多。

2. 为什么需要知识蒸馏?

大模型的困境

在大模型火爆的今天,使用大模型的人越来越多,但大模型通常参数众多,计算成本高昂且资源消耗巨大,而蒸馏技术可以将这些大型的教师模型的知识传递给规模更小的学生模型,从而显著降低计算复杂度和存储需求,使得模型更适合在资源受限的环境中部署。

小模型的优势

相对来说,部署更小的模型需要更少的GPU资源,并且小模型的推理速度更快。此外,通过知识蒸馏使得一些模型能在手机、摄像头等边缘设备运行。

3. 蒸馏的核心思想——学“软标签”而不是“硬标签”

我们以图片识别任务为例,来对比一下传统的训练与知识蒸馏的训练:

传统训练(硬标签)知识蒸馏(软标签)
输入一张图片一张图片
标签猫(猫100%)教师模型的输出(猫90%,狗5%,...)
学习目标

模型直接学习“非黑即白”的答案。

学生模型的输出尽可能的接近老师模型。

相比于使用硬标签,软标签的优势如下: 

1. 丰富的信息表达

软标签提供了更加灵活和丰富的信息。在分类问题中,软标签是一个概率分布,表示样本属于各个类别的可能性,而硬标签仅提供了一个确定的类别。这种概率分布的形式能够更好地反映数据的复杂性和不确定性,有助于模型学习到更细致的数据特征。(比如猫和狗都有四条腿,但猫更可能尖耳朵)。

2. 提升模型泛化能力

软标签通过提供类别间的关联信息,帮助模型学习到更平滑的决策边界,从而提高模型的泛化能力。在面对模糊分类、噪声数据或类别间界限不明确的情况时,软标签能够使模型更好地处理这些复杂情况,提高分类准确率。

3. 防止过拟合

软标签作为一种正则化手段,能够减少模型对训练数据的过度拟合。通过引入软标签,模型被迫考虑更多的类别可能性,而不是仅仅关注正确类别,这有助于模型在训练过程中保持一定的“不确定性”,从而提高其在未见数据上的表现。

4. 优化效率更高

在优化过程中,软标签可以保证优化过程始终处于优化效率最高的中间区域,避免进入饱和区。相比之下,硬标签监督下,由于 softmax 的作用,优化到达一定程度时,优化效率会显著降低。而软标签通过提供更平滑的概率分布,使得模型在训练过程中能够更有效地更新参数。

5. 更好的知识迁移

在知识蒸馏中,教师模型的软标签包含了其对数据的深层次理解和特征捕捉。通过使用软标签,学生模型能够更好地模仿教师模型的行为,学习到教师模型的决策过程和知识表示,从而在保持较高性能的同时实现模型压缩。

6. 提高模型鲁棒性

软标签能够增强模型的鲁棒性,使其在面对数据噪声和对抗攻击时表现得更加稳定。通过学习软标签中的概率分布,模型能够更好地处理输入数据中的不确定性,从而减少对噪声和对抗样本的敏感性。

7. 适用于复杂任务

在一些复杂的任务中,如多标签分类或多模态学习,软标签能够更好地捕捉数据之间的细微差别和关联性。例如,在图文检索任务中,软标签可以跨模态和模态内捕获更细粒度和细微的语义信息,从而提高模型的性能。

8. 提供更平滑的标签分布

软标签通过引入温度参数,可以调节教师模型输出概率分布的平滑程度。当温度参数大于1时,教师模型的输出变得更加平滑,这有助于学生模型更容易地模仿教师模型的行为,从而提高蒸馏效果。

9. 降低学习难度

与仅使用硬标签的传统训练方法相比,知识蒸馏技术通过引入教师模型的软标签信息,显著降低了学生模型的学习难度。这种知识迁移机制使得构建小型高效模型成为可能,为模型压缩技术提供了新的解决方案。

10. 增强模型校准

软标签能够使模型输出的预测概率更加接近真实概率,从而增强模型的校准能力。这对于一些需要精确概率估计的任务,如风险评估和决策支持系统,具有重要意义。

 4. 蒸馏的关键步骤

步骤1:训练老师模型

用常规方法训练一个大模型(例如ResNet-50),使其在任务上达到高精度。

步骤2:生成软标签

用老师模型对训练数据做预测,得到每个样本的概率分布(例如 [0.99, 0.008, 0.002])。

步骤3:训练学生模型

学生模型同时学习:

  1. 软标签损失:模仿老师的概率分布(使用KL散度或交叉熵)。

  2. 硬标签损失(可选):传统的真实标签损失。

  3. 温度参数(Temperature):软化概率分布,让模型更关注类别间的关系。

 5. PyTorch实现蒸馏的完整代码案例

我们使用CIFAR-10数据集,将ResNet-18(老师)的知识蒸馏到MobileNetV2(学生)。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 超参数
epochs = 20
batch_size = 256
temperature = 4  # 温度参数
alpha = 0.7      # 软标签损失权重# 数据加载
transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)# 定义老师模型(ResNet-18)
teacher = torchvision.models.resnet18(pretrained=True)
teacher.fc = nn.Linear(teacher.fc.in_features, 10)  # CIFAR-10有10类
teacher = teacher.to(device)# 定义学生模型(MobileNetV2)
student = torchvision.models.mobilenet_v2(pretrained=True)
student.classifier[1] = nn.Linear(student.last_channel, 10)
student = student.to(device)# 训练老师模型(此处假设老师已预训练好,直接加载)
# 实际中需要先训练老师模型,此处为简化跳过# 定义损失函数和优化器
criterion_hard = nn.CrossEntropyLoss()           # 硬标签损失
criterion_soft = nn.KLDivLoss(reduction='batchmean')  # 软标签损失
optimizer = optim.Adam(student.parameters(), lr=0.001)# 蒸馏训练循环
for epoch in range(epochs):teacher.eval()   # 固定老师模型student.train()  # 训练学生模型running_loss = 0.0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)# 前向传播with torch.no_grad():teacher_logits = teacher(inputs)student_logits = student(inputs)# 计算损失# 软标签损失(使用温度参数软化)soft_loss = criterion_soft(nn.functional.log_softmax(student_logits / temperature, dim=1),nn.functional.softmax(teacher_logits / temperature, dim=1)) * (alpha * temperature * temperature)  # 缩放损失# 硬标签损失hard_loss = criterion_hard(student_logits, labels) * (1 - alpha)total_loss = soft_loss + hard_loss# 反向传播optimizer.zero_grad()total_loss.backward()optimizer.step()running_loss += total_loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}')print("Distillation finished!")

其中,有几点需要注意的是:

  • 温度参数temperature=4 放大模型的“不确定性”,让学生更关注类别间关系。

  • 损失混合alpha=0.7 表示70%依赖老师软标签,30%依赖真实标签。

  • KL散度:衡量学生与老师概率分布的差异。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/894101.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

关联传播和 Python 和 Scikit-learn 实现

文章目录 一、说明二、什么是 Affinity Propagation。2.1 先说Affinity 传播的工作原理2.2 更多细节2.3 传播两种类型的消息2.4 计算责任和可用性的分数2.4.1 责任2.4.2 可用性分解2.4.3 更新分数&#xff1a;集群是如何形成的2.4.4 估计集群本身的数量。 三、亲和力传播的一些…

通过配置代理解决跨域问题(Vue+SpringBoot项目为例)

跨域问题&#xff1a; 是由浏览器的同源策略引起的&#xff0c;同源策略是一种安全策略&#xff0c;用于防止一个网站访问其他网站的数据。 同源是指协议、域名和端口号都相同。 跨域问题常常出现在前端项目中&#xff0c;当浏览器中的前端代码尝试从不同的域名、端口或协议…

(1)Linux高级命令简介

Linux高级命令简介 在安装好linux环境以后第一件事情就是去学习一些linux的基本指令&#xff0c;我在这里用的是CentOS7作演示。 首先在VirtualBox上装好Linux以后&#xff0c;启动我们的linux&#xff0c;输入账号密码以后学习第一个指令 简介 Linux高级命令简介ip addrtou…

TOGAF之架构标准规范-信息系统架构 | 数据架构

TOGAF是工业级的企业架构标准规范&#xff0c;信息系统架构阶段是由数据架构阶段以及应用架构阶段构成&#xff0c;本文主要描述信息系统架构阶段中的数据架构阶段。 如上所示&#xff0c;信息系统架构&#xff08;Information Systems Architectures&#xff09;在TOGAF标准规…

Windows 程序设计7:文件的创建、打开与关闭

文章目录 前言一、文件的创建与打开CreateFile1. 创建新的空白文件2. 打开已存在文件3. 打开一个文件时&#xff0c;如果文件存在则打开&#xff0c;如果文件不存在则新创建文件4.打开一个文件&#xff0c;如果文件存在则打开文件并清空内容&#xff0c;文件不存在则 新创建文件…

FastReport.NET控件篇之富文本控件

简介 FastReport.NET 提供了 RichText 控件&#xff0c;用于在报表中显示富文本内容。富文本控件支持多种文本格式&#xff08;如字体、颜色、段落、表格、图片等&#xff09;&#xff0c;非常适合需要复杂排版和格式化的场景。 富文本控件(RichText)使用场景不多&#xff0c…

爬虫基础(三)Session和Cookie讲解

目录 一、前备知识点 &#xff08;1&#xff09;静态网页 &#xff08;2&#xff09;动态网页 &#xff08;3&#xff09;无状态HTTP 二、Session和Cookie 三、Session 四、Cookie &#xff08;1&#xff09;维持过程 &#xff08;2&#xff09;结构 正式开始说 Sessi…

PythonFlask框架

文章目录 处理 Get 请求处理 POST 请求应用 app.route(/tpost, methods[POST]) def testp():json_data request.get_json()if json_data:username json_data.get(username)age json_data.get(age)return jsonify({username: username测试,age: age})从 flask 中导入了 Flask…

002-基于Halcon的图像几何变换

本节将简要介绍Halcon中有关图像几何变换的基本算子及其应用&#xff0c;主要涉及五种常见的二维几何变换形式&#xff1a;平移、镜像、旋转、错切和放缩。这几种变换可归结为一类更高级更抽象的空间变换类型&#xff0c;即仿射变换&#xff08;Affine transformation&#xff…

Hive:日志,hql运行方式,Array,行列转换

日志 可以在终端通过 find / | grep hive-log4j2 命令查找Hive的日志配置文件 这些文件用于配置Hive的日志系统。它们不属于系统日志也不属于Job日志&#xff0c;而是用于配置Hive如何记录系统日志和Job日志, 可以通过hive-log4j2 查找日志的位置 HQL的3种运行方式 第1种就是l…

Unity 粒子特效在UI中使用裁剪效果

1.使用Sprite Mask 首先建立一个粒子特效在UI中显示 新建一个在场景下新建一个空物体&#xff0c;添加Sprite Mask组件&#xff0c;将其的Layer设置为UI相机渲染的UI层&#xff0c; 并将其添加到Canvas子物体中&#xff0c;调整好大小&#xff0c;并选择合适的Sprite&#xff…

【实践案例】使用Dify构建企业知识库

文章目录 背景知识检索增强生成&#xff08;RAG&#xff09;向量检索关键词检索混合检索向量化和相似度计算实例说明 实践案例创建知识库Rerank 模型设置创建Dify工作流测试 背景知识 检索增强生成&#xff08;RAG&#xff09; 检索增强生成&#xff08;Retrieval-Augmented …

Maui学习笔记- SQLite简单使用案例02添加详情页

我们继续上一个案例&#xff0c;实现一个可以修改当前用户信息功能。 当用户点击某个信息时&#xff0c;跳转到信息详情页&#xff0c;然后可以点击编辑按钮导航到编辑页面。 创建项目 我们首先在ViewModels目录下创建UserDetailViewModel。 实现从详情信息页面导航到编辑页面…

算法基础学习——快排与归并(附带java模版)

快速排序和归并排序是两种速度较快的排序方式&#xff0c;是最应该掌握的两种排序算法&#xff0c; &#xff08;一&#xff09;快速排序&#xff08;不稳定的&#xff09; 基本思想&#xff1a;分治 平均时间复杂度&#xff1a;O(nlogn) / 最慢O(n^2) / 最快O(n) 步骤&…

数据结构的队列

一.队列 1.队列&#xff08;Queue&#xff09;的概念就是先进先出。 2.队列的用法&#xff0c;红色框和绿色框为两组&#xff0c;offer为插入元素&#xff0c;poll为删除元素&#xff0c;peek为查看元素红色的也是一样的。 3.LinkedList实现了Deque的接口&#xff0c;Deque又…

1. Java-MarkDown文件创建-工具类

Java-MarkDown文件创建-工具类 1. 思路 根据markdown语法&#xff0c;拼装markdown文本内容 2. 工具类 import java.util.Arrays; import java.util.List;/*** Markdown生成工具类* Author: 20004855* Date: 2021/1/15 16:00*/ public class MarkdownGenerator {private Str…

Go学习:格式化输入输出

目录 1. 输出 2. 输入 1. 输出 常用格式&#xff1a; 格式说明%d整型格式%s字符串格式%c字符格式%f浮点数格式%T操作变量所属类型%v自动匹配格式输出 简单示例代码&#xff1a; package mainimport "fmt"func main() {a : 10b : "abc"c : ad : 3.14/…

回顾:Maven的环境搭建

1、下载apache-maven-3.6.0 **网址:**http://maven.apache.org 然后解压到指定的文件夹&#xff08;记住文件路径&#xff09; 2、配置Maven环境 复制bin文件夹 的路径D:\JavaTool\apache-maven-3.6.0\bin 环境配置成功 3、检查是否配置成功 winR 输入cmd 命令行输入mvn -v…

【以音频软件FFmpeg为例】通过Python脚本将软件路径添加到Windows系统环境变量中的实现与原理分析

在Windows系统中&#xff0c;你可以通过修改环境变量 PATH 来使得 ffmpeg.exe 可在任意路径下直接使用。要通过Python修改环境变量并立即生效&#xff0c;如图&#xff1a; 你可以使用以下代码&#xff1a; import os import winreg as reg# ffmpeg.exe的路径 ffmpeg_path …

解决报错“The layer xxx has never been called and thus has no defined input shape”

解决报错“The layer xxx has never been called and thus has no defined input shape”(这里写自定义目录标题) 报错显示 最近在跑yolo的代码时遇到这样一个错误&#xff0c;显示“the layer {self.name} has never been called”.这个程序闲置了很久&#xff0c;每次一遇到…