# Source code for gunz_ml.metrics.classification.sphereface

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

class SphereFace(nn.Module):
    """
    SphereFace: Deep Hypersphere Embedding for Face Recognition.

    Reference: Liu, W., Wen, Y., Yu, Z., Li, M., Raj, B., & Song, L. (2017).
    SphereFace: Deep Hypersphere Embedding for Face Recognition. In Proceedings
    of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR).

    Both the input embeddings and the class weight rows are L2-normalised, so
    every logit is the cosine of the angle ``theta`` between an embedding and a
    class centre. When labels are supplied, the target-class logit is replaced
    by the monotonically decreasing angular-margin function from the paper::

        psi(theta) = (-1)^k * cos(m * theta) - 2k,   k = floor(m * theta / pi)

    and all logits are then multiplied by ``s``. For ``theta < pi / m`` this is
    exactly ``cos(m * theta)``; the ``k`` term keeps the target logit
    decreasing in ``theta`` over the whole range ``[0, pi]``, which a bare
    ``cos(m * theta)`` does not when ``m > 1``.

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification.
    s : float, optional
        Scaling factor applied to the logits when labels are given.
        Defaults to 30.0.
    m : float, optional
        Angular margin multiplicative factor. Defaults to 1.35.
    """

    def __init__(self, num_features, num_classes, s=30.0, m=1.35):
        super().__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        # One weight row (class centre) per class. torch.empty allocates
        # uninitialised storage that xavier_uniform_ immediately fills.
        self.W = Parameter(torch.empty(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def _psi(self, cos_theta):
        """
        Apply the SphereFace angular margin to target-class cosines.

        Parameters
        ----------
        cos_theta : torch.Tensor
            Cosines between embeddings and their target class centres.

        Returns
        -------
        torch.Tensor
            ``psi(theta)`` with the same shape and dtype as ``cos_theta``.
        """
        # Keep acos (and its derivative -1/sqrt(1 - x^2)) finite at +/-1.
        theta = torch.acos(cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
        m_theta = self.m * theta
        # k selects the interval [k*pi/m, (k+1)*pi/m) that contains theta.
        # floor has zero gradient, so k acts as a piecewise constant.
        k = torch.floor(m_theta / math.pi)
        sign = 1.0 - 2.0 * torch.remainder(k, 2.0)  # (-1)^k
        return sign * torch.cos(m_theta) - 2.0 * k

    def forward(self, input, label=None):
        """
        Forward pass of SphereFace.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground truth labels with shape (batch_size,).

        Returns
        -------
        torch.Tensor
            Shape (batch_size, num_classes). If ``label`` is None, the raw,
            unscaled cosine similarities (no margin, no ``s`` scaling).
            Otherwise the margin-adjusted logits multiplied by ``s``,
            computed in at least float32 precision.
        """
        # Normalise features and class centres so that the dot product is
        # the cosine of the angle between them.
        x = F.normalize(input)
        W = F.normalize(self.W)
        logits = F.linear(x, W)
        if label is None:
            return logits

        index = label.reshape(-1, 1).long()
        cos_target = logits.gather(1, index)
        # Under mixed precision the logits may be fp16, where 1 - 1e-7 rounds
        # to 1.0 and acos' gradient becomes infinite. Work in >= float32.
        work_dtype = torch.promote_types(cos_target.dtype, torch.float32)
        psi = self._psi(cos_target.to(work_dtype))
        # Replace only the target column; other logits keep their cosine.
        output = logits.to(work_dtype).scatter(1, index, psi)

        # feature re-scale
        return output * self.s