본문 바로가기
Paper Review/Generative Model

[논문리뷰] DDT : Decoupled Diffusion Trasnformer

by 서윤하 2025. 4. 17.
반응형
 

DDT: Decoupled Diffusion Transformer

Diffusion transformers have demonstrated remarkable generation quality, albeit requiring longer training iterations and numerous inference steps. In each denoising step, diffusion transformers encode the noisy inputs to extract the lower-frequency semantic

arxiv.org

오늘 리뷰할 논문은 DDT라고 해서, 기존의 DiT / SiT 등 Transformer 기반의 Diffusion Process에서의 한계점을 극복하기 위한 논문이라고 할 수 있다. 

Abstract

Diffusion transformers have demonstrated remarkable generation quality, albeit requiring longer training iterations and numerous inference steps. In each denoising step, diffusion transformers encode the noisy inputs to extract the lower-frequency semantic component and then decode the higher frequency with identical modules. This scheme creates an inherent optimization dilemma: encoding lowfrequency semantics necessitates reducing high-frequency components, creating tension between semantic encoding and high-frequency decoding. To resolve this challenge, we propose a new Decoupled Diffusion Transformer (DDT), with a decoupled design of a dedicated condition encoder for semantic extraction alongside a specialized velocity decoder. Our experiments reveal that a more substantial encoder yields performance improvements as model size increases. For ImageNet 256 × 256, Our DDT-XL/2 achieves a new state-of-the-art performance of 1.31 FID (nearly 4× faster training convergence compared to previous diffusion transformers). For ImageNet 512 × 512, Our DDTXL/2 achieves a new state-of-the-art FID of 1.28. Additionally, as a beneficial by-product, our decoupled architecture enhances inference speed by enabling the sharing selfcondition between adjacent denoising steps. To minimize performance degradation, we propose a novel statistical dynamic programming approach to identify optimal sharing strategies.

굵은 글씨로 표현된 부분이 이 논문의 핵심으로 보일 수 있는 부분이다. 간단히 아래와 같이 정리해볼 수 있을 것이다.

  1. 기존의 Diffusion Transformer 구조는 동일한 모듈로 Semantic / detail 정보를 동시에 처리하였다.
  2. 이러한 문제 떄문에 최적화 과정에서 semantic / high-frequency 사이에서 상충되는 문제가 발생한다.
  3. 이 문제를 해결하기 위해 인코더와 디코더를 분리하는 구조를 제안하고, 아래와 같다.
    • Condition Encoder : 저주파수 부분의 의미론적인 정보를 추출하는 역할
    • Velocity Decoder : 고주파수의 세부 정보를 복원하는 데에 사용됨. 여기서 Velocity란, noise가 image로 복원되는 경로(trajectory)의 변화율을 의미한다.
  4. 이 분리를 통해 각 모듈이 특정 주파수 대역에 집중할 수 있도록 하여, 위에서 대두되었던 최적화 과정에서의 문제를 개선하고 추가적인 방법을 통해 추론 속도를 향상시켰다.

 

Introduction

DiT(아래를 보고 와도 됨)이 UNet Based Model들을 이미 대체한 상황에서, 물론 좋은 성능을 내고 있었지만 훈련이 느리고, 최적화가 어렵다는 구조적 한계들이 명확해졌다.

  • UNet 기반 디퓨전 모델들은, 효율성은 존재하지만 표현력이 아무래도 제한적이다.
  • DiT는 성능이 좋지만, 위에서 언급했듯이 최적화적인 문제가 존재해서 학습의 안정성이 떨어지고 학습 및 추론 과정이 느리다는 단점이 존재한다.

➔ Main Contribution

1. Encoder / Decoder를 분리하여, Semantic한 정보와 Detail 정보를 분리해서 처리한다.
2. Encoder들의 출력을 diffusion step 사이에 공유함으로써, 계산 효율성을 개선한다.
3. 통계적 Dynamic Programming 방식을 사용하여, 어느 step에서 encoder output을 reuse할지 선택한다.

 

Method

DTT Abstract Structure

Method들에 대해서 알아보도록 하자. 여기서 특징적으로 봐야할 점은, Encoder과 Decoder 사이의 기능 분리(Modularity)를 철저히 한 것이라고 볼 수 있을 것이다.

The condition encoder extracted the low-frequency component from noisy input, class label, and timestep to serve as a selfcondition for the velocity decoder; the velocity decoder processed the noisy latent with the self-condition to regress the high-frequency velocity.

조건 인코더의 경우에는 이제 input, class label, timestep 3가지를 이용하여 low-frequency component를 뽑아내는 역할을 하고, 속도 디코더의 경우에는 self-condition, timestep, noise를 이용하여 denoising을 위한 noise component를 뽑아냄으로써 high-frequency velocity를 얻어내는 것이 목표라고 할 수 있다. 각각에 대해서 알아보자.

Preliminary

앞으로 논문에서 말하는 'semantic'을 저주파 정보, 논문에서 말하는 'detail'한 정보들을 고주파 정보라고 지칭할 것이다. 사실 진작 했어야 했던 일인 것 같기는 한데, 그래도 지금부터라도 이렇게 설명하도록 하겠다.

Condition Encoder

conditional encoder는 noisy한 image와 timestep, class label을 입력받아서 낮은 주파수의 정보를 인코딩하는 역할이다. 이 역할은 DiT / SiT 기반의 트랜스포머 구조를 따르되, 여기서 주로 나오게 되던 long residual connection은 없애고 진행한다. 코드와 함께 알아보도록 하자.

def forward(self, x, t, y, s=None, mask=None):
    B, _, H, W = x.shape
    pos = self.fetch_pos(H//self.patch_size, W//self.patch_size, x.device)
    x = torch.nn.functional.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
    t = self.t_embedder(t.view(-1)).view(B, -1, self.hidden_size)
    y = self.y_embedder(y).view(B, 1, self.hidden_size)
    c = nn.functional.silu(t + y)
    if s is None:
    	s = self.s_embedder(x)
        for i in range(self.num_encoder_blocks):
        	s = self.blocks[i](s, c, pos, mask)
        s = nn.functional.silu(t + s)

	x = self.x_embedder(x)
	for i in range(self.num_encoder_blocks, self.num_blocks):
		x = self.blocks[i](x, s, pos, None)
	x = self.final_layer(x, s)
	x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
	return x, s

여기서 보면 t_embedder, y_embbedder를 이용해서 timestep, class label을 임베딩한 후, 거기에 t까지 같이 임베딩한 후 semantic condition을 생성한다는 것을 볼 수 있다. 더 자세한 임베딩까지는 뜯어보면서 찾아보길 바란다. 필자도 현재 뜯어보고 있는 단계라, 더 뜯어보기에는 조금 부족한 점이 있는 것 같다.

일단 전체적인 flow를 보자면, 먼저 ViT와 같이 x_t를 unfold 함수를 통해 patch embedding으로 쪼갠 후, t,y를 adaln_zero를 통해서 각 블록에 progressive하게 injection하게 되는데, 이에 대해서는 나중에 알아보도록 하고 일단 여기서는 semantic condition을 생성하는 것을 볼 수 있다.(if s in None 이후 5줄 보면 attention먹이고 ffn block을 통과시켜서 생성한 후, timestep을 이 뒤에 임의로 더해 최종 s를 완성하게 된다)

여기에서는 가장 중요한 것이 timestep 사이의 consistency인데, 이를 유지하기 위해서 REPA(Representation Alignment)를 사용하게 된다. 

여기서 r*는 DINO를 이욯나 pre-trained representation인데, 이것과 i번째 output이 일치하면 일치할수록 이후에  나온느 indirect supervision에 활용할 가능성이 높아진다. 이에 대해서는 추후에 다루도록 하겠다.

Velocity Decoder

velocity decoder는 self-condition을 이용해서 noisy latent로부터 velocity를 예측하는 것을 목표로 한다. 이는 고주파의 정보를 복원함으로써, 디테일한 구조나 텍스처의 복원을 주 역할로 삼는다.
구조를 위의 forward 코드에서 보면 아래와 같다.

	x = self.x_embedder(x)
	for i in range(self.num_encoder_blocks, self.num_blocks):
		x = self.blocks[i](x, s, pos, None)
	x = self.final_layer(x, s)
	x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), (H, W), kernel_size=self.patch_size, stride=self.patch_size)
	return x, s

여기서 보면 noisy image를 latent로 embedding한 후, DDTBlock을 decoder로 사용하는데 여기서 self-condition을 이용하게 된다.(timestep같은 경우에는, 이미 encoder에서 사용하였기 때문에 별도로 사용하지 않음. 이미 s에 포함되었다고 간주하는 것) 추후, final_layer를 통해서 최종 velocity를 예측하게 된다. 이를 수식으로 표현하면 아래와 같다.

여기서 s는 여기서도 Adaptive Layer Normalization-Zero(AdaLN-Zero) 방식으로 Decoder의 attention block에 주입된다.

여기서의 학습 손실은 flow matching loss를 이용하는데, 이는 결국 예측하는 것이 하나의 값이 아니라 velocity(복원하는 과정)이기 때문이라고 할 수 있다. 이를 식으로 나타내면 아래와 같다.

flow matching loss

Sampling Accleration

앞에서 Representation Alignment(REPA)를 통해 인접한 timestep들이 유사하다고 하였다. 그렇기 때문에, 이를 이용해 인접한 timestep 간에 재사용한다면 어느 정도의 속도의 향상을 할 수 있을 것이라는 것이 저자의 아이디어였고, 이를 통해서 인코더의 계산적 부하까지 줄일 수 있었다.

DDT에서는 이러한 Self-Condition의 공유를 위해 두 가지의 주요 전략을 사용하였다.

  1. 균일 공유(Uniform Encoder Sharing) : 전체 추론 단계 중에서 특정 K 단계에서만 self-condition을 재계산하고, 나머지 단계에서는 K/N 간격으로 이전 단계의 self condition을 재사용하는 방식이다. 그렇기 때문에 K가 작을수록 더 많은 단계에서 self-condition을 공유함으로써 추론 속도를 올릴 수 있다. 이전 연궁니 DeepCache에서도 이와 같이 균일 공유 전략을 사용하였다고 하는데, 이에 대해서는 읽어본 적이 없어서 잘 모르겠다.
  2. 통계적 동적 프로그래밍(Statistical Dynamic Programming) : 여기서는 일단 timestep별로 self-condition 간의 cosine 유사도를 구한다. 이후 이를 이용해서 유사도 행렬을 구성하고, 여기서 minimal path를 dynamic programming 방식을 통해 계산함으로써 최적의 해를 구하여 이를 통해 추론을 진행할 수 있도록 하였다.

이러한 다양한 방식을 통해 self-condition 공유 비율이 적절한 범위 내에 있으면 FID 상의 저하가 거의 없이 추론 속도의 향상을 얻을 수 있다고 한다. 아래의 표에서 Uniform / StatisticDP를 보면 DP 방식이 조금은 더 나은 방식이라는 것을 볼 수 있다. share ratio 파라미터의조절이 필요해보이긴 하지만, 그래도 좋은 아이디어임에는 틀림 없다.

Experiments

기본적으로 Imagenet 256x256에 대한 생성을 진행하였고, 표는 아래와 같다.

Classifier-Free Guidance가 들어갔을 때의 수치를 보면 확실히 FID에 대해서도 좋은 모습을 보이고 있고 Inception Score 부분에서도 좋은 모습을 보였으며, 가장 큰 부분은 훨씬 적은 에폭으로 (REPA기준 4배 빠른 속도) 비슷하거나 더 상위의 결과를 내었다는 것이라 할 수 있다.

Conclusion

**여기는 필자의 개인적 의견입니다**
이렇게 DDT 논문을 읽어봤는데, 이런 류의 논문을 읽을 때마다 드는 생각은 '작은 변화도 큰 결과를 가져올 수 있다'는 것이다. 기존 방식들이 Decoder-only구조로 진행을 하는 것에서 한계를 느낀다는 것이 사실 쉽지는 않은 일인 것 같은데, 이게 코드로 비교해보니 생각보다 큰 일이 아니지만 큰 변화를 가져왔음이 참 느끼는 점이 많아지는 논문이었다. 단순히 논문을 보기만 하지 않고, 비판적으로 보는 것의 의미를 다시 한 번 생각하게끔 하는 논문이 아니었나... 싶다. 한번 써봐야지

반응형