import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
class CosFace(nn.Module):
    """
    CosFace: Large Margin Cosine Loss for Deep Face Recognition.

    Produces margin-adjusted, scaled cosine logits intended to be fed to a
    cross-entropy loss. Both the input embeddings and the class weight
    vectors are L2-normalized, so the raw logits are cosine similarities.
    When labels are supplied, the target-class cosine is reduced by ``m``
    and all logits are multiplied by ``s``.

    Reference:
        Wang, H., Wang, Y., Zhou, Z., Ji, X., Gong, D., Zhou, J., ... & Liu, W. (2018).
        CosFace: Large Margin Cosine Loss for Deep Face Recognition.
        In Proceedings of the IEEE Conference on Computer Vision and Pattern
        Recognition (CVPR).

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification.
    s : float, optional
        Scaling factor applied to the logits when labels are given.
        Defaults to 30.0.
    m : float, optional
        Cosine margin subtracted from the target-class logit when labels
        are given. Defaults to 0.35.

    Attributes
    ----------
    W : torch.nn.Parameter
        Class weight matrix of shape (num_classes, num_features),
        initialized with Xavier-uniform.
    """

    def __init__(self, num_features, num_classes, s=30.0, m=0.35):
        super().__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        # torch.empty replaces the legacy torch.FloatTensor constructor; the
        # uninitialized values are overwritten by the Xavier init below.
        self.W = Parameter(torch.empty(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        """
        Forward pass of CosFace.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground truth class indices with shape (batch_size,). If None,
            the margin and scaling are skipped.

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, num_classes). Without ``label``:
            the unscaled cosine similarities between each embedding and each
            class weight vector. With ``label``: those cosines with ``m``
            subtracted from the target-class entry, all multiplied by ``s``.
        """
        # Normalize embeddings and class weights along the feature dimension
        # so their dot product is a cosine similarity.
        x = F.normalize(input)
        W = F.normalize(self.W)
        logits = F.linear(x, W)
        if label is None:
            # Inference path: plain cosines, no margin and no scaling.
            return logits
        # One-hot mask selecting each sample's target class; reshape (unlike
        # view) accepts any memory layout of the label tensor.
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, label.reshape(-1, 1).long(), 1)
        # Penalize only the target-class cosine by m, then re-scale by s.
        # For a 0/1 mask this equals logits*(1-mask) + (logits-m)*mask.
        return self.s * (logits - self.m * one_hot)