参考笔记:
YOLOv5改进系列(二十五) 知识蒸馏理论与实践_yolov5知识蒸馏-CSDN博客
全网最细图解知识蒸馏(涉及知识点:知识蒸馏实现代码,知识蒸馏训练过程,推理过程,蒸馏温度,蒸馏损失函数)-CSDN博客
学习视频:【精读AI论文】知识蒸馏_哔哩哔哩_bilibili
目录
1. 什么是知识蒸馏
2.轻量化网络的方式有哪些
3.为什么要知识蒸馏
4.知识蒸馏的理论依据
5.知识蒸馏分类
5.1 目标蒸馏-Logits方法
5.1.1 Hard-targets、Soft-targets
5.1.2 蒸馏温度T
5.1.3 蒸馏温度T的特点
5.2 特征蒸馏方法
6.知识蒸馏过程
7.图解知识蒸馏
本文主讲的是目标蒸馏-Logits方法
1. 什么是知识蒸馏
知识蒸馏就是把一个大的教师模型的知识萃取出来,把它浓缩到一个小的学生模型,可以理解为一个大的教师神经网络把他的知识教给小的学生网络,这里有一个知识的迁移过程,从教师网络迁移到了学生网络身上,教师网络一般比较臃肿,所以教师网络把知识教给学生网络,学生网络是一个比较小的网络,这样就可以用学生网络去做一些轻量化网络做的事情
2.轻量化网络的方式有哪些
(1)压缩已训练好的模型:知识蒸馏、权值量化、权重剪枝、通道剪枝、注意力迁移
(2)直接训练轻量化网络:SqueezeNet、MobileNetv1v2v3、MnasNet、ShuffleNet、EfficientNet、EfficientDet
(3)加速卷积运算:im2col + GEMM、Wiongrad、低秩分解
(4)硬件部署:TensorRT、Jetson、TensorFlow-lite、Openvino、FPGA集成电路
3.为什么要知识蒸馏
深度学习在计算机视觉、语音识别、自然语言处理等内的众多领域中均取得了令人难以置信的性能。但是,大多数模型在计算上过于昂贵,无法在移动端或嵌入式设备上运行。因此需要对模型进行压缩,这样小的模型就适用于部署在终端设备上了
Student Model 部署在终端设备上
(1)提升模型精度
如果对目前的网络模型 A 的精度不是很满意,那么可以先训练一个更高精度的 teacher 模型 B(通常参数量更多,时延更大),然后用这个训练好的 teacher 模型 B 对 student 模型 A 进行知识蒸馏,得到一个更高精度的 A 模型。
(2)降低模型时延,压缩模型参数
如果对目前的网络模型 A 的时延不满意,可以先找到一个时延更低,参数量更小的模型 B ,通常来讲, B 模型精度也会比较低,然后通过训练一个更高精度的 teacher 模型 C 来对这个参数量小的模型 B 进行知识蒸馏,使得该模型 B 的精度接近最原始的模型 A,从而达到降低时延的目的。
(3)标签之间的域迁移
假如使用狗和猫的数据集训练了一个 teacher 模型 A ,使用香蕉和苹果训练了一个 teacher 模型 B ,那么就可以用这两个模型同时蒸馏出一个可以识别狗、猫、香蕉以及苹果的模型,将两个不同域的数据集进行集成和迁移
4.知识蒸馏的理论依据
知识蒸馏使用的是 Teacher—Student 模型,其中 Teacher 是“知识”的输出者, Student 是“知识”的接受者。知识蒸馏的过程分为 3 个阶段:
(1)Teacher 模型训练:简称为 Net-T ,它的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对 Teacher 模型不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入 X , 其都能输出 Y ,其中 Y 经过 softmax 的映射,输出值是对应类别的概率值
(2)Student 模型训练: 简称为 Net-S ,它是参数量较小、模型结构相对简单的模型。同样的,对于输入 X ,其都能输出 Y , Y 经过 softmax 映射后同样能输出对应类别的概率值
(3)知识蒸馏训练: Net-T 学习能力强,可以将它学到的知识迁移给学习能力相对弱的 Net-S 模型,以此来增强 Net-T 模型的泛化能力。复杂笨重但是效果好的 Net-T 模型不上线,就单纯是个导师角色,真正部署上线进行预测任务的是灵活轻巧的 Net-S 小模型
5.知识蒸馏分类
知识蒸馏是对模型的能力进行迁移,根据迁移的方法不同可以简单分为基于目标蒸馏(也称为 Soft-target 蒸馏或 Logits 方法蒸馏)和基于特征蒸馏的算法两个大的方向
5.1 目标蒸馏-Logits方法
分类问题的共同点是模型最后会有一个 softmax 层,其输出值对应了相应类别的概率值。在知识蒸馏时,由于我们已经有了一个泛化能力较强的 Teacher 模型,我们利用 Teacher 模型来蒸馏训练 Student 模型时,可以直接让 Student 模型去学习 Teacher 模型的泛化能力。一个很直白且高效的迁移泛化能力的方法就是:使用 Tearcher 模型的 softmax 层输出的类别概率来作为 Soft-targets ,即 Soft-targets 作为训练 Student 模型的标签
5.1.1 Hard-targets、Soft-targets
传统的模型训练过程:采用 Hard-targets ,类别标签对应的 One-Hot 向量,除了正确类别为 1,其他类别都是 0
知识蒸馏训练过程:用 Teacher Model 的 class probabilities 作为 Student Model 的 Soft-targets
下面看一些例子
假设有 3 个类别,分别是 [马,驴子,汽车]
对于马这张图片,Hard-targets 和 Soft-targets 两种形式的类别标签如下:
Hard-targets、Soft-targets
使用 Hard-targets 来训练网络,对于马这张图片,相当于告诉网络这就是一匹马,不是驴和汽车,并且不是驴和汽车的程度是相等的,因为 Hard-targets 中驴和汽车的标签值都为 0
但通过肉眼其实能观察到马和驴其实是有一点相似的,马更像驴子而更不像汽车。Hard-targets 一个致命的缺陷就是秉持绝对的 “正确” ,对于非正确类别不提供任何额外的信息
使用 Soft-targets 来训练网络,对于马这种图片,相当于告诉网络是马的概率为 0.7 ,是汽车的概率为 0.05 ,说明马和汽车是非常不像的,而是驴的概率为 0.25,说明马和驴是有一点类似的。同理,驴和汽车也是非常不像的。显然 Soft-targets 传递的信息更多
其他例子
综上:
(1)在使用 Soft-targets 训练时,Student Model 可以很快学习到 Teacher Model 的推理过程
(2)传统的 Hard-targets 的训练方式,所有的负标签都会被平等对待。而 Soft-targets 包含了更多的“知识”和“信息”,像谁,不像谁,有多像,有多不像,特别是非正确类别概率的相对大小(驴和车)。所以我们可以用 Teacher 网络的预测结果作为训练 Student 网络的标签
5.1.2 蒸馏温度T
如果觉得 Teacher Model 的 Soft-targets 还不够 Soft ,即想把其他非正确类别的概率也变大,把它们的相对大小充分暴露出来,让学生网络能够有一个强烈的信号,知道这些非正确类别的更多信息,因此提出了蒸馏温度 T , T 越大,Soft-targets 就越 Soft
实现方法:在原始的 Softmax 公式加 T。T = 1 时,就是常规的 Softmax 操作
计算公式
下面我们来看一下 T = 1 与 T = 3 时的对比例子
不同温度的Softmax对比
(1)T = 1 时,普通的 Softmax 操作, 4 个概率值的两级分化十分严重,猫和狗的预测概率非常小,接近于 0 。而马的预测概率非常高,达到 0.88 。
(2)T = 3 时,使用蒸馏温度,4 个概率值相比较于 T = 1 的概率值两级分化情况没那么严重,变得更 Soft 了,而且各类别概率值的大小顺序和 T=1 时是一样的
5.1.3 蒸馏温度T的特点
- 原始的 Softmax 函数是 T = 1 时的特例;T<1 时,概率分布比原始更“陡峭”,即当 0<T<1 时,Softmax 的输出值会接近与 Hard-targets ;T>1 时,概率分布比原始更“平缓”
- 随着 T 的增加,Softmax 的输出分布会越来越平缓。温度越高,Softmax 上各个值的分布就越平均,极端情况的 Softmax 值是平均分布的,如下图 T = 100 时
- 温度的高低改变的是 Student 模型训练过程中对负标签的关注程度。温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;温度较高时,负标签相关的值会相对增大, Student 模型会相对更多地关注到负标签
- 针对较困难的分类或检测任务, Temperature 通常取 1 ,确保 Teacher Model 中正确预测的贡献
- 温度的选取需要进行实际实验的比较,本质上就是如下两种情况作取舍:
- 当想从负标签中多学到一些信息的时候,温度应调高一些
- 当想减少负标签的干扰的时候,温度应调低一些
总的来说,温度的选择和 Student 模型的大小有关,Student 模型参数量比较小时,相对比较低的温度就可以了。因为参数量小的模型不能学到所有 Teacher 模型的知识,所以可以适当忽略掉
5.2 特征蒸馏方法
另外一种知识蒸馏思路是特征蒸馏方法,它不像 Logits 方法那样 Student 只学习 Teacher 的 Logits 这种结果知识,而是学习 Teacher 网络结构中的中间层特征
最早采用这种模式的工作来自于论文《FITNETS:Hints for Thin Deep Nets》,它强迫 Student 某些中间层的网络响应,去逼近 Teacher 对应的中间层的网络响应。这种情况下, Teacher 中间特征层的响应,就是传递给 Student 的知识。在此之后,提出了各种新方法,但是大致思路还是这个思路,本质是 Teacher 将特征级知识迁移给 Student 。
6.知识蒸馏过程
知识蒸馏过程
损失函数计算
上图中的 distillation loss 和 student loss 也可以称为 soft loss 和 hard loss,两者都可以采用传统的交叉熵损失函数,交叉熵损失函数的计算公式如下:
交叉熵损失函数计算公式
下图是损失函数计算的例子
图解损失函数计算过程
实际的损失函数设计中,会给 soft loss 和 hard loss 增加一个权重系数,构成完成的损失函数,即:
Total loss 设计为 Soft Loss 与 Hard Loss 所对应的交叉熵的加权和,其中 越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小 Soft Loss 的比重,让真实标签帮助鉴别困难样本