Hung-yi Lee Machine Learning 2023 - HW13 - Network Compression

Table of Contents

  • Task
  • Link
  • Baseline
    • Simple Baseline
    • Medium Baseline
    • Strong Baseline
    • Boss Baseline
      • FitNet Knowledge Distillation
      • Relational Knowledge Distillation (RKD)
      • Distance Metric (DM) Knowledge Distillation

Task

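The task is image classification via network compression. The dataset is the same 11-class food dataset used in HW3. You need to design a small student model with fewer than 60k parameters and train it to approach the teacher model's accuracy (test accuracy ≈ 0.902).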

Link

kaggle

Baseline

Simple Baseline

Just run the sample code

Medium Baseline

The loss function is defined with KL divergence, as follows:
$$Loss = \alpha T^2 \times KL(p \,\|\, q) + (1 - \alpha)(\text{Original Cross Entropy Loss}),$$
where $p = softmax(\frac{\text{student's logits}}{T})$ and $q = softmax(\frac{\text{teacher's logits}}{T})$.
The number of epochs can also be increased to 50, and the other places marked # medium must be modified as well.

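# medium
import torch.nn as nn
import torch.nn.functional as F

def loss_fn_kd(student_logits, labels, teacher_logits, alpha=0.5, temperature=5.0):
    # ------------TODO-------------
    # Refer to the above formula and finish the loss function for knowledge distillation using KL divergence loss and CE loss.
    # If you have no idea, please take a look at the provided useful link above.
    # F.kl_div takes log-probabilities as input and probabilities as target,
    # so the student side must go through log_softmax (the original code
    # subtracted raw student probabilities from teacher log-probabilities).
    student_log_prob = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
    # "batchmean" sums over classes and averages over the batch.
    KL_loss = F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
    CE_loss = nn.CrossEntropyLoss()(student_logits, labels)
    loss = alpha * temperature**2 * KL_loss + (1 - alpha) * CE_loss
    return loss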
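To see what the temperature T does in the formula above, here is a small pure-Python sketch that needs no PyTorch. The logits are made-up values for a single image over three classes, and the helper names softmax_with_temperature and kl_divergence are illustrative only, not part of the sample code. It prints the teacher distribution and the teacher-to-student KL term at T = 1 and T = 5.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(target, pred):
    """KL(target || pred) for two discrete distributions of equal length."""
    return sum(t * math.log(t / p) for t, p in zip(target, pred) if t > 0)

# Hypothetical logits for one image (assumed values, not from the homework).
teacher_logits = [8.0, 2.0, 0.0]
student_logits = [5.0, 3.0, 1.0]

for T in (1.0, 5.0):
    q = softmax_with_temperature(teacher_logits, T)  # teacher distribution
    p = softmax_with_temperature(student_logits, T)  # student distribution
    print(f"T={T}: teacher={[round(v, 3) for v in q]}, KL={kl_divergence(q, p):.4f}")
```

A higher T flattens the teacher distribution, so the smaller class probabilities carry more weight in the KL term. Flattening also shrinks the gradients, which is why the formula multiplies the term by T^2.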

Strong Baseline

Modify the model architecture with depthwise and pointwise convolutions, and increase the number of epochs.

  • other useful techniques
    • group convolution (Actually, depthwise convolution is a specific type of group convolution)
    • SqueezeNet
    • MobileNet
    • ShuffleNet
    • Xception
    • GhostNet

Intermediate-layer feature learning can also be applied; I did not do that here.

# strong
import torch.nn as nn

def dwpw_conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    return nn.Sequential(
        nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=padding, groups=in_channels),  # depthwise convolution
        nn.BatchNorm2d(in_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels, out_channels, 1),  # pointwise convolution
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
    )
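The reason depthwise and pointwise convolutions matter for the 60k-parameter budget can be checked with plain arithmetic. The sketch below counts convolution weights and biases the way a grouped 2D convolution stores them (weight shape: out_channels x in_channels/groups x k x k). BatchNorm parameters are left out. The helper names and the 64 -> 128 layer are assumptions chosen for illustration.

```python
def conv_params(in_channels, out_channels, kernel_size, groups=1, bias=True):
    """Number of weights (plus optional biases) in a 2D convolution layer."""
    weights = (in_channels // groups) * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)

def dwpw_params(in_channels, out_channels, kernel_size):
    """Convolution parameters of a depthwise + pointwise pair (BatchNorm excluded)."""
    # Depthwise: one k x k filter per input channel (groups == in_channels).
    depthwise = conv_params(in_channels, in_channels, kernel_size, groups=in_channels)
    # Pointwise: a 1 x 1 convolution that mixes channels.
    pointwise = conv_params(in_channels, out_channels, 1)
    return depthwise + pointwise

# Hypothetical layer: 64 -> 128 channels with a 3x3 kernel.
standard = conv_params(64, 128, 3)
separable = dwpw_params(64, 128, 3)
print(standard, separable)  # 73856 8960
```

In this example a single standard 3x3 convolution already exceeds the whole 60k budget, while the depthwise + pointwise pair uses roughly an eighth of the parameters.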

Boss Baseline

Other advanced knowledge distillation methods (FitNet/RKD/DM) + more epochs + depthwise & pointwise conv layers (depthwise separable convolution).

Intermediate-layer feature learning can of course also be applied.

FitNet Knowledge Distillation

FitNet focuses on transferring knowledge from intermediate feature representations (hidden layers) instead of just using the output logits. The student model is trained to mimic the feature maps from certain layers of the teacher model.

# boss
import torch.nn.functional as F

def loss_fn_fitnet(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    FitNet Knowledge Distillation Loss Function.
    Args:
    - teacher_feature: The feature maps from a hidden layer of the teacher model.
    - student_feature: The feature maps from the corresponding hidden layer of the student model
      (must have the same shape as teacher_feature).
    - student_logits: The student's classification output. This argument is added here because
      cross-entropy needs class scores, not hidden feature maps.
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the feature loss.
    Returns:
    - loss: Combined loss with cross-entropy and feature map alignment.
    """
    # Mean squared error loss to align feature maps of teacher and student
    feature_loss = F.mse_loss(student_feature, teacher_feature)
    # Hard label cross-entropy loss for the student output (classification)
    hard_loss = F.cross_entropy(student_logits, labels)
    # Combine both losses
    loss = alpha * hard_loss + (1 - alpha) * feature_loss
    return loss
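An MSE between feature maps only makes sense when the two have the same shape, and a student under 60k parameters usually has fewer channels than the teacher. FitNet handles this with a small learned regressor on the student side (a convolution in the paper; a plain linear map below, for brevity). The following pure-Python sketch uses made-up feature vectors and regressor weights, and the names linear_regressor and mse are illustrative only.

```python
def linear_regressor(student_feature, weight):
    """Project a student feature vector into the teacher's feature dimension."""
    return [sum(w * s for w, s in zip(row, student_feature)) for row in weight]

def mse(a, b):
    """Mean squared error between two vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("features must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

# Hypothetical features: the teacher layer has 4 dimensions, the student only 2.
teacher_feature = [1.0, 0.0, 2.0, 1.0]
student_feature = [1.0, 2.0]

# Hypothetical 4 x 2 regressor weights; in training they would be learned
# jointly with the student and discarded at inference time.
weight = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0], [0.5, 0.0]]

projected = linear_regressor(student_feature, weight)
hint_loss = mse(projected, teacher_feature)
print(projected, hint_loss)
```

Because the regressor is only used during training, it does not have to count toward the deployed student's parameter budget; whether the homework's parameter check sees it depends on how the model is defined.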

Relational Knowledge Distillation (RKD)

Relational Knowledge Distillation focuses on transferring the relationships (distances and angles) between data samples as learned by the teacher. The student model is trained to match these relationships instead of just focusing on output probabilities.

# boss
import torch
import torch.nn.functional as F

def pairwise_distance(x):
    """Calculate pairwise distance between batch samples (x has shape (N, D))."""
    return torch.cdist(x, x, p=2)

def angle_between_pairs(x):
    """Calculate angles between all pairs of points in batch (x has shape (N, D))."""
    diff = x.unsqueeze(1) - x.unsqueeze(0)  # (N, N, D), diff[i, j] = x_i - x_j
    norm = diff.norm(dim=-1, p=2, keepdim=True)
    normalized_diff = diff / (norm + 1e-8)
    # angles[i, j, k] is the cosine of the angle at x_i between x_j and x_k
    angles = torch.bmm(normalized_diff, normalized_diff.transpose(1, 2))
    return angles

def loss_fn_rkd(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Relational Knowledge Distillation Loss Function.
    Args:
    - teacher_feature: Teacher model feature embeddings, shape (N, D_teacher).
    - student_feature: Student model feature embeddings, shape (N, D_student).
    - student_logits: The student's classification output. This argument is added here because
      cross-entropy needs class scores, not feature embeddings.
    - labels: Ground truth labels.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the relational loss.
    Returns:
    - loss: Combined relational knowledge and hard label loss.
    """
    # Pairwise distances between features in the teacher and student model
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)
    # Distillation loss: MSE between the two distance matrices
    distance_loss = F.mse_loss(student_dist, teacher_dist)
    # Angle-based loss between teacher and student feature vectors
    teacher_angle = angle_between_pairs(teacher_feature)
    student_angle = angle_between_pairs(student_feature)
    angle_loss = F.mse_loss(student_angle, teacher_angle)
    # Hard label cross-entropy loss for the student output
    hard_loss = F.cross_entropy(student_logits, labels)
    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * (distance_loss + angle_loss)
    return loss
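One detail the distance term above leaves out: as I understand the RKD paper, pairwise distances are divided by their mean within the mini-batch before teacher and student are compared, so the student is not penalized merely for living in a differently scaled feature space. The pure-Python sketch below illustrates this with three made-up 2-D embeddings; the helper names are illustrative, and plain MSE stands in for the Huber loss used in the paper.

```python
import math

def pairwise_distances(points):
    """Euclidean distance matrix for a list of points."""
    return [[math.dist(a, b) for b in points] for a in points]

def normalize_by_mean(dist):
    """Divide every entry by the mean of the off-diagonal distances."""
    n = len(dist)
    off_diag = [dist[i][j] for i in range(n) for j in range(n) if i != j]
    mean = sum(off_diag) / len(off_diag)
    return [[d / mean for d in row] for row in dist]

def matrix_mse(a, b):
    """Mean squared error between two equally sized matrices."""
    cells = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(cells) / len(cells)

# Hypothetical embeddings: the student has the same geometry at 1/10 the scale.
teacher = [[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]]
student = [[0.0, 0.0], [0.3, 0.0], [0.0, 0.4]]

raw_loss = matrix_mse(pairwise_distances(student), pairwise_distances(teacher))
normalized_loss = matrix_mse(
    normalize_by_mean(pairwise_distances(student)),
    normalize_by_mean(pairwise_distances(teacher)),
)
print(raw_loss, normalized_loss)
```

The raw loss is large even though the student reproduces the teacher's relational structure exactly, while the normalized loss is essentially zero.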

Distance Metric (DM) Knowledge Distillation

Distance Metric distillation focuses on transferring the distance metric (such as Euclidean distance or cosine similarity) between instances in the teacher’s feature space to the student model.

# boss
import torch.nn.functional as F

def loss_fn_dm(teacher_feature, student_feature, student_logits, labels, alpha=0.5):
    """
    Distance Metric (DM) Knowledge Distillation Loss Function.
    Args:
    - teacher_feature: The feature representations from the teacher model.
    - student_feature: The feature representations from the student model.
    - student_logits: The student's classification output. This argument is added here because
      cross-entropy needs class scores, not feature representations.
    - labels: Ground truth labels for the task.
    - alpha: Weighting factor for the hard-label loss; (1 - alpha) weights the distance metric loss.
    Returns:
    - loss: Combined distance metric loss and cross-entropy loss.
    """
    # Pairwise distances within the teacher batch and within the student batch
    # (pairwise_distance is the helper defined in the RKD section)
    teacher_dist = pairwise_distance(teacher_feature)
    student_dist = pairwise_distance(student_feature)
    # Distance metric loss using Mean Squared Error (MSE) loss
    dist_loss = F.mse_loss(student_dist, teacher_dist)
    # Hard label cross-entropy loss for the student's output
    hard_loss = F.cross_entropy(student_logits, labels)
    # Combine the losses
    loss = alpha * hard_loss + (1 - alpha) * dist_loss
    return loss
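The description above also mentions cosine similarity as a possible metric, while the function only uses Euclidean distance. A cosine-based variant could look like the following pure-Python sketch. The embeddings are made-up, the helper names are illustrative, and the vectors are assumed to be non-zero. Note that the teacher and student embeddings may have different dimensions, because only the N x N similarity matrices are compared.

```python
import math

def cosine_similarity_matrix(features):
    """Cosine similarity between every pair of (non-zero) feature vectors."""
    norms = [math.sqrt(sum(v * v for v in f)) for f in features]
    return [
        [sum(x * y for x, y in zip(a, b)) / (na * nb) for b, nb in zip(features, norms)]
        for a, na in zip(features, norms)
    ]

def matrix_mse(a, b):
    """Mean squared error between two equally sized matrices."""
    cells = [(x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)]
    return sum(cells) / len(cells)

# Hypothetical batch of three samples: 3-D teacher features, 2-D student features.
teacher = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
good_student = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]  # same angular structure
bad_student = [[1.0, 0.0], [1.0, 0.1], [0.0, 1.0]]   # different angular structure

teacher_sim = cosine_similarity_matrix(teacher)
good_loss = matrix_mse(cosine_similarity_matrix(good_student), teacher_sim)
bad_loss = matrix_mse(cosine_similarity_matrix(bad_student), teacher_sim)
print(good_loss, bad_loss)
```

The first student matches the teacher's similarity structure and gets a loss near zero; the second places two samples almost on top of each other and is penalized.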
