<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="3.10.0">Jekyll</generator><link href="https://kitewatermelon.github.io/feed.xml" rel="self" type="application/atom+xml" /><link href="https://kitewatermelon.github.io/" rel="alternate" type="text/html" /><updated>2026-03-20T22:28:18+09:00</updated><id>https://kitewatermelon.github.io/feed.xml</id><title type="html">개발새발 01한 인생</title><subtitle>Computer Vision, Medical AI 논문 리뷰 및 개발 이야기를 다루는 기술 블로그입니다.</subtitle><author><name>YSPARK</name></author><entry><title type="html">[코드 리뷰] I-JEPA-(1) overall train code</title><link href="https://kitewatermelon.github.io/code-review/ijepa-1/" rel="alternate" type="text/html" title="[코드 리뷰] I-JEPA-(1) overall train code" /><published>2026-03-20T00:00:00+09:00</published><updated>2026-03-20T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/ijepa-1</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/ijepa-1/"><![CDATA[<p>실습 예제: <a href="https://github.com/facebookresearch/ijepa/tree/main">I-JEPA official repository</a>
논문 리뷰: <a href="/paper-review/ijepa/">I-JEPA 논문 리뷰</a></p>

<h3 id="0-i-jepa">0. I-JEPA</h3>

<p>I-JEPA는 the target block representations를 single context block으로 부터 predict 하는 것을 목표로 하는 self-supervised learning architecture이다.</p>

<p>따라서 target block representation을 만들기 위한 target encoder와 single context block을 만들기 위한 context encoder, 그리고 predict를 하기 위한 predictor가 필요하다.</p>

<p>이번 코드 리뷰는 공식 repo의 src/train.py 아래 5가지로 나뉘며 필요한 부분을 주석으로 설명한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Epoch</span> <span class="n">loop</span>    
  <span class="err">└─</span> <span class="n">Iteration</span> <span class="n">loop</span> <span class="p">(</span><span class="n">배치마다</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">1.</span> <span class="n">데이터</span> <span class="n">로드</span> <span class="p">(</span><span class="n">이미지</span> <span class="o">+</span> <span class="n">마스크</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">2.</span> <span class="n">Forward</span> <span class="p">(</span><span class="n">Target</span> <span class="o">/</span> <span class="n">Context</span><span class="p">)</span>   
    <span class="err">├─</span> <span class="mf">3.</span> <span class="n">Loss</span> <span class="n">계산</span>   
    <span class="err">├─</span> <span class="mf">4.</span> <span class="n">Backward</span> <span class="o">&amp;</span> <span class="n">Optimizer</span> <span class="n">step</span>   
    <span class="err">└─</span> <span class="mf">5.</span> <span class="n">Target</span> <span class="n">encoder</span> <span class="n">momentum</span> <span class="n">update</span> <span class="p">(</span><span class="n">EMA</span><span class="p">)</span>   
</code></pre></div></div>

<h3 id="1-데이터-로드-이미지--마스크">1. 데이터 로드 (이미지 + 마스크)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># -- TRAINING LOOP
</span><span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_epoch</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">):</span>
    <span class="n">logger</span><span class="p">.</span><span class="n">info</span><span class="p">(</span><span class="s">'Epoch %d'</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span>

    <span class="c1"># -- update distributed-data-loader epoch
</span>    <span class="n">unsupervised_sampler</span><span class="p">.</span><span class="n">set_epoch</span><span class="p">(</span><span class="n">epoch</span><span class="p">)</span>

    <span class="n">loss_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">maskA_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">maskB_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>
    <span class="n">time_meter</span> <span class="o">=</span> <span class="n">AverageMeter</span><span class="p">()</span>

    <span class="k">for</span> <span class="n">itr</span><span class="p">,</span> <span class="p">(</span><span class="n">udata</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">unsupervised_loader</span><span class="p">):</span>

        <span class="k">def</span> <span class="nf">load_imgs</span><span class="p">():</span>
            <span class="c1"># -- unsupervised imgs
</span>            <span class="n">imgs</span> <span class="o">=</span> <span class="n">udata</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
            <span class="n">masks_1</span> <span class="o">=</span> <span class="p">[</span><span class="n">u</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="n">masks_enc</span><span class="p">]</span>
            <span class="n">masks_2</span> <span class="o">=</span> <span class="p">[</span><span class="n">u</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">,</span> <span class="n">non_blocking</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span> <span class="k">for</span> <span class="n">u</span> <span class="ow">in</span> <span class="n">masks_pred</span><span class="p">]</span>
            <span class="k">return</span> <span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">masks_1</span><span class="p">,</span> <span class="n">masks_2</span><span class="p">)</span>

        <span class="c1"># 원본 이미지, context 위치 정보, target 위치 정보
</span>        <span class="n">imgs</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span> <span class="o">=</span> <span class="n">load_imgs</span><span class="p">()</span> 
        <span class="n">maskA_meter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_enc</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>
        <span class="n">maskB_meter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_pred</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]))</span>

        <span class="k">def</span> <span class="nf">train_step</span><span class="p">():</span>
            <span class="n">_new_lr</span> <span class="o">=</span> <span class="n">scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">_new_wd</span> <span class="o">=</span> <span class="n">wd_scheduler</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="2-forward-target--context">2. Forward (Target / Context)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># --
</span>            <span class="k">def</span> <span class="nf">forward_target</span><span class="p">():</span>
                <span class="c1"># target encoder는 EMA based optimization이라 grad 계산 필요 없음.
</span>                <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
                    <span class="n">h</span> <span class="o">=</span> <span class="n">target_encoder</span><span class="p">(</span><span class="n">imgs</span><span class="p">)</span>
                    <span class="c1"># normalize over feature-dim
</span>                    <span class="c1"># 각 토큰 벡터를 독립적으로 평균 0, 분산 1로 만들어 예측 타깃의 스케일을 고정
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">),))</span> 
                    <span class="n">B</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">h</span><span class="p">)</span>
                    <span class="c1"># -- create targets (masked regions of h)
</span>                    <span class="c1"># target embedding에서 예측 대상 위치(masks_pred)만 추출
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">apply_masks</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span>
                    
                    <span class="c1"># context block 수(len(masks_enc))만큼 target을 복제하여 shape 맞춤.
</span>                    <span class="c1"># 현재 공식 구현에서는 len(masks_enc)==1이지만,
</span>                    <span class="c1"># multi-context 확장을 고려한 일반화 코드임.
</span>                    <span class="n">h</span> <span class="o">=</span> <span class="n">repeat_interleave_batch</span><span class="p">(</span>
                        <span class="n">h</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> 
                        <span class="n">repeat</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">masks_enc</span><span class="p">)</span>
                      <span class="p">)</span>
                    <span class="k">return</span> <span class="n">h</span>

            <span class="k">def</span> <span class="nf">forward_context</span><span class="p">():</span>
                <span class="c1"># 인코딩 후
</span>                <span class="n">z</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">imgs</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">)</span>
                <span class="c1"># context embedding, context 위치 정보, target 위치 정보를 이용하여 예측
</span>                <span class="n">z</span> <span class="o">=</span> <span class="n">predictor</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">masks_enc</span><span class="p">,</span> <span class="n">masks_pred</span><span class="p">)</span>
                <span class="k">return</span> <span class="n">z</span>
</code></pre></div></div>

<h3 id="3-loss-계산">3. Loss 계산</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">):</span>
                <span class="c1"># 논문은 L2이나 실제 구현으로는 L1
</span>                <span class="c1"># gradient 안정성을 위한 것으로 예측됨. 
</span>                <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">smooth_l1_loss</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">AllReduce</span><span class="p">.</span><span class="nb">apply</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
                <span class="k">return</span> <span class="n">loss</span>
</code></pre></div></div>

<h3 id="4-backward--optimizer-step">4. Backward &amp; Optimizer step</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># Step 1. Forward
</span>            <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">amp</span><span class="p">.</span><span class="n">autocast</span><span class="p">(</span>
                <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">bfloat16</span><span class="p">,</span> 
                <span class="n">enabled</span><span class="o">=</span><span class="n">use_bfloat16</span>
              <span class="p">):</span>
                <span class="n">h</span> <span class="o">=</span> <span class="n">forward_target</span><span class="p">()</span>
                <span class="n">z</span> <span class="o">=</span> <span class="n">forward_context</span><span class="p">()</span>
                <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span>

            <span class="c1">#  Step 2. Backward &amp; step (context encoder와 prediction encoder만 gradient based optimization)
</span>            <span class="k">if</span> <span class="n">use_bfloat16</span><span class="p">:</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">scale</span><span class="p">(</span><span class="n">loss</span><span class="p">).</span><span class="n">backward</span><span class="p">()</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">)</span>
                <span class="n">scaler</span><span class="p">.</span><span class="n">update</span><span class="p">()</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">loss</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
                <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
            <span class="n">grad_stats</span> <span class="o">=</span> <span class="n">grad_logger</span><span class="p">(</span>
                <span class="n">encoder</span><span class="p">.</span><span class="n">named_parameters</span><span class="p">()</span>
              <span class="p">)</span>
            <span class="n">optimizer</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
</code></pre></div></div>

<h3 id="5-target-encoder-momentum-update-ema">5. Target encoder momentum update (EMA)</h3>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>            <span class="c1"># Step 3. momentum update of target encoder는 EMA based optimization
</span>            <span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
                <span class="n">m</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">momentum_scheduler</span><span class="p">)</span>
                <span class="k">for</span> <span class="n">param_q</span><span class="p">,</span> <span class="n">param_k</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span>
                    <span class="n">encoder</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="c1"># 파라미터를 들고 와서
</span>                    <span class="n">target_encoder</span><span class="p">.</span><span class="n">parameters</span><span class="p">()</span>
                  <span class="p">):</span>
                    <span class="c1"># θk​←m⋅θk​+(1−m)⋅θq​
</span>                    <span class="n">param_k</span><span class="p">.</span><span class="n">data</span><span class="p">.</span><span class="n">mul_</span><span class="p">(</span><span class="n">m</span><span class="p">).</span><span class="n">add_</span><span class="p">(</span>
                        <span class="p">(</span><span class="mf">1.</span><span class="o">-</span><span class="n">m</span><span class="p">)</span> <span class="o">*</span> <span class="n">param_q</span><span class="p">.</span><span class="n">detach</span><span class="p">().</span><span class="n">data</span> 
                      <span class="p">)</span>

            <span class="k">return</span> <span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="n">loss</span><span class="p">),</span> <span class="n">_new_lr</span><span class="p">,</span> <span class="n">_new_wd</span><span class="p">,</span> <span class="n">grad_stats</span><span class="p">)</span>
</code></pre></div></div>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="JEPA" /><summary type="html"><![CDATA[I-JEPA의 전반적인 학습 코드를 본다.]]></summary></entry><entry><title type="html">[논문리뷰] Brain Network Transformer</title><link href="https://kitewatermelon.github.io/paper-review/bnt/" rel="alternate" type="text/html" title="[논문리뷰] Brain Network Transformer" /><published>2026-03-16T00:00:00+09:00</published><updated>2026-03-16T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/bnt</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/bnt/"><![CDATA[<blockquote>
  <p>NeurIPS 2022 [<a href="https://arxiv.org/pdf/2210.06681">Paper</a>] [<a href="https://github.com/Wayfear/BrainNetworkTransformer">GitHub</a>]<br />
 Xuan Kan, Wei Dai, Hejie Cui, Zilong Zhang, Ying Guo, Carl Yang
 15 Oct 2022</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>Brain network analysis는 신경 과학자들에게 사람의 뇌 구조 이해와 임상 결과 예측을 위한 흥미로운 연구이다. 다양한 모달 중 fMRI가 뇌 네트워크 구조를 위해 주로 사용된다. fMRI의 노드는 atlas 기반의 ROIs로 정의되고 각 노드간 BOLD 신호의 상관관계로 엣지를 정의한다. 연구자들은 action, language, and vision 같은 cognitive-related tasks가 일어날때 특정 영역이 동시에 활성화되거나 비활성화 되는 것을 관측해왔다. 이런 패턴을 바탕으로 brain regions을 다양한 기능적 모듈로 분류하여 질병을 분석하여 진단, 진행 이해 및 치료에 사용할 수 있다.</p>

<p>Transfer의 여러 분야에서의 성공을 해왔고 그 중 GAT라는 모델이 처음으로 GNNs의 영역에 적용되었다. 그러나 이는 이웃 노드의 local 구조만 고려하였다. Graph Transformer는 edge information을 attention mechanism에 집어넣었고 각 노드의 eigenvectors를 position embedding(PE)로 활용하였다. SAN은 eigenvalue와 eigenvectors를 동시에 고려하여 PE를 더 강화하고 attention을 local 구조에서 global 구조로 확장하였다. Graphomer는 독창적인 메커니즘으로 OGB Large-Scale Challenge에서 우승하였다.</p>

<p>그러나 brain networks는 기존 Graph Transformer 모델이 실용적이지 않은 독특한 특성이 여럿 있다.</p>

<ol>
  <li>주로 사용되는 방법이 ROIs 간의 BOLD 신호의 correlation이다. 이는 centrality, spatial 그리고 edge encoding 같은 디자인을 방해한다. 왜냐하면 각 노드들은 같은 차수를 가지고 모든 노드간 single hop으로 연결되어 있기 때문이다.</li>
  <li>기존 Graph transformer models는 eigenvalue와 eigenvectors를 주로 PE에 사용한다. 왜냐하면 그들은 각 노드의 identity와 positional information을 제공하기 때문이다. 그러나 brain network에서는, brain network adjacency matrix의 각 노드의 해당 행으로 정의되는 connection profile이 가장 효과적인 node feature로 인식된다. 이 node feature는 구조적, 위치적 정보 둘 다 자연스럽게 인코딩한다. 이는 앞서 설명한 PE 디자인이 중복으로 여겨 진다.</li>
  <li>Scalability도 중요하다. 보통은 노드와 엣지의 수가 50과 2500개 미만인데, brain network는 atlas에 따라 100~400 개의 노드를 가지고 이는 최대 160k개의 엣지가 생긴다. 그러므로 현존하는 Graph transformer model으로 모든 엣지의 기능 생성과 같은 작업이 불가능하지 않더라도 시간이 많이 걸릴 수 있다.</li>
</ol>

<p>본 논문에서는 brain network analysis를 위해 brain network의 독특한 특성을 transformer-based moedl의 힘으로 완전히 해방하는 BRAIN NETWORK TRANSFORMER (BRAINNETTF, 다른 논무에서는 주로 BNT로 불림)을 제안한다. 특히, 기존 GNN에서의 발견을 통해 connection profiles의 초기 node feature를 효과적으로 초기화하는 방법을 제안한다.</p>

<p>한 단계 나아가서 brain network analysis에 GNNs를 사용할 때는 학습된 node embedding을 기반으로 readout function을 통해 graph-level embedding을 생성해야 한다. brain network의 특성상 동일한 기능적 모듈에 속하는 노드들은 다양한 자극에 대한 활성화 및 비활성화 반응에서 유사한 행동 양식을 공유하는 경우가 많다. 이를 위해 ORTHONORMAL CLUSTERING READOUT을 설계하여 노드들의 cluster에서 graph-level embedding을 pooling한다.</p>

<p>마지막으로 공개 데이터셋의 부족은 brain network analysis에서 무시할 수 없는 과제이다. 예를 들어 ABIDE는 별도의 접근 허가 없이 fMRI를 완전히 사용할 수 있지만 17개의 기관에서 서로 다른 스캐너와 파라미터를 사용하여 획득된다. 이러한 inter-site variability는 실질적으로 의미 있는 집단 간 차이를 가려버리며 이로 인해 학습시 불안정성 증가 및 검증/테스트 셋과의 유의미한 격차로 나타난다. 이를 해결하기 위해 stratified sampling을 제안하고 표준화 할 것을 제안한다.</p>

<h2 id="2-background-and-related-work">2 Background and Related Work</h2>
<p>해당 섹션에서는 배경과 관련 연구를 설명한다. 본문에서는 다루지 않는다.</p>

<h2 id="3-brain-network-transformer">3 BRAIN NETWORK TRANSFORMER</h2>
<h3 id="31-problem-definition">3.1 Problem Definition</h3>
<p>brain network analysis에서 brain network: $X \in \mathbb{R}^{V \times V} $ where $V$는 node (ROIs)의 수 이며 주로 성, 질병의 유무, brain subject의 특징을 예측하는 것이 모델의 주 목표이다. BNT의 전반적인 프레임워크는 아래 그림과 같다. $L$개의 MHSA layer와 graph pooling operator OCREAD로 두개의 main component로 이루어져있다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig2.webp" width="80%" />
</center>
<p><br /></p>

<p>MHSA에서 non-linear mapping $X \rarr Z^L \in \mathbb{R}^{V \times V}$을 통해 attention-enhanced node features $Z^L$을 학습한다. OCREAD은 enhanced node embeddings $Z^L$을 graph-level embeddings $Z_G \in \mathbb{R}^{K \times V}$ 로 압축한다. $K$는 hyperparam인 number of clusters이다. $Z_G$는 flatten 후 MLP를 통과하여 graph-level prediction을 한다. 모든 학습 과정은 CE를 통한 supervised learning을 통해 이루어진다.</p>

<h3 id="32-multi-head-self-attention-module-mhsa">3.2 Multi-Head Self-Attention Module (MHSA)</h3>
<p>brain network에 적합한 Transformer-based model을 개발하기 위해 PE와 attention mechanism이라는 두가지 기초 디자인에 대한 재고가 필요하다. 현존하는 모델은 eigendecomposition을 통해 주로 위치 정보를 인코딩하지만, 이는 dense한 brain network에서 비용이 많이 들고 edge의 존재가 유익하지 않다.</p>

<p>brain networks의 ROI 노드는 이미 필요한 위치 정보를 가지고 있으므로 eigendecomposition을 통한 position encoding은 중복이다. 기존 연구에서 connection profile $X_i$을 통한 분석이 항상 eigenvectors를 이용한 방법보다 좋은 성능을 보였다. 또한 이전 연구에서 edge weight를 attention score에 통합하면 완전 그래프에서 attention의 효과를 크게 저하시킬 수 있음을 경험적으로 입증하였다. 또한 brain network는 edge가 매우 많기 때문에 edge-wise embedding을 생성하는 것이 계산적으로 감당하기 어렵다. 또한 이 케이스에서 모든 엣지가 단순히 존재하기 때문에 존재 여부 자체도 attention score 계산에 유용한 정보를 제공하지 않는다.</p>

<p>이러한 관점에서 BNT는</p>
<ol>
  <li>connection profile을 초기 node feature로 삼고 PE를 없앤다.</li>
  <li>edge weight나 상대 위치 정보를를 사용하지 않는 바닐라 pair-wise mechanism을 사용한다.</li>
</ol>

<p>$Z^L = \text{MHSA}(X) \in \mathbb{R}^{V \times V}$를 생성하는 $L$-Layer non-linear mapping module인 MHSA 수식은 다음과 같다. 각 레이어 $l$에 대하여 출력 $Z^l$은 다음과 같이 얻어진다.</p>

\[\begin{equation}
Z^l = \left( \Big\|_{m=1}^{M} h^{l,m} \right) W^l_Oh^{l,m} = \text{Softmax} \left( 
    \frac{W^{l,m}_Q Z^{l-1} 
    \left( W^{l,m}_K Z^{l-1} \right)^\top}
    {\sqrt{d^{l,m}_K}} 
\right) W^{l,m}_V Z^{l-1}
\end{equation}\]

<p>여기서 $Z^0 = X$, ∥는 concatenation 연산자, M은 헤드 수, l은 레이어 인덱스, $W^l_O,\ W^{l,m}_Q,\ W^{l,m}_K,\ W^{l,m}_V$는 학습 가능한 파라미터, $d^{l,m}_K$ 는 $W^{l,m}_K$ ​의 첫 번째 차원이다.</p>

<h3 id="33-orthonormal-clustering-readout-ocread">3.3 ORTHONORMAL CLUSTERING READOUT (OCREAD)</h3>
<p>graph-level representation을 학습하는 readout function은 brain network analysis를 위해 필수적이다. Mean, Sum, Max가 주로 사용된다. 그러나 현존하는 방법 중 어느 것도 fig1(a)에 표시된 것 처럼 brain network의 동일한 기능 모듈의 노드가 유사한 동작과 clustering representation을 갖는 경향이 있는 속성을 사용하지 않는다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>이를 해결하기 위해 ROIs긴 modular-level similarities의 이점을 활용하는 novel readout function을 제안한다.</p>

<p>$V$ 차원을 가진 $K$ 개의 cluster center $E \in \mathbb{R}^{K \times V}$ 가 주어졌을때 Softmax projection operator가 노드 $i$를 cluster $k$로 할당하는 probability $P_{ik}$를 계산하는 함수로 사용된다.</p>

\[\begin{equation}
P_{ik} = \frac{e^{\langle Z^L_{i\cdot},\, E_{k\cdot} \rangle}}{\sum_{k'}^{K} e^{\langle Z^L_{i\cdot},\, E_{k'\cdot} \rangle}},
\end{equation}\]

<p>soft assignment가 계산된 이후로 $Z^L$은 soft cluster information의 가이드를 받아 graph-level embedding $Z_G$로 집약된다. $Z_G = P^TZ^L$
그러나 GT 없이 node embedding과 cluster를 학습하는 것은 어렵다. 따라서 클러스터 센터 초기화가 매우 중요하다. 이를 해결하기 위해 Fig1(b)에서의 관측을 활용한다. 이는 orthonormal(직교) embedding이 brain network내에서 node clustering을 향상시키는 것이다.</p>

<blockquote>
  <p>이 부분은 완전히 이해하지 못했다. 논문 본문에 이론적 정의가 잘 나와 있으니 참고하면 될 것 같다.</p>
</blockquote>

<h3 id="34-generalizing-ocread-to-other-graph-tasks-and-domains">3.4 Generalizing OCREAD to Other Graph Tasks and Domains</h3>
<p>본 논문에서 OCREAD는 FC based brain network를 이용하였다. 그러나 이에 국한되지 않고 Structural connectivities(SC)등에도 사용가능하며 protein-protein interaction networks나 유전자 발현 network에서도 사용가능하다.</p>

<h2 id="4-experiments">4 Experiments</h2>
<p>저자들은 다음 세가지 RQ에 대한 검증을 위주로 실험했다.</p>

<p>RQ1. How does BRAINNETTF perform compared with state-of-the-art models of various types? (SOTA 급인지?)</p>

<p>RQ2. How does our proposed OCREAD module perform with different model choices? (OCREAD 모듈은 다양한 모델 선택에서 어떻게 작동하는지?)</p>

<p>RQ3. Does the learned model of BRAINNETTF exhibit consistency with existing neuroscience knowledge and suggest reasonable explainability? (BNT가 현존하는 neuroscience knowledge와 일관성있고 함리적인 설명 가능성을 제안하는지?)</p>

<h3 id="41-experimental-settings">4.1 Experimental Settings</h3>
<p>Dataset:</p>
<ul>
  <li>ABIDE
    <ul>
      <li>#: 1009 subjects</li>
      <li>자폐: 516 subjects</li>
      <li>atals:Craddock 200</li>
      <li>network 바로 다운 가능</li>
      <li>multi-site problem을 stratified sampling로 해결</li>
    </ul>
  </li>
  <li>ABCD
    <ul>
      <li>#: 7901 subjects</li>
      <li>여자: 3961 subjects</li>
      <li>atals: HCP 360 ROI</li>
    </ul>
  </li>
</ul>

<p>Metrics는 두 데이터셋 모두 binary classification이므로 AUROC를 사용했으며 임상 적용성을 위해 Sensitivity와 Specificity까지 3가지를 5 random seed에 평균과 표준편차를 보고한다. Model의 헤드는 4개 레이어는 2개를 사용했으면 7:1:2 split을 활용한다. Adam으로 1e-4의 초기 learning rate를 사용한다. weight decay는 1e-4를 사용하며, batch size는 64로 설정됐다. 200 epoch동안 AUROC가 가장 좋은 모델을 테스트에 사용했다.</p>

<h3 id="42-performance-analysis-rq1">4.2 Performance Analysis (RQ1)</h3>

<center>
<img src="/assets/img/paper-review/bnt/tab1.webp" width="80%" />
</center>
<p><br /></p>

<p>(a) BNT vs other graph transformers <br />
위 표를 보면 알 수 있지만 VanillaTF가 다른 비교군을 AUROC 측면에서 이겼다. 저자들은 이를 brain network의 특성 때문이라고 분석했다. 이는</p>

<p>(b) BRAINNETTF vs. neural network models on fixed brain networks <br />
BrainGNN, BrainGB 및 BrainnetCNN 등의 NN 기반 모델보다 더 좋은 성능을 보였다.</p>

<p>(c) BRAINNETTF vs. neural network models on learnable brain networks <br />
learnable graph를 사용한 경우에도 더 좋은 성능을 보였다.</p>

<h3 id="43-ablation-studies-on-the-ocread-module-rq2">4.3 Ablation Studies on the OCREAD Module (RQ2)</h3>
<h4 id="431-ocread-with-varying-readout-functions">4.3.1 OCREAD with varying readout functions</h4>
<p>아래 표는 SAN, Graphormer and VanillaTF에 대하여 다양한 readout function과 본 논문에서 제안된 OCREAD를 비교한 표이다. 전반적으로 OCREAD가 가장 효과적인 readout function이었으며 다양한 transformer 아키텍처의 예측 성능을 높혀주었다.</p>

<center>
<img src="/assets/img/paper-review/bnt/tab2.webp" width="80%" />
</center>
<p><br /></p>

<h4 id="432-ocread-with-varying-cluster-initializations">4.3.2 OCREAD with varying cluster initializations</h4>
<p>OCREAD의 설계가 BNT의 성능에 어떻게 영향을 미치는지 추가로 입증하기 위해 cluster center 초기화 방법과 cluster $K$의 수를 어떻게 선택하는지 이 섹션에서 논의한다.
Random, Learnable, Orthonormal 세가지로 초기화 방법을 비교했으며 2, 3, 4, 5, 10, 50, 100로 $K$를 비교한다. fig3(a)가 이의 결과이다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>결과에 따르면 최적의 $K$의 수는 상대적으로 작다. 이는 적은 계산량으로 이끌며 일반적인 functional module의 수가 25개 미만인 것과 일치한다. 충분히 큰 $K$에서는 세가지 방법이 비슷한 성능을 보이나 적은 $K$에서 제안된 orthonormal 방법이 가장 안정적으로 성능을 낸다. OCREAD의 std가 가장 작은 것도 알아두면 좋다.</p>

<h3 id="44-in-depth-analysis-of-attention-scores-and-cluster-assignments-rq3">4.4 In-depth Analysis of Attention Scores and Cluster Assignments (RQ3)</h3>
<p>fig3(b)는 ABCD 테스트셋에서 MHSA 첫 레이어의 평균 self-attention score이다. 이는 학습된 attention score가 available labels 기반의 functional modules의 구분과 잘 일치하여 transformer 모델의 효율성과 설명 가능성을 보여준다. ABIDE는 label을 제공하지 않기 때문에 시각화하지 못했다.</p>

<p>아래 그림은 두 가지 초기화 방법을 사용하여 OCREAD의 노드에 대한 cluster soft assignment 결과 $P$를 보여준다. orthonormal 초기화는 random 초기화보다 좀 더 구분 가능한 $P$를 보여준다. 각 클래스 내에서 orthonormal 초기화는 노드가 그룹을 형성하도록 장려한다.</p>

<center>
<img src="/assets/img/paper-review/bnt/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="5-discussion-and-conclusion">5 Discussion and Conclusion</h2>
<p>본 논문에서는 brain network analysis를 위한 OCREAD를 갖춘 특화된 graph transformer를 제시한다. 두개의 대규모 brain network 데이터셋에 대한 광범위한 실험으로 BNT가 SOTA인 것을 확인했으며 brain network의 잠재적인 노드 기능 유사성을 모델링하기 위해 OCREAD를 설계하고 이론적, 경험적으로 그 효과를 입증한다. 마지막으로 ABIDE에 대한 재표준화된 데이터셋 분할은 커뮤니티의 새로운 방법에 대한 공정한 평가를 제공할 수 있다. 향후 작업을 위해 BNT는 explicit explanation modules을 통해 개선될 수 있으며 정신 장애에 대한 필수 신경 회로 발굴 및 청소년의 인지 발달 이해와 같은 추가 뇌 네트워크 분석을 위한 백본으로 사용될 수 있다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>본 논문은 brain network 관련 연구에서 베이스라인으로 사용되는 모델이라 자세히 리뷰를 하고 싶었는데, 일단 OCREAD라는 것이 나에게는 너무 어려워서 제대로 이해하지 못했다. 추후에 BNT에 대하여, 또 OCREAD에 대하여 자세히 공부할 일이 있으면 이론까지 공부를 하여 해당 본문을 수정할 의향이 있다.</li>
  <li>기존 모델들이 왜 brain network에서 약한 모습을 보였는지 이해할 수 있었다. brain network의 고유한 특성에 대하여 이해할 수 있었다. 저자들의 뇌 기능에 관한 깊은 insight를 얻을 수 있어서 좋았다.</li>
  <li>brain network가 complete graph 라서 edge 정보가 불필요한 것을 이해하였으며 connection profile이 구조적 위치적 정보를 가지고 있는 것을 알게 되었다.</li>
  <li>다시 원문을 보니 stratified sampling에 대하여 자세히 설명하지 않은 것 같다.</li>
  <li>글을 보면 자연스럽지 못한 부분이 참 많은 것 같다. 이는 내가 이 논문을 완전히 이해하고 있지 못하다는 뜻이다. 여러모로 아쉬운 리뷰였지만, 이번 리뷰를 발판 삼아 더 나은 리뷰를 하는 블로거가 되어야겠다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Brain-Network" /><category term="NeurIPS" /><summary type="html"><![CDATA[Brain Network Transformer (NeurIPS)]]></summary></entry><entry><title type="html">[논문리뷰] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture</title><link href="https://kitewatermelon.github.io/paper-review/ijepa/" rel="alternate" type="text/html" title="[논문리뷰] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" /><published>2026-03-14T00:00:00+09:00</published><updated>2026-03-12T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/ijepa</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/ijepa/"><![CDATA[<blockquote>
  <p>CVPR 2023 [<a href="https://arxiv.org/pdf/2301.08243">Paper</a>] [<a href="https://github.com/facebookresearch/ijepa">GitHub</a>]<br />
 Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas
 13 Apr 2023</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>CV 분야에서는 invariance-based method와 generative method라는 두가지 SSL 기법이 있다. invariance-based pretraining method는 같은 이미지의 서로 다른 view에서 비슷한 임베딩을 얻으며 최적화한다. 이때 서로 다른 view는 hand-crafted augmentations를 통해 주로 만든다. 이 pretraining 기법은 high semantic level의 representations을 얻을 수 있지만, 특정 task나 다른 데이터 분포에서 강한 편향을 주입하기도 한다. 다양한 수준의 추상화가 필요한 task에 대해 이런 편향을 일반화하는 것은 아직 불분명한 경우가 많다. 예를 들어 image classification과 instance segmentation은 같은 invariance를 요구하지 않는다. 추가적으로 image-specific augmentation을 audio 같은 다른 모달에 일반화하는 것은 간단하지 않다.</p>

<p>Cognitive learning theories는 생물학적 시스템에서 representations learning 이면에 있는 구동 메커니즘은 감각 입력 반응을 예측하기 위한 내부 모델의 adaptation이라고 제안한다. 이 아이디어는 self-supervised generative methods의 핵심이다. 이는 입력의 일부를 제거하거나 오염시키고 해당 부분을 예측하는 방식이다. Masked pretraining task는 view-invariance method 보다 사전 지식(hand-crafted transformers을 의미하는 것 같음.)이 덜 필요하고 다른 모달에 일반화 성능이 좋다. 그러나 invariance-based method보다 낮은 semantic 수준을 보이며ㅡ off-the-shelf evaluation에서 성능이 낮다. 결과적으로 end-to-end fine-tunning 같은 복잡한 adaptation 메커니즘을 활용해야 이 방법의 완전한 이점을 누릴 수 있다.</p>

<p>본 논문에서는, I-JEPA라는 추가 사전 지식 없이 self-supervised representations의 semantic 수준을 높이는 method를 도입한다. I-JEPA의 핵심 아이디어는 abstract representation space(추상 표현 공간)에서 누락된 정보를 예측하는 것이다. 예를 들어 같은 이미지의 단일 context block를 주고, 여러 target block의 representation을 예측하는 것이다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>pixel/token space에서 예측하는 generative methods에 비해 I-JEPA는 불필요한 pixel-level 디테일이 잠재적으로 지워진 target의 abstract를 예측함으로 model이 더 의미있는 특징을 학습하도록 이끌어낸다.</p>

<p>I-JEPA를 더 semantic representations을 생산하도록 선택된 또 다른 핵심 디자인은 multi-block masking strategy이다. 특히 이미지에서 충분히 큰 target blocks를 예측하도록 하는 것의 중요성을 입증한다.</p>

<p>저자들은 방대한 양의 실험을 통해 다음을 입증한다.</p>
<ul>
  <li>strong off-the-shelf representation을 hand-crafted view augmentation없이 학습한다.</li>
  <li>I-JEPA는 view-invatiant pretraining approaches와 비등한 semantic task 결과를 보였고, low-level visions tasks에서는 더 나은 결과를 보였다.</li>
  <li>I-JEPA는 scalable하고 효율적인다.</li>
</ul>

<h2 id="2-background">2. Background</h2>
<p>SSL은 system이 입력간의 관계를 포착하도록 하는 representation learning 기법이다. 이 목표는 incompatible inputs 끼리는 높은 에너지를, compatible inputs끼리는 낮은 에너지를 할당하는 Energy-Based Models(EBMs)의 프레임워크를 이용하여 쉽게 설명 가능하다. 현존하는 많은 SSL 방법들이 이 프레임워크로 설명가능하다. 다음 그림을 보면 이해가 쉽다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig2.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="joint-embedding-architectures">Joint-Embedding Architectures</h3>
<p>Invariance-based pretraining methods는 compatible inputs인 $x,y$ 간에 비슷한 embedding을 산출하고,incompatible inputs에는 다른 embedding을 산출하느 Joint-Embedding Architectures을 사용하는 EBMs로 설명가능하다(Figure 2a.). image-based pretraining의 관점에서 compatible inputs인 $x,y$는 주로 같은 입력 이미지에 랜덤하게 hand-crafted augmentation을 적용하여 만든다. JEAs의 주된 과제는 energy landscapes가 평평해 지는 representation collapse이다(입력에 관계없이 완전히 같은 출력을 내보냄.). 지난 몇년간 representation collapse을 방지하기 위한 다양한 방법이 연구되었다. contrastive loss를 사용하거나 non-contrastive loss를 사용하여 embedding간 정보 중복을 최소화 하거나 평균 임베딩의 엔트로피를 극대화하는 clustering-based 등의 방법이 있다. 또 서로 다른 인코더를 사용해서 collapse를 방지하는 방법도 있다.</p>

<h3 id="generative-architectures">Generative Architectures</h3>
<p>Reconstruction-based method 역시 Generative Architectures를 사용하는 EBMs 프레임워크로 설명할 수 있다(Figure 2b.). Generative Architectures는 compatible signal $x$에서 바로 $y$를 reconstruction한다. 이때 디코더는 이를 촉진하기 위해서 $z$를 추가적으로 조건으로 받는다. image-based pretraining의 관점에서 가장 흔한 compatible inputs $x,y$를 만드는 방법은 masking이다. $z$는 mask와 position token이다. 이는 어느 이미지 패치를 디코더가 reconstruction할지 명시해준다. $y$보다 $z$의 정보 용량이 적은한 representation collapse는 문제가 되지 않는다.</p>

<h3 id="joint-embedding-predictive-architectures">Joint-Embedding Predictive Architectures</h3>
<p>Figure 2c.에서 볼 수 있듯이 Joint-Embedding Predictive Architectures는 Generative Architectures와 비슷하다. 이와 가장 큰 차이는 loss 계산이 input space가 아닌 embedding space에서 일어나는 것이다. JEPAs는 예측을 용이하기 위한 변수 $z$를 조건으로 받아 compatible signal $x$로 부터 signal $y$의 임베딩을 예측 네트워크를 통해 예측하도록 학습한다. 자세한 그림은 Figure 3.을 참조하면된다.</p>

<p>JEA와 다르게 JEPAs는 representation invatiant를 hand-crafted augmentation을 이용하여 representation invatiant를 찾지 않고 대신 추가 정보 $z$를 조건으로 할 때 서로를 예측하는 representation을 찾는다. 그러나 JEA와 마찬가지로 representation collapse는 JEPAs에서도 문제가 되는데, 이를 방지하기 위해 $x$와 $y$ 사이에 비대칭 아키텍처를 사용한다.</p>

<h2 id="3-method">3. Method</h2>
<p>Figure 3.을 다시 한번 보자.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>I-JEPA의 전반적인 목적은 같은 이미지에서 context block이 주어졌을때 다양한 target block의 representation을 예측하는 것이다. context-encoder와 target-encoder 모두 ViT를 사용하였으며 decoder구조는 MAE에서 따왔다. 둘의 차이는 I-JEPA는 non-generative method이며 prediction이 representation space에서 일어나는 것이다.</p>

<h3 id="targets">Targets</h3>
<p>주어진 입력 이미지 $y$를 겹치지 않는 N개의 패치를 만든다. target-encoder $f_{\bar \theta}$에 이걸 넣어서 대응하는 patch-level representation $s_y=\lbrace s_{y1},…, s_{yN} \rbrace $을 만든다. loss를 위한 targets를 얻기 위해 $s_y$에서 M개의 랜덤한 sample block을 뽑는다. i번째 블록에 대응하는 마스크를 $B_i$로 표시하고 $s_y(i) = \lbrace s_{yj} \rbrace _{j \in B_i}$로 표시한다. 저자들은 실험에서 M=4로 셋하고 0.75:1의 종횡비와 0.15~0.2의 스케일로 block을 샘플링했다.</p>

<h3 id="context">Context</h3>
<p>I-JEPA의 목표는 single context block으로 부터 target block의 representation을 예측하는 것이다. 이를 위해 이미지의 0.85~1의 스케일로 $x$를 샘플링하고 $B_x$를 이용해 context blocks을 할당한다. 이후 target block과 겹치는 부분을 없애준다. 아래 그림이 target blocks와 context block을 이해하는데 도움을 준다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="prediction">Prediction</h3>
<p>context encoder의 출력 $s_x$가 주어졌을 때 우리는 $M$ 개의 target block의 
representation $s_y(1), \ldots, s_y(M)$ 을 예측하기를 바란다. 이를 위해, 
대상 마스크 $B_i$ 에 해당하는 주어진 target block $s_y(i)$ 에 대해 예측기 
$g_\phi(\cdot, \cdot)$ 는 context encoder의 출력 $s_x$와 예측하려는 각 패치에 
대한 마스크 토큰 $\lbrace m_j \rbrace_{j \in B_i}$ 를 입력으로 취하고 패치 수준 예측 \(\hat{s}_{y}(i) = \lbrace \hat{s}_{yj} \rbrace_{j \in B_i} = g_\phi(s_x, \lbrace m_j \rbrace_{j \in B_i})\) 을 출력한다. 
mask token은 positional embedding이 추가된 shared learnable vector이다.</p>

<h3 id="loss">Loss</h3>
<p>loss는 predicted patch-level representations $\hat{s}_y(i)$과 the target patch-level representation $s_y(i)$간의 평균 $L_2$ distance이다. 
predictor, $\phi$와 context encoder, $\theta$는 gradient based optimization을 target encoder $\bar \theta$는 EMA 방식으로 학습한다.</p>

<h2 id="4-related-work">4. Related Work</h2>
<p>본문에서 다양한 SSL 기법에 대한 설명을 하고 있으나 이 글에서는 다루지 않겠다.</p>

<h2 id="5-image-classification">5. Image Classification</h2>

<h3 id="imagenet-1k">ImageNet-1K</h3>
<center>
<img src="/assets/img/paper-review/ijepa/tab1.webp" width="80%" />
</center>
<p><br /></p>

<p>hand-crafted augmentation을 사용하지 않는 다른 유명한 방법인 MAE, CAE 그리고 data2vec과 비교했을때 I-JEPA는 더 적은 연산으로 linear probing 성능을 향상시켰다.</p>

<h3 id="low-shot-imagenet-1k">Low-Shot ImageNet-1K</h3>
<center>
<img src="/assets/img/paper-review/ijepa/tab2.webp" width="80%" />
</center>
<p><br /></p>

<p>IN1k의 1%(각 클래스 별로 12~13 장)으로 학습한 결과이다. 적은 에폭으로도 비슷한 구조의 MAE보다 나은 성능을 보였다. 이미지 해상도가 높아져도 JEAs보다 더 나은 성능을 보인다.</p>

<h3 id="transfer-learning">Transfer learning</h3>
<p>기존 모델들 보다 더 좋은 성능을 보였으며 view-invariance-based와의 간격도 줄었다.</p>
<center>
<img src="/assets/img/paper-review/ijepa/tab3.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="6-local-prediction-tasks">6. Local Prediction Tasks</h2>
<ol>
  <li>에서 I-JEPA의 강력함을 엿볼 수 있었는데 이 섹션에서는 I-JEPArk local image feature를 학습하고 low-level이고 dense prediction task에서 view-invariance based method보다 더 나은 결과를 보임을 입증한다.</li>
</ol>
<center>
<img src="/assets/img/paper-review/ijepa/tab4.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="7-scalability">7. Scalability</h2>
<h3 id="model-efficiency">Model Efficiency</h3>
<p>I-JEPA는 기존 방법들보다 더 높은 확장성을 제공한다. MAE 같은 reconstructionbased methods는 픽셀을 target으로 삼는 반면, I-JEPA는 representation space에서 계산을 하기 때문에 약간의 오버헤드가 있다. 그러나 5배 더 빠른 수렴을 보여준다. 또한 I-JEPA로 ViT-H/14를 학습하는 것 보다 ViT-S/16으로 iBOT을 학습하는 것이 더 적은 연산을 필요로 한다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig5.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="scaling-data-size">Scaling data size</h3>
<p>I-JEPA는 더 큰 데이터셋에서 pretraining할 때 효과적임을 아래 표를 통해 확인할 수 있다.</p>
<center>
<img src="/assets/img/paper-review/ijepa/tab5.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="scaling-model-size">Scaling model size</h3>
<p>위 표는 I-JEPA가 더 큰 모델에서 pretraining을 할때 더 효과적임을 입증한다. 그러나 ViT-G/16은 입력 팿치가 더 커서 local prediction task 성능이 안좋다.</p>

<h2 id="8-predictor-visualizations">8. Predictor Visualizations</h2>
<p>I-JEPA로 학습한 모델을 RCDM framework로 생성을 시킨 결과이다. I-JEPA의 predictor는 고수준 object의 부분을 정확한 Pose로 잘잡아낸다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/fig6.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="9-ablations">9. Ablations</h2>
<h3 id="predicting-in-representation-space">Predicting in representation space</h3>
<p>pixel space vs representation space에서 loss를 계산할 때의 성능 차이를 ImageNet-1K 1% linear probe로 비교한 실험이다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/tab7.webp" width="80%" />
</center>
<p><br /></p>

<p>pixel space에서 예측하게 되면 모델이 픽셀 수준의 세부 정보(텍스처, 조명, 노이즈 등)까지 다 맞춰야 해서 representation이 low-level detail에 오염된다. 그러나 
representation space에서 예측하게 되면 target encoder가 추상적인 예측 타겟을 만들 수 있으므로 의미없는 픽셀 수준 디테일이 제거된 상태로 학습된다.</p>

<h3 id="masking-strategy">Masking strategy</h3>
<p>다양한 마스킹 전략을 ablation한 결과이다. multi-block masking이 I-JEPA가 semantic representation을 학습하는 데에 도움이 되는 가이드를 하는 것을 알아냈다.</p>

<center>
<img src="/assets/img/paper-review/ijepa/tab6.webp" width="80%" />
</center>
<p><br /></p>

<h2 id="10-conclusion">10. Conclusion</h2>
<p>본 논문에서 저자들은 I-JEPA를 소개한다. I-JEPA는 hand-crafted augmentation에 의존하지 않으며 semantic image representation을 학습하는 간단하고 효율적인 방법이다. I-JEPA는 다른 pixel-level의 방법보다 빠르게 수렴하며 높은 수준의 semantic representation을 학습한다. view-invariance based method와 달리 I-JEPA는 and-crafted augmentation에 의존하지 않고 JEA를 사용하여 general representation을 학습할 수 있는 경로를 강조한다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>오랜만에 CV 이론 논문을 리뷰해서 introduction에 힘을 줘버려서 Experiments 부분을 제대로 리뷰하지 못한 것 같아서 아쉬웠다.</li>
  <li>얀 르쿤이 심열을 기울인 방법으로 디테일이 크게 돋보인 논문이었다.</li>
  <li>아직 익숙하지 않은 개념이라, 따로 드는 생각은 없는 것 같다. 코드를 한번 뜯어봐야겠다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Computer-Vision" /><category term="Self-Supervised-Learning" /><category term="CVPR" /><summary type="html"><![CDATA[Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture (CVPR)]]></summary></entry><entry><title type="html">[코드 리뷰] Tensor Slicing</title><link href="https://kitewatermelon.github.io/code-review/tensor-slice/" rel="alternate" type="text/html" title="[코드 리뷰] Tensor Slicing" /><published>2026-03-12T00:00:00+09:00</published><updated>2026-03-12T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/tensor-slice</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/tensor-slice/"><![CDATA[<p>실습 예제: <a href="https://colab.research.google.com/drive/1btkCLqW3QAqOZZ6ymbmL5trOHW4ri4IS#scrollTo=9_nIJ9P9H5Np">Colab</a></p>

<h2 id="1-introduction">1. Introduction</h2>
<p>딥러닝 관련 논문을 읽으며 공부하다 보면 대부분의 논문 구현이 PyTorch(torch) 기반으로 되어 있음을 알 수 있다. 아래 그래프에서 볼 수 있듯이 2024년 기준 PyTorch, TensorFlow, JAX 중 PyTorch를 사용한 프로젝트가 제일 많은 비중을 차지하는 것을 볼 수 있다 <a href="https://softwaremill.com/ml-engineer-comparison-of-pytorch-tensorflow-jax-and-flax/">[1]</a>.</p>

<p><img src="https://softwaremill.com/user/pages/blog/229.ml-engineer-comparison-of-pytorch-tensorflow-jax-and-flax/image2.png?g-1efd1e18" alt="image.png" /></p>

<h3 id="11-tensor-slicing이란">1.1. Tensor Slicing이란?</h3>
<p>텐서 슬라이싱은 다차원 배열에서 원하는 부분만 선택적으로 추출하는 연산으로 NumPy의 배열 인덱싱에서 유래했고 PyTorch도 동일한 문법을 사용한다 (NumPy-like).</p>

<h3 id="12-왜-필요한가">1.2. 왜 필요한가?</h3>
<p>딥러닝에서 텐서는 보통 (B, H, W, C) 같은 고차원 구조를 가진다. 모델 내부에서 특정 배치만, 특정 채널만, 특정 공간 위치만 꺼내서 연산해야 할 일이 매우 많다. 슬라이싱 없이는 불필요한 데이터까지 복사하거나 반복문으로 순회해야 하는데, 슬라이싱은 이걸 뷰(view) 방식으로 해결한다.</p>

<h3 id="13-원리">1.3. 원리</h3>
<p>핵심은 “메모리를 복사하지 않는다”는 것이다.</p>

<p>텐서는 내부적으로 두 가지로 구성되는데:</p>
<ul>
  <li>storage: 실제 데이터가 1D로 연속 저장된 메모리</li>
  <li>stride + offset: “몇 칸 건너뛰면 다음 원소인지”를 기술하는 메타데이터</li>
</ul>

<p>슬라이싱을 하면 storage는 그대로 두고 stride와 offset만 바꾼 새 텐서 객체를 반환하기 때문에 빠르고 메모리 효율적이다 <a href="https://docs.pytorch.org/docs/stable/tensor_view.html">[2]</a>.</p>

<p>메모리를 복사하지 않는 것으로 문제가 생길 수 있는데, 이 부분은 3장에서 다룬다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>  <span class="c1"># (B, H, W, C)
</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">].</span><span class="n">shape</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>torch.Size([3, 3, 3])
torch.Size([3, 3, 3])
torch.Size([3, 3, 3])
torch.Size([2, 3, 3, 3])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">].</span><span class="n">stride</span><span class="p">())</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">].</span><span class="n">stride</span><span class="p">())</span>
<span class="n">t_slice</span> <span class="o">=</span> <span class="n">t</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">()</span> <span class="o">==</span> <span class="n">t</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">())</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(27, 9, 3, 1)
(9, 3, 1)
(27, 3, 1)
True
</code></pre></div></div>

<h2 id="2-시각화로-텐스-슬라이싱-제대로-보기">2. 시각화로 텐스 슬라이싱 제대로 보기</h2>
<p>본 글에서는 사람들의 이해를 돕기 위해 matplotlib로 시각화한다. Computer Vision 영역에서 제일 많이 사용되는 4D [B, H, W, C] 형태를 시각화 할 것이며, [3,3,3,3] 사이즈의 boolean 자료형의 입력을 사용한다. 이때 B는 batch 사이즈이고, H는 이미지의 높이, W는 이미지의 너비, C는 RGB로 판단한다. 따라서 $3^2$ 크기의 컬러 이미지 3개가 있는 상황이다.</p>

<p>기본적으로 <code class="language-plaintext highlighter-rouge">tensor.zeros().dtype(bool)</code> 로 4차원 False tensor를 생성하여 슬라이싱 되는 부분만 True로 변환하여 어떤 부분이 슬라이싱 되는지 시각화한다.</p>

<p>우리들의 천하무적 클로드가 <code class="language-plaintext highlighter-rouge">visualize_tensor()</code>라는 시각화 코드를 만들어줬다:</p>

<details>
<summary>코드 정보</summary>

<div>

    <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">matplotlib.patches</span> <span class="k">as</span> <span class="n">mpatches</span>
<span class="kn">from</span> <span class="nn">mpl_toolkits.mplot3d.art3d</span> <span class="kn">import</span> <span class="n">Poly3DCollection</span>

<span class="n">RGB_COLORS</span> <span class="o">=</span> <span class="p">[</span><span class="s">'red'</span><span class="p">,</span> <span class="s">'green'</span><span class="p">,</span> <span class="s">'blue'</span><span class="p">]</span>
<span class="n">RGB_LABELS</span> <span class="o">=</span> <span class="p">[</span><span class="s">'R'</span><span class="p">,</span> <span class="s">'G'</span><span class="p">,</span> <span class="s">'B'</span><span class="p">]</span>

<span class="k">def</span> <span class="nf">draw_cube</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">z</span><span class="p">,</span> <span class="n">filled</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'steelblue'</span><span class="p">):</span>
    <span class="n">vertices</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([</span>
        <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="p">],</span>   <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="p">],</span>
        <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span>   <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="n">x</span><span class="p">,</span>   <span class="n">y</span><span class="o">+</span><span class="mi">1</span><span class="p">,</span> <span class="n">z</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
    <span class="p">])</span>
    <span class="n">faces</span> <span class="o">=</span> <span class="p">[</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">7</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">4</span><span class="p">]],</span>
        <span class="p">[</span><span class="n">vertices</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">6</span><span class="p">],</span> <span class="n">vertices</span><span class="p">[</span><span class="mi">5</span><span class="p">]],</span>
    <span class="p">]</span>
    <span class="k">if</span> <span class="n">filled</span><span class="p">:</span>
        <span class="n">poly</span> <span class="o">=</span> <span class="n">Poly3DCollection</span><span class="p">(</span><span class="n">faces</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
                                <span class="n">facecolor</span><span class="o">=</span><span class="n">color</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'black'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">poly</span> <span class="o">=</span> <span class="n">Poly3DCollection</span><span class="p">(</span><span class="n">faces</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.03</span><span class="p">,</span>
                                <span class="n">facecolor</span><span class="o">=</span><span class="s">'white'</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">add_collection3d</span><span class="p">(</span><span class="n">poly</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">visualize_tensor</span><span class="p">(</span><span class="n">tensor</span><span class="p">):</span>
    <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="s">'numpy'</span><span class="p">):</span>
        <span class="n">arr</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">.</span><span class="n">numpy</span><span class="p">().</span><span class="n">astype</span><span class="p">(</span><span class="nb">bool</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">arr</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">tensor</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

    <span class="k">assert</span> <span class="n">arr</span><span class="p">.</span><span class="n">shape</span> <span class="o">==</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="s">"Input must be [3,3,3,3]"</span>
    <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">arr</span><span class="p">.</span><span class="n">shape</span>

    <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span> <span class="o">*</span> <span class="n">B</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span>

    <span class="k">for</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">B</span><span class="p">):</span>
        <span class="n">ax</span> <span class="o">=</span> <span class="n">fig</span><span class="p">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">B</span><span class="p">,</span> <span class="n">b</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">projection</span><span class="o">=</span><span class="s">'3d'</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">H</span><span class="p">):</span>
            <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">W</span><span class="p">):</span>
                <span class="k">for</span> <span class="n">c</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">C</span><span class="p">):</span>
                    <span class="n">draw_cube</span><span class="p">(</span><span class="n">ax</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span>
                              <span class="n">filled</span><span class="o">=</span><span class="nb">bool</span><span class="p">(</span><span class="n">arr</span><span class="p">[</span><span class="n">b</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="n">c</span><span class="p">]),</span>
                              <span class="n">color</span><span class="o">=</span><span class="n">RGB_COLORS</span><span class="p">[</span><span class="n">c</span><span class="p">])</span>  <span class="c1"># c=0→R, c=1→G, c=2→B
</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s">'C'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s">'W'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zlabel</span><span class="p">(</span><span class="s">'H'</span><span class="p">,</span> <span class="n">labelpad</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span>

        <span class="n">ticks</span>  <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">,</span> <span class="mf">2.5</span><span class="p">]</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_xticklabels</span><span class="p">(</span><span class="n">RGB_LABELS</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>  <span class="c1"># R/G/B 표기
</span>        <span class="n">ax</span><span class="p">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_yticklabels</span><span class="p">([</span><span class="s">'1'</span><span class="p">,</span> <span class="s">'2'</span><span class="p">,</span> <span class="s">'3'</span><span class="p">],</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_zticks</span><span class="p">(</span><span class="n">ticks</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_zticklabels</span><span class="p">([</span><span class="s">'1'</span><span class="p">,</span> <span class="s">'2'</span><span class="p">,</span> <span class="s">'3'</span><span class="p">],</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">7</span><span class="p">)</span>

        <span class="n">ax</span><span class="p">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">C</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">W</span><span class="p">);</span> <span class="n">ax</span><span class="p">.</span><span class="n">set_zlim</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">H</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s">'B=</span><span class="si">{</span><span class="n">b</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">view_init</span><span class="p">(</span><span class="n">elev</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">azim</span><span class="o">=-</span><span class="mi">60</span><span class="p">)</span>

    <span class="c1"># Legend: R/G/B + False
</span>    <span class="n">patches</span> <span class="o">=</span> <span class="p">[</span><span class="n">mpatches</span><span class="p">.</span><span class="n">Patch</span><span class="p">(</span><span class="n">facecolor</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'black'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s">'True (</span><span class="si">{</span><span class="n">l</span><span class="si">}</span><span class="s">)'</span><span class="p">)</span>
               <span class="k">for</span> <span class="n">c</span><span class="p">,</span> <span class="n">l</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">RGB_COLORS</span><span class="p">,</span> <span class="n">RGB_LABELS</span><span class="p">)]</span>
    <span class="n">patches</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">mpatches</span><span class="p">.</span><span class="n">Patch</span><span class="p">(</span><span class="n">facecolor</span><span class="o">=</span><span class="s">'white'</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s">'gray'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'False'</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'--'</span><span class="p">))</span>
    <span class="n">fig</span><span class="p">.</span><span class="n">legend</span><span class="p">(</span><span class="n">handles</span><span class="o">=</span><span class="n">patches</span><span class="p">,</span> <span class="n">loc</span><span class="o">=</span><span class="s">'lower center'</span><span class="p">,</span>
               <span class="n">ncol</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">frameon</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">bbox_to_anchor</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">))</span>

    <span class="n">plt</span><span class="p">.</span><span class="n">suptitle</span><span class="p">(</span><span class="s">'[B, H, W, C] Tensor Slice Visualization'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="mf">1.01</span><span class="p">)</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>


</code></pre></div>    </div>
  </div>
</details>

<p>전체 체크</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t1</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_7_1.webp" alt="png" /></p>

<p>두번째 이미지만 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[</span><span class="mi">1</span><span class="p">,:,:,:]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_9_1.webp" alt="png" /></p>

<p>RGB 중 G만 시각화<br />
<code class="language-plaintext highlighter-rouge">...</code>(Ellipsis)는 “나머지 차원은 전부 : 로 채워줘” 라는 뜻이다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[:,:,:,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">t</span><span class="p">[...,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_11_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_11_3.webp" alt="png" /></p>

<p>십자가 모양으로 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,...]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[...,</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_13_1.webp" alt="png" /></p>

<p>중심 부분만 슬라이싱</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_15_1.webp" alt="png" /></p>

<p>멋지게 인덱싱 해보기</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 안 멋진 방법
</span><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[:,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">,</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>
<span class="n">t</span><span class="p">[:,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>

<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="c1"># 멋진 방법
</span><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">idx</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span>
<span class="n">t</span><span class="p">[:,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">idx</span><span class="p">,</span> <span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span>  <span class="c1"># H=W=C 인 대각선
</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_17_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_17_3.webp" alt="png" /></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">,</span><span class="mi">3</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">bool</span><span class="p">)</span>

<span class="n">t</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># R채널만, shape [B, H, W]
</span><span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">T</span>
<span class="n">t</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="bp">True</span> <span class="c1"># R채널만, shape [B, H, W]
</span>
<span class="n">visualize_tensor</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
</code></pre></div></div>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_18_1.webp" alt="png" /></p>

<p><img src="/assets/img/code-review/tensor-slice/tensor-slice_18_3.webp" alt="png" /></p>

<h2 id="3-contiguous에-대하여">3. <code class="language-plaintext highlighter-rouge">contiguous()</code>에 대하여</h2>

<h3 id="31-메모리-레이아웃부터-이해하기">3.1. 메모리 레이아웃부터 이해하기</h3>

<p>PyTorch 텐서는 내부적으로 <strong>1D 메모리(storage)</strong> 위에 존재한다. 예를 들어 shape <code class="language-plaintext highlighter-rouge">[2, 3]</code> 텐서는 실제로 메모리에 이렇게 저장된다:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>메모리: [a, b, c, d, e, f]
         ↕
tensor([[a, b, c],
        [d, e, f]])
</code></pre></div></div>

<p>이때 “다음 원소로 가려면 몇 칸 건너뛰어야 하는가”를 <strong>stride</strong>라고 한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>

<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">],[</span><span class="mi">4</span><span class="p">,</span><span class="mi">5</span><span class="p">,</span><span class="mi">6</span><span class="p">]])</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>  <span class="c1"># (3, 1) → 행 이동시 3칸, 열 이동시 1칸
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(3, 1)
</code></pre></div></div>

<h3 id="32-슬라이싱-후-stride가-꼬이는-상황">3.2. 슬라이싱 후 stride가 꼬이는 상황</h3>

<p>이 슬라이싱은 <strong>메모리를 복사하지 않고</strong> stride/offset만 바꿔서 반환한다.
그 결과 메모리 상에서 원소들이 <strong>띄엄띄엄</strong> 놓이게 된다.</p>

<p>stride의 마지막 값이 1이 아니라는 건, 메모리에서 원소들이 연속적으로 붙어있지 않다는 뜻이다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">t_slice</span> <span class="o">=</span> <span class="n">t</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span>   <span class="c1"># C 채널 중 G만 추출 → shape [3,3,3]
</span>
<span class="k">print</span><span class="p">(</span><span class="n">t</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>               <span class="c1"># (27, 9, 3, 1)
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">stride</span><span class="p">())</span>         <span class="c1"># (27, 9, 3)  ← 마지막이 1이 아님!
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>  <span class="c1"># False
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>(27, 9, 3, 1)
(27, 9, 3)
False
</code></pre></div></div>

<h3 id="33-언제-문제가-터지나">3.3. 언제 문제가 터지나?</h3>

<p><code class="language-plaintext highlighter-rouge">view()</code>는 메모리가 연속적으로 배치되어 있다고 가정한다. 그래서 비연속 텐서에 <code class="language-plaintext highlighter-rouge">.view()</code>를 쓰면 에러가 발생한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">rand</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">2</span><span class="p">])</span>

<span class="c1"># transpose/permute는 stride 관계가 틀어져서 진짜 에러 발생
</span><span class="n">t_transposed</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># stride 관계가 깨짐
</span><span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>  <span class="c1"># False
</span>
<span class="k">try</span><span class="p">:</span>
    <span class="n">t_transposed</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">except</span> <span class="nb">RuntimeError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="n">e</span><span class="p">)</span>

<span class="c1"># 해결
</span><span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">t_transposed</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>False
view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
tensor([0.0949, 0.1639, 0.6846, 0.3884, 0.6910, 0.5094, 0.1464, 0.4296])
tensor([0.0949, 0.1639, 0.6846, 0.3884, 0.6910, 0.5094, 0.1464, 0.4296])
</code></pre></div></div>

<h3 id="34-contiguous의-역할">3.4. <code class="language-plaintext highlighter-rouge">contiguous()</code>의 역할</h3>

<p><code class="language-plaintext highlighter-rouge">.contiguous()</code>는 <strong>메모리를 새로 할당하고 데이터를 연속된 형태로 복사</strong>한다.
이때 <strong>실제 copy가 발생</strong>하기 때문에 주의가 필요하다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">t_cont</span> <span class="o">=</span> <span class="n">t_slice</span><span class="p">.</span><span class="n">contiguous</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="n">t_cont</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span>                      <span class="c1"># True
</span><span class="k">print</span><span class="p">(</span><span class="n">t_slice</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">()</span> <span class="o">==</span> <span class="n">t_cont</span><span class="p">.</span><span class="n">data_ptr</span><span class="p">())</span>     <span class="c1"># False → 다른 메모리
</span><span class="n">t_cont</span><span class="p">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>                                    <span class="c1"># 정상 작동
</span></code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
False

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])
</code></pre></div></div>

<h3 id="35-view-vs-reshape-정리">3.5. <code class="language-plaintext highlighter-rouge">view()</code> vs <code class="language-plaintext highlighter-rouge">reshape()</code> 정리</h3>

<table>
  <thead>
    <tr>
      <th> </th>
      <th><code class="language-plaintext highlighter-rouge">view()</code></th>
      <th><code class="language-plaintext highlighter-rouge">reshape()</code></th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>contiguous 필요</td>
      <td>✅ 반드시</td>
      <td>❌ 아니어도 됨</td>
    </tr>
    <tr>
      <td>동작 방식</td>
      <td>항상 view (zero-copy)</td>
      <td>contiguous면 view, 아니면 내부적으로 copy</td>
    </tr>
    <tr>
      <td>에러 발생</td>
      <td>비연속이면 RuntimeError</td>
      <td>없음</td>
    </tr>
  </tbody>
</table>

<p>실무에서는 보통 아래 두 패턴 중 하나를 쓴다:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 패턴 1: 명시적으로 contiguous 보장 후 view
</span><span class="n">t_slice</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>

<span class="c1"># 패턴 2: reshape에 맡기기 (더 간편)
</span><span class="n">t_slice</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])
</code></pre></div></div>

<h3 id="36-실무에서-자주-만나는-케이스">3.6. 실무에서 자주 만나는 케이스</h3>

<p>Vision Transformer나 멀티헤드 어텐션 구현에서 특히 자주 나온다.
<code class="language-plaintext highlighter-rouge">transpose()</code>와 <code class="language-plaintext highlighter-rouge">permute()</code>는 <strong>항상 비연속 텐서를 반환</strong>하기 때문에, 이후에 <code class="language-plaintext highlighter-rouge">view()</code>를 쓸 계획이라면 <code class="language-plaintext highlighter-rouge">.contiguous()</code>를 습관적으로 붙여주는 것이 좋다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span>
<span class="n">feature_map</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>

<span class="c1"># [B, H, W, C] → 특정 채널 추 후 reshape
</span><span class="n">x</span> <span class="o">=</span> <span class="n">feature_map</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span>     <span class="c1"># shape [B, H, W], 비연속 가능성 있음
</span><span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># 안전하게 flatten출
</span>
<span class="c1"># transpose 후 reshape할 때
</span><span class="n">x</span> <span class="o">=</span> <span class="n">feature_map</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># transpose는 항상 비연속!
</span><span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">contiguous</span><span class="p">().</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</code></pre></div></div>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="PyTorch-basic" /><summary type="html"><![CDATA[tensor slicing이란 무엇인지 깨닫고, 메모리를 고려하며 코딩하는 법]]></summary></entry><entry><title type="html">[논문리뷰] Interpretable fMRI Captioning via Contrastive Learning</title><link href="https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2/" rel="alternate" type="text/html" title="[논문리뷰] Interpretable fMRI Captioning via Contrastive Learning" /><published>2026-03-10T00:00:00+09:00</published><updated>2026-03-10T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/brain-decoding-with-blip2/"><![CDATA[<blockquote>
  <p>MICCAI 2025 [<a href="https://papers.miccai.org/miccai-2025/paper/2049_paper.pdf">Paper</a>] [<a href="https://github.com/slavaheroes/brain-decoding-with-blip2">GitHub</a>]<br />
 Vyacheslav Shen, Kassymzhomart Kunanbayev, Donggon Jang, Daeshik Kim
 20 Sep 2025</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>뇌의 계층적 이미지 처리는 CNN 개발에 영감을 주었다. CNN 레이어 전체에 걸쳐 특징과 돌출 맵을 시각화하면 초기 레이어에서 엣지를 감지하고 깊어질수록 클래스 특화된 특징을 감지하는 것은 시각 피질의 기능과 유사하다. 더욱이 CNN-learned representations는 원숭이와 사람의 neural activity와 강한 상관관계가 있다. 이런 유사성 덕분에 neural activity로 DNN feature를 역으로 예측하는 방식으로 DNNs는 visual representations를 디코딩하는데 많이 사용된다.</p>

<p>Huthet al.은 fMRI 데이터를 단어 임베딩에 매핑하여 몇 시간 분량의 서술된 이야기를 디코딩할 수 있음을 보여주었으며, 최근에는 LDM을 이용하여 fMRI 데이터로부터 고해상도 자극 이미지를 reconstruction 하는 연구도 있었다. 한편, Transformer 아키텍처와 GPT-2는 neural activity로부터 자연어 재구성을 크게 향상시켰다. 그러나 생성된 출력물의 품질과 의미론적 일관성을 위해 추가적인 개성과 대안이 필요하다. 기존에는 brain activity의 시각 자극에서 이미지를 재구성하는 방식으로 접근했으나, 최근에는 multimodal deep learning이 대안을 제공한다. 신경 반응을 바로 textual descriptions로 디코딩하는 것인데 이를 fMRI captioning이라고 한다. 이런 관점에서 multimodal retrieval은 brain activity로부터 무엇이 보였고 근본적으로 의미론적인 내용을 유연하게 디코딩할 수 있다. fMRI-based decoding의 발전에도 불구하고 효율적으로 brain activity와 의미 있는 textual descriptions을 align하는 것은 아직 여러 문제가 있는데, 연산 효율, 의미론적 일관성 그리고 retrieval capabilities이다. 본 논문에선 contrastive learning을 통해 이 문제를 해결한다.</p>

<p>본 논문의 contribution은 다음과 같다.</p>
<ul>
  <li>연산 효율이 좋은 two-stage training을 도입하여 fMRI 데이터와 VL model(BLIP-2)을 align한다.</li>
  <li>synthetic fMRI patterns을 이용하여 interpretability decoding analysis를 제안한다.</li>
</ul>

<h3 id="11-related-work">1.1 Related Work</h3>
<p>CLIP (Contrastive Language-Image Pre-training)은 image 인코더와 text 인코더로 구성되며 multimodal model의 진보에 크게 기여했다. LDM의 reverse diffusion process에서 가이드를 하는 역할도 하고 VLMs에 LLMs과 visual data를 align 할 때도 사용한다.</p>

<p>이런 유능함에 힘입어 fMRI 신호로 CLIP의 image embedding을 예측하도록 하여 시각 자극을 재건하는 곳에 쓰인다. 그러나 breain decoding 연구에는 fMRI의 차원이 15,724로 충분히 고차원인데 conditional embedding 역시 257 × 768이나 257 × 1024 같은 고차원으로, 높은 연산량을 요구받는 어려움이 있다.</p>

<p>본 논문에서는 BLIP-2를 이용하여 visual embedding의 차원을 32 × 768로 compact하게 만든다. BLIP-2는 Q-Former(Querying Transformer)를 사용하여 이미지 인코더 기능을 LLM 임베딩 공간에 매핑한다. 압축 네트워크 역할을 하는 Q-Former는 대규모 frozen image features(257 × 1024)를 compact query tokens(32 × 768)으로 인코딩하여 뇌 디코딩에 적합한 텍스트 관련 및 의미론적으로 풍부한 이미지 표현을 보존한다.</p>

<h2 id="2-methodology">2 Methodology</h2>
<h3 id="21-dataset">2.1 Dataset</h3>
<p>Natural Scenes Dataset (NSD) 데이터셋을 사용한다. 이는 COCO dataset의 image를 각각 3초간 본 8명의 피험자의 7T fMRI 데이터셋이다. 기존 연구와 일관되도록 subj1의 데이터에서만 정량 분석을 한다. subj1은 모든 실험 시험을 완료하여 24,980개의 fMRI 시험(이미지당 최대 3회 반복)에 해당하는 8,859개의 훈련 이미지와 2,770개의 fMRI 시험이 포함된 982개의 테스트 이미지의 데이터 세트를 얻었다. 여러번 보여진 이미지에 대해서는 대응하는 fMRI trials에 대하여 평균을 취했다.</p>

<p>Ozcelik et al.을 따라 ridge regression을 사용한 GLM에서 억은 단일 실험 베타 가중치를 사용하여 fMRI를 처리했다. 시각 축을 따라 z-정규화했으며 NSDGeneral Regions-of-Interest (ROI) 마스크를 사용하여 15,764 복셀 벡터를 추출했다.</p>

<h3 id="22-fmri-captioning-with-blip-2">2.2 fMRI Captioning with BLIP-2</h3>
<p>본 논문에서는 textual descriptions from fMRI activity를 생성하기 위해 pre-trained BLIP-2를 사용했다. 이는 compact language-aligned image representations (32 × 768)을 제공하기 때문이다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>위 그림에서 볼 수 있듯이 feature extraction and Brain Model training으로 2 단계 프레임워크가 시작된다.</p>

<p>첫번째 단계에서는 stimulus image이 BLIP-2 이미지 인코더로 처리되고 BLIP-2 Q-Former안의 learned query vectors와 cross attention 하여 32 × 768의 최종 representation을 뽑는다. Brain Model은 ridge regression을 이용하여 fMRI activity(15,764 voxel)을 32 × 768의 최종 representation의 임베딩과 매핑한다.</p>

<p>두번째 단계에서는 retrieval을 위해 Brain Model의 출력과 text embeddings를 contrastive learning을 통해 align한다. GT caption은 BLIP-2 Q-Former’s self-attention 레이어를 통해 text embedding을 생성한다. image-text space와 Brain Model의 출력을 align하기 위해 linear projection layer를 도입한다. (fig 2 참고) 최종 loss는 다음과 같다.</p>

\[\begin{equation}
\mathcal L = \lambda_1\mathcal L MSE(b,i) + \lambda_2\mathcal L CLIP(b,t) + \lambda_3\mathcal L CLIP(i,t)
\end{equation}\]

<p>역할은 다음과 같다.</p>
<ol>
  <li>Mean Squared Error (MSE) loss: Brain Model’s predicted embeddings b 와 the GT image embeddings i의 alignment를 보존</li>
  <li>Brain-text contrastive loss: Brain Model’s outputs b 와 text embeddings t를 align해서 text retrieval 성능 향상</li>
  <li>Image-text contrastive loss: catastrophic forgetting 방지 및 t와 i의 일관성을 강화하며 robust image-text를 align</li>
</ol>

<h2 id="3-results--discussion">3 Results &amp; Discussion</h2>
<h3 id="31-retrieval">3.1 Retrieval</h3>
<blockquote>
  <p>Multimodal Retrieval이란?</p>
  <ul>
    <li>여러 종류의 데이터(뇌 신호, 이미지, 텍스트)를 서로 검색할 수 있는 능력</li>
    <li>예시</li>
  </ul>

  <table>
    <thead>
      <tr>
        <th>입력 (Query)</th>
        <th>검색 대상 (Retrieved)</th>
        <th>의미</th>
      </tr>
    </thead>
    <tbody>
      <tr>
        <td>fMRI 뇌 신호</td>
        <td>이미지 (B→I)</td>
        <td>“이 뇌 활동을 봤을 때 어떤 이미지를 본 거지?”</td>
      </tr>
      <tr>
        <td>이미지</td>
        <td>fMRI 뇌 신호 (I→B)</td>
        <td>“이 이미지를 봤을 때의 뇌 신호는 어느 것이지”</td>
      </tr>
      <tr>
        <td>fMRI 뇌 신호</td>
        <td>텍스트 (B→T)</td>
        <td>“이 뇌 활동을 설명하는 문장은 무엇이지?”</td>
      </tr>
      <tr>
        <td>텍스트</td>
        <td>fMRI 뇌 신호 (T→B)</td>
        <td>“이 문장에 해당하는 뇌 신호는 어느 것이지?”</td>
      </tr>
    </tbody>
  </table>
</blockquote>

<h4 id="image-and-brain-retrieval">image and brain retrieval</h4>
<p>이미지를 BLIP-2 Q-Former representation으로 만들고 fMRI-derived representation과 image embedding의 cosine similarity를 계산한다. MindEye-2의 eval protocol을 따라 300 sample의 top-1 retrieval accuracy를 측정한다. 보고된 결과는 30번의 시도에 대한 평균 정확도를 반영한다.</p>

<h4 id="textbrain-retrieval">text/brain retrieval</h4>
<p>text-aligned image embedding을 stage 2의 Brain Model을 이용하여 예측한다. caption embedding을 BLIP-2 Q-Former를 이용하여 얻으며 올바른지 확인하기 위해 cosine similarity를 계산한다. 50번의 시도에 대한 평균 정확도를 보고 한다.</p>

<p>성능은 다음 표와 같으며 T → B와 B → T가 가능 한 모델임을 보여준다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab1.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="32-fmri-captioning">3.2 fMRI Captioning</h3>
<p>BLIP-2에 구현되어 있는 OPT-2.7B decoder-only language model를 이용하여 textual descriptions을 생성한다. 6개 중 5개에서 다른 모델들을 stage 1에서도 이미 넘어섰으며 stage 2는 압도적인 성능을 보인다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab2.webp" width="80%" />
</center>
<p><br /></p>

<p>아래의 Figure 4는 정성적인 성능을 보여준다. Stage 1보다 Stage 2에서 구체적인 caption이 나왔다. (beach 보다 wave, horses보다 zebra 등…)</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="33-interpretability-analysis-of-roi-specific-fmri-signals">3.3 Interpretability Analysis of ROI-Specific fMRI Signals</h3>
<p>서로 다른 뇌 영역의 역할을 분석하기 위해 ROI-based interpretability analysis를 Brain Diffuser를 따라 한다. ROI의 voxel의 값을 1로 하고, 나머지를 0으로 만들어 synthetic fMRI 신호를 생성한다. Brain Model을 통해 처리 되고 정규화 후 11로 스케일되고 나서 caption 생성을 위해 language model을 통과한다. 아래 표는 그 결과이다.</p>

<center>
<img src="/assets/img/paper-review/brain-decoding-with-blip2/tab3.webp" width="80%" />
</center>
<p><br /></p>

<p>이 결과는 인간의 계층적, 모듈적 특성을 반영하는 시각 처리의 신경과학적 연구 결과와 일치한다. 예를 하나만 들자면 V1은 basic black-and-white features를 highlight한다. floc-words 영역은 텍스트 및 기호와 관련된 caption을 생성한다. 이런 결과는 Brain Diffuser의 결과와 일관되게 같다.</p>

<h2 id="4-conclusion">4 Conclusion</h2>
<p>본 논문에서는 연산 효율이 좋은 2 단계의 학습 프레임워크를 제안한다. contrastive learning을 도입하여 fMRI로 부터 정확한 captions을 생성하도록 하였으며, Vision-Language model representations과 brain activity를 align하여 multimnodal retrieval의 성능을 향상시켰다. ROI-optimal stimuli analysis는 decoding 과정에서 특정 뇌 영역의 contribuution을 식별하며 interpretability를 향상시켰다. 일반화 능력을 향상시키기 위해 cross-subject decoding에 초점을 두고, 적용 가능성을 향상시키기 위하여 multimodal generarion을 더 탐구하는 것을 future work로 두며 저자들은 글을 마무리 짓는다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>본 논문은 BLIP-2 Q-Former를 이용하여 연산 효율을 높이며 multimodal retrieval, fMRI captioning, Interpretability Analysis의 3가지 실험을 통해 우수성을 입증했다.</li>
  <li>새로운 데이터셋을 통해 fMRI가 질환 연구에만 사용되는 것이 아닌 신경과학 분야에서 뇌를 이해하기 위해 사용되는 것을 확인하며 fMRI의 범용성을 알 수 있었다.</li>
  <li>b, i에서는 왜 MSE를 사용하고, 나머지는 왜 CLIP loss를 사용하는지 이해하지 못했는데 이유는 다음과 같다.
    <blockquote>
      <p>MSE: Brain Model의 출력 b가 Image Embedding i와 “최대한 똑같은 벡터값”이 되길 원함 <br />
CLIP: Brain Model의 출력 b가 Text Embedding t와 “의미적으로 가까운 공간”에 있길 원함</p>
    </blockquote>
  </li>
  <li>이 논문 역시 작년에 직접 설명을 들었었는데, 배경지식의 부족으로 그저 지나친 논문중에 하나였다. 이제 공부를 해서 어느정도 이해를 할 수 있어서 기쁘다. 저자분은 한국말을 잘하셨다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Brain-Decoding" /><category term="Contrastive-Learning" /><category term="MICCAI" /><summary type="html"><![CDATA[Interpretable fMRI Captioning via Contrastive Learning (MICCAI 2025)]]></summary></entry><entry><title type="html">[fMRI] BOLD</title><link href="https://kitewatermelon.github.io/study/BOLD/" rel="alternate" type="text/html" title="[fMRI] BOLD" /><published>2026-03-09T00:00:00+09:00</published><updated>2026-03-09T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/study/BOLD</id><content type="html" xml:base="https://kitewatermelon.github.io/study/BOLD/"><![CDATA[<h2 id="introduction">Introduction</h2>
<p>본 게시글은 <a href="https://ko.wikipedia.org/wiki/%ED%98%88%EC%95%A1-%EC%82%B0%EC%86%8C-%EB%86%8D%EB%8F%84-%EC%9D%98%EC%A1%B4_%EC%98%81%EC%83%81">BOLD wikipedia</a>의 글을 그대로 들고 와서 다시 읽었을 때 뜻을 이해할 수 있게 풀어 쓰는 것을 목표로 한다.</p>

<h2 id="bold">BOLD</h2>
<p>혈액-산소-농도-의존 영상 (blood-oxygen-level-dependent imaging) 또는 볼드-대조 영상 (BOLD-contrast imaging, 두드러진 대조 영상)은 기능적 자기 공명 영상(fMRI)에서 특정 시간에 활성화되는 것으로 밝혀진 뇌나 기타 기관의 다양한 영역을 관찰하는 데 사용되는 방법이다.</p>

<h2 id="이론">이론</h2>
<p>뉴런은 당 및 산소 형태의 내부 에너지 보유량이 없으므로 발화(firing)로 인해 더 많은 에너지가 빠르게 유입되어야 한다. 혈액은 혈역학적 반응이라는 과정을 통해 비활성 뉴런보다 더 빠른 속도로 활성 뉴런에 산소를 방출한다. 이로 인해 옥시헤모글로빈 및 디옥시헤모글로빈(산소화 또는 탈산소화 혈액)의 상대적 수준이 변경되며, 이는 차동 자기 감수율(differential magnetic susceptibility)을 기반으로 감지할 수 있다.</p>

<p>1990년에 오가와 세이지와 동료들이 발표한 세 편의 논문에서는 헤모글로빈이 산소화된 형태와 탈산소화된 형태에서 서로 다른 자기 특성을 가지고 있음을 보여주었다(탈산소화된 헤모글로빈은 상자성이고 산소화된 헤모글로빈은 반자성).</p>

<p>둘 모두 MRI를 사용하여 감지할 수 있다. 이로 인해 MRI 스캐너를 사용하여 감지할 수 있는 자기 신호 변화가 발생한다. 생각, 행동 또는 경험을 많이 반복한 후, 통계적 방법을 사용하여 결과적으로 이러한 차이가 더 많이 나는 뇌 영역을 결정할 수 있으며, 따라서 해당 생각, 행동 또는 경험 중에 뇌의 어느 영역이 가장 활동적인지 결정할 수 있다.</p>

<p>원리:</p>
<ul>
  <li>DeoxyHb (상자성) → 주변 자기장 왜곡</li>
  <li>OxyHb (반자성) → 자기장 왜곡 적음</li>
  <li>ex) 뇌가 활동하면 특정 부위에 산소 소비가 많아지고 이로 인해 혈류가 쏠려 oxyHb 비율이 높아 진다.</li>
</ul>

<h2 id="비판과-한계">비판과 한계</h2>
<p>대부분의 fMRI 연구에서는 뇌의 어느 부분이 가장 활동적인지 확인하는 방법으로 볼드 대비 영상(BOLD contrast imaging)을 사용하지만 신호는 상대적이고 개별적으로 정량적이지 않기 때문에 일부에서는 그 엄격성에 의문을 제기한다. 신경 활동을 직접 측정하기 위해 제안된 다른 방법도 시도되었다(예를 들어, 혈액 내 산소헤모글로빈이 얼마나 많은 탈산소헤모글로빈으로 전환되었는지를 측정하는 뇌 영역의 산소 추출 비율(oxygen extraction fraction, OEF) 측정). 그러나 활성 또는 발화 뉴런에 의해 생성된 전자기장은 너무 약하기 때문에 신호 대 잡음 비율(signal-to-noise ratio)이 매우 낮고 정량적 데이터를 추출하는 데 사용되는 통계적 방법은 지금까지 대체로 성공하지 못했다.</p>

<p>볼드 대조 영상에서 저주파 신호(low-frequency signal)를 폐기하는 일반적인 현상은 1995년에 오른손 움직임을 제어하는 뇌 영역의 “소음”이 왼손 움직임과 관련된 뇌의 반대쪽 영역의 유사한 활동과 동시에 변동하는 것이 관찰되면서 의문이 제기되었다.  볼드-대비 영상은 두개의 뇌 상태 간의 차이에만 민감하므로 이러한 연관된 변동을 분석하려면 휴식 상태 fMRI(resting state fMRI)라고 하는 새로운 방법이 필요했다.</p>]]></content><author><name>YSPARK</name></author><category term="Study" /><category term="fMRI" /><summary type="html"><![CDATA[[fMRI] BOLD]]></summary></entry><entry><title type="html">[논문리뷰] BrainLM: A foundation model for brain activity recordings</title><link href="https://kitewatermelon.github.io/paper-review/BrainLM/" rel="alternate" type="text/html" title="[논문리뷰] BrainLM: A foundation model for brain activity recordings" /><published>2026-03-05T00:00:00+09:00</published><updated>2026-03-05T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/BrainLM</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/BrainLM/"><![CDATA[<blockquote>
  <p>ICLR 2024 [<a href="https://openreview.net/pdf?id=RwI7ZEfR27">Paper</a>] [<a href="https://huggingface.co/vandijklab/brainlm">huggingface</a>]<br />
 Josue Ortega Caro, Antonio Henrique de Oliveira Fonseca, Syed A Rizvi, Matteo Rosati, Christopher Averill, James L Cross, Prateek Mittal, Emanuele Zappala, Rahul Madhav Dhodapkar, Chadi Abdallah, David van Dijk
 16 Jan 2024</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>뇌에서 어떻게 인지와 행동이 일어나는지 이해하는 것은 neuroscience research 분야에서 근본적인 chllenghes 중 하나이다. fMRI는 이를 해결할 도구중에 하나이다. 그러나 fMRI가 측정하는 BOLD 시그널은 뇌 기능의 간접적인 신호라서 해석에 어려움이 있다. 또한 fMRI는 시간과 공간에 모두 종속되는 복잡한 slatoptemporal(시공간) 역학을 나타낸다. 기존의 접근 방식은 이 복잡한 비선형적인 상호작용을 완전히 모델링하는데 실패했다.</p>

<p>기존의 fMRI analysis 기술들은 일반화 성능을 저해하며, 특정 task들에만 모델링했다. 또한 label이 없는 가용하며 풍부한 fMRI 데이터를 사용하는 것에 어려움을 겪었다. 본 논문에서는 NLP에서 foundation model의 획기적인 성공에 힘입어 대규모 데이터에 대하여 다목적 모델을 훈련하여  fine-tuning을 통해 downstream 기능을 활성화한다. BrainLM은 fMRI에 대한 첫 foundation model이다. Transformer 베이스의 모델을 사용하여 큰 스케일의 뇌 활성 데이터의 고유한 시공간 역학을 캡쳐한다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>위 그림에서 볼 수 있듯이, unsupervised representation learning을 통해 다양한 downstream task에서도 일반적인 성능을 낼 수 있다. pretraining이 완료된 후에는 뛰어난 fine-tune 성능과 zero-shot 성능을 보인다.</p>

<h2 id="2-realted-work">2. Realted Work</h2>
<p>이전 연구에서는 fMRI recording을 분석하기 위한 다양한 ML 기술을 탐구했다. SVM, NN 등을 이용하여 자극 카테고리나 관심 변수를 회귀했다. 그러나 이 모델들은 특정 task를 위해 representation을 학습했으며 따라서 일반화에 어려움을 겪었다. 최근 연구는 일반화에 힘을 많이 썼다. recordings을 recon하는 autoencoder를 학습하는 방식을 사용하곤 했는데, 이는 적은 양의 데이터셋을 사용해서 robust한 일반화 능력을 갖지는 못했다. (약 1~2배 많은 데이터로 본 논문은 MAE를 이용한다.) 또한 다른 연구들이 LLM과 brain recordings간 representational similities를 찾는데 중심을 두었는데 이 부분의 연장 선상에서 본 연구에서는 LLM과 언어 처리 영역과 같은 특정 뇌 영역 사이의 높은 상관 관계가 발견되었다. 그러나 이 논문은 뇌 역학의 foundation model을 학습하거나 downstream biological tasks를 위해 모델을 fine-tune하는 데 중점을 두지 않는다.</p>

<h2 id="3-method">3. Method</h2>
<h3 id="31-dataset-ansd-preprocessing">3.1. Dataset ansd Preprocessing</h3>
<p>데이터셋으로 공공 데이터 중 큰 사이즈인 Human Connectome Project (HCP)과 UK Biobank(UKB)를 사용했다. UKB는 40~69세의 robust한 76,296 task-based and resting-state functional MRI (rs-fMRI) recordings를 가지고 있으며 이는 Siemens 3T scanner로  0.735s temporal resolution로 측정한 데이터셋이다. HCP는 건강한 성인의 1,002 high-quality fMRI recordings를 0.72s resolution으로 측정한 데이터셋이다.</p>

<p>BrainLM은 UKB 80%로 학습하고, 나머지 20%와 HCP 데이터로 검증한다. 모든 recordings는 데이터 준비를 위해 motion correction, normalization, temporal filtering, 그리고 ICA denoising 까지 standard preprocessing을 거쳤다. parcel-wise time series를 추출하기 위해 AAL-424 atlas를 사용했다. 이는 424D의 ~1Hz로 샘플링된 시퀀스를 산출하게 된다. 각 parcel별로 환자들을 Robust scaling을 함으로 전처리를 마무리 한다. 따라서 6,700시간의 77,298 recordings을 얻었으며 이런 큰 데이터셋이 BrainLM으로 하여금 robust functional representations을 학습하도록 한다.</p>

<h3 id="32-model-architecture--training-procedure">3.2. Model Architecture &amp; Training Procedure</h3>
<center>
<img src="/assets/img/paper-review/brainlm/fig2.webp" width="80%" />
</center>
<p><br /></p>

<p>본 모델은 전적으로 MAE 기반의 Transformer를 사용한다. 이것의 핵심에는 마스킹 된 패치의 원래 신호를 예측하는 것이다. 자세한 사항은 위 그림을 참조하면 된다.</p>

<p>학습 시에 랜덤으로 200 timestep의 subsequence를 선택하고 20 timestep으로 쪼개의 10개의 겹치지 않는 조각으로 나누다. 이 조각은 512D의 벡터로 변형되고 20%, 75%, 90% 이율로 마스킹한다. 100M 이상의 모델에 확장하기 위해 CV 쪽에서 영감을 받아 alternative masking 전략을 채택한다. 따라서 424 parcel 별로 랜덤하게 200 timestep의 window를 2D 이미지로 처리한다.이렇게 하면 brain region의 locality를 보존하고(뇌의 해부학적 인접성을 유지) multi-parcel encoding을 허용하여 total token 양을 줄인다(연산 효율 증대). parcel은 공간과 관련이 있기 때문에 single-parcel보다 mult-parcel이 더 복원하기 어렵다(공간적으로 상관된 parcel들을 묶어서 한꺼번에 masking하면 모델이 인접 정보를 “커닝”할 수 없어서 더 robust한 representation을 학습하게 됨, 이는 MAE에서 random masking보다 block masking이 더 어렵고 효과적인 것과 같은 원리). masking 되지 않은 조각들만이 4개의 레이어와 4개의 헤드를 가진 Transformer로 인코딩되고 이는 레이어 2개로 구성된 디코더로 복원된다.</p>

<p>batch size는 512, Adam optimizer, 100 epoch으로 MSE를 최소화 하도록 학습한다.</p>

<h3 id="33-clinical-variable-prediction">3.3. Clinical Variable Prediction</h3>
<p>다양한 태스크에 적용하기 위해 레이어 3개로 구성된 MLP head를 사용했으며, age, neuroticism, PTSD 그리고 anxiety disorder scores를 regress하도록 했다. age는 z-score로 정규화했으며 neuroticism은 min-max scaling했다. 나머지 두개는 log transformation을 사용하여 exponential distribution을 조절한 후 min-max scaling을 했다. 오버피팅을 피하기 위해 BrainLM과 MLP 모두 10%의 Dropout을 적용했다.</p>

<h2 id="4-results">4. Results</h2>
<h3 id="41-model-generalization">4.1. Model Generalization</h3>
<p>BrainLM이 UKB에서 $R^2$는 0.464로 같은 분포에서 unseen data에 대한 일반화 능력이 우수했으며, HCP에서는 0.278로 다른 분포에서도 잘 일반화되었다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig3.webp" width="80%" />
</center>
<p><br /></p>

<p>아래 그림에서 볼 수 있듯이, 모델 확장성도 강력했다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig4.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="42-clinical-variable-prediction">4.2. Clinical Variable Prediction</h3>
<p>foundation mdoel의 주요한 이점 중 하나는 특정한 downstream task들에 잘 fine-tune 가능한 것이다. latent space를 조사한 결과 pretrained BrainLM은 임상에서 중요한 정보를 적절히 인코딩한다는 사실이 드러났다. (아래 그림 참고) 이는 age, neuroticism, PTSD, 그리고 anxiety disorder scores 같은 예측 변수를 더 잘 예측한다는 뜻이다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig5.webp" width="80%" />
</center>
<p><br /></p>

<p>저자들은 다른 저명한 방법들과 이를 비교했다. 인상깊게도 BrainLM은 모든 임상 변수에서 일관성있게 더 나은 결과를 냈다. (아래 표 참고)</p>

<center>
<img src="/assets/img/paper-review/brainlm/tab1.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="43-prediction-of-future-brain-states">4.3. Prediction of Future Brain States</h3>
<p>BrainLM이 시공간 역학을 잘 잡아내는지 검증하기 위해 저자들은 미래 뇌 상태를 예측하는 성능을 측정했다. UKB의 subset을 이용하여 미래 시점의 parcel의 활동을 예측하는 task를 수행할 수 있게 fine-tune했다. 180 timestep을 주고 나머지 20 timestep을 예측하도록 하여 LSTMs, NODE, non-pretrained BrainLM을 비교한다. 결과는 다음과 같다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig6.webp" width="80%" />
</center>
<p><br /></p>

<p>fine-tuned BrainLM이 UKB와 HCP 모두에서 성능이 좋았다. 이는 BrainLM이 fMRI의 역학을 직관적으로 파악하는 robust한 기능을 highlight한다.</p>

<h3 id="44-interpretability-via-attention-analysis">4.4. Interpretability via Attention analysis</h3>
<p>BrainLM의 중요한 특징은 interpretability이다. self-attention weight를 시각화 함으로 모델 내부 representation에 대한 깊은 통찰을 얻을 수 있다. 각각의 parcel에 할당된 fMRI recordings의 cls 토큰을 평균 내어 이를 계산한다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/fig7.webp" width="80%" />
</center>
<p><br /></p>

<p>위 그림에서 볼 수 있듯이 rs와 비교했을때 task recordings는 visual cortex에 뚜렷한 초점을 가지는 것을 볼 수 있다. 이는 tasks 중 발생하는 시각적 자극과 잘 align 된다. 더 나아가 PHQ-9라는 우울증 임상 척도를 기준으로 중증 우울증 환자의 fMRI를 BrainLM이 인코딩할 때, frontal lobe(전두엽-감정 조절, 의사결정, 실행 기능 저하와 연관) 와 limbic system(변연계-감정 처리, 스트레스 반응, 보상 회로 이상과 연관) 영역에 attention이 집중되었다.  BrainLM은 우울증 레이블을 직접 학습한 게 아님에도 임상적으로 의미 있는 뇌 영역을 자동으로 포착했다.</p>

<h3 id="45-functional-network-prediction">4.5. Functional Network Prediction</h3>
<p>마지막으로 저자들은 BrainLM가 network-based supervision 없이도 parcels를 fMRI의 활성화 패턴만으로 intrinsic functional brain network로 segment할 수 있는지 검증했다. parcels를 7개의 functional categories로 나누고, 1,000 UKB recordings에서 각 parcel을 7개 중 하나로 k-NN classifier을 이용하여 분류하는 작업을 거쳤다. 결과는 다음과 같다.</p>

<center>
<img src="/assets/img/paper-review/brainlm/tab3.webp" width="80%" />
</center>
<p><br /></p>

<p>BrainLM의 attention-driven approach가 다른 모델들을 압도 했으며 GCN이 가장 낮은 성능을 보였다. 이는 BrainLM가 label 없이 pre-training만으로 뇌의 functional topography를 내재적으로 학습하는 것을 의미한다.</p>

<h2 id="5-discussion">5. Discussion</h2>
<p>본 논문에서 BrainLM이라는 fMRI 분석을 위한 첫 foundation model을 도입했다. 6.7k시간의 풍부한 양의 brain activity recordings를 이용하여 modeling, predicting, interpreting 모두 좋은 성능을 보였다. BrainLM의 핵심은 fMRI recordings의 다른 분포의 데이터셋에서도 일반화된 representation이다. 또한 모델이 많은 파라미터를 가져도 잘 학습이 됐다.</p>

<p>BrainLM은 biomarker를 찾는 강력한 프레임워크를 제공한다. fine-tune을 통해 임상 변수와 psychiatric disorders를 예측할 수 있다. 이는 rs-fMRI 만으로 비침습적인 인지 건강 평가를 가능하게 한다. 마지막으로 network-based supervision 없이도 뇌의 intrinsic functional connectivity map을 바로 구분할 수 있었다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<ul>
  <li>본 논문은 fMRI 분석을 위한 첫 foundation model이다. 예측 성능뿐 아니라 attention map을 이용한 해석 가능성과 functional network prediction을 하는 것은 매우 놀라웠다.</li>
  <li>baseline model이 다소 예전 모델들이지만, 괄목할 만한 성능을 냈다.</li>
  <li>도메인을 넘어선 self-supervised learning에서 masking approach의 강력함을 다시금 깨달았다. 이는 robust하고 generalized representation을 만들며, 모델 파라미터에 갯수에 구애받지 않는 아주 좋은 전략인 것 같다.</li>
  <li>아직 fMRI 관련 논문을 찾는 단계라 모르는 부분이 많았는데, 학습 과정과 전처리 과정이 잘 나와있어서 좋았다.</li>
</ul>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Foundation-Model" /><category term="Contrastive-Learning" /><category term="ICLR" /><summary type="html"><![CDATA[BrainLM: A foundation model for brain activity recordings (ICLR 2024)]]></summary></entry><entry><title type="html">[코드 리뷰] Contrastive Loss - InfoNCE(NT-Xent)</title><link href="https://kitewatermelon.github.io/code-review/contrastive-loss/" rel="alternate" type="text/html" title="[코드 리뷰] Contrastive Loss - InfoNCE(NT-Xent)" /><published>2026-03-03T00:00:00+09:00</published><updated>2026-03-03T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/contrastive-loss</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/contrastive-loss/"><![CDATA[<p>code 다운로드: <a href="/assets/code/code-review/contrastive-loss.ipynb">📥 contrastive-loss.ipynb 다운로드</a></p>

<h3 id="1-introduction">1. Introduction</h3>

<p>“같은 것은 가까이, 다른 것은 멀리” 라는 아이디어에서 시작된 Contrastive learning은 Deep learning 분야에서 많이 각광 받아온 학습 방법이다.</p>

<p>Self-supervised learning 분야에서 주로 사용되며 2010년 말 빅테크기업에서 압도적인 연산력을 토대로 만들어낸 MoCo, <a href="/paper-review/simclr/">SimCLR</a> 같은 모델이 대표적인 예시이며, Multi-modal 학습 방법인 OpenAI의 CLIP의 C도 Contrastive이다.</p>

<p>Contrastive Loss는 여러가지가 있지만, 본 게시글에서는 <a href="/paper-review/simclr/">SimCLR</a> 논문에서 사용된 가장 많이 사용되는 Loss 중 하나인 InfoNCE 계열의 NT-Xent loss를 구현하고 코드 리뷰를 할 것이다.</p>

<p>Contrastive Loss에 대해 입문하거나 관심이 있다면 <a href="https://proceedings.mlr.press/v119/wang20k/wang20k.pdf">Understanding Contrastive Representation Learning through Alignment and Uniformity on the Hypersphere</a>를 한번 읽어보는 것을 추천한다.</p>

<h3 id="2-preliminary">2. Preliminary</h3>

<p><a href="/paper-review/simclr/">SimCLR</a>은 2020년 구글에서 발표한 self-supervised learning 논문으로 한 sample에 두개의 augmented view를 만들어서 같은 이미지끼리는 큰 cosine similarity를 가지게 하고 서로 다른 이미지 샘플끼리는 낮은 cosine similarity를 가지도록 하는 방식으로 학습한다.</p>

<p>우선 수식부터 알아보자</p>

\[\begin{equation}
l(i, j)=-log\frac{exp(sim(z_i, z_j)/\tau)}{\sum^{2N}_{k=1} \mathbb{1}_{[k \neq i]}exp(sim(z_i, z_k)/\tau) }
\end{equation}\]

<p>위 식을 흐린눈 하고 보면 익숙한 수식이 보인다. 많은 블로거들이 외치듯 위 loss는 softmax의 형태를 하고 있다. 이 수식은 positive pair와 negative pair로 구분하도록 한다.</p>
<hr />

<p>N은 batch size이다. 분모항의 sumation에서 2N이 나타난 이유는 모든 샘플에 있어서 2개의 augmented view를 생성하기 때문이다. 여기서 i와 j는 같은 sample의 서로 다른 augmented view이고 sim은 cosine similarity이다.</p>

<p>즉 i, j는 같은 sample이므로 positive pair가 되고 분모는 i가 자기 자신을 제외한 positive, negative pair의 cosine similarity를 지수승하여 합한 값이다.</p>

<p>$\tau$는 분포를 좀 더 뾰족하거나 완만하게 만들어준다.</p>

<h3 id="implementation">implementation</h3>

<p><a href="https://github.com/sthalles/SimCLR/blob/master/simclr.py">참고 코드</a>에서 class 내의 함수를 꺼내서 사용할 수 있도록 아주 약간 수정했으며 핵심 메커니즘은 그대로임을 미리 알린다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>    
<span class="kn">import</span> <span class="nn">torch.nn.functional</span> <span class="k">as</span> <span class="n">F</span>

<span class="k">def</span> <span class="nf">NTXentLoss</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">n_views</span><span class="p">,</span> <span class="n">temperature</span><span class="p">,</span> <span class="n">device</span><span class="p">):</span>
    <span class="c1"># classification 문제이기 때문에 labels을 생성함.
</span>    <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n_views</span><span class="p">)],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)).</span><span class="nb">float</span><span class="p">()</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>

    <span class="c1"># cosine similarity를 global하게 계산함.
</span>    <span class="n">features</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># 입력은 [B, D]
</span>    <span class="n">similarity_matrix</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">features</span><span class="p">.</span><span class="n">T</span><span class="p">)</span>

    <span class="c1"># discard the main diagonal from both: labels and similarities matrix 
</span>    <span class="c1"># diagonal 원소는 본인이기 때문에 k!=i를 지키기 위해 masking
</span>    <span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">bool</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
    
    <span class="c1"># view 함수는 tensor 형태의 데이터의 shape을 바꿔주는 함수 - Shared Data / Memory Efficiency
</span>    <span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="o">~</span><span class="n">mask</span><span class="p">].</span><span class="n">view</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># ~mask은 torch에서 True인 원소들을 없애고 값을 반환하므로 2N x (2N-1)짜리 행렬이 나오게 됨
</span>    <span class="n">similarity_matrix</span> <span class="o">=</span> <span class="n">similarity_matrix</span><span class="p">[</span><span class="o">~</span><span class="n">mask</span><span class="p">].</span><span class="n">view</span><span class="p">(</span><span class="n">similarity_matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    
    <span class="c1"># 굳이 아래처럼 positives와 negatives를 일렬로 작업하는 이유는 cross entropy 계산 시 label을 항상 0번째 열로 고정할 수 있기 때문임.
</span>    <span class="c1"># select and combine multiple positives
</span>    <span class="n">positives</span> <span class="o">=</span> <span class="n">similarity_matrix</span><span class="p">[</span><span class="n">labels</span><span class="p">.</span><span class="nb">bool</span><span class="p">()].</span><span class="n">view</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># select only the negatives the negatives
</span>    <span class="n">negatives</span> <span class="o">=</span> <span class="n">similarity_matrix</span><span class="p">[</span><span class="o">~</span><span class="n">labels</span><span class="p">.</span><span class="nb">bool</span><span class="p">()].</span><span class="n">view</span><span class="p">(</span><span class="n">similarity_matrix</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

    <span class="n">logits</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">positives</span><span class="p">,</span> <span class="n">negatives</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">logits</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">long</span><span class="p">).</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>

    <span class="c1"># temperature 계산
</span>    <span class="n">logits</span> <span class="o">=</span> <span class="n">logits</span> <span class="o">/</span> <span class="n">temperature</span>
    <span class="k">return</span> <span class="n">logits</span><span class="p">,</span> <span class="n">labels</span>
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">BATCH_ZISE</span> <span class="o">=</span> <span class="n">N_VIEW</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># Batch가 2이고 augmentated view가 2개 있다고 가정했을 때
</span>
<span class="c1"># labels 구성은 다음과 같다.
</span><span class="n">labels</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">torch</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="n">BATCH_ZISE</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">N_VIEW</span><span class="p">)],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">labels</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0, 1, 0, 1])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># unsueeze는 추가하는 차원축의 인덱스를 지정할 수 있다.
</span><span class="k">print</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[0, 1, 0, 1]])
tensor([[0],
        [1],
        [0],
        [1]])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">labels</span> <span class="o">=</span> <span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">==</span> <span class="n">labels</span><span class="p">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)).</span><span class="nb">float</span><span class="p">()</span>
<span class="n">labels</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[1., 0., 1., 0.],
        [0., 1., 0., 1.],
        [1., 0., 1., 0.],
        [0., 1., 0., 1.]])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># torch.eye는 대각 성분만 1이고 나머지는 0인 텐서를 만든다.
</span><span class="n">mask</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">eye</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="nb">bool</span><span class="p">)</span>
<span class="n">mask</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># 마스킹을 하면 다음과 같이 True인 부분만 남기고 False인 부분은 없앤다. 이는 torch 특유의 연산이다.
</span><span class="n">m</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="bp">True</span><span class="p">,</span> <span class="bp">False</span><span class="p">])</span>
<span class="n">l1</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">])</span>
<span class="n">l1</span><span class="p">[</span><span class="n">m</span><span class="p">]</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([1])
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="o">~</span><span class="n">mask</span><span class="p">].</span><span class="n">view</span><span class="p">(</span><span class="n">labels</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">labels</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([[0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]])
</code></pre></div></div>

<h3 id="4-low-loss-high-loss-collapse-loss">4. Low Loss, High Loss, Collapse Loss</h3>
<p>이제 구현한 Loss를 활용하여 언제 Loss가 높아지고 낮아지는지 그리고 붕괴는 어떻게 일어나는지 알아보자.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="c1"># ── Low Loss ──────────────────────────────────────────────────────────────────
</span><span class="n">z1_low</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">]])</span>
<span class="n">z2_low</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">]])</span>  <span class="c1"># z1이랑 거의 같은 방향
</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z1_low</span><span class="p">,</span> <span class="n">z2_low</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">NTXentLoss</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="s">'cpu'</span><span class="p">)</span>
<span class="n">loss_low</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Low  Loss: </span><span class="si">{</span><span class="n">loss_low</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Low  Loss: 0.0003
</code></pre></div></div>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ── High Loss ─────────────────────────────────────────────────────────────────
</span><span class="n">z1_high</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">]])</span>
<span class="n">z2_high</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="o">-</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">]])</span>  <span class="c1"># 정반대 방향
</span>
<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z1_high</span><span class="p">,</span> <span class="n">z2_high</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">NTXentLoss</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="s">'cpu'</span><span class="p">)</span>
<span class="n">loss_high</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"High Loss: </span><span class="si">{</span><span class="n">loss_high</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>High Loss: 20.0002
</code></pre></div></div>

<p>Collapse란 model이 입력과 무관하게 상수의 값을 내뱉는 상황이다. Collapse가 일어나면 model의 representation이 저하된다. 따라서 InfoNCE는 큰 배치를 이용하여 분모를 키워 collapse penalty를 강하게 함으로써 collapse가 쉽게 일어나지 않도록 한다. 이것이 <a href="/paper-review/simclr/">SimCLR</a>이 대용량 배치(e.g. 4096)를 필요로 하는 이유이기도 하다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># ── Collapse ─────────────────────────────────────────────────────────────────
</span><span class="n">z1_collapse</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">])</span>
<span class="n">z2_collapse</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">],</span> <span class="p">])</span>

<span class="n">z</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cat</span><span class="p">([</span><span class="n">z1_collapse</span><span class="p">,</span> <span class="n">z2_collapse</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">logits</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">NTXentLoss</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="s">'cpu'</span><span class="p">)</span>
<span class="n">loss_collapse</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">).</span><span class="n">item</span><span class="p">()</span>

<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Collapse Loss: </span><span class="si">{</span><span class="n">loss_collapse</span><span class="si">:</span><span class="p">.</span><span class="mi">4</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Collapse Loss: 1.9459
</code></pre></div></div>

<p>만능 클로드의 힘을 빌려 시각화를 하며 글을 마무리한다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="c1"># ── 시각화 ────────────────────────────────────────────────────────────────────
</span><span class="n">COLORS</span> <span class="o">=</span> <span class="p">[</span><span class="s">'tab:orange'</span><span class="p">,</span> <span class="s">'tab:blue'</span><span class="p">,</span> <span class="s">'tab:purple'</span><span class="p">,</span> <span class="s">'tab:green'</span><span class="p">]</span>
<span class="n">theta</span>  <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">np</span><span class="p">.</span><span class="n">pi</span><span class="p">,</span> <span class="mi">300</span><span class="p">)</span>

<span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">,</span> <span class="n">ax3</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>

<span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">title</span><span class="p">,</span> <span class="n">loss_val</span><span class="p">,</span> <span class="n">z1</span><span class="p">,</span> <span class="n">z2</span> <span class="ow">in</span> <span class="p">[</span>
    <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="sa">f</span><span class="s">"Low Loss  (</span><span class="si">{</span><span class="n">loss_low</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">)"</span><span class="p">,</span>      <span class="n">loss_low</span><span class="p">,</span>      <span class="n">z1_low</span><span class="p">,</span>      <span class="n">z2_low</span><span class="p">),</span>
    <span class="p">(</span><span class="n">ax2</span><span class="p">,</span> <span class="sa">f</span><span class="s">"High Loss (</span><span class="si">{</span><span class="n">loss_high</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">)"</span><span class="p">,</span>     <span class="n">loss_high</span><span class="p">,</span>     <span class="n">z1_high</span><span class="p">,</span>     <span class="n">z2_high</span><span class="p">),</span>
    <span class="p">(</span><span class="n">ax3</span><span class="p">,</span> <span class="sa">f</span><span class="s">"Collapse  (</span><span class="si">{</span><span class="n">loss_collapse</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">)"</span><span class="p">,</span> <span class="n">loss_collapse</span><span class="p">,</span> <span class="n">z1_collapse</span><span class="p">,</span> <span class="n">z2_collapse</span><span class="p">),</span>
<span class="p">]:</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">cos</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="n">np</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="s">'lightgray'</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_aspect</span><span class="p">(</span><span class="s">'equal'</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_xlim</span><span class="p">(</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">1.5</span><span class="p">,</span> <span class="mf">1.5</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">title</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">13</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">.</span><span class="n">grid</span><span class="p">(</span><span class="bp">True</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>

    <span class="n">z1n</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">z1</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">numpy</span><span class="p">()</span>
    <span class="n">z2n</span> <span class="o">=</span> <span class="n">F</span><span class="p">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">z2</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">).</span><span class="n">numpy</span><span class="p">()</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
        <span class="n">c</span> <span class="o">=</span> <span class="n">COLORS</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">annotate</span><span class="p">(</span><span class="s">""</span><span class="p">,</span> <span class="n">xy</span><span class="o">=</span><span class="n">z1n</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">xytext</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">),</span>
                    <span class="n">arrowprops</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">arrowstyle</span><span class="o">=</span><span class="s">"-|&gt;"</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">))</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">annotate</span><span class="p">(</span><span class="s">""</span><span class="p">,</span> <span class="n">xy</span><span class="o">=</span><span class="n">z2n</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">xytext</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span><span class="mi">0</span><span class="p">),</span>
                    <span class="n">arrowprops</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">arrowstyle</span><span class="o">=</span><span class="s">"-|&gt;"</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">linestyle</span><span class="o">=</span><span class="s">'dashed'</span><span class="p">))</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">z1n</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1.25</span><span class="p">),</span> <span class="sa">f</span><span class="s">"$z_</span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">$"</span><span class="p">,</span>  <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s">'center'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>
        <span class="n">ax</span><span class="p">.</span><span class="n">text</span><span class="p">(</span><span class="o">*</span><span class="p">(</span><span class="n">z2n</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="mf">1.25</span><span class="p">),</span> <span class="sa">f</span><span class="s">"$z'_</span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s">$"</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="n">c</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s">'center'</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">11</span><span class="p">)</span>

<span class="n">plt</span><span class="p">.</span><span class="n">tight_layout</span><span class="p">()</span>
<span class="n">plt</span><span class="p">.</span><span class="n">savefig</span><span class="p">(</span><span class="s">"ntxent_circle.png"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<center>
<img src="/assets/img/code-review/contrastive-loss/fig1.webp" width="80%" />
</center>
<p><br /></p>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="Contrastive-Learning" /><summary type="html"><![CDATA[Contrastive Loss 코드 리뷰]]></summary></entry><entry><title type="html">[논문리뷰] Learning 3D Medical Image Models From Brain Functional Connectivity Network Supervision For Mental Disorder Diagnosis</title><link href="https://kitewatermelon.github.io/paper-review/cinp/" rel="alternate" type="text/html" title="[논문리뷰] Learning 3D Medical Image Models From Brain Functional Connectivity Network Supervision For Mental Disorder Diagnosis" /><published>2026-02-27T00:00:00+09:00</published><updated>2026-02-27T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/paper-review/cinp</id><content type="html" xml:base="https://kitewatermelon.github.io/paper-review/cinp/"><![CDATA[<blockquote>
  <p>MICCAI 2025 [<a href="https://papers.miccai.org/miccai-2025/paper/4296_paper.pdf">Paper</a>] <br />
 Xingcan Hu, Wei Wang, Li Xiao
 Thu, 6 Mar 2025</p>
</blockquote>

<h2 id="1-introduction">1. Introduction</h2>
<p>blood-oxygen-level-dependent (BOLD)를 활용하여 뇌의 신경활동을 포착하는 fMRI는 대표적인 인간 뇌의 기능과 관련한 다양한 행동 및 인지적 특성에 관련한 비침습적 neuroimaging 기법이다. 최근엔 brain ROIs를 노드로, 각 노드의 기능적 연결을 엣지로 한 그래프인 Functional Connectivity Network(FCN)가 mental disorder를 진단하는데 많은 주목을 받고 있다. 지금까지는 FCN을 GNNs, CNNs, Graph Transformer로 학습을 해왔다. 그러나 아직 임상에는 널리 사용되지 않았는데 이는 fMRI 데이터 자체의 부족과 structure MRI(sMRI)에서 쉽게 얻을 수 있는 anatomical information과의 통합이 부족하다는 두가지 원인이 있다. anatomical structure가 근본적으로 뇌의 기능을 제한하기 때문에 mental disorder 진단에 더 나은 정밀도로 이끌 수 있다.</p>

<p>image-text pair의 contrastive learning의 성공에서 영감을 받아 의료쪽에서도 image-text, sMRI-fMRI 간의 멀티 모달 학습이 이루어지고 있다. 그러나 특정 모델을 사용해야 하며, 데이터가 부족하여 feature의 coarse representation로 인해 확장성이 떨어진다. 이와 관련하여 이 논문은 확장성에 초점을 두어 subject-level에서 sMRI와 fMRI를 contrastive pre-training한다. 3D T1w MRI-FCN 4619 쌍을 구성하고 Contrastive Image-Network Pre-training (CINP) framework를 구성하여 FCN Supervision을 통해 3D T1w MRI 영상의 representation을 학습한다.</p>

<h2 id="methods">Methods</h2>
<h3 id="21-contrastive-image-network-pre-training">2.1 Contrastive Image-Network Pre-training</h3>

<center>
<img src="/assets/img/paper-review/cinp/fig1.webp" width="80%" />
</center>
<p><br /></p>

<p>fig. 1에 명시되어 있듯 CINP는 visual encoder, visual decoder, network encoder로 구성되어 있다. visual encoder로는 Self-supervised pre-training of swin transformers for 3d medical image analysis 논문에서 weight로 초기화를 한다. 입력 3D T1w MRI 이미지 $I$는 전체 voxel의 30%
를 랜덤 마스크로 가려 masked MRI 이미지 $I^\ast$를 만든다. encoder를 통해 normalized image embedding $v_I$와 normalized masked image embedding $v^{\ast}_I$를 얻는다. (각 768D) visual decoder는 위의 언급한 논문의 구조를 따르며 $v^{\ast}_I$를 $\hat I$로 reconstruct 한다. Brain Network Transformer(BNT)를 network encoder로 사용하며 이는 FCN을 768 D의 normalized network embedding $w_N$로 인코딩한다.</p>

<p>Image-Network Contrastive Learning(CINP)는 CLIP의 성공에 영감을 받아 image-network 간 contrastive learning을 통해 FCN Supervision으로 활용하여 3D T1w MRI의 representation을 강화하는 학습 방법이다. 자세히 설명하면 similarity score $s(I, N)=v_I^T w_N $가 있을때 같은 subject이면 더 높은 score를 다른 subject이면 더 낮은 score를 받도록 학습한다. 배치에서 각 image와 network는 softmax-normalized image-to-network와 network-to-image similarity는 다음과 같이 계산된다.</p>

\[\begin{equation}
p_k^{in}(I)=\frac{exp(s(I,N_k)/\tau)}{\sum^K_{k=1}exp(s(I,N_k)/\tau)}
\quad and \quad
p_k^{ni}(N)=\frac{exp(s(N,I_k)/\tau)}{\sum^K_{k=1}exp(s(N,I_k)/\tau)}
\end{equation}\]

<p>여기서 $\tau$는 학습가능한 파라미터이고 K는 배치 사이즈이다. $y^{in}$과 $y^{ni}$를 모든 image와 network의 GT(positive는 1, negative는 0)라고 할때, cross entropy $H(\cdot , \cdot)$를 통헤 image-network contrastive(INC) loss를 다음과 같이 구성한다.</p>

\[\begin{equation}
\mathcal{L}_{INC}=\frac{1}{2}\mathbb{E}_{(I,N)~D}[H(y^{ni}(I),p^{ni}(I))+H(y^{in}(N),p^{in}(N))]
\end{equation}\]

<p>Masked image modeling (MIM)은 MRI 이미지의 robust representation을 구성하는 것을 목표로 한다. $L_1$를 이용하여 학습한다.</p>

\[\begin{equation}
\mathcal{L}_{MIM}=\mathbb{E}_{(I,\hat I)~D}||I- \hat I ||_1
\end{equation}\]

<p>Image-network matching (INM)은 image-network pair가 같은 subject에서 왔는지 파악하는 이진 분류 문제이다. $v_I$와 $w_N$을 concatenation 하여 fully-connected layer를 통과하여 이진 분류 확률인 $q$를 얻는다.
\(\begin{equation}
\mathcal{L}_{INM}=\mathbb{E}_{(I,N)~D}[H(z_{INM},q(I,N))]
\end{equation}\)</p>

<p>이때 $z_{INM}$은 GT label로 2D one-hot vector이다. 
위의 세 Lossf를 합하여 CINP loss는 다음과 같이 구성된다.</p>

\[\begin{equation}
\mathcal{L} = \mathcal{L}_{INC}+\alpha\mathcal{L}_{MIM}+\beta\mathcal{L}_{INM}
\end{equation}\]

<h3 id="22-network-prompting">2.2 Network prompting</h3>

<center>
<img src="/assets/img/paper-review/cinp/fig2.webp" width="80%" />
</center>
<p><br /></p>

<p>기존의 연구들은 linear probe나 fine-tunning을 사용하여 downstream task에 사용한다. 이는 많은 양의 annotated data를 요구하는데, fMRI는 정기적으로 수집되지 않는 문제가 있다. 저자들은 이 문제를 해결하기 위해 서로 다른 subject에서 FCNs이 significant한 group-level 차이를 보이는 것에서 영감을 받아 network prompting을 사용한다(fig. 2). 이들은 pre-trained CINP가 3D T1w MRI 이미지 embedding과 FCN embedding이 공동의 semantic space에서 학습됐기에 이들의 similarity를 측정할 수 있다는 가정을 한다.</p>

<p>k subject class에서 network embedding set $\mathcal{U}={\mathcal{C_1},\mathcal{C_2},…\mathcal{C_k}}$을 얻는다. 이때 $\mathcal{U}$ 의 각 원소는 n개의 FCNs를 포함한다. $\mathcal{C_l}= { w^l_1, w^l_2, … w^l_n }$ for $l=1,2,…k$. 또한 각 class 별로 FCNs을 r개의 같은 크기의 분리된 하위 집합으로 나누면 다음과 같이 표기할 수 있다.
\(\mathcal{C}_l = \bigcup_{i=1}^{r} \mathcal{C}_l^i,
\quad
\mathcal{C}_l^i \cap \mathcal{C}_l^j = \emptyset \ (i \ne j),
\quad
|\mathcal{C}_l^i| = \frac{|\mathcal{C}_l|}{r}.\)</p>

<p>같은 subset 안의 network embedding은 평균이 내어져서 r group-level reference network embeddings를 구성한다: $\mathcal{\bar U}={{\bar w^i_1, \bar w^i_2,…\bar w^i_r,}|l=1,2,…,k},$ where $\bar w^i_l=\frac{1}{|\mathcal{C^i_l}} \sum_{w \in \mathcal{c^i_l}}w$ for $i=1,2,..,r$, 이는 subject-level biases를 지우는 것을 도와준다. 그리고 난 후에 3D T1w MRI의 image embedding인 $v$를 reference network embedding과 유사도인 
$\mathcal{S}={{s^l_1,s^l_2,…,s^l_r}|l=1,2,…,k}$ where 
$s^l_i=v^T\bar w^l_i$ for $i=1,2,…r$를 구한다. 각 class의 FCNs와의 평균 similarity $\mathcal{\bar S}={\bar s^1,\bar s^2,…,\bar s^r}$, where $\bar s^l = \frac{1}{r} \sum_{i=1}^r s^l_i$ for $l=1,2,…k,$를 구하고 가장 높은 avg. similarity를 이용하여 환자에게 class를 할당한다.</p>

<h2 id="3-experiments">3. Experiments</h2>
<h3 id="31-experimental-settings">3.1 Experimental Settings</h3>
<p>저자들은 공공 데이터를 이용하여 3D T1w MRI and resting-state fMRI (rs-fMRI) 데이터셋을 구성했다. 4개의 데이터셋을 pre-training에 사용한다(HBM, HCP, QTIM, CNP). evaluation에는 ABIDE, ADHD, SRPBS의 세개의 데이터셋을 이용한다. fMRIPrep preprocessing pipeline을 이용하여 전처리 했으며 모든 mri의 복셀 크기는 $2\times2\times2 mm^3$로 preprocessing 되었다. rs-fMRI로부터 FCNs를 뽑기 위해 116개의 ROIs를 가진 AAL atals를 피어슨 상관계수로 사용했다. 자세한 데이터에 관한 설명은 아래 표를 참고하면 된다.</p>

<center>
<img src="/assets/img/paper-review/cinp/tab1.webp" width="80%" />
</center>
<p><br /></p>

<p>Implementation Details은 다음과 같다.</p>

<blockquote>
  <ul>
    <li>lr: 1e-5 (weight decay of 1e−5)</li>
    <li>cosine annealing schedule (1e−6)</li>
    <li>batch size(K): 256</li>
    <li>epoch: 8 $\times$ NVIDIA A800 ($\approx$ 100 hours)</li>
    <li>$\alpha = \beta = 1$</li>
    <li>augmentation: Gaussian noise addition, flipping, intensity scaling and shifting</li>
    <li>resize: $96 \times 96 \times 96$</li>
    <li>linear probe protocol: SVM / train:val:test=7:2:1</li>
  </ul>
</blockquote>

<p>evaluation 시 임상에서 부족한 FCNs을 반영하기 위해 FCNS의 10%만을 사용한다.</p>

<h3 id="32-quantitative-results">3.2 Quantitative Results</h3>
<p>아래 표에서 볼 수 있듯이 ADHD와 SRPBS에서 가장 좋은 ACC를 얻었다. 이는 sMRI-based model과 비교했을때 ABIDE, ADHD, and SRPBS 데이터셋에서 1.46%, 1.26%, 1.21% 차이가 난다. 이는 sMRI와 FCNs을 contrastive learning 함으로 상호 보완적인 정보를 완전히 포착할 수 있어서 mental disorder diagnosis에 도움을 준다는 것을 확인할 수 있다.</p>

<center>
<img src="/assets/img/paper-review/cinp/tab2.webp" width="80%" />
</center>
<p><br /></p>

<p>그러나 ABIDE에서는 SOTA 급 성능을 내지 못했는데 이는 ADHD 진단에 sMRI의 정보가 더 필요하다는 것을 의미할지도 모른다고 저자들은 말한다.</p>

<p>table 3은 network prompting에 대한 ablation study 결과이다.</p>

<center>
<img src="/assets/img/paper-review/cinp/tab3.webp" width="80%" />
</center>
<p><br /></p>

<h3 id="33-ablation-study">3.3 Ablation Study</h3>
<p>table 4는 Loss에 대한 abltation study 결과이다.</p>

<h2 id="4-conclusion">4. Conclusion</h2>
<p>본 논문에서 저자들은 3D T1w MRI와 FCN을 contrastive learning하는 framework인 CINP을 소개한다. mental disorder diagnosis의 성능을 향상시키기 위해 pre-training 하는 동안 3개의 loss가 3D T1w MRI image의 representation 품질을 FCN supervision으로 향상시켰다. 또한 network prompting을 도입함으로 적은 양의 FCN으로 환자가 sMRI만 촬영하여 예측에 사용하는 것이 가능하도록 한다. 3개의 데이터셋에서 CINP의 효과를 입증하였으며 sMRI를 임상 진단에 통합하는 것에 대한 가능성을 보여준다.</p>

<h2 id="개인적인-생각">개인적인 생각</h2>
<p>(1) fMRI는 촬영 시간이 길고 1회 촬영이 비싸서 환자들이 정기적으로 검사를 받기 어려운 modality이다. 그런 현실적인 상황을 검증시에 반영하였으며 network prompting이라는 기법으로 꽤 잘해결했다. 그럼에도 불구하고 ABIDE dataset에서 BNT에 밀린 것은 다소 아쉽다. 또한 github에 코드가 공개되어 있지 않은 것도 아쉽다.</p>

<p>(2) 이 논문을 통해 fMRI에 관심이 생겼었다. BOLD란 무엇인지 FCN이 뭐고 어떻게 계산되고 활용되고 있는지 공부할 수 있어서 좋았다. 앞으로도 fMRI를 사용한 논문을 자주 찾아볼 것 같은데 큰 도움이 되었다.</p>

<p>(3) 여담으로 작년 MICCAI에서 저자의 포스터 발표를 들었는데 저자분께서 매우 친절하게 잘 설명해주셨던 기억이 난다.</p>]]></content><author><name>YSPARK</name></author><category term="Paper-Review" /><category term="Medical-AI" /><category term="Contrastive-Learning" /><category term="MICCAI" /><summary type="html"><![CDATA[Learning 3D Medical Image Models From Brain Functional Connectivity Network Supervision For Mental Disorder Diagnosis (MICCAI 2025)]]></summary></entry><entry><title type="html">[코드리뷰] Entropy, Cross Entropy, KL Divergence</title><link href="https://kitewatermelon.github.io/code-review/information/" rel="alternate" type="text/html" title="[코드리뷰] Entropy, Cross Entropy, KL Divergence" /><published>2026-02-18T00:00:00+09:00</published><updated>2026-02-18T00:00:00+09:00</updated><id>https://kitewatermelon.github.io/code-review/information</id><content type="html" xml:base="https://kitewatermelon.github.io/code-review/information/"><![CDATA[<h3 id="0-index">0. Index</h3>
<ol>
  <li>Entropy</li>
  <li>Cross Entropy</li>
  <li>KL Divergence</li>
  <li>Mutual Information</li>
</ol>

<p>code 다운로드: <a href="/assets/code/code-review/information.ipynb">📥 information.ipynb 다운로드</a></p>

<h3 id="1-entropy">1. Entropy</h3>

<p>정보이론에서 시스템은 송신자, 채널, 수신자를 이용하여 모형화한다. 송신자는 채널을 통해 전달되는 메시지를 만들어낸다. 채널은 특정한 방식을 통해 메시지를 변경한다. 수신자는 어떤 메시지가 보내진 것인지 추론하고자 한다. 이 맥락에서 정보 엔트로피(또는 섀넌 엔트로피)는 <strong>각 메시지에 포함된 정보의 기댓값(평균)이다.</strong> ‘메시지’는 어떤 흐름의 정보에 대해서도 모형화할 수 있다.</p>
<ul>
  <li>기댓값(expected value, E)은 각 사건이 벌어졌을 때의 이득과 그 사건이 벌어질 확률을 곱한 것을 전체 사건에 대해 합한 값이다.</li>
</ul>

<p>기술적인 관점에서 보면 정보는 발생 가능한 사건이나 메시지의 확률분포의 음의 로그로 정의할 수 있다. 각 사건의 정보량은 그 기댓값, 또는 평균이 섀넌 엔트로피인 확률변수를 형성한다.</p>

<p>확률이 낮을수록, 어떤 정보일지는 불확실하게 되고, 우리는 이때 ‘정보가 많다’, ‘엔트로피가 높다’고 표현한다.</p>

<p>$H(X)=-\sum_{i} p_i log_2 p_i$</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">])</span>
<span class="n">y</span> <span class="o">=</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">y</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span><span class="mi">3</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"-log2*p"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">"-p*log2*p"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="o">*</span><span class="n">x</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000,
        1.0000])
tensor([3.3219, 2.3219, 1.7370, 1.3219, 1.0000, 0.7370, 0.5146, 0.3219, 0.1520,
        -0.0000])
tensor([0.3322, 0.4644, 0.5211, 0.5288, 0.5000, 0.4422, 0.3602, 0.2575, 0.1368,
        -0.0000])
</code></pre></div></div>

<center>
<img src="/assets/img/code-review/information/fig1.webp" width="80%" />
</center>
<p><br /></p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">COIN</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">DICE</span> <span class="o">=</span> <span class="mi">6</span>

<span class="n">coin_probabilty</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="o">/</span><span class="n">COIN</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">COIN</span><span class="p">)])</span>
<span class="n">dice_probabilty</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="o">/</span><span class="n">DICE</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">DICE</span><span class="p">)])</span>
<span class="n">random_probabilty</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">])</span>

<span class="k">def</span> <span class="nf">entropy</span><span class="p">(</span><span class="n">p</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">p</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"확률 분포가 아닙니다."</span>
    <span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-8</span>
    <span class="k">return</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">*</span><span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">p</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>

<span class="n">coin_entropy</span> <span class="o">=</span> <span class="n">entropy</span><span class="p">(</span><span class="n">coin_probabilty</span><span class="p">)</span>
<span class="n">dice_entropy</span> <span class="o">=</span> <span class="n">entropy</span><span class="p">(</span><span class="n">dice_probabilty</span><span class="p">)</span>
<span class="n">random_entropy</span> <span class="o">=</span> <span class="n">entropy</span><span class="p">(</span><span class="n">random_probabilty</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">coin_probabilty</span><span class="p">,</span> <span class="n">dice_probabilty</span><span class="p">,</span> <span class="n">random_probabilty</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">coin_entropy</span><span class="p">,</span> <span class="n">dice_entropy</span><span class="p">,</span> <span class="n">random_entropy</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor([0.5000, 0.5000]) tensor([0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]) tensor([0.1000, 0.9000])
tensor(1.) tensor(2.5850) tensor(0.4690)
</code></pre></div></div>

<p>위와 같은 균등분포에서 $p_i = 1/n$이고 따라서</p>

<p>$H=-\sum^n_{i=1}p_ilog_2p_i=log_2n$이다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">coin_entropy</span> <span class="o">==</span> <span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">COIN</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">dice_entropy</span> <span class="o">==</span> <span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">DICE</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor(True)
tensor(True)
</code></pre></div></div>

<p>로그 함수는 독립적인 불확실성에 가산성을 제공하는데 사용된다. 예를 들어, 크기 m의 이산 표본 공간과 크기 n의 이산 표본 공간에서, 서로 독립이며 균등분포를 따르는 두 확률변수를 동시에 측정할 경우, 그 총 엔트로피는</p>

<p>$log_2(mn)=log_2(m)+log_2(n)$과 같다.</p>

<p>즉, 서로 독립인 두 확률변수의 엔트로피는 각 확률변수의 엔트로피의 합과 같다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">((</span><span class="n">coin_entropy</span> <span class="o">+</span> <span class="n">dice_entropy</span><span class="p">)</span> <span class="o">==</span> <span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="n">COIN</span><span class="o">*</span><span class="n">DICE</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor(True)
</code></pre></div></div>

<h3 id="2-cross-entropy">2. Cross Entropy</h3>

<p>정보이론에서 교차 엔트로피란, 두 확률 분포 p 와 q를 구분하기 위해 필요한 평균 비트 수를 의미한다.</p>

<p>동일한 이벤트 공간의 두 분포 p와 q 사이의 교차 엔트로피는 다음과 같이 정의된다.</p>

\[E_P[X] = \sum_ip_ix_i \\
H(P, Q) = E_P[-log_2Q] = -\sum_ip_ilog_2q_i \\
H(P,Q) \neq H(Q,P)\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">cross_entropy</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">p</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">q</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"확률 분포가 아닙니다."</span>
    <span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-8</span>
    <span class="k">return</span> <span class="o">-</span><span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">*</span><span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">(</span><span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">))</span>

<span class="c1"># binary class classification
</span><span class="n">binary_label</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">])</span>
<span class="c1"># multi class classification
</span><span class="n">multi_label</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="n">multi_label</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span>

<span class="n">binary_predictions</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">],</span>
                                   <span class="p">[</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span>
                                   <span class="p">[</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">],</span>
                                   <span class="p">[</span><span class="mf">1.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">]</span>
                                   <span class="p">])</span>

<span class="n">multi_predictions</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span>
                                   <span class="p">[</span><span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,],</span>
                                   <span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">],</span>
                                   <span class="p">[</span><span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">1.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,]</span>
                                   <span class="p">])</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Binary Class Classification"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">binary_prediction</span> <span class="ow">in</span> <span class="n">binary_predictions</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"prediction: "</span><span class="p">,</span> <span class="n">binary_prediction</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"CE: "</span><span class="p">,</span><span class="n">cross_entropy</span><span class="p">(</span><span class="n">binary_label</span><span class="p">,</span> <span class="n">binary_prediction</span><span class="p">))</span>

<span class="k">print</span><span class="p">(</span><span class="s">"Multi Class Classification"</span><span class="p">)</span>

<span class="k">for</span> <span class="n">multi_prediction</span> <span class="ow">in</span> <span class="n">multi_predictions</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"prediction: "</span><span class="p">,</span> <span class="n">multi_prediction</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"CE: "</span><span class="p">,</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="n">multi_label</span><span class="p">,</span> <span class="n">multi_prediction</span><span class="p">))</span>    
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Binary Class Classification
prediction:  tensor([0.1000, 0.9000])
CE:  tensor(3.3219)
prediction:  tensor([0.5000, 0.5000])
CE:  tensor(1.)
prediction:  tensor([0.9000, 0.1000])
CE:  tensor(0.1520)
prediction:  tensor([1., 0.])
CE:  tensor(-0.)
Multi Class Classification
prediction:  tensor([0.0500, 0.0500, 0.0500, 0.0500, 0.8000])
CE:  tensor(4.3219)
prediction:  tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
CE:  tensor(2.3219)
prediction:  tensor([0.0500, 0.0500, 0.8000, 0.0500, 0.0500])
CE:  tensor(0.3219)
prediction:  tensor([0., 0., 1., 0., 0.])
CE:  tensor(-0.)
</code></pre></div></div>

<h3 id="3-kl-divergence">3. KL Divergence</h3>
<p>쿨백-라이블러 발산(Kullback–Leibler divergence, KLD)은 두 확률분포의 차이를 계산하는 데에 사용하는 함수로, 어떤 이상적인 분포에 대해, 그 분포를 근사하는 다른 분포를 사용해 샘플링을 한다면 발생할 수 있는 정보 엔트로피 차이를 계산한다. 상대 엔트로피(relative entropy), 정보 획득량(information gain), 인포메이션 다이버전스(information divergence)라고도 한다. 정보이론에서는 상대 엔트로피, 기계학습의 결정 트리에서는 정보 획득량을 주로 사용한다.</p>

<p>쿨백-라이블러 발산은 비대칭으로, 두 값의 위치를 바꾸면 함수값도 달라진다. 따라서 이 함수는 거리 함수는 아니다.</p>

\[D_{KL}(P||Q) = \sum P(i)log\frac{P(i)}{Q(i)}\]

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">KLD</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">):</span>
    <span class="k">assert</span> <span class="n">p</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">and</span> <span class="n">q</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">,</span> <span class="s">"확률 분포가 아닙니다."</span>
    <span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-8</span>
    <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">p</span><span class="o">*</span><span class="n">torch</span><span class="p">.</span><span class="n">log2</span><span class="p">((</span><span class="n">p</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">q</span> <span class="o">+</span> <span class="n">eps</span><span class="p">)))</span>

<span class="n">p</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">])</span>
<span class="n">q</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">])</span>
<span class="n">rst1</span> <span class="o">=</span> <span class="n">KLD</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span>
<span class="n">rst2</span> <span class="o">=</span> <span class="n">KLD</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">rst1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">rst2</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor(0.1536)
tensor(0.1757)
</code></pre></div></div>

<p>쿨백-라이블러 발산은 어떠한 확률분포 P가 있을 때, 샘플링 과정에서 그 분포를 근사적으로 표현하는 확률분포 Q 를 P 대신 사용할 경우 엔트로피 변화를 의미한다. 따라서, 원래의 분포가 가지는 엔트로피 H(P)와 P 대신 Q를 사용할 때의 교차 엔트로피(cross entropy) H(P,Q)의 차이를 구하면,
\(D_{KL}(P||Q) = H(P, Q) - H(P) \\
= (-\sum_iP(i)log_2Q(i)) - (-\sum_iP(i)log_2P(i)) \\
= \sum_iP(i)log_2(P(i))-log_2(Q(i)) \\
= D_{KL}(P||Q) = \sum P(i)log\frac{P(i)}{Q(i)}\)
로, 원래 정의했던 식과 같은 결과가 나온다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">KLD2</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">cross_entropy</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span> <span class="o">-</span> <span class="n">entropy</span><span class="p">(</span><span class="n">p</span><span class="p">)</span>

<span class="n">rst3</span> <span class="o">=</span> <span class="n">KLD2</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span>
<span class="n">rst4</span> <span class="o">=</span> <span class="n">KLD2</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">rst1</span> <span class="o">==</span> <span class="n">rst3</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">rst2</span> <span class="o">==</span> <span class="n">rst4</span><span class="p">)</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>tensor(False)
tensor(False)
</code></pre></div></div>

<p>torch.allclose() 함수는 두 텐서가 지정된 허용 오차 내에서 거의 동일한지를 확인한다. 이는 부동 소수점 연산의 미세한 차이로 인한 불일치를 허용할 수 있기에 권장된다.</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">rst1</span><span class="p">,</span> <span class="n">KLD2</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">q</span><span class="p">)))</span>
<span class="k">print</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">allclose</span><span class="p">(</span><span class="n">rst2</span><span class="p">,</span> <span class="n">KLD2</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">p</span><span class="p">)))</span>
</code></pre></div></div>

<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>True
True
</code></pre></div></div>

<p>Distance와 Divergence 차이는 무엇일까?</p>

<p>거리 척도 특성이란, “거리”라고 말할 수 있게 되기 위한 기준을 의미한다. 이 특성에는 다음 4가지가 존재한다.
비음성: d(x, y) &gt;= 0
동일성: d(x,y)=0 이면 x=y이고, 그 역도 성립한다.
대칭성: d(x, y) = d(y, x)
삼각부등식 d(x, z) &lt;= d(x, y) + d(y, z)</p>

<p>문제: 실제로 KL 발산은 거리의 척도(metric) 특성 네가지 중 두가지만을 만족한다. 만족하지 않는 두가지는 무엇일까?</p>

<p>Reference</p>
<ol>
  <li>entropy:
    <ul>
      <li>https://ko.wikipedia.org/wiki/%EC%A0%95%EB%B3%B4_%EC%97%94%ED%8A%B8%EB%A1%9C%ED%94%BC</li>
    </ul>
  </li>
  <li>cross entropy:
    <ul>
      <li>https://en.wikipedia.org/wiki/Cross-entropy</li>
    </ul>
  </li>
  <li>KL Diverence:
    <ul>
      <li>https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence</li>
      <li>https://ds31x.tistory.com/265</li>
      <li>https://go-gradually.tistory.com/m/entry/%EC%A0%95%EB%B3%B4-%EC%9D%B4%EB%A1%A0-KL-Divergence-KL-%EB%B0%9C%EC%82%B0-%ED%81%AC%EB%A1%9C%EC%8A%A4-%EC%97%94%ED%8A%B8%EB%A1%9C%ED%94%BC%EB%A5%BC-%EC%93%B0%EB%8A%94-%EC%9D%B4%EC%9C%A0</li>
    </ul>
  </li>
</ol>]]></content><author><name>YSPARK</name></author><category term="Code-Review" /><category term="Information-Theory" /><summary type="html"><![CDATA[Entropy, Cross Entropy, KL Divergence]]></summary></entry></feed>