import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
[docs]
class AdaCos(nn.Module):
    """
    AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations.

    Classification head that computes cosine similarities between
    L2-normalised input features and L2-normalised class weights, and
    multiplies them by a scale ``s`` that is re-estimated from the batch
    every time labels are supplied.

    Reference:
        Zhang, X., Zhao, R., Qiao, Y., Wang, X., & Li, H. (2019).
        AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations.
        In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR).

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification. Must be at least 2.
    m : float, optional
        Margin parameter. Defaults to 0.50. It is stored as ``self.m`` but
        is not read by :meth:`forward`.

    Raises
    ------
    ValueError
        If ``num_classes`` is smaller than 2 (the initial scale
        ``sqrt(2) * log(num_classes - 1)`` is undefined in that case).
    """

    def __init__(self, num_features, num_classes, m=0.50):
        super().__init__()
        if num_classes < 2:
            # math.log(num_classes - 1) would otherwise fail with an opaque
            # "math domain error"; keep the exception type, improve the text.
            raise ValueError(
                f"num_classes must be at least 2, got {num_classes}"
            )
        self.num_features = num_features
        self.n_classes = num_classes
        # Initial scale as a plain float; forward() replaces it with a
        # 0-dim tensor the first time labels are provided.
        self.s = math.sqrt(2) * math.log(num_classes - 1)
        self.m = m
        # One weight row per class, float32 like the legacy FloatTensor ctor.
        self.W = Parameter(
            torch.empty(num_classes, num_features, dtype=torch.float32)
        )
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        """
        Forward pass of AdaCos.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground truth labels with shape (batch_size,). If None, the
            unscaled cosine logits are returned and ``self.s`` is left
            untouched.

        Returns
        -------
        torch.Tensor
            Tensor with shape (batch_size, num_classes): cosine logits when
            ``label`` is None, otherwise cosine logits multiplied by the
            freshly updated scale ``self.s``.

        Notes
        -----
        ``self.s`` is overwritten on every call that passes ``label``,
        independent of ``self.training``.
        """
        # Cosine similarity between unit-length features and class weights.
        features = F.normalize(input)
        weights = F.normalize(self.W)
        logits = F.linear(features, weights)
        if label is None:
            return logits

        # The scale is a statistic of the batch, not a learnable quantity,
        # so everything feeding it is computed without autograd tracking.
        with torch.no_grad():
            # Clamp keeps acos away from +/-1 where it is ill-conditioned.
            theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))

            one_hot = torch.zeros_like(logits)
            one_hot.scatter_(1, label.view(-1, 1).long(), 1)
            is_target = one_hot == 1

            # Per-sample average of exp(s * cos) over the non-target classes,
            # evaluated with the scale from the previous step.
            non_target_exp = torch.exp(self.s * logits).masked_fill(is_target, 0.0)
            B_avg = torch.sum(non_target_exp) / input.size(0)

            # Median angle to the ground-truth class, capped at pi/4.
            theta_med = torch.median(theta[is_target])
            capped_theta = torch.clamp(theta_med, max=math.pi / 4)
            self.s = torch.log(B_avg) / torch.cos(capped_theta)

        # Gradients flow through the logits only; s carries no graph.
        return self.s * logits