[코드 리뷰] I-JEPA-(1) overall train code
실습 예제: I-JEPA official repository 논문 리뷰: I-JEPA 논문 리뷰
0. I-JEPA
I-JEPA는 the target block representations를 single context block으로 부터 predict 하는 것을 목표로 하는 self-supervised learning architecture이다.
따라서 target block representation을 만들기 위한 target encoder와 single context block을 만들기 위한 context encoder, 그리고 predict를 하기 위한 predictor가 필요하다.
이번 코드 리뷰는 공식 repo의 src/train.py 아래 5가지로 나뉘며 필요한 부분을 주석으로 설명한다.
Epoch loop
└─ Iteration loop (배치마다)
├─ 1. 데이터 로드 (이미지 + 마스크)
├─ 2. Forward (Target / Context)
├─ 3. Loss 계산
├─ 4. Backward & Optimizer step
└─ 5. Target encoder momentum update (EMA)
1. 데이터 로드 (이미지 + 마스크)
# -- TRAINING LOOP
for epoch in range(start_epoch, num_epochs):
logger.info('Epoch %d' % (epoch + 1))
# -- update distributed-data-loader epoch
unsupervised_sampler.set_epoch(epoch)
loss_meter = AverageMeter()
maskA_meter = AverageMeter()
maskB_meter = AverageMeter()
time_meter = AverageMeter()
for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader):
def load_imgs():
# -- unsupervised imgs
imgs = udata[0].to(device, non_blocking=True)
masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
return (imgs, masks_1, masks_2)
# 원본 이미지, context 위치 정보, target 위치 정보
imgs, masks_enc, masks_pred = load_imgs()
maskA_meter.update(len(masks_enc[0][0]))
maskB_meter.update(len(masks_pred[0][0]))
def train_step():
_new_lr = scheduler.step()
_new_wd = wd_scheduler.step()
2. Forward (Target / Context)
# --
def forward_target():
# target encoder는 EMA based optimization이라 grad 계산 필요 없음.
with torch.no_grad():
h = target_encoder(imgs)
# normalize over feature-dim
# 각 토큰 벡터를 독립적으로 평균 0, 분산 1로 만들어 예측 타깃의 스케일을 고정
h = F.layer_norm(h, (h.size(-1),))
B = len(h)
# -- create targets (masked regions of h)
# target embedding에서 예측 대상 위치(masks_pred)만 추출
h = apply_masks(h, masks_pred)
# context block 수(len(masks_enc))만큼 target을 복제하여 shape 맞춤.
# 현재 공식 구현에서는 len(masks_enc)==1이지만,
# multi-context 확장을 고려한 일반화 코드임.
h = repeat_interleave_batch(
h, B,
repeat=len(masks_enc)
)
return h
def forward_context():
# 인코딩 후
z = encoder(imgs, masks_enc)
# context embedding, context 위치 정보, target 위치 정보를 이용하여 예측
z = predictor(z, masks_enc, masks_pred)
return z
3. Loss 계산
def loss_fn(z, h):
# 논문은 L2이나 실제 구현으로는 L1
# gradient 안정성을 위한 것으로 예측됨.
loss = F.smooth_l1_loss(z, h)
loss = AllReduce.apply(loss)
return loss
4. Backward & Optimizer step
# Step 1. Forward
with torch.cuda.amp.autocast(
dtype=torch.bfloat16,
enabled=use_bfloat16
):
h = forward_target()
z = forward_context()
loss = loss_fn(z, h)
# Step 2. Backward & step (context encoder와 prediction encoder만 gradient based optimization)
if use_bfloat16:
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
optimizer.step()
grad_stats = grad_logger(
encoder.named_parameters()
)
optimizer.zero_grad()
5. Target encoder momentum update (EMA)
# Step 3. momentum update of target encoder는 EMA based optimization
with torch.no_grad():
m = next(momentum_scheduler)
for param_q, param_k in zip(
encoder.parameters(), # 파라미터를 들고 와서
target_encoder.parameters()
):
# θk←m⋅θk+(1−m)⋅θq
param_k.data.mul_(m).add_(
(1.-m) * param_q.detach().data
)
return (float(loss), _new_lr, _new_wd, grad_stats)