论文题目:《The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification》
期刊与时间:IEEE Transactions on Image Processing 2020 (TIP 2020)
论文地址:https://arxiv.org/pdf/2002.04264
源码地址(PyTorch版本):https://github.com/dongliangchang/Mutual-Channel-Loss
针对领域:细粒度图像分类(FGVC)
以resnet为例
class model_bn(nn.Module): def __init__(self, feature_size=512,classes_num=200): super(model_bn, self).__init__() # 定义特征提取网络,删掉原resnet中的全局平均池化和全连接层 # 注意,作者将最后一层输出的特征图通道数改了,改成了分类数*ζ,作者源码中以600为例(200*3) # 只改变了layer4输出特征图的通道数,其他部分与原始的resnet相同 self.features = nn.Sequential(*list(net.children())[:-2]) # 全局最大池化 self.max = nn.MaxPool2d(kernel_size=14, stride=14) # 特征图通道数。这里做了修改,与原通道数不一样 self.num_ftrs = 600*1*1 # 分类器,依次为批量标准化、线性回归、批量标准化ELU激活函数、线性回归 self.classifier = nn.Sequential( nn.BatchNorm1d(self.num_ftrs), #nn.Dropout(0.5), nn.Linear(self.num_ftrs, feature_size), nn.BatchNorm1d(feature_size), nn.ELU(inplace=True), #nn.Dropout(0.5), nn.Linear(feature_size, classes_num), ) def forward(self, x, targets): # 首先图片经过特征提取,得到特征图 x = self.features(x) # 之后经过MC_Loss模块,得到MC损失 if self.training: MC_loss = supervisor(x, targets, height=14, cnum=3) # 特征图依次经过全局最大池化,得到特征向量 x = self.max(x) x = x.view(x.size(0), -1) # 再经过分类器,得到网络的预测值 x = self.classifier(x) # 求交叉熵损失 loss = criterion(x, targets) # 如果是训练阶段,则返回预测值和预测损失的同时,还需要返回MC损失 if self.training: return x, loss, MC_loss # 如果是测试阶段,则只需要返回预测值与损失 else: return x, loss
def supervisor(x, targets, height, cnum): # 首先得到掩模图 mask = Mask(x.size(0), cnum).cpu() branch = x # 将特征图改变形状,变成(batch,200*ζ,h*w),ζ表示多少特征图代表一类,作者以ζ=3为例 # 第二维度(dim=2)表示特征图中的特征数据 branch = branch.reshape(branch.size(0),branch.size(1), branch.size(2) * branch.size(3)) # 将特征数据放入softmax,沿第二维度进行归一化操作(相当沿原来特征图上的数据进行扫描),对应论文中公式(7)后半段 branch = F.softmax(branch, 2) # 再将特征图变回原来的形状 branch = branch.reshape(branch.size(0), branch.size(1), x.size(2), x.size(2)) # 将归一化后的数据传入CCMP模块,对应论文中公式(7)前半段 branch = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch) # 特征图经过CCMP之后,通道数变为分类数,之后再转化一下形状 # 转化为(batch,200,w*h) branch = branch.reshape(branch.size(0),branch.size(1), branch.size(2) * branch.size(3)) # 之后首先对branch中的元素按第二维度求和,即对特征数据求和 # 之后再对所有通道取平均值,对应论文公式(6) loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(branch, 2)) / cnum# set margin = 3.0 # CWA模块:掩模图M与特征图相乘 branch_1 = x * mask # CCMP模块,将所有特征图取相应类别的最大值,(对于每一类,3张压缩成1张),得到的特征图尺寸为(batch,200,h,w) branch_1 = my_MaxPool2d(kernel_size=(1,cnum), stride=(1,cnum))(branch_1) # 全局平均化,得到每一类的预测分数(h*w个值压缩成1个数),最终得到论文中公式(5)的结果 branch_1 = nn.AvgPool2d(kernel_size=(height,height))(branch_1) # 压扁,便于后续取交叉熵损失 branch_1 = branch_1.view(branch_1.size(0), -1) # 取交叉熵损失,对于论文中公式(4) loss_1 = criterion(branch_1, targets) # 返回损失 return [loss_1, loss_2]
计算CWA模块中的掩模图M:
# 得到CWA模块中的掩模图M def Mask(nb_batch, channels): # 假设三张特征图表示一个类别,即论文中的参数ζ为3 # 此时一组掩模M_i中由两个1,一个0组成 foo = [1] * 2 + [0] * 1 # 初始化总的M列表 bar = [] # 这里的200表示分类数 for i in range(200): # 打乱初始化后M_i中的元素,表示随机生成M_i random.shuffle(foo) # 与总列表合并 bar += foo # 按批次(batch)复制 bar = [bar for i in range(nb_batch)] # 转换成array格式 bar = np.array(bar).astype("float32") # 转换形状,转换成(batch,200*ζ,1,1),前两个维度中,掩模和特征图大小相同,便于后续的点乘操作 bar = bar.reshape(nb_batch, 200 * channels, 1, 1) # 转换成tensor格式,之后放入显卡,再令其可求导 bar = torch.from_numpy(bar) bar = bar.cuda() bar = Variable(bar) # 最后返回掩模M return bar
CCMP模块:
class my_MaxPool2d(Module): def __init__(self, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False): super(my_MaxPool2d, self).__init__() # 最大池化的一系列参数,可以在定义my_MaxPool2d的同时引入 self.kernel_size = kernel_size self.stride = stride or kernel_size self.padding = padding self.dilation = dilation self.return_indices = return_indices self.ceil_mode = ceil_mode def forward(self, input): # 将输入的1,3维度进行交换,即将通道维度与图片的宽w交换,得到(batch,w,h,600)的数据(以CUB数据集为例) input = input.transpose(3,1) # 最大池化,注意,此时池化核为(1, cnum) # 相当于在原始的三张特征图中沿通道选择最大值,对应论文中公式(5)的中间部分(CCMP) input = F.max_pool2d(input, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode, self.return_indices) # 再将特征图维度变回去,变成正常的尺寸,即(1,200,h,w) input = input.transpose(3,1).contiguous() # 最后返回特征图 return input
以上内容仅是笔者的个人观点,若有错误,欢迎大家批评指正。
笔记原创,未经同意禁止转载!