知识蒸馏的蒸馏损失方法代码总结(包括:基于logits的方法:KLDiv,dist,dkd等,基于中间层提示的方法:)

有两种知识蒸馏方法:一种利用教师模型的输出概率(基于logits的方法)[15,14,11],另一种利用教师模型的中间表示(基于提示的方法)[12,13,18,17]。基于logits的方法利用教师的输出作为辅助信号来训练一个较小的模型,即学生模型:

利用教师模型的输出概率(基于logits的方法)

该类方法损失函数为:
在这里插入图片描述

DIST

Tao Huang,Shan You,Fei Wang,Chen Qian,and Chang Xu.Knowledge distillation from a strongerteacher.In Advances in Neural Information Processing Systems,2022.

import torch.nn as nn


def cosine_similarity(a, b, eps=1e-8):
    return (a * b).sum(1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def pearson_correlation(a, b, eps=1e-8):
    return cosine_similarity(a - a.mean(1).unsqueeze(1),
                             b - b.mean(1).unsqueeze(1), eps)


def inter_class_relation(soft_student_outputs, soft_teacher_outputs):
    return 1 - pearson_correlation(soft_student_outputs, soft_teacher_outputs).mean()


def intra_class_relation(soft_student_outputs, soft_teacher_outputs):
    return inter_class_relation(soft_student_outputs.transpose(0, 1), soft_teacher_outputs.transpose(0, 1))


class DIST(nn.Module):
    def __init__(self, beta=1.0, gamma=1.0, temp=1.0):
        super(DIST, self).__init__()
        self.beta = beta
        self.gamma = gamma
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = (student_preds / self.temp).softmax(dim=1)
        soft_teacher_outputs = (teacher_preds / self.temp).softmax(dim=1)
        inter_loss = self.temp ** 2 * inter_class_relation(soft_student_outputs, soft_teacher_outputs)
        intra_loss = self.temp ** 2 * intra_class_relation(soft_student_outputs, soft_teacher_outputs)
        kd_loss = self.beta * inter_loss + self.gamma * intra_loss
        return kd_loss

KLDiv (2015年的原始方法)

import torch.nn as nn
import torch.nn.functional as F

# loss = alpha * hard_loss + (1-alpha) * kd_loss,此处是单单的kd_loss
class KLDiv(nn.Module):
    def __init__(self, temp=1.0):
        super(KLDiv, self).__init__()
        self.temp = temp

    def forward(self, student_preds, teacher_preds, **kwargs):
        soft_student_outputs = F.log_softmax(student_preds / self.temp, dim=1)
        soft_teacher_outputs = F.softmax(teacher_preds / self.temp, dim=1)
        kd_loss = F.kl_div(soft_student_outputs, soft_teacher_outputs, reduction="none").sum(1).mean()
        kd_loss *= self.temp ** 2
        return kd_loss

dkd (Decoupled KD(CVPR 2022) )

Borui Zhao,Quan Cui,Renjie Song,Yiyu Qiu,and Jiajun Liang.Decoupled knowledge distillation.InIEEE/CVF Conference on Computer Vision and Pattern Recognition,2022.

import torch
import torch.nn as nn
import torch.nn.functional as F


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    tckd_loss = (
            F.kl_div(log_pred_student, pred_teacher, reduction='batchmean')
            * (temperature ** 2)
    )
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
            F.kl_div(log_pred_student_part2, pred_teacher_part2, reduction='batchmean')
            * (temperature ** 2)
    )
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt


class DKD(nn.Module):
    def __init__(self, alpha=1., beta=2., temperature=1.):
        super(DKD, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.temperature = temperature

    def forward(self, z_s, z_t, **kwargs):
        target = kwargs['target']
        if len(target.shape) == 2:  # mixup / smoothing
            target = target.max(1)[1]
        kd_loss = dkd_loss(z_s, z_t, target, self.alpha, self.beta, self.temperature)
        return kd_loss

利用教师模型的中间表示(基于提示的方法)

该类方法损失函数为:
[ L_{hint} = D_{hint}(T_s(F_s), T_t(F_t)) ]

ReviewKD (CVPR2021)

论文:

Pengguang Chen,Shu Liu,Hengshuang Zhao,and Jiaya Jia.Distilling knowledge via knowledge review.In IEEE/CVF Conference on Computer Vision and Pattern Recognition,2021.

代码:

https://github.com/dvlab-research/ReviewKD

Adriana Romero,Nicolas Ballas,Samira Ebrahimi Kahou,Antoine Chassang,Carlo Gatta,and YoshuaBengio.Fitnets:Hints for thin deep nets.arXiv preprint arXiv:1412.6550,2014.

Yonglong Tian,Dilip Krishnan,and Phillip Isola.Contrastive representation distillation.In IEEE/CVFInternational Conference on Learning Representations,2020.

Baoyun Peng,Xiao Jin,Jiaheng Liu,Dongsheng Li,Yichao Wu,Yu Liu,Shunfeng Zhou,and ZhaoningZhang.Correlation congruence for knowledge distillation.In International Conference on ComputerVision,2019.

关于知识蒸馏损失函数的文章

FitNet(ICLR 2015)、Attention(ICLR 2017)、Relational KD(CVPR 2019)、ICKD (ICCV 2021)、Decoupled KD(CVPR 2022) 、ReviewKD(CVPR 2021)等方法的介绍:

https://zhuanlan.zhihu.com/p/603748226?utm_id=0

待更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/217911.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

图像万物分割——Segment Anything算法解析与模型推理

一、概述 在视觉任务中,图像分割任务是一个很广泛的领域,应用于交互式分割,边缘检测,超像素化,感兴趣目标生成,前景分割,语义分割,实例分割,泛视分割等。 交互式分割&am…

(c语言进阶)offsetof函数——计算结构体元素的偏移量

一.基本概念&#xff1a; 头文件&#xff1a;<stddef.h> offsetof(结构体名,结构体元素名) 返回值为size_t&#xff08;unsigned int&#xff09; 二.应用 例题&#xff1a; #include<stdio.h> #include<stddef.h> typedef struct S1 {char a;int b;c…

Clean 架构下的现代 Android 架构指南

Clean 架构下的现代 Android 架构指南 Clean 架构是 Uncle Bob 提出的一种软件架构&#xff0c;Bob 大叔同时也是 SOLID 原则的命名者。 Clean 架构图如下&#xff1a; 这张图描述的是整个软件系统的架构&#xff0c;而不是单体软件&#xff0c;其中至少包括服务端以及客户端…

JVM类加载全过程

Java虚拟机类加载的全过程&#xff0c;即加载&#xff0c;验证&#xff0c;准备&#xff0c;解析&#xff0c;初始化 一、加载 加载 是 类加载过程中的一个阶段&#xff0c; 有以下三部分组成 1&#xff09;通过一个类的全限定名来获取定义此类的二进制流 2&#xff09;将这…

【Java 基础】19 多线程基础

文章目录 进程和线程进程&#xff08;Process&#xff09;线程&#xff08;Thread&#xff09; 线程的创建1&#xff09;继承 Thread 类2&#xff09;实现 Runnable 接口3&#xff09;使用 Lambda 表达式4&#xff09;总结 线程的状态状态的分类状态间转换 多线程是一种 同时执…

Github无法打开

文章目录 一、问题二、解决2.1、科学上网&#xff08;使用中&#xff09;2.2、使用代理&#xff08;不稳定&#xff09;2.3、修改hosts&#xff08;得更新&#xff09;2.3.1、找到hosts文件2.3.2、复制hosts文件2.3.3、添加记录2.3.4、替换原来的hosts文件2.3.5、成功访问Githu…

CefSharp 获取POST(AJAX)、GET消息返回值(request)

CefSharp作为专门为爬虫工具开发的库比Selenium这种开发目的是页面测试工具然后用来做爬虫的工具要贴心得多。我们操作网页的时候发送或者做了某个动作提交表单之后需要知道我们的动作或者提交是否成功&#xff0c;因为有的页面会因为网络延迟问题提交失败&#xff0c;需要准确…

[Azure]azure磁盘加密(Windows/Linux) ADE(Azure Disk Encryption)

Azure 磁盘加密用于保护数据&#xff0c;对于Windows使用BitLocker对磁盘进行加密&#xff0c;同时与Key Vault集成&#xff0c;控制和管理Key和Secret。 本文利用Potal对磁盘进行加密 注&#xff1a;Azure DIsk Encryption 可能会导致VM重启&#xff0c;对VM造成影响&#xff…

哈希与哈希表

哈希表的概念 哈希表又名散列表&#xff0c;官话一点讲就是&#xff1a; 散列表&#xff08;Hash table&#xff0c;也叫哈希表&#xff09;&#xff0c;是根据关键码值(Key value)而直接进行访问的数据结构。也就是说&#xff0c;它通过把关键码值映射到表中一个位置来访问记…

基于SSM的老年公寓信息管理的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

nodejs微信小程序+python+PHP天天网站书城管理系统的设计与实现-计算机毕业设计推荐

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性&#xff1a;…

QT 中 QMessageBox 的简单用法

效果 思路 // 创建一个question弹出对话框&#xff0c;添加两个按钮&#xff1a;Yes和NoQMessageBox *box new QMessageBox(QMessageBox::Question, "提示", "确认删除的信息吗&#xff1f;", QMessageBox::Yes | QMessageBox::No, this);box->button(…

成人学生钢笔练字快速入门,硬笔书法行书楷书教程合集

一、教程描述 虽然现在都是电脑打字&#xff0c;需要手写的场合越来越少&#xff0c;但是可以写一手人见人爱&#xff0c;花见花开的好字&#xff0c;仍然是很拉风很惊艳的&#xff0c;可以给人留下深刻印象。本套硬笔书法教程&#xff0c;大小40.90G&#xff0c;共有591个文件…

Python 网络爬虫(二):HTTP 基础知识

《Python入门核心技术》专栏总目录・点这里 文章目录 1. HTTP 协议简述2. HTTP 请求过程3. HTTP 的结构3.1 请求行3.2 请求头3.3 请求体3.4 状态行3.5 响应头3.6 响应体 4. Cookie 状态管理5. HTTP 请求示例6. 总结 大家好&#xff0c;我是水滴~~ 在准备学习网络爬虫之前&…

公有云迁移研究——AWS Route53

大纲 1 什么是Route 532 Route 53能做些什么# 3 通过DNS托管来实现分流3.1 创建DNS托管3.2 对托管创建记录对流量进行分配 4 通过流量策略来对流量进行分流4.1 创建流量策略 5 对比两者的区别6 推荐 在给客户从本地机房往AWS迁移的过程中&#xff0c;我们接到如下需求&#xff…

vue打印功能

安装 vue3-print-nb yarn add vue3-print-nb //或 npm install vue3-print-nb main.js中引入 vue3-print-nb import { createApp } from vue; import App from ./App.vue; const app createApp(App); // 打印插件 import print from vue3-print-nb app.use(print) // 页面…

公有云迁移研究——AWS Translate

大纲 1 什么是Translate2 Aws Translate是怎么运作的3 Aws Translate和Google Translate的区别4 迁移任务4.1 迁移原因 5 Aws Translate的Go demo6 迁移中遇到的问题6.1 账号和权限问题&#xff1a;6.2 小语种 1 什么是Translate Translate是一种文本翻译服务&#xff0c;它使…

制作一个RISC-V的操作系统三-编译与链接

文章目录 GCCGCC简介GCC的命令格式gcc -Egcc -cgcc -Sgcc -ggcc -vGCC的主要执行步骤GCC涉及的文件类型针对多个源文件的处理 ELFELF介绍ELF文件格式ELF文件处理相关工具&#xff1a;Binutils&#xff08;binary utility&#xff09;readlelf -hreadelf -S或readelf -SW&#x…

思维模型 路径依赖定律

本系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知。难以摆脱的惯性。 1 路径依赖定律的应用 1.1 打破路径依赖定律的苹果 在 20 世纪 80 年代&#xff0c;苹果公司推出了 Macintosh 电脑&#xff0c;这是一款具有图形用户界面和鼠标的创新产…

MacOS M芯片 安装MySQL5.7教程

目录 1. 安装Homebrew1.1 快速安装1.2 检查是否安装成功 2. 通过Homebrew安装MySQL2.1 搜索 MySQL 版本2.2 安装MySQL 5.72.3 位置说明2.4 启动MySQL服务2.5 检查服务状态2.6 设置环境变量2.7 重置密码 3. 测试安装 1. 安装Homebrew 1.1 快速安装 /bin/bash -c "$(curl …