知识蒸馏基础知识

news/2025/2/23 19:06:17

参考笔记:

YOLOv5改进系列(二十五) 知识蒸馏理论与实践_yolov5知识蒸馏-CSDN博客

全网最细图解知识蒸馏(涉及知识点:知识蒸馏实现代码,知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)-CSDN博客

学习视频:【精读AI论文】知识蒸馏_哔哩哔哩_bilibili 


目录

1. 什么是知识蒸馏

2.轻量化网络的方式有哪些

3.为什么要知识蒸馏

4.知识蒸馏的理论依据

5.知识蒸馏分类

5.1 目标蒸馏-Logits方法

5.1.1 Hard-targets、Soft-targets

5.1.2 蒸馏温度T

5.1.3 蒸馏温度T的特点

5.2 特征蒸馏方法

6.知识蒸馏过程

7.图解知识蒸馏


 本文主讲的是目标蒸馏-Logits方法 

1. 什么是知识蒸馏

知识蒸馏

知识蒸馏就是把一个大的教师模型的知识萃取出来,把它浓缩到一个小的学生模型,可以理解为一个大的教师神经网络把他的知识教给小的学生网络,这里有一个知识的迁移过程,从教师网络迁移到了学生网络身上,教师网络一般比较臃肿,所以教师网络把知识教给学生网络,学生网络是一个比较小的网络,这样就可以用学生网络去做一些轻量化网络做的事情

2.轻量化网络的方式有哪些

(1)压缩已训练好的模型:知识蒸馏、权值量化、权重剪枝、通道剪枝、注意力迁移

(2)直接训练轻量化网络:SqueezeNet、MobileNetv1v2v3、MnasNet、ShuffleNet、EfficientNet、EfficientDet

(3)加速卷积运算:im2col + GEMM、Wiongrad、低秩分解

(4)硬件部署:TensorRT、Jetson、TensorFlow-lite、Openvino、FPGA集成电路

3.为什么要知识蒸馏

深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,这样小的模型就适用于部署在终端设备上了

Student Model 部署在终端设备上 

(1)提升模型精度

如果对目前的网络模型 A 的精度不是很满意,那么可以先训练一个更高精度的 teacher 模型 B(通常参数量更多,时延更大),然后用这个训练好的 teacher 模型 Bstudent 模型 A 进行知识蒸馏,得到一个更高精度的 A 模型。

(2)降低模型时延,压缩模型参数

如果对目前的网络模型 A 的时延不满意,可以先找到一个时延更低,参数量更小的模型 B ,通常来讲, B 模型精度也会比较低,然后通过训练一个更高精度的 teacher 模型 C 来对这个参数量小的模型 B 进行知识蒸馏,使得该模型 B 的精度接近最原始的模型 A,从而达到降低时延的目的。

(3)标签之间的域迁移

假如使用狗和猫的数据集训练了一个 teacher 模型 A ,使用香蕉和苹果训练了一个 teacher 模型 B ,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移

4.知识蒸馏的理论依据

知识蒸馏使用的是 Teacher—Student 模型,其中 Teacher 是“知识”的输出者, Student 是“知识”的接受者。知识蒸馏的过程分为 3 个阶段:

(1)Teacher 模型训练:简称为 Net-T ,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对 Teacher 模型不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入 X , 其都能输出 Y ,其中 Y 经过 softmax 的映射,输出值是对应类别的概率值
(2)Student 模型训练: 简称为 Net-S ,它是参数量较小、模型结构相对简单的模型。同样的,对于输入 X ,其都能输出 YY 经过 softmax 映射后同样能输出对应类别的概率值

(3)知识蒸馏训练: Net-T 学习能力强,可以将它学到的知识迁移给学习能力相对弱的 Net-S 模型,以此来增强 Net-T 模型的泛化能力。复杂笨重但是效果好的 Net-T 模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的 Net-S 小模型

5.知识蒸馏分类

知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为 Soft-target 蒸馏或 Logits 方法蒸馏)和基于特征蒸馏的算法两个大的方向

5.1 目标蒸馏-Logits方法

分类问题的共同点是模型最后会有一个 softmax 层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的 Teacher 模型,我们利用 Teacher 模型来蒸馏训练 Student 模型时,可以直接让 Student 模型去学习 Teacher 模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用 Tearcher 模型的 softmax 层输出的类别概率来作为 Soft-targets Soft-targets 作为训练 Student 模型的标签

5.1.1 Hard-targets、Soft-targets

传统的模型训练过程:采用 Hard-targets ,类别标签对应的 One-Hot 向量,除了正确类别为 1,其他类别都是 0

知识蒸馏训练过程:用 Teacher Modelclass probabilities 作为 Student ModelSoft-targets


下面看一些例子

假设有 3 个类别,分别是 [马,驴子,汽车]

 对于马这张图片,Hard-targetsSoft-targets 两种形式的类别标签如下:

Hard-targets、Soft-targets 

使用 Hard-targets 来训练网络,对于马这张图片,相当于告诉网络这就是一匹马,不是驴和汽车,并且不是驴和汽车的程度是相等的,因为 Hard-targets 中驴和汽车的标签值都为 0

但通过肉眼其实能观察到马和驴其实是有一点相似的,马更像驴子而更不像汽车。Hard-targets 一个致命的缺陷就是秉持绝对的 “正确” ,对于非正确类别不提供任何额外的信息

使用 Soft-targets 来训练网络,对于马这种图片,相当于告诉网络是马的概率为 0.7 ,是汽车的概率为 0.05 ,说明马和汽车是非常不像的,而是驴的概率为 0.25,说明马和驴是有一点类似的。同理,驴和汽车也是非常不像的。显然 Soft-targets 传递的信息更多

其他例子 


综上:

(1)在使用 Soft-targets 训练时,Student Model 可以很快学习到 Teacher Model 的推理过程

(2)传统的 Hard-targets 的训练方式,所有的负标签都会被平等对待。而 Soft-targets 包含了更多的“知识”和“信息”,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)。所以我们可以用 Teacher 网络的预测结果作为训练 Student 网络的标签 

5.1.2 蒸馏温度T

如果觉得 Teacher ModelSoft-targets 还不够 Soft ,即想把其他非正确类别的概率也变大,把它们的相对大小充分暴露出来,让学生网络能够有一个强烈的信号,知道这些非正确类别的更多信息,因此提出了蒸馏温度 TT 越大,Soft-targets 就越 Soft

实现方法:在原始的 Softmax 公式加 TT = 1 时,就是常规的 Softmax 操作

计算公式 

 下面我们来看一下 T = 1T = 3 时的对比例子

不同温度的Softmax对比 

(1)T = 1 时,普通的 Softmax 操作, 4 个概率值的两级分化十分严重,猫和狗的预测概率非常小,接近于 0  。而马的预测概率非常高,达到 0.88 。 

(2)T = 3 时,使用蒸馏温度,4 个概率值相比较于 T = 1 的概率值两级分化情况没那么严重,变得更 Soft 了,而且各类别概率值的大小顺序和 T=1 时是一样的

5.1.3 蒸馏温度T的特点

  • 原始的 Softmax 函数是 T = 1 时的特例;T<1 时,概率分布比原始更“陡峭”,即当 0<T<1 时,Softmax 的输出值会接近与 Hard-targetsT>1 时,概率分布比原始更“平缓”
  • 随着 T 的增加,Softmax 的输出分布会越来越平缓。温度越高,Softmax 上各个值的分布就越平均,极端情况的 Softmax 值是平均分布的,如下图 T = 100

  • 温度的高低改变的是 Student 模型训练过程中对负标签的关注程度。温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;温度较高时,负标签相关的值会相对增大, Student 模型会相对更多地关注到负标签
  • 针对较困难的分类或检测任务, Temperature 通常取 1 ,确保 Teacher Model 中正确预测的贡献
  • 温度的选取需要进行实际实验的比较,本质上就是如下两种情况作取舍:
    • 当想从负标签中多学到一些信息的时候,温度应调高一些
    • 当想减少负标签的干扰的时候,温度应调低一些

总的来说,温度的选择和 Student 模型的大小有关,Student 模型参数量比较小时,相对比较低的温度就可以了。因为参数量小的模型不能学到所有 Teacher 模型的知识,所以可以适当忽略掉 

5.2 特征蒸馏方法

另外一种知识蒸馏思路是特征蒸馏方法,它不像 Logits 方法那样 Student 只学习 TeacherLogits 这种结果知识,而是学习 Teacher 网络结构中的中间层特征

最早采用这种模式的工作来自于论文《FITNETS:Hints for Thin Deep Nets》,它强迫 Student 某些中间层的网络响应,去逼近 Teacher 对应的中间层的网络响应。这种情况下, Teacher 中间特征层的响应,就是传递给 Student 的知识。在此之后,提出了各种新方法,但是大致思路还是这个思路,本质是 Teacher 将特征级知识迁移给 Student

6.知识蒸馏过程

知识蒸馏过程

损失函数计算 

上图中的 distillation lossstudent loss 也可以称为 soft losshard loss,两者都可以采用传统的交叉熵损失函数,交叉熵损失函数的计算公式如下:

交叉熵损失函数计算公式 

下图是损失函数计算的例子

图解损失函数计算过程 

实际的损失函数设计中,会给 soft losshard loss 增加一个权重系数,构成完成的损失函数,即:

Total loss 设计为 Soft LossHard Loss 所对应的交叉熵的加权和,其中 \color{red}\alpha 越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小 Soft Loss 的比重,让真实标签帮助鉴别困难样本

7.图解知识蒸馏


http://www.niftyadmin.cn/n/5863693.html

相关文章

《AI与NLP:开启元宇宙社交互动新纪元》

在科技飞速发展的当下&#xff0c;元宇宙正从概念逐步走向现实&#xff0c;成为人们关注的焦点。而在元宇宙诸多令人瞩目的特性中&#xff0c;社交互动体验是其核心魅力之一。人工智能&#xff08;AI&#xff09;与自然语言处理&#xff08;NLP&#xff09;技术的迅猛发展&…

Plant Simulation培训教程-机器人3D仿真模块

原创 知行 天理智能科技 2025年01月04日 16:43 浙江 又到年终盘点的时候了&#xff0c;在这里我把之前录制的Plant Simulation培训教程-机器人3D仿真模块分享出来&#xff0c;有需要的可以直接联系我。 3D仿真模块包含多种设备机构自定义、动画制作。多轴机器人自定义动画实现…

值传递,引用传递

在Java中&#xff0c;值传递和引用传递是两种不同的参数传递方式&#xff0c;尽管存在一些关于Java是否支持引用传递的争议。 值传递&#xff08;Pass by Value&#xff09; 值传递意味着当你调用一个方法时&#xff0c;方法参数接收到的是调用时传入的实际值的副本。换句话说…

企业数据分析-偿债能力分析

核心财务分析指标&#xff1a; 企业的资产 负债 所有者权益 资产是企业拥有的经济资源&#xff0c;负债是企业需要偿还的经济责任&#xff0c;所有者权益则是企业所有者在企业中的权益。这个公式表明&#xff0c;企业的所有资产&#xff0c;要么是债权人的&#xff0c;要么…

基于SpringBoot的城乡商城协作系统【附源码】

基于SpringBoot的城乡商城协作系统 效果如下&#xff1a; 系统登陆页面 系统管理员主页面 商品信息管理页面 系统用户主页面 社区交流页面 用户充值页面 订单提交页面 商品信息页面 研究背景 随着互联网技术的飞速发展&#xff0c;电子商务在我国城乡地区的普及程度越来越高…

【Shell编程 / 9】脚本实战项目:从基础到进阶的自动化管理方案

文章目录 Shell脚本实战项目自动化部署脚本系统监控脚本文件备份脚本定时任务管理脚本文件传输自动化脚本自动化日志清理脚本用户管理脚本 Shell脚本实战项目 在掌握了 Shell 脚本的基本语法和高级技巧后&#xff0c;实践是进一步提升脚本编写能力的关键。通过参与一些实际的项…

一种简单有效的分析qnx+android智能座舱项目中的画面闪烁的方法(8155平台)

在智能座舱项目的开发过程中&#xff0c;画面闪烁问题是一个常见但棘手的挑战。由于这些闪烁现象往往转瞬即逝&#xff0c;传统的分析工具如截图、录屏或dump图层等方法难以捕捉和定位问题根源。针对这一难题&#xff0c;本文介绍了一种较为有效的分析方法&#xff0c;能够帮助…

使用Supervisor管理PHP脚本进程任务

Supervisor 是一个用 Python 编写的进程管理工具&#xff0c;用于监控和管理后台进程。它可以帮助你确保某些进程&#xff08;如 PHP 消费者脚本&#xff09;持续运行&#xff0c;并在进程崩溃或服务器重启后自动重启。 在 PHP Redis 实现异步处理的场景中&#xff0c;Supervi…