浅谈知识蒸馏技术

        最近爆火的DeepSeek 技术,将知识蒸馏技术运用推到我们面前。今天就简单介绍一下知识蒸馏技术并附上python示例代码。

        知识蒸馏(Knowledge Distillation)是一种模型压缩技术,它的核心思想是将一个大型的、复杂的教师模型(teacher model)的知识迁移到一个小型的、简单的学生模型(student model)中,从而在保持模型性能的前提下,减少模型的参数数量和计算复杂度。以下是对知识蒸馏使用的算法及技术的深度分析,并附上 Python 示例代码。

1. 基本原理

知识蒸馏的基本原理是让学生模型学习教师模型的输出概率分布,而不仅仅是学习真实标签。教师模型通常是一个大型的、经过充分训练的模型,它具有较高的性能,但计算成本也较高。学生模型则是一个小型的、结构简单的模型,其目标是在教师模型的指导下学习到与教师模型相似的知识,从而提高自身的性能。

2. 软标签(Soft Labels)

在传统的监督学习中,模型的输出是硬标签(Hard Labels),即每个样本只对应一个确定的类别标签。而在知识蒸馏中,使用的是软标签(Soft Labels),即教师模型输出的概率分布。软标签包含了更多的信息,因为它不仅反映了样本的真实类别,还反映了教师模型对其他类别的不确定性。通过学习软标签,学生模型可以更好地捕捉到数据中的细微差别和不确定性。

3. 损失函数

知识蒸馏的损失函数通常由两部分组成:硬标签损失(Hard Label Loss)和软标签损失(Soft Label Loss)。硬标签损失是学生模型的输出与真实标签之间的交叉熵损失,用于保证学生模型在基本的分类任务上的准确性。软标签损失是学生模型的输出与教师模型的输出之间的交叉熵损失,用于让学生模型学习教师模型的知识。最终的损失函数是硬标签损失和软标签损失的加权和,权重可以根据具体情况进行调整。

4. 温度参数(Temperature)

在计算软标签损失时,通常会引入一个温度参数(Temperature)。温度参数可以控制教师模型输出的概率分布的平滑程度。当温度参数较大时,概率分布会更加平滑,即教师模型对不同类别的不确定性会增加;当温度参数较小时,概率分布会更加尖锐,即教师模型对真实类别的信心会增强。通过调整温度参数,可以平衡教师模型的知识传递和学生模型的学习效果。

5.Python 示例代码


以下是一个使用 PyTorch 实现知识蒸馏的简单示例代码:

import torch

import torch.nn as nn

import torch.optim as optim

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

# 定义教师模型

class TeacherModel(nn.Module):

def __init__(self):

super(TeacherModel, self).__init__()

self.fc1 = nn.Linear(784, 1200)

self.fc2 = nn.Linear(1200, 1200)

self.fc3 = nn.Linear(1200, 10)

self.relu = nn.ReLU()

def forward(self, x):

x = x.view(-1, 784)

x = self.relu(self.fc1(x))

x = self.relu(self.fc2(x))

x = self.fc3(x)

return x

# 定义学生模型

class StudentModel(nn.Module):

def __init__(self):

super(StudentModel, self).__init__()

self.fc1 = nn.Linear(784, 200)

self.fc2 = nn.Linear(200, 200)

self.fc3 = nn.Linear(200, 10)

self.relu = nn.ReLU()

def forward(self, x):

x = x.view(-1, 784)

x = self.relu(self.fc1(x))

x = self.relu(self.fc2(x))

x = self.fc3(x)

return x

# 数据加载

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))

])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 初始化教师模型和学生模型

teacher_model = TeacherModel()

student_model = StudentModel()

# 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# 训练教师模型(这里省略教师模型的训练过程,假设已经训练好)

# ...

# 知识蒸馏训练

def distillation_loss(y, labels, teacher_scores, T, alpha):

hard_loss = criterion(y, labels)

soft_loss = nn.KLDivLoss(reduction='batchmean')(nn.functional.log_softmax(y / T, dim=1),

nn.functional.softmax(teacher_scores / T, dim=1)) * (T * T)

return alpha * hard_loss + (1 - alpha) * soft_loss

T = 5.0 # 温度参数

alpha = 0.1 # 硬标签损失和软标签损失的权重

for epoch in range(10):

for data, labels in train_loader:

optimizer.zero_grad()

teacher_scores = teacher_model(data)

student_scores = student_model(data)

loss = distillation_loss(student_scores, labels, teacher_scores, T, alpha)

loss.backward()

optimizer.step()

print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

代码解释

  1. 模型定义:定义了一个简单的教师模型(TeacherModel)和一个简单的学生模型(StudentModel),用于 MNIST 手写数字识别任务。
  2. 数据加载:使用torchvision加载 MNIST 数据集,并进行数据预处理。
  3. 损失函数定义:定义了知识蒸馏的损失函数distillation_loss,它由硬标签损失和软标签损失组成。
  4. 训练过程:在训练过程中,首先计算教师模型的输出,然后计算学生模型的输出,最后计算知识蒸馏的损失并进行反向传播和参数更新。

通过以上的算法和技术,知识蒸馏可以有效地将教师模型的知识迁移到学生模型中,提高学生模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68932.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小红的小球染色期望

B-小红的小球染色_牛客周赛 Round 79 题目描述 本题与《F.R小红的小球染色期望》共享题目背景,但是所求内容与范围均不同,我们建议您重新阅读题面。 有 n 个白色小球排成一排。小红每次将随机选择两个相邻的白色小球,将它们染成红色。小红…

ASP.NET Core与配置系统的集成

目录 配置系统 默认添加的配置提供者 加载命令行中的配置。 运行环境 读取方法 User Secrets 注意事项 Zack.AnyDBConfigProvider 案例 配置系统 默认添加的配置提供者 加载现有的IConfiguration。加载项目根目录下的appsettings.json。加载项目根目录下的appsettin…

Redis集群理解以及Tendis的优化

主从模式 主从同步 同步过程: 全量同步(第一次连接):RDB文件加缓冲区,主节点fork子进程,保存RDB,发送RDB到从节点磁盘,从节点清空数据,从节点加载RDB到内存增量同步&am…

沙皮狗为什么禁养?

各位铲屎官们,今天咱们来聊聊一个比较敏感的话题:沙皮狗为什么会被禁养?很多人对沙皮狗情有独钟,但有些地方却明确禁止饲养这种犬种,这背后到底是什么原因呢?别急,今天就来给大家好好揭秘&#…

物联网 STM32【源代码形式-ESP8266透传】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

一、MQTT介绍 MQTT(Message Queuing Telemetry Transport,消息队列遥测传输协议)是一种基于发布/订阅模式的轻量级通讯协议,构建于TCP/IP协议之上。它最初由IBM在1999年发布,主要用于在硬件性能受限和网络状况不佳的情…

w186格障碍诊断系统spring boot设计与实现

🙊作者简介:多年一线开发工作经验,原创团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹赠送计算机毕业设计600个选题excel文…

题解 洛谷 Luogu P1955 [NOI2015] 程序自动分析 并查集 离散化 哈希表 C++

题目 传送门 P1955 [NOI2015] 程序自动分析 - 洛谷 | 计算机科学教育新生态https://www.luogu.com.cn/problem/P1955 思路 主要用到的知识是并查集 (如何实现并查集,这里不赘述了) 若 xi xj,则合并它们所在的集合。若 xi ! xj,则 i 和 …

无用知识之:std::initializer_list的秘密

先说结论,用std::initializer_list初始化vector,内部逻辑是先生成了一个临时数组,进行了拷贝构造,然后用这个数组的起终指针初始化initializer_list。然后再用initializer_list对vector进行初始化,这个动作又触发了拷贝…

97,【5】buuctf web [极客大挑战 2020]Greatphp

进入靶场 审代码 <?php // 关闭所有 PHP 错误报告&#xff0c;防止错误信息泄露可能的安全隐患 error_reporting(0);// 定义一个名为 SYCLOVER 的类 class SYCLOVER {// 定义类的公共属性 $sycpublic $syc;// 定义类的公共属性 $loverpublic $lover;// 定义魔术方法 __wa…

蓝桥杯单片机第七届省赛

前言 这套题不难&#xff0c;相对于第六套题这一套比较简单了&#xff0c;但是还是有些小细节要抓 题目 OK&#xff0c;以上就是全部的题目了&#xff0c;这套题目相对来说逻辑比较简单&#xff0c;四个按键&#xff0c;S4控制pwm占空比&#xff0c;S5控制计时时间&#xff0…

【C语言】自定义类型讲解

文章目录 一、前言二、结构体2.1 概念2.2 定义2.2.1 通常情况下的定义2.2.2 匿名结构体 2.3 结构体的自引用和嵌套2.4 结构体变量的定义与初始化2.5 结构体的内存对齐2.6 结构体传参2.7 结构体实现位段 三、枚举3.1 概念3.2 定义3.3 枚举的优点3.3.1 提高代码的可读性3.3.2 防止…

我问了DeepSeek和ChatGPT关于vue中包含几种watch的问题,它们是这么回答的……

前言&#xff1a;听说最近DeepSeek很火&#xff0c;带着好奇来问了关于Vue的一个问题&#xff0c;看能从什么角度思考&#xff0c;如果回答的不对&#xff0c;能不能尝试纠正&#xff0c;并帮我整理出一篇不错的文章。 第一次回答的原文如下&#xff1a; 在 Vue 中&#xff0c;…

纯后训练做出benchmark超过DeepseekV3的模型?

论文地址 https://arxiv.org/pdf/2411.15124 模型是AI2的&#xff0c;他们家也是玩开源的 先看benchmark&#xff0c;几乎是纯用llama3 405B后训练去硬刚出一个gpt4o等级的LLamA405 我们先看之前的机遇Lllama3.1 405B进行全量微调的模型 Hermes 3&#xff0c;看着还没缘模型…

UbuntuWindows双系统安装

做系统盘&#xff1a; Ubuntu20.04双系统安装详解&#xff08;内容详细&#xff0c;一文通关&#xff01;&#xff09;_ubuntu 20.04-CSDN博客 ubuntu系统调整大小&#xff1a; 调整指南&#xff1a; 虚拟机中的Ubuntu扩容及重新分区方法_ubuntu重新分配磁盘空间-CSDN博客 …

在 Zemax 中使用布尔对象创建光学光圈

在 Zemax 中&#xff0c;布尔对象用于通过组合或减去较简单的几何形状来创建复杂形状。布尔运算涉及使用集合运算&#xff08;如并集、交集和减集&#xff09;来组合或修改对象的几何形状。这允许用户在其设计中为光学元件或机械部件创建更复杂和定制的形状。 本视频中&#xf…

AI作画提示词:Prompts工程技巧与最佳实践

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///计算机爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于物联网智能项目之——智能家居项目…

LLMs之DeepSeek:Math-To-Manim的简介(包括DeepSeek R1-Zero的详解)、安装和使用方法、案例应用之详细攻略

LLMs之DeepSeek&#xff1a;Math-To-Manim的简介(包括DeepSeek R1-Zero的详解)、安装和使用方法、案例应用之详细攻略 目录 Math-To-Manim的简介 1、特点 2、一个空间推理测试—考察不同大型语言模型如何解释和可视化空间关系 3、DeepSeek R1-Zero的简介&#xff1a;处理更…

二叉树——429,515,116

今天继续做关于二叉树层序遍历的相关题目&#xff0c;一共有三道题&#xff0c;思路都借鉴于最基础的二叉树的层序遍历。 LeetCode429.N叉树的层序遍历 这道题不再是二叉树了&#xff0c;变成了N叉树&#xff0c;也就是该树每一个节点的子节点数量不确定&#xff0c;可能为2&a…

详解单片机学的是什么?(电子硬件)

大家好&#xff0c;我是山羊君Goat。 单片机&#xff0c;对于每一个硬件行业的从业者或者在校电子类专业的学生&#xff0c;相信对于这个名词都不陌生&#xff0c;但是掌没掌握就另说了。 那单片机到底学的是什么呢&#xff1f; 其实单片机在生活中就非常常见&#xff0c;目前…

数据密码解锁之DeepSeek 和其他 AI 大模型对比的神秘面纱

本篇将揭露DeepSeek 和其他 AI 大模型差异所在。 目录 ​编辑 一本篇背景&#xff1a; 二性能对比&#xff1a; 2.1训练效率&#xff1a; 2.2推理速度&#xff1a; 三语言理解与生成能力对比&#xff1a; 3.1语言理解&#xff1a; 3.2语言生成&#xff1a; 四本篇小结…