import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math
# (stray "[docs]" source-link marker from the documentation viewer removed here)
class SphereFace(nn.Module):
    """
    SphereFace: Deep Hypersphere Embedding for Face Recognition.

    Classification head that compares L2-normalized features with
    L2-normalized class weights (i.e. cosine similarity). When labels are
    supplied, the target-class cosine ``cos(theta)`` is replaced by
    ``cos(m * theta)`` and all logits are multiplied by ``s``.

    Reference:
        Liu, W., Wen, Y., Yu, Z., Li, M., Raj, B., & Song, L. (2017).
        SphereFace: Deep Hypersphere Embedding for Face Recognition.
        In Proceedings of the IEEE Conference on Computer Vision and
        Pattern Recognition (CVPR).

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification.
    s : float, optional
        Scaling factor. Defaults to 30.0.
    m : float, optional
        Angular margin multiplicative factor. Defaults to 1.35.

    Attributes
    ----------
    W : torch.nn.Parameter
        Class weight matrix with shape (num_classes, num_features),
        initialized with Xavier-uniform values.
    """

    def __init__(self, num_features: int, num_classes: int,
                 s: float = 30.0, m: float = 1.35):
        super().__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        # torch.empty replaces the legacy torch.FloatTensor(size) constructor;
        # the storage is uninitialized either way and is filled right below.
        self.W = Parameter(torch.empty(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        """
        Forward pass of SphereFace.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground truth labels with shape (batch_size,). If None, the
            margin and the scale are both skipped.

        Returns
        -------
        torch.Tensor
            Tensor with shape (batch_size, num_classes). With ``label``
            given: margin-adjusted logits multiplied by ``s``. With
            ``label=None``: the plain cosine similarities, NOT multiplied
            by ``s``.
        """
        # Cosine similarity between unit-length features and unit-length
        # class weights.
        features = F.normalize(input)
        weights = F.normalize(self.W)
        logits = F.linear(features, weights)

        if label is None:
            return logits

        # Keep the cosines strictly inside (-1, 1) before acos
        # (presumably to avoid the unbounded acos gradient at +/-1).
        theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
        # NOTE(review): theta lies in (0, pi), so for m > 1 the product
        # m * theta can exceed pi, where cos() is no longer decreasing.
        target_logits = torch.cos(self.m * theta)

        # One-hot mask selecting each sample's ground-truth class column.
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # Margin-adjusted value at the target class, plain cosine elsewhere.
        output = logits * (1 - one_hot) + target_logits * one_hot

        # Feature re-scale.
        return output * self.s