LeJEPA:论文与代码综述

目录

论文:《LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics》(arXiv:2511.08544,Randall Balestriero & Yann LeCun)
代码:本仓库将论文中的 Sketched Isotropic Gaussian Regularization (SIGReg) 实例化为 PyTorch 模块,可直接嵌入任意自监督预训练框架。

1. 论文要点速览

README 中的 Key Features 概要也与上述对应,可用于 README ↔ 论文/代码的快速对照:

2. SIGReg:理论直观与数学推导

SIGReg (Sketched Isotropic Gaussian Regularization) 的核心直觉是:**与其通过对比损失(Contrastive Loss)去“拉近正例、推远负例”,不如直接让所有数据的特征分布塌缩成一个各向同性的高斯分布(Isotropic Gaussian)。**

2.1 为什么是高斯?

论文证明,如果特征空间 $Z$ 的分布 $p(z)$ 满足 $p(z) \propto e^{-\|z\|^2/2}$(标准正态),则下游线性分类任务的风险存在显式上界。即:特征分布越接近各向同性高斯,线性可分性越好。

2.2 降维打击:随机切片 (Slicing)

在高维空间直接计算分布距离(如 KL 散度或 Wasserstein 距离)非常困难。LeJEPA 利用 Radon 变换定理如果一个高维分布在所有 1D 投影(切片)上都是标准正态分布,那么它本身就是标准多元正态分布。

公式化描述:

给定 batch 嵌入 $Z \in \mathbb{R}^{B \times D}$,随机采样 $K$ 个单位方向向量 $u_k \in \mathbb{S}^{D-1}$。

计算投影:$y_{b,k} = z_b^T u_k$。

损失函数为所有投影分布与标准正态 $N(0,1)$ 的差异之和: $$ \mathcal{L}_{\text{SIGReg}} = \frac{1}{K} \sum_{k=1}^K \text{Dist}( \{y_{b,k}\}_{b=1}^B, \mathcal{N}(0,1) ) $$

2.3 距离度量:Epps-Pulley 检验

论文中默认使用 EppsPulley 统计量作为 $\text{Dist}$,它是基于特征函数(Characteristic Function)的距离。相比 KS 或 AD 检验,它更适合梯度下降优化。

$$ T_{EP} = \int \left| \frac{1}{N} \sum_{j=1}^N e^{i t y_j} - e^{-t^2/2} \right|^2 w(t) dt $$

代码实现(lejepa/univariate/epps_pulley.py)通过离散化积分点 $t$ 来高效计算该损失。

3. Demo 与性能速览

README 中的 GIF / 图像与性能表也搬运到此,方便直接浏览:

demo 1 demo 2 demo 3
shotsmodelparamspretrainepochs DTDaircr.carscifar10cifar100flowers102foodpetsavg.
1LeJEPA ViT-L304MIN-1K10033.219.373.4051.6527.0148.5317.1446.1129.55
1LeJEPA ConvNeXtV2-H660MIN-1K10032.158.074.2850.9531.4848.7417.9558.9831.58
1I-JEPA ViT-H632MIN-1K30027.719.864.3356.5230.5844.6914.5353.3830.20
10LeJEPA ViT-L304MIN-1K10064.7235.2522.2585.1559.7792.5350.9077.0060.95
10LeJEPA ConvNeXtV2-H660MIN-1K10061.8430.6724.4685.7463.2991.7849.3278.5360.70
10I-JEPA ViT-H632MIN-1K30057.6833.8221.9688.7766.4288.2443.9783.2360.51
allLeJEPA ViT-L304MIN-1K10078.3057.0157.2896.5083.7191.2182.0589.7479.48
allLeJEPA ConvNeXtV2-H660MIN-1K10076.6052.9954.8896.1581.3491.1177.6489.7677.56
allI-JEPA ViT-H632MIN-1K30073.3256.6154.4797.5486.4286.4781.0292.1178.50

4. 预训练数据流与配置

根据 README.mdscripts/launch_*.md,LeJEPA 的输入管线遵循 DINO 风格的多视角增强,随后由 backbone + projector 输出嵌入,再通过 SIGReg 约束。

2.1 视角生成(多 Crop)

Global Views ×2Local Views ×6
RandomResizedCrop 224²,scale 0.3–1.0RandomResizedCrop 98²,scale 0.05–0.3
RandomHorizontalFlip 0.5RandomHorizontalFlip 0.5
ColorJitter 0.8(亮0.4/对比0.4/饱0.2/色0.1)同上
RandomGrayscale 0.2RandomGrayscale 0.2
GaussianBlur 0.5 + Solarize 0.2同上
Normalize(mean,std)Normalize(mean,std)

2.2 Backbone + Projector

2.3 优化器与调度

5. SIGReg 损失在代码中的实现

核心思想:将高维嵌入 $X \in \mathbb{R}^{N \times D}$ 投影到大量随机 1D 切片上,应用单变量统计检验,使每个切片都符合标准正态,从而逼近各向同性高斯。

5.1 随机切片 + 分布同步

from torch import nn
from torch.distributed._functional_collectives import all_reduce as functional_all_reduce

def all_reduce(x, op="AVG"):
    if dist.is_available() and dist.is_initialized():
        return functional_all_reduce(x, op.lower(), dist.group.WORLD)
    return x

class SlicingUnivariateTest(nn.Module):
    def __init__(self, univariate_test, num_slices, reduction="mean",
                 sampler="gaussian", clip_value=None):
        super().__init__()
        self.univariate_test = univariate_test
        self.num_slices = num_slices
        self.reduction = reduction
        self.clip_value = clip_value
        self.register_buffer("global_step", torch.zeros((), dtype=torch.long))
        self._generator = None
        self._generator_device = None

    def forward(self, x):
        with torch.no_grad():
            seed = all_reduce(self.global_step.clone(), op="MAX").item()
            g = self._get_generator(x.device, seed)
            proj = torch.randn((x.size(-1), self.num_slices),
                               device=x.device, generator=g)
            proj /= proj.norm(p=2, dim=0)
            self.global_step.add_(1)

        stats = self.univariate_test(x @ proj)
        if self.clip_value is not None:
            stats = stats.clamp_min(self.clip_value)
        return stats.mean() if self.reduction == "mean" else stats

要点:

5.2 单变量 Epps–Pulley 统计

class EppsPulley(UnivariateTest):
    def __init__(self, t_max=3.0, n_points=17):
        super().__init__()
        t = torch.linspace(0, t_max, n_points)
        dt = t_max / (n_points - 1)
        weights = torch.full((n_points,), 2 * dt)
        weights[[0, -1]] = dt
        self.register_buffer("t", t)
        self.register_buffer("phi", (-0.5 * t**2).exp())
        self.register_buffer("weights", weights * self.phi)

    def forward(self, x):
        N = x.size(-2)
        x_t = x.unsqueeze(-1) * self.t
        cos_mean = torch.cos(x_t).mean(-3)
        sin_mean = torch.sin(x_t).mean(-3)
        cos_mean = self.dist_mean(cos_mean)
        sin_mean = self.dist_mean(sin_mean)
        err = (cos_mean - self.phi).square() + sin_mean.square()
        return (err @ self.weights) * N * self.world_size

该检验比较经验特征函数与标准正态特征函数的差异,返回正标量作为“偏离高斯”的能量,正好可直接作为损失。

5.3 其他统计(BHEP 示例)

class BHEP(MultivariatetTest):
    def __init__(self, beta=0.1):
        super().__init__()
        if beta <= 0:
            raise ValueError("beta must be positive")
        self.beta = beta

    def forward(self, x):
        x = self.prepare_data(x)
        N, D = x.shape
        beta2 = self.beta ** 2
        sq_norm = x.square().sum(dim=1)
        dist = -2 * x @ x.T + sq_norm[:, None] + sq_norm[None, :]
        lhs = torch.exp(-0.5 * beta2 * dist).sum() / (N ** 2)
        scaling = 2 / ((1 + beta2) ** (D / 2))
        rhs = scaling * torch.exp(sq_norm * (-beta2 / (2 + 2 * beta2))).sum() / N
        constant = 1 / ((1 + 2 * beta2) ** (D / 2))
        return lhs - rhs + constant

仓库还提供 Henze–Zirkler、Cramér–von Mises、Anderson–Darling、Shapiro–Wilk 等多种检验,可随意替换,以适配不同数据分布或灵敏度。

5.4 将 SIGReg 接入训练循环

import lejepa

backbone = build_model(...)
projector = build_projector(...)

univar = lejepa.univariate.EppsPulley(t_max=5.0, n_points=21)
sigreg = lejepa.multivariate.SlicingUnivariateTest(
    univariate_test=univar,
    num_slices=1024,
    reduction="mean",
)

def training_step(images):
    views = augment(images)         # 2 global + 6 local
    embeddings = projector(backbone(views))  # [B, tokens, dim]
    loss = sigreg(embeddings)
    loss.backward()
    optimizer.step()

(上例与 README 保持一致,可直接复制到 Lightning / PyTorch 训练脚本中。)

6. Ablation Study (消融实验)

基于 scripts/launch_*.md 和论文数据,以下是关键超参数对模型性能的影响总结。

6.1 投影头与切片数量 (launch_proj_ablation.md)

Embedding DimProjector DimNum Slices说明
512641024投影头过窄,信息瓶颈严重,性能下降。
5125121024默认配置。特征维度与投影一致,平衡性能与计算量。
51220484096投影头过宽虽能保留更多信息,但计算切片统计量开销增大,且收益边际递减。

6.2 视图策略与 Batch Size (launch_views_ablation.md)

Local ViewsGlobal ViewsBatch SizeLinear Probe (IN-1K)
02512Baseline。仅学习全局语义,局部细节丢失。
62512推荐配置 (LeJEPA Default)。多尺度增强显著提升分类精度。
82640更多局部视图通常有益,但显存占用线性增加。

注意:LeJEPA 发现 SIGReg 使得大 Batch Size (如 4096+) 训练极其稳定,不像对比学习方法那样需要精细调整 Learning Rate Scaling。

6.3 正则化强度 $\lambda$

bstat_lambda 是 LeJEPA 唯一 需要调整的权衡参数(Trade-off)。

7. 优势对比:LeJEPA vs DINO / I-JEPA / MAE

特性 MAE (Masked Autoencoder) DINO / DINOv2 I-JEPA LeJEPA (Ours)
核心目标 像素级重建 (Reconstruction) 多视图对比 / 蒸馏 (Self-Distillation) 特征级预测 (Feature Prediction) 特征分布正则化 (SIGReg)
Heuristics Mask 比例, 解码器深度 Teacher-Student, EMA, Center/Sinkhorn, Temp. Sched. Target Encoder EMA, Mask 策略 无 (Zero Heuristics)
训练稳定性 高 (但需长周期收敛) 中 (易崩塌,依赖 Centering/Temp) 极高 (理论保证收敛)
计算开销 低 (Encoder 只看 25% patches) 高 (2x Forward + EMA) 中 (Predictor 开销) 低 (O(N) 统计量计算)
显存占用 高 (存 Teacher 副本) 极低 (无 EMA 副本, 无 Predictor)
扩展性 强 (Vision 专用) 极强 (可用于非图像模态)

总结: LeJEPA 用统计学约束替代了工程上的 Trick(如 EMA、Stop-gradient),在保持高性能的同时,大幅降低了调参门槛和显存需求。

8. 推理与评估流程

9. 应用指南:迁移与调优

9.1 迁移到下游任务 (检测/分割)

LeJEPA 预训练的 Backbone 可以像 MAE 或 DINO 那样直接加载到 Detectron2 或 MMSegmentation 中。

  1. 提取权重: 使用 torch.load 读取 checkpoint,提取 backbone. 开头的权重。
  2. 移除 Projector: 丢弃 MLP Projector 和 SIGReg 模块,它们只在预训练时用于约束分布。
  3. LayerNorm 缩放: 由于 LeJEPA 强约束特征为 $N(0,1)$,下游任务初始化时建议不对 Backbone 输出再做大幅度的 Scale/Shift 操作,或者使用学习率 warmup。

9.2 适配自定义数据集

如果你要在非 ImageNet 数据(如医疗影像、遥感图像)上训练 LeJEPA:

10. 分布式与确定性细节

11. 安装与快速上手

无需查阅 README 即可在此获取安装与最小示例:

pip install lejepa stable-pretraining hydra-core

最小 SIGReg 使用示例:

import lejepa

univariate_test = lejepa.univariate.EppsPulley(num_points=17)
loss_fn = lejepa.multivariate.SlicingUnivariateTest(
    univariate_test=univariate_test,
    num_slices=1024,
    reduction="mean"
)

embeddings = backbone_outputs  # 假设 shape = [batch, dim]
loss = loss_fn(embeddings)
loss.backward()

12. 复现步骤建议

  1. 环境:pip install lejepa stable-pretraining hydra-core,确保 PyTorch≥1.10、Python≥3.8。
  2. 数据:准备 ImageNet-1K 或更小的数据集(如 INet-10),按 README 的 multi-crop 规则构建 DataLoader
  3. 配置:参考 scripts/launch_inet10.py 中的 Hydra 参数(bstat_namebstat_lambdanum_slices 等),按需要枚举骨干。
  4. 训练:关闭 predictor/teacher 相关 flag,启用 bf16、AdamW、cosine 调度;记录 loss_sigreg 即可监控训练。
  5. 评估:按 README 的线性探针流程在 DTD、Aircraft、CIFAR、Flowers、Food、Pets 等数据集上对比 I-JEPA / DINO。

具体的 Hydra 启动示例(与 README 中保持一致)如下:

HYDRA_FULL_ERROR=1 python scripts/je.py \
  --config-dir scripts/configs \
  --config-name base \
  +accelerator=single_gpu_sc \
  ++dataset_name="inet10" \
  ++bstat_name="epps_pulley" \
  ++bstat_num_slices=1000 \
  ++bstat_lambda=0.05 \
  ++embedding_dim=512 \
  ++projector_dim=512 \
  ++lr=3e-3 \
  ++weight_decay=3e-2 \
  ++n_views=8 \
  ++teacher_student=false \
  ++predictor=false

该命令直接复用了 README 的建议参数,可根据需要修改 backbone/节点数/视角策略。

13. 结语与展望

LeJEPA 证明了“好的表征 = 各向同性高斯”这一简单假设的强大威力。它不仅在 ImageNet 上取得了 SOTA 级别的性能,更重要的是,它为自监督学习提供了一个可解释、可证明、无 Trick 的新范式。

未来方向包括: