超大规模分类(四):Partial FC

news/2025/2/27 5:16:32

人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题:

  • 真实的人脸识别数据噪声严重
  • 真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少
  • 人脸类别越来越多,全连接层训练成本越来越高,难度越来越大

于是,有研究人员提出Partial FC,拒绝全量更新负类别中心,而是仅更新少部分负类别中心。该做法优势在于

  • 降低噪声数据被采样的概率
  • 降低高频负类别中心被选中的概率
  • 降低负类别中心的更新频率,降低训练难度
    如下图所示:
    ![[Pasted image 20250219215540.png]]

问题建模

人脸识别领域,常用的分类损失公式化定义如下:
L = − 1 B ∑ i = 1 B l o g e W y i T ⋅ x i e W y i T ⋅ x i + ∑ j = 1 , j ≠ y i C e W j T ⋅ x i (1) L=-\frac{1}{B}\sum_{i=1}^{B}log\frac{e^{W_{y_i}^T\cdot x_i}}{e^{W_{y_i}^T}\cdot x_i+\sum_{j=1,j\neq y_i}^Ce^{W_j^T\cdot x_i}} \tag{1} L=B1i=1BlogeWyiTxi+j=1,j=yiCeWjTxieWyiTxi(1)
,其中, B B B表示batch size, C C C表示类别个数, W j T W_{j}^T WjT表示第 j j j个类别中心的特征, ( x i , y i ) (x_i,y_i) (xi,yi)表示第 i i i个样本的特征为 x i x_i xi,类别为 y i y_i yi

真实大规模人脸数据实际使用时,有以下问题:![[Pasted image 20250219222626.png]]

  • 噪声问题:见上图(a),图片对都是一个人的图片,但是被分到不同的类别,这对模型训练有非常大的干扰。
  • 长尾分布:见上图(b),大部分类别(identity)包含的图像数量很少,在WebFace42M中,44.57%的类别包含的图像数量少于10张。这会导致低频类别的类别中心更新缓慢,而高频类别的类别中心更新频繁。
  • 训练资源:全连接层一般表示为 W ∈ R D × C W\in \mathbb{R}^{D\times C} WRD×C,其中 D D D表示维度, C C C表示类别数。假设 D = 512 D=512 D=512,如果类别数是1000,000(一百万)
    • fp16下,全连接层的显存消耗为: 512 × 100 , 000 × 2 1024 × 1024 × 1024 = 0.95 G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}=0.95GB 1024×1024×1024512×100,000×2=0.95GB
    • 公式(1)中,需要计算 B B B x i x_i xi属于类别中心 W j T W_{j}^T WjT的logit,维度是 R B × D × C \mathbb{R}^{B\times D\times C} RB×D×C,显存消耗为 512 × 100 , 000 × 2 1024 × 1024 × 1024 ⋅ B = 0.95 B   G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}\cdot B=0.95B \,GB 1024×1024×1024512×100,000×2B=0.95BGB,batchsize越大,需要的显存越大。
    • 在下图,进行了模型并行和partial fc在显存消耗和训练速度上的比较,可以发现:
      • partial fc显著降低了对logit的显存消耗
      • partial fc略微降低了对存储类别中心的显存消耗
      • partial fc未降低对特征抽取网络的显存消耗(将原图像转换为特征的模型的消耗)
      • 由于partial fc减少了负类别中心的数量,降低了logit计算的复杂度,随着训练类别越多,加速比越高。
        ![[Pasted image 20250219222640.png]]

partial fc

为了缓解上述问题,提出了partial fc,通过稀疏更新全连接层的参数,来支持大规模人脸识别模型的训练。
整体架构如下图所示:![[Pasted image 20250219222722.png]]

模型通过数据并行训练的,不同GPU包含了不同数据的特征,整体步骤如下:

  • 汇总不同GPU里的图像特征和图像标签
  • 将汇总的图像特征和图像标签送到每张GPU上
  • 将全连接层(即 C C C个类别中心)均分到每张GPU上
  • 在单张卡上,保留需要的正类别中心,以及采样固定比例的负类别中心
  • 利用样本、正类别中心、负类别中心计算损失函数

代码实现

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            pass
        """
        local_labels.squeeze_()
        local_labels = local_labels.long()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size)).cuda()
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

		# 汇总不同GPU里的图像特征和图像标签
        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        # self.class_start表示该GPU中,分配的类别中心起始id
        # self.num_local表示该GPU中,分配的类别中心数量
        # 于是,该GPU的类别中心id范围是[类别中心起始id, 类别中心起始id + 类别中心数量]
        # 在单张卡上,仅保留需要的正类别中心
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

		# 在单张卡上,采样固定比例的负类别中心
        if self.sample_rate < 1:
            weight = self.sample(labels, index_positive)
        else:
            weight = self.weight

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(weight)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

		# 基于样本特征、样本标签、正类别中心,采样的负类别中心,计算损失
        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

优势

partial fc的核心思想是”降低训练中负类别中心数量,显式得减少需要更新的参数量“。负类别中心采样比例越低,节约的显存越多。

为了更好理解partial fc对长尾分布、噪声问题的影响,计算分类损失对于样本 x i x_i xi的梯度,如下:
∂ L ∂ x i = − ( ( 1 − p + ) W + − ∑ j ∈ S , j ≠ y i p j − W j − ) (2) \frac{\partial L}{\partial x_i}=-((1-p^+)W^+-\sum_{j\in \mathbb{S}, j\neq y_i}p_j^-W_j^-) \tag{2} xiL=((1p+)W+jS,j=yipjWj)(2)
其中, p + p^+ p+ p − p^- p分别表示通过样本特征 x i x_i xi计算的logit分数, S \mathbb{S} S表示负类别中心, ∣ S ∣ = C × r |\mathbb{S}|=C\times r S=C×r,通过采样比例 r r r,调整训练时的负样本数量。

样本特征 x i x_i xi的更新方向和正类别中心和负类别中心都有关系,partial fc随机减少负类别中心数量,减低噪声数据被采样的概率,降低高频负类别中心被选中的概率,有效缓解长尾问题和噪声问题。

注意:采样率为1,等同于选取所有负类别中心,进行模型训练。相当于原始fc分类

为了进一步验证partial fc的作用原理,做了下述验证下实验

探究采样率与类内、类间相似度关系

![[Pasted image 20250224175118.png]]

(a)图中,采样率越低,APCS收敛至更高数值。APCS表示类内距离 A P C S = 1 B ∑ i = 1 B W y i T x i ∣ ∣ W y i ∣ ∣ ⋅ ∣ ∣ x i ∣ ∣ APCS=\frac{1}{B}\sum_{i=1}^B\frac{W_{y_i}^Tx_i}{||W_{y_i}||\cdot ||x_i||} APCS=B1i=1B∣∣Wyi∣∣∣∣xi∣∣WyiTxi,说明采样率越低,类内相似度越大,类内越紧密。

(b)图中,采样率越低,MICS分布越往右,整体数值越大。MICS表示最大的类间余弦相似度 MICS i = max ⁡ j ≠ i W i T W j ∥ W i ∥   ∥ W j ∥ \text{MICS}_i = \max_{j \neq i} \frac{W_i^T W_j}{\|W_i\| \, \|W_j\|} MICSi=j=imaxWiWjWiTWj,说明采样率越低,类间相似度越大,类间拉不开。

探究采样率与评测集合效果关系

随着采样率越来越低,IJB-C、MFR-All评测集上的效果越来越差,如下:
![[Pasted image 20250226180734.png]]

探究采样率对噪声数据的鲁棒性

为验证在噪声数据上的效果,做了如下实验:
![[Pasted image 20250226181829.png]]
构造了WebFace12M-Conflict数据集,随机将20万个类的样本放到另外60万个类别数据中。
图(a)纵坐标为AMNCS,指最小的负类中心距离(越大,说明样本和负类离得远)。可以发现,降低采样率(1.0->0.1),在干净数据上,效果接近;在噪声数据上,缓解过拟合问题。

图(b)的MICS,指最大的类间余弦相似度(越大,说明类别越相似,越区分不开类别)。可以发现,降低采样率(1.0->0.1),MICS分布往右,整体数值偏大。由于WebFace12M-Conflict数据集中,20万类的样本随机分布在其他类别中,类间余弦相似度本就很大,图(b)更好的刻画实际噪声分布。

![[Pasted image 20250226210518.png]]

上图定义了两个概念,分别是conflict-hard和conflict-noise。conflict-hard表示利用真实负样本计算AMNCS;conflict-noise表示利用噪声负样本计算AMNCS。结果表明:

  • r=1.0时,针对AMNCS指标,conflict-hard>conflict-noise,表明负样本不采样,会使得模型过分拟合数据集,导致对噪声数据不鲁棒(按理说应该是conflict-noise>conflict-hard)。
  • r=0.1时,针对AMNCS指标,conflict-hard<conflict-noise,刻画出真实数据特性。

消融实验

不同数据集、不同采样率下partial fc

![[Pasted image 20250226212125.png]]

不同网络结构下partial fc

![[Pasted image 20250226212145.png]]

对噪声数据鲁棒

采用WebFace12M-Conflict作为训练集合。
![[Pasted image 20250226212355.png]]

![[Pasted image 20250226212449.png]]

对长尾数据鲁棒

![[Pasted image 20250226212618.png]]

收敛速度、训练时间

![[Pasted image 20250226212802.png]]


http://www.niftyadmin.cn/n/5869478.html

相关文章

Nacos + Dubbo3 实现微服务的Rpc调用

文章目录 概念整理基本概念概念助记前提RPC与HTTP类比RPC接口类的一些理解 实例代码主体结构父项目公共接口项目提供者项目项目结构POM文件实现配置文件实现公共接口实现程序入口配置启动项目检查是否可以注入到Nacos 消费者项目项目结构POM文件实现配置文件实现注册RPC服务类实…

wordpress按不同页调用不同的标题3种形式

在WordPress中&#xff0c;可以通过多种方式根据不同的页面调用不同的标题。这通常用于实现SEO优化、自定义页面标题或根据页面类型显示不同的标题内容。 使用wp_title函数 wp_title函数用于在HTML的title标签中输出页面标题。你可以通过修改主题的header.php文件来实现自定义…

1.2 Kaggle大白话:Eedi竞赛Transformer框架解决方案02-GPT_4o生成训练集缺失数据

目录 0. 本栏目竞赛汇总表1. 本文主旨2. AI工程架构3. 数据预处理模块3.1 配置数据路径和处理参数3.2 配置API参数3.3 配置输出路径 4. AI并行处理模块4.1 定义LLM客户端类4.2 定义数据处理函数4.3 定义JSON保存函数4.4 定义数据分片函数4.5 定义分片处理函数4.5 定义文件名排序…

第6章 数据工程(二)

6.3 数据治理和建模 数据治理是开展数据价值化活动的基础&#xff0c;关注对数字要素的管控能力覆盖组织对数据相关活动的统筹、评估、指导和监督等工作&#xff0c;需要重点关注元数据、数据标准化、数据质量数据模型和数据建模等方面的内容。 6.3.1 元数据 元数据是关于数…

量子计算可能改变世界的四种方式

世界各地的组织和政府正将数十亿美元投入到量子研究与开发中&#xff0c;谷歌、微软和英特尔等公司都在竞相实现量子霸权。 这其中的利害关系重大&#xff0c;有这么多重要的参与者&#xff0c;量子计算机的问世可能指日可待。 为做好准备&#xff0c;&#xff0c;我们必须了…

Storage Gateway:解锁企业混合云存储的智能钥匙

在数字化转型的浪潮中&#xff0c;企业数据量呈指数级增长&#xff0c;传统本地存储面临成本高、扩展难、管理复杂等挑战。如何实现本地基础设施与云端的无缝协同&#xff0c;构建灵活、安全且经济的存储架构&#xff1f;AWS Storage Gateway 作为混合云存储的核心枢纽&#xf…

健康检查、k8s探针、Grails+Liquibase框架/health 404 Not Found排查及解决

概述 健康检查对于一个pod而言&#xff0c;其重要性不言而喻。 k8s通过探针来实现健康检查。 探针 k8s提供三种探针&#xff1a; 存活探针&#xff1a;livenessProbe就绪探针&#xff1a;readinessProbe启动探针&#xff1a;startupProbe 存活探针 存活探针决定何时重启…

ChatGPT入驻Safari,AI搜索时代加速到来

2月25日&#xff0c;人工智能领域巨头OpenAI宣布了一项重磅更新&#xff1a;为其广受欢迎的ChatGPT应用新增Safari浏览器扩展功能&#xff0c;并支持用户将ChatGPT设置为Safari地址栏的默认搜索引擎。这一举措标志着OpenAI在将ChatGPT整合进用户日常网络浏览体验方面迈出了重要…