Focal and Global Knowledge Distillation: Knowledge Distillation for Object Detectors
Paper: https://arxiv.org/abs/2111.11837
GitHub: https://github.com/yzd-v/FGD
Method

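FGKD (Focal and Global Knowledge Distillation; the paper and repository call it FGD) combines focal distillation with global distillation, so the student learns instance-level information, spatial and channel attention, and the global relations between pixels.
The method first defines a foreground/background mask, a scale mask and attention masks. It then computes focal distillation on the feature maps as the sum of a feature loss and an attention loss: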
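- A binary mask separates foreground from background, where r denotes the gt-box region:
  $$M_{i,j}=\begin{cases}1, & (i,j)\in r\\ 0, & \text{otherwise}\end{cases}$$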
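- A scale mask balances the foreground and background losses, where $H_r$ and $W_r$ are the height and width of the gt box and $N_{bg}$ counts the background positions:
  $$S_{i,j}=\begin{cases}\dfrac{1}{H_r W_r}, & (i,j)\in r\\ \dfrac{1}{N_{bg}}, & \text{otherwise}\end{cases},\qquad N_{bg}=\sum_{i=1}^{H}\sum_{j=1}^{W}\big(1-M_{i,j}\big)$$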
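- Spatial and channel attention are obtained as reduced means of absolute activations (over channels and over positions, respectively); a temperature-scaled softmax then turns them into attention masks:
  $$G^S(F)=\frac{1}{C}\sum_{c=1}^{C}\big|F_c\big|,\qquad G^C(F)=\frac{1}{HW}\sum_{i=1}^{H}\sum_{j=1}^{W}\big|F_{i,j}\big|$$
  $$A^S(F)=H\cdot W\cdot\mathrm{softmax}\big(G^S(F)/T\big),\qquad A^C(F)=C\cdot\mathrm{softmax}\big(G^C(F)/T\big)$$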
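- The feature loss is computed from the teacher and student features (usually the neck outputs), where $f$ is the adaptation layer that aligns the student features to the teacher's shape, and $A^S$, $A^C$ are the teacher's attention masks:
  $$L_{fea}=\alpha\sum_{k=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W}M_{i,j}S_{i,j}A^S_{i,j}A^C_k\big(F^T_{k,i,j}-f(F^S_{k,i,j})\big)^2+\beta\sum_{k=1}^{C}\sum_{i=1}^{H}\sum_{j=1}^{W}\big(1-M_{i,j}\big)S_{i,j}A^S_{i,j}A^C_k\big(F^T_{k,i,j}-f(F^S_{k,i,j})\big)^2$$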
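- The attention loss is then computed, where $l$ denotes the L1 loss, and the focal distillation loss is the sum of the two:
  $$L_{at}=\gamma\Big(l\big(A^S_T,A^S_S\big)+l\big(A^C_T,A^C_S\big)\Big)$$
  $$L_{focal}=L_{fea}+L_{at}$$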
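The masks and attention maps in the steps above can be sketched in NumPy for a single image. This is a minimal illustration under stated assumptions, not the repository's implementation (which works on batched PyTorch tensors): the names `fg_bg_masks`, `softmax` and `attention` are hypothetical, boxes are assumed to be given directly in feature-map cells with exclusive end coordinates, and `temp=0.5` is only an illustrative temperature. Where boxes overlap, the sketch keeps the larger weight, i.e. that of the smaller box.

```python
import numpy as np


def fg_bg_masks(h, w, boxes):
    # boxes: (x1, y1, x2, y2) in feature-map cells, end-exclusive (an assumption).
    m = np.zeros((h, w))  # binary mask M
    s = np.zeros((h, w))  # scale mask S
    for x1, y1, x2, y2 in boxes:
        area = (y2 - y1) * (x2 - x1)  # H_r * W_r
        m[y1:y2, x1:x2] = 1.0
        # Overlap: keep the larger weight (smaller box), an assumption here.
        s[y1:y2, x1:x2] = np.maximum(s[y1:y2, x1:x2], 1.0 / area)
    n_bg = (1.0 - m).sum()  # N_bg
    if n_bg > 0:
        s[m == 0] = 1.0 / n_bg
    # Return M plus the combined masks M*S and (1-M)*S (Mask_fg / Mask_bg).
    return m, m * s, (1.0 - m) * s


def softmax(x):
    e = np.exp(x - x.max())  # shift for numerical stability
    return e / e.sum()


def attention(feat, temp=0.5):
    # Spatial [H, W] and channel [C] attention masks of one feature map [C, H, W].
    c, h, w = feat.shape
    value = np.abs(feat)
    g_s = value.mean(axis=0)       # G^S: mean over channels
    g_c = value.mean(axis=(1, 2))  # G^C: mean over positions
    a_s = h * w * softmax(g_s.ravel() / temp).reshape(h, w)
    a_c = c * softmax(g_c / temp)
    return a_s, a_c


feat = np.random.default_rng(0).normal(size=(8, 6, 6))
m, mask_fg, mask_bg = fg_bg_masks(6, 6, [(1, 1, 3, 4)])
a_s, a_c = attention(feat)
print(mask_fg.sum(), mask_bg.sum(), a_s.sum(), a_c.sum())
```

Each box contributes a total foreground weight of 1 and the whole background also sums to 1, which is how the scale mask keeps large objects and the background from dominating the loss. The attention masks sum to H·W and C, so their mean value is 1.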
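```python
import torch
import torch.nn as nn


def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
    # preds_*: [N, C, H, W] (student already aligned to the teacher's channels);
    # Mask_fg / Mask_bg: [N, H, W], binary mask times scale mask;
    # C_t: [N, C] and S_t: [N, H, W] are the teacher's attention masks,
    # which weight both feature maps. C_s and S_s are not used here.
    loss_mse = nn.MSELoss(reduction='sum')
    Mask_fg = Mask_fg.unsqueeze(dim=1)             # [N, 1, H, W]
    Mask_bg = Mask_bg.unsqueeze(dim=1)             # [N, 1, H, W]
    C_t = C_t.unsqueeze(dim=-1).unsqueeze(dim=-1)  # [N, C, 1, 1]
    S_t = S_t.unsqueeze(dim=1)                     # [N, 1, H, W]

    # Scaling both maps by sqrt(w) turns the squared error into w * (t - s)^2.
    fea_t = torch.mul(preds_T, torch.sqrt(S_t))
    fea_t = torch.mul(fea_t, torch.sqrt(C_t))
    fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
    bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

    fea_s = torch.mul(preds_S, torch.sqrt(S_t))
    fea_s = torch.mul(fea_s, torch.sqrt(C_t))
    fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
    bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

    # Summed squared error, averaged over the batch (len() returns N).
    fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
    bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)
    return fg_loss, bg_loss
```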
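Two details of the focal loss can be checked numerically. First, `get_fea_loss` multiplies both feature maps by the square root of the weights before a summed MSE; this equals the weighted squared error in $L_{fea}$ because $(\sqrt{w}\,t-\sqrt{w}\,s)^2=w\,(t-s)^2$. Second, the attention loss is an L1 distance between teacher and student attention masks. The NumPy sketch below works on one image with random stand-in attention values; `weighted_sq_error`, `sqrt_trick` and `attention_loss` are hypothetical helper names, and `gamma=1.0` is a placeholder rather than the paper's setting.

```python
import numpy as np


def weighted_sq_error(t, s, w):
    # Direct form of one L_fea term: sum of w * (t - s)^2.
    return float((w * (t - s) ** 2).sum())


def sqrt_trick(t, s, w):
    # Form used in get_fea_loss: scale both maps by sqrt(w), then summed MSE.
    return float(((np.sqrt(w) * s - np.sqrt(w) * t) ** 2).sum())


def attention_loss(a_s_t, a_s_s, a_c_t, a_c_s, gamma=1.0):
    # L_at = gamma * (L1(A^S_T, A^S_S) + L1(A^C_T, A^C_S)), summed L1 per image.
    return gamma * float(np.abs(a_s_t - a_s_s).sum() + np.abs(a_c_t - a_c_s).sum())


rng = np.random.default_rng(0)
C, H, W = 4, 5, 5
t = rng.normal(size=(C, H, W))   # teacher feature F^T
s = rng.normal(size=(C, H, W))   # aligned student feature f(F^S)
a_s = rng.random((H, W))         # stand-in teacher spatial attention
a_c = rng.random(C)              # stand-in teacher channel attention
m_s = np.zeros((H, W))
m_s[1:3, 1:4] = 1.0 / 6          # M * S for a single 2x3 box
w = m_s[None] * a_s[None] * a_c[:, None, None]  # broadcast to [C, H, W]

direct = weighted_sq_error(t, s, w)
trick = sqrt_trick(t, s, w)
print(direct, trick)                        # equal up to floating-point rounding
print(attention_loss(a_s, a_s, a_c, a_c))   # identical attention gives 0.0
```

The square-root form lets the repository reuse a single `nn.MSELoss(reduction='sum')` call instead of writing the weighted sum by hand.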
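Global distillation extracts the relations between different pixels to transfer context information, where $R(F)$ denotes the GcBlock feature transform, $N_p$ is the number of pixels, $W_k$, $W_{v1}$, $W_{v2}$ are convolution layers and LN is layer normalization:
$$L_{global}=\lambda\sum\big(R(F^T)-R(F^S)\big)^2$$
$$R(F)=F+W_{v2}\Big(\mathrm{ReLU}\Big(\mathrm{LN}\Big(W_{v1}\Big(\sum_{j=1}^{N_p}\frac{e^{W_kF_j}}{\sum_{m=1}^{N_p}e^{W_kF_m}}F_j\Big)\Big)\Big)\Big)$$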


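```python
import torch.nn as nn


def get_rela_loss(self, preds_S, preds_T):
    # Global loss: both features pass through a GcBlock-style transform R(F)
    # (separate student/teacher modules) and are compared with a summed MSE.
    loss_mse = nn.MSELoss(reduction='sum')
    # Attention-pooled global context of shape [N, C, 1, 1];
    # the second argument selects the student (0) or teacher (1) branch.
    context_s = self.spatial_pool(preds_S, 0)
    context_t = self.spatial_pool(preds_T, 1)
    # R(F) = F + transform(context), broadcast over all spatial positions.
    out_s = preds_S + self.channel_add_conv_s(context_s)
    out_t = preds_T + self.channel_add_conv_t(context_t)
    # Summed squared error, averaged over the batch.
    rela_loss = loss_mse(out_s, out_t) / len(out_s)
    return rela_loss
```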
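The $R(F)$ transform and the global loss can be sketched in NumPy for one image, following the GcBlock formula term by term. Under the assumptions of this sketch, the 1×1 convolutions $W_k$, $W_{v1}$, $W_{v2}$ are plain matrices, biases and LayerNorm affine parameters are omitted, and the reduction ratio `r=2` is illustrative. As in `get_rela_loss`, teacher and student each get their own parameters. `layer_norm`, `gc_transform`, `global_loss` and `make_params` are hypothetical names.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # Normalizes a [C'] vector (the [C', 1, 1] context), no affine parameters.
    return (x - x.mean()) / np.sqrt(x.var() + eps)


def gc_transform(feat, w_k, w_v1, w_v2):
    # R(F) for one feature map feat of shape [C, H, W].
    # w_k: [C], w_v1: [C//r, C], w_v2: [C, C//r]; 1x1 convs as matrices, no bias.
    c, h, w = feat.shape
    flat = feat.reshape(c, h * w)            # F_j for j = 1..N_p
    logits = w_k @ flat                      # W_k F_j, shape [N_p]
    p = np.exp(logits - logits.max())
    p /= p.sum()                             # softmax over all pixels
    context = flat @ p                       # attention-pooled context, [C]
    hidden = np.maximum(layer_norm(w_v1 @ context), 0.0)  # ReLU(LN(W_v1 ...))
    delta = w_v2 @ hidden                    # [C]
    return feat + delta[:, None, None]       # added to every position


def global_loss(f_t, f_s, params_t, params_s, lam=1.0):
    # lambda * sum((R(F^T) - R(F^S))^2), teacher and student with separate R.
    diff = gc_transform(f_t, *params_t) - gc_transform(f_s, *params_s)
    return lam * float((diff ** 2).sum())


rng = np.random.default_rng(0)
C, H, W, r = 8, 4, 4, 2


def make_params():
    return (rng.normal(size=C),
            rng.normal(size=(C // r, C)),
            rng.normal(size=(C, C // r)))


params_t, params_s = make_params(), make_params()
f_t = rng.normal(size=(C, H, W))
f_s = rng.normal(size=(C, H, W))

d = gc_transform(f_t, *params_t) - f_t
print(global_loss(f_t, f_t, params_t, params_t))  # 0.0
print(global_loss(f_t, f_s, params_t, params_s))
```

Because the transformed context vector is added identically at every position, $R(F)-F$ is constant across the H×W grid for each channel (the array `d` shows this); the pixel relations enter through the softmax-weighted pooling over all $N_p$ positions.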
Experimental Results
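The paper applies FGKD to a range of object detectors and analyzes the contribution of the individual loss terms, of focal and global distillation, and the sensitivity to the temperature T (see the experiments section of the paper for details). The results on the different detectors are reported there as well.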

