"""Source code for the ``gunz_ml.metrics.classification.cosface`` module: CosFace (large margin cosine) classification head."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

class CosFace(nn.Module):
    """
    CosFace: Large Margin Cosine Loss for Deep Face Recognition.

    Classification head that scores L2-normalised embeddings against
    L2-normalised class weight vectors. When ground-truth labels are given,
    the cosine margin ``m`` is subtracted from the target-class cosine and
    every logit is multiplied by ``s``. The resulting logits are intended to
    be passed to a softmax cross-entropy loss.

    Reference:
        Wang, H., Wang, Y., Zhou, Z., Ji, X., Gong, D., Zhou, J., ... & Liu, W.
        (2018). CosFace: Large Margin Cosine Loss for Deep Face Recognition.
        In Proceedings of the IEEE Conference on Computer Vision and Pattern
        Recognition (CVPR).

    Parameters
    ----------
    num_features : int
        Number of input features (dimension of the embedding).
    num_classes : int
        Number of classes for classification.
    s : float, optional
        Scaling factor. Defaults to 30.0.
    m : float, optional
        Cosine margin penalty. Defaults to 0.35.

    Attributes
    ----------
    W : torch.nn.Parameter
        Class weight matrix of shape (num_classes, num_features),
        initialised with Xavier-uniform.
    """

    def __init__(self, num_features, num_classes, s=30.0, m=0.35):
        super().__init__()
        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        # torch.empty replaces the legacy torch.FloatTensor constructor; its
        # uninitialised contents are overwritten by the Xavier init below.
        self.W = Parameter(torch.empty(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input, label=None):
        """
        Forward pass of CosFace.

        Parameters
        ----------
        input : torch.Tensor
            Input features with shape (batch_size, num_features).
        label : torch.Tensor, optional
            Ground-truth class indices with batch_size elements (any shape
            that flattens to (batch_size,)). If None, the raw cosine
            similarities are returned.

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch_size, num_classes).
            Without ``label``: raw cosine similarities in [-1, 1]
            (no margin and no ``s`` scaling applied).
            With ``label``: ``s * (cos - m * one_hot(label))``, i.e. the
            margin is subtracted from the target-class cosine only and all
            logits are scaled by ``s``.

        Raises
        ------
        ValueError
            If ``label`` does not contain exactly one entry per input row.
        """
        # Cosine similarity between every embedding and every class weight:
        # L2-normalise both along the feature dimension, then take dot products.
        x = F.normalize(input, dim=1)
        W = F.normalize(self.W, dim=1)
        logits = F.linear(x, W)

        if label is None:
            return logits

        # reshape (rather than view) accepts labels with any memory layout.
        index = label.reshape(-1, 1).long()
        if index.size(0) != logits.size(0):
            # scatter_ only requires index.size(0) <= logits.size(0), so a
            # short label tensor would silently leave trailing rows unmarked.
            raise ValueError(
                f"label has {index.size(0)} elements but input has "
                f"{logits.size(0)} rows"
            )

        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, index, 1)

        # Subtract the margin from the target-class cosine only, then
        # re-scale. Out-of-place ops keep the autograd graph free of mutation.
        return (logits - self.m * one_hot) * self.s