2 분 소요

실습 예제: I-JEPA official repository 논문 리뷰: I-JEPA 논문 리뷰

0. I-JEPA

I-JEPA는 the target block representations를 single context block으로 부터 predict 하는 것을 목표로 하는 self-supervised learning architecture이다.

따라서 target block representation을 만들기 위한 target encoder와 single context block을 만들기 위한 context encoder, 그리고 predict를 하기 위한 predictor가 필요하다.

이번 코드 리뷰는 공식 repo의 src/train.py 아래 5가지로 나뉘며 필요한 부분을 주석으로 설명한다.

Epoch loop    
  └─ Iteration loop (배치마다)   
    ├─ 1. 데이터 로드 (이미지 + 마스크)   
    ├─ 2. Forward (Target / Context)   
    ├─ 3. Loss 계산   
    ├─ 4. Backward & Optimizer step   
    └─ 5. Target encoder momentum update (EMA)   

1. 데이터 로드 (이미지 + 마스크)

# -- TRAINING LOOP
for epoch in range(start_epoch, num_epochs):
    logger.info('Epoch %d' % (epoch + 1))

    # -- update distributed-data-loader epoch
    unsupervised_sampler.set_epoch(epoch)

    loss_meter = AverageMeter()
    maskA_meter = AverageMeter()
    maskB_meter = AverageMeter()
    time_meter = AverageMeter()

    for itr, (udata, masks_enc, masks_pred) in enumerate(unsupervised_loader):

        def load_imgs():
            # -- unsupervised imgs
            imgs = udata[0].to(device, non_blocking=True)
            masks_1 = [u.to(device, non_blocking=True) for u in masks_enc]
            masks_2 = [u.to(device, non_blocking=True) for u in masks_pred]
            return (imgs, masks_1, masks_2)

        # 원본 이미지, context 위치 정보, target 위치 정보
        imgs, masks_enc, masks_pred = load_imgs() 
        maskA_meter.update(len(masks_enc[0][0]))
        maskB_meter.update(len(masks_pred[0][0]))

        def train_step():
            _new_lr = scheduler.step()
            _new_wd = wd_scheduler.step()

2. Forward (Target / Context)

            # --
            def forward_target():
                # target encoder는 EMA based optimization이라 grad 계산 필요 없음.
                with torch.no_grad():
                    h = target_encoder(imgs)
                    # normalize over feature-dim
                    # 각 토큰 벡터를 독립적으로 평균 0, 분산 1로 만들어 예측 타깃의 스케일을 고정
                    h = F.layer_norm(h, (h.size(-1),)) 
                    B = len(h)
                    # -- create targets (masked regions of h)
                    # target embedding에서 예측 대상 위치(masks_pred)만 추출
                    h = apply_masks(h, masks_pred)
                    
                    # context block 수(len(masks_enc))만큼 target을 복제하여 shape 맞춤.
                    # 현재 공식 구현에서는 len(masks_enc)==1이지만,
                    # multi-context 확장을 고려한 일반화 코드임.
                    h = repeat_interleave_batch(
                        h, B, 
                        repeat=len(masks_enc)
                      )
                    return h

            def forward_context():
                # 인코딩 후
                z = encoder(imgs, masks_enc)
                # context embedding, context 위치 정보, target 위치 정보를 이용하여 예측
                z = predictor(z, masks_enc, masks_pred)
                return z

3. Loss 계산

            def loss_fn(z, h):
                # 논문은 L2이나 실제 구현으로는 L1
                # gradient 안정성을 위한 것으로 예측됨. 
                loss = F.smooth_l1_loss(z, h)
                loss = AllReduce.apply(loss)
                return loss

4. Backward & Optimizer step

            # Step 1. Forward
            with torch.cuda.amp.autocast(
                dtype=torch.bfloat16, 
                enabled=use_bfloat16
              ):
                h = forward_target()
                z = forward_context()
                loss = loss_fn(z, h)

            #  Step 2. Backward & step (context encoder와 prediction encoder만 gradient based optimization)
            if use_bfloat16:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            grad_stats = grad_logger(
                encoder.named_parameters()
              )
            optimizer.zero_grad()

5. Target encoder momentum update (EMA)

            # Step 3. momentum update of target encoder는 EMA based optimization
            with torch.no_grad():
                m = next(momentum_scheduler)
                for param_q, param_k in zip(
                    encoder.parameters(), # 파라미터를 들고 와서
                    target_encoder.parameters()
                  ):
                    # θk​←m⋅θk​+(1−m)⋅θq​
                    param_k.data.mul_(m).add_(
                        (1.-m) * param_q.detach().data 
                      )

            return (float(loss), _new_lr, _new_wd, grad_stats)