# Source code for gunz_ml.metrics.classification.adacos

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

class AdaCos(nn.Module):
    """
    AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep
    Face Representations.

    The layer computes cosine similarities between L2-normalized input
    features and L2-normalized class weight vectors. When labels are
    supplied, it re-estimates the scale ``s`` from the current batch and
    returns the cosine logits multiplied by that scale.

    Reference:
        Zhang, X., Zhao, R., Qiao, Y., Wang, X., & Li, H. (2019).
        AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning
        Deep Face Representations. In Proceedings of the IEEE/CVF Conference
        on Computer Vision and Pattern Recognition (CVPR).

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification. Must be at least 2.
    m : float, optional
        Margin parameter. Defaults to 0.50. It is stored on the instance as
        ``self.m`` but is not used by ``forward``.

    Attributes
    ----------
    s : float or torch.Tensor
        Current scale. Starts as the Python float
        ``sqrt(2) * log(num_classes - 1)`` and is replaced by a 0-dim tensor
        (without gradient) on every ``forward`` call that receives labels.
    W : torch.nn.Parameter
        Class weight matrix with shape (num_classes, num_features).

    Raises
    ------
    ValueError
        If ``num_classes`` is smaller than 2, for which the initial scale
        ``log(num_classes - 1)`` is undefined.
    """

    def __init__(self, num_features, num_classes, m=0.50):
        super().__init__()
        # Fail with a readable message instead of math.log's bare
        # "math domain error" (same exception type as before).
        if num_classes < 2:
            raise ValueError(
                "AdaCos requires num_classes >= 2, got {}".format(num_classes)
            )
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = math.sqrt(2) * math.log(num_classes - 1)
        self.m = m
        # float32 storage, immediately overwritten by Xavier initialisation.
        self.W = Parameter(
            torch.empty(num_classes, num_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        """
        Forward pass of AdaCos.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground truth labels with shape (batch_size,). If None, the
            unscaled cosine logits are returned and ``self.s`` is left
            unchanged.

        Returns
        -------
        torch.Tensor
            Logits with shape (batch_size, num_classes): plain cosine
            similarities when ``label`` is None, otherwise the cosine
            similarities multiplied by the freshly updated ``self.s``.

        Notes
        -----
        Passing ``label`` overwrites ``self.s`` as a side effect, whether or
        not the module is in training mode.
        """
        # Cosine similarity between unit-length features and class weights.
        features = F.normalize(input)
        weights = F.normalize(self.W)
        logits = F.linear(features, weights)
        if label is None:
            return logits

        # No batch statistics can be derived from an empty batch; keep the
        # current scale and return the (empty) scaled logits.
        if logits.size(0) == 0:
            return self.s * logits

        # The scale update is a statistic of the batch, not part of the
        # differentiable graph, so everything needed for it is computed
        # without gradient tracking.
        with torch.no_grad():
            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)

            # Clamp away from +/-1 before taking acos.
            theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))

            # Per-sample average of exp(s * cos) summed over non-target classes.
            non_target_exp = torch.where(
                one_hot < 1,
                torch.exp(self.s * logits),
                torch.zeros_like(logits),
            )
            b_avg = torch.sum(non_target_exp) / input.size(0)

            # Median angle between each sample and its own class centre,
            # capped at pi/4.
            theta_med = torch.median(theta[one_hot == 1])
            self.s = torch.log(b_avg) / torch.cos(
                torch.clamp(theta_med, max=math.pi / 4)
            )

        return self.s * logits