论文:《LeJEPA: Provable and Scalable Self-Supervised Learning Without the Heuristics》(arXiv:2511.08544,Randall Balestriero & Yann LeCun)
代码:本仓库将论文中的 Sketched Isotropic Gaussian Regularization (SIGReg) 实例化为 PyTorch 模块,可直接嵌入任意自监督预训练框架。
README 中的 Key Features 概要也与上述对应,可用于 README ↔ 论文/代码的快速对照:
scripts/launch_*.md 中仅需设置 bstat_lambda。lejepa/multivariate/slicing.py 的随机切片实现。global_step 和 torch.distributed all_reduce 保持多卡确定性。SIGReg (Sketched Isotropic Gaussian Regularization) 的核心直觉是:**与其通过对比损失(Contrastive Loss)去“拉近正例、推远负例”,不如直接让所有数据的特征分布塌缩成一个各向同性的高斯分布(Isotropic Gaussian)。**
论文证明,如果特征空间 $Z$ 的分布 $p(z)$ 满足 $p(z) \propto e^{-\|z\|^2/2}$(标准正态),则下游线性分类任务的风险存在显式上界。即:特征分布越接近各向同性高斯,线性可分性越好。
在高维空间直接计算分布距离(如 KL 散度或 Wasserstein 距离)非常困难。LeJEPA 利用 Radon 变换定理:如果一个高维分布在所有 1D 投影(切片)上都是标准正态分布,那么它本身就是标准多元正态分布。
公式化描述:
给定 batch 嵌入 $Z \in \mathbb{R}^{B \times D}$,随机采样 $K$ 个单位方向向量 $u_k \in \mathbb{S}^{D-1}$。
计算投影:$y_{b,k} = z_b^T u_k$。
损失函数为所有投影分布与标准正态 $N(0,1)$ 的差异之和: $$ \mathcal{L}_{\text{SIGReg}} = \frac{1}{K} \sum_{k=1}^K \text{Dist}( \{y_{b,k}\}_{b=1}^B, \mathcal{N}(0,1) ) $$
论文中默认使用 EppsPulley 统计量作为 $\text{Dist}$,它是基于特征函数(Characteristic Function)的距离。相比 KS 或 AD 检验,它更适合梯度下降优化。
$$ T_{EP} = \int \left| \frac{1}{N} \sum_{j=1}^N e^{i t y_j} - e^{-t^2/2} \right|^2 w(t) dt $$
代码实现(lejepa/univariate/epps_pulley.py)通过离散化积分点 $t$ 来高效计算该损失。
README 中的 GIF / 图像与性能表也搬运到此,方便直接浏览:
![]() |
![]() |
![]() |
![]() |
| shots | model | params | pretrain | epochs | DTD | aircr. | cars | cifar10 | cifar100 | flowers102 | food | pets | avg. |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | LeJEPA ViT-L | 304M | IN-1K | 100 | 33.21 | 9.37 | 3.40 | 51.65 | 27.01 | 48.53 | 17.14 | 46.11 | 29.55 |
| 1 | LeJEPA ConvNeXtV2-H | 660M | IN-1K | 100 | 32.15 | 8.07 | 4.28 | 50.95 | 31.48 | 48.74 | 17.95 | 58.98 | 31.58 |
| 1 | I-JEPA ViT-H | 632M | IN-1K | 300 | 27.71 | 9.86 | 4.33 | 56.52 | 30.58 | 44.69 | 14.53 | 53.38 | 30.20 |
| 10 | LeJEPA ViT-L | 304M | IN-1K | 100 | 64.72 | 35.25 | 22.25 | 85.15 | 59.77 | 92.53 | 50.90 | 77.00 | 60.95 |
| 10 | LeJEPA ConvNeXtV2-H | 660M | IN-1K | 100 | 61.84 | 30.67 | 24.46 | 85.74 | 63.29 | 91.78 | 49.32 | 78.53 | 60.70 |
| 10 | I-JEPA ViT-H | 632M | IN-1K | 300 | 57.68 | 33.82 | 21.96 | 88.77 | 66.42 | 88.24 | 43.97 | 83.23 | 60.51 |
| all | LeJEPA ViT-L | 304M | IN-1K | 100 | 78.30 | 57.01 | 57.28 | 96.50 | 83.71 | 91.21 | 82.05 | 89.74 | 79.48 |
| all | LeJEPA ConvNeXtV2-H | 660M | IN-1K | 100 | 76.60 | 52.99 | 54.88 | 96.15 | 81.34 | 91.11 | 77.64 | 89.76 | 77.56 |
| all | I-JEPA ViT-H | 632M | IN-1K | 300 | 73.32 | 56.61 | 54.47 | 97.54 | 86.42 | 86.47 | 81.02 | 92.11 | 78.50 |
根据 README.md 与 scripts/launch_*.md,LeJEPA 的输入管线遵循 DINO 风格的多视角增强,随后由 backbone + projector 输出嵌入,再通过 SIGReg 约束。
| Global Views ×2 | Local Views ×6 |
|---|---|
| RandomResizedCrop 224²,scale 0.3–1.0 | RandomResizedCrop 98²,scale 0.05–0.3 |
| RandomHorizontalFlip 0.5 | RandomHorizontalFlip 0.5 |
| ColorJitter 0.8(亮0.4/对比0.4/饱0.2/色0.1) | 同上 |
| RandomGrayscale 0.2 | RandomGrayscale 0.2 |
| GaussianBlur 0.5 + Solarize 0.2 | 同上 |
| Normalize(mean,std) | Normalize(mean,std) |
scripts/launch_inet10.py 通过 stable_pretraining 批量扫描。embedding_dim = 512、projector_dim = 512。weight_decay=5e-2,ResNet 使用 5e-4。5e-4,线性 warmup 后采用 cosine 衰减,终值为初始的 /1000。核心思想:将高维嵌入 $X \in \mathbb{R}^{N \times D}$ 投影到大量随机 1D 切片上,应用单变量统计检验,使每个切片都符合标准正态,从而逼近各向同性高斯。
from torch import nn
from torch.distributed._functional_collectives import all_reduce as functional_all_reduce
def all_reduce(x, op="AVG"):
if dist.is_available() and dist.is_initialized():
return functional_all_reduce(x, op.lower(), dist.group.WORLD)
return x
class SlicingUnivariateTest(nn.Module):
def __init__(self, univariate_test, num_slices, reduction="mean",
sampler="gaussian", clip_value=None):
super().__init__()
self.univariate_test = univariate_test
self.num_slices = num_slices
self.reduction = reduction
self.clip_value = clip_value
self.register_buffer("global_step", torch.zeros((), dtype=torch.long))
self._generator = None
self._generator_device = None
def forward(self, x):
with torch.no_grad():
seed = all_reduce(self.global_step.clone(), op="MAX").item()
g = self._get_generator(x.device, seed)
proj = torch.randn((x.size(-1), self.num_slices),
device=x.device, generator=g)
proj /= proj.norm(p=2, dim=0)
self.global_step.add_(1)
stats = self.univariate_test(x @ proj)
if self.clip_value is not None:
stats = stats.clamp_min(self.clip_value)
return stats.mean() if self.reduction == "mean" else stats
要点:
global_step 并复用同一随机数生成器,确保多 GPU 推理得到完全一致的投影矩阵 $A$。class EppsPulley(UnivariateTest):
def __init__(self, t_max=3.0, n_points=17):
super().__init__()
t = torch.linspace(0, t_max, n_points)
dt = t_max / (n_points - 1)
weights = torch.full((n_points,), 2 * dt)
weights[[0, -1]] = dt
self.register_buffer("t", t)
self.register_buffer("phi", (-0.5 * t**2).exp())
self.register_buffer("weights", weights * self.phi)
def forward(self, x):
N = x.size(-2)
x_t = x.unsqueeze(-1) * self.t
cos_mean = torch.cos(x_t).mean(-3)
sin_mean = torch.sin(x_t).mean(-3)
cos_mean = self.dist_mean(cos_mean)
sin_mean = self.dist_mean(sin_mean)
err = (cos_mean - self.phi).square() + sin_mean.square()
return (err @ self.weights) * N * self.world_size
该检验比较经验特征函数与标准正态特征函数的差异,返回正标量作为“偏离高斯”的能量,正好可直接作为损失。
class BHEP(MultivariatetTest):
def __init__(self, beta=0.1):
super().__init__()
if beta <= 0:
raise ValueError("beta must be positive")
self.beta = beta
def forward(self, x):
x = self.prepare_data(x)
N, D = x.shape
beta2 = self.beta ** 2
sq_norm = x.square().sum(dim=1)
dist = -2 * x @ x.T + sq_norm[:, None] + sq_norm[None, :]
lhs = torch.exp(-0.5 * beta2 * dist).sum() / (N ** 2)
scaling = 2 / ((1 + beta2) ** (D / 2))
rhs = scaling * torch.exp(sq_norm * (-beta2 / (2 + 2 * beta2))).sum() / N
constant = 1 / ((1 + 2 * beta2) ** (D / 2))
return lhs - rhs + constant
仓库还提供 Henze–Zirkler、Cramér–von Mises、Anderson–Darling、Shapiro–Wilk 等多种检验,可随意替换,以适配不同数据分布或灵敏度。
import lejepa
backbone = build_model(...)
projector = build_projector(...)
univar = lejepa.univariate.EppsPulley(t_max=5.0, n_points=21)
sigreg = lejepa.multivariate.SlicingUnivariateTest(
univariate_test=univar,
num_slices=1024,
reduction="mean",
)
def training_step(images):
views = augment(images) # 2 global + 6 local
embeddings = projector(backbone(views)) # [B, tokens, dim]
loss = sigreg(embeddings)
loss.backward()
optimizer.step()
(上例与 README 保持一致,可直接复制到 Lightning / PyTorch 训练脚本中。)
基于 scripts/launch_*.md 和论文数据,以下是关键超参数对模型性能的影响总结。
launch_proj_ablation.md)| Embedding Dim | Projector Dim | Num Slices | 说明 |
|---|---|---|---|
| 512 | 64 | 1024 | 投影头过窄,信息瓶颈严重,性能下降。 |
| 512 | 512 | 1024 | 默认配置。特征维度与投影一致,平衡性能与计算量。 |
| 512 | 2048 | 4096 | 投影头过宽虽能保留更多信息,但计算切片统计量开销增大,且收益边际递减。 |
launch_views_ablation.md)| Local Views | Global Views | Batch Size | Linear Probe (IN-1K) |
|---|---|---|---|
| 0 | 2 | 512 | Baseline。仅学习全局语义,局部细节丢失。 |
| 6 | 2 | 512 | 推荐配置 (LeJEPA Default)。多尺度增强显著提升分类精度。 |
| 8 | 2 | 640 | 更多局部视图通常有益,但显存占用线性增加。 |
注意:LeJEPA 发现 SIGReg 使得大 Batch Size (如 4096+) 训练极其稳定,不像对比学习方法那样需要精细调整 Learning Rate Scaling。
bstat_lambda 是 LeJEPA 唯一 需要调整的权衡参数(Trade-off)。
0.01 ~ 0.1(ImageNet 任务通常用 0.05)。| 特性 | MAE (Masked Autoencoder) | DINO / DINOv2 | I-JEPA | LeJEPA (Ours) |
|---|---|---|---|---|
| 核心目标 | 像素级重建 (Reconstruction) | 多视图对比 / 蒸馏 (Self-Distillation) | 特征级预测 (Feature Prediction) | 特征分布正则化 (SIGReg) |
| Heuristics | Mask 比例, 解码器深度 | Teacher-Student, EMA, Center/Sinkhorn, Temp. Sched. | Target Encoder EMA, Mask 策略 | 无 (Zero Heuristics) |
| 训练稳定性 | 高 (但需长周期收敛) | 中 (易崩塌,依赖 Centering/Temp) | 高 | 极高 (理论保证收敛) |
| 计算开销 | 低 (Encoder 只看 25% patches) | 高 (2x Forward + EMA) | 中 (Predictor 开销) | 低 (O(N) 统计量计算) |
| 显存占用 | 低 | 高 (存 Teacher 副本) | 中 | 极低 (无 EMA 副本, 无 Predictor) |
| 扩展性 | 强 (Vision 专用) | 强 | 强 | 极强 (可用于非图像模态) |
总结: LeJEPA 用统计学约束替代了工程上的 Trick(如 EMA、Stop-gradient),在保持高性能的同时,大幅降低了调参门槛和显存需求。
1e-6。eval/ 文件夹提供原图与 PCA 回放、GIF 动画,帮助验证表示是否捕获全局语义。LeJEPA 预训练的 Backbone 可以像 MAE 或 DINO 那样直接加载到 Detectron2 或 MMSegmentation 中。
torch.load 读取 checkpoint,提取 backbone. 开头的权重。如果你要在非 ImageNet 数据(如医疗影像、遥感图像)上训练 LeJEPA:
bstat_lambda (0.1 ~ 0.5)。数据少时更容易过拟合,需要更强的正则化防止塌缩。bstat_lambda (0.01 ~ 0.05)。数据本身的多样性已经提供了很好的约束。torch.distributed 的 all_reduce 聚合,确保 batch 拆分后仍等价于单机结果。SlicingUnivariateTest 用共享 global_step 控制随机投影,避免不同 GPU 采到不同切片导致噪声。torch.no_grad() 下生成切片和积分点,不会额外占用显存;只需反向传播单个标量 loss。无需查阅 README 即可在此获取安装与最小示例:
stable_pretraining 以使用仓库脚本。pip install lejepa stable-pretraining hydra-core
最小 SIGReg 使用示例:
import lejepa
univariate_test = lejepa.univariate.EppsPulley(num_points=17)
loss_fn = lejepa.multivariate.SlicingUnivariateTest(
univariate_test=univariate_test,
num_slices=1024,
reduction="mean"
)
embeddings = backbone_outputs # 假设 shape = [batch, dim]
loss = loss_fn(embeddings)
loss.backward()
pip install lejepa stable-pretraining hydra-core,确保 PyTorch≥1.10、Python≥3.8。DataLoader。scripts/launch_inet10.py 中的 Hydra 参数(bstat_name、bstat_lambda、num_slices 等),按需要枚举骨干。loss_sigreg 即可监控训练。具体的 Hydra 启动示例(与 README 中保持一致)如下:
HYDRA_FULL_ERROR=1 python scripts/je.py \
--config-dir scripts/configs \
--config-name base \
+accelerator=single_gpu_sc \
++dataset_name="inet10" \
++bstat_name="epps_pulley" \
++bstat_num_slices=1000 \
++bstat_lambda=0.05 \
++embedding_dim=512 \
++projector_dim=512 \
++lr=3e-3 \
++weight_decay=3e-2 \
++n_views=8 \
++teacher_student=false \
++predictor=false
该命令直接复用了 README 的建议参数,可根据需要修改 backbone/节点数/视角策略。
LeJEPA 证明了“好的表征 = 各向同性高斯”这一简单假设的强大威力。它不仅在 ImageNet 上取得了 SOTA 级别的性能,更重要的是,它为自监督学习提供了一个可解释、可证明、无 Trick 的新范式。
未来方向包括: