안녕하세요, 수달이입니다.
라마 모델 이해하기 두 번째 파트! 오늘은 디코더 (Decoder) 모듈에 대해 다루겠습니다. 디코더의 8할은 어텐션 (Attention) 모듈이라고 할 수 있는데요. 지난 포스팅을 열심히 읽으신 분이라면 이미 디코더의 반 이상은 이해하신 셈이죠. ت
그럼 가벼운 마음으로 시작해 볼까요?
디코더 구성요소 살펴보기
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
먼저 HuggingFace Transformers 코드를 통해 디코더 모듈이 어떤 요소들로 이루어져 있는지 살펴볼까요? 위 코드는 LlamaDecoderLayer의 __init__ 함수인데요. 크게 세 가지 클래스가 눈에 띕니다.
1. LlamaAttention: 지난 시간에 배운 바로 그 어텐션 모듈이죠!
2. LlamaMLP: 데이터를 더욱 풍부하게 처리해 주는 신경망입니다.
3. LlamaRMSNorm (input layernorm, post attention layernorm): 모델의 안정적인 학습을 돕는 정규화 모듈입니다.
어텐션 모듈은 이미 익숙하실 테니, 오늘은 나머지 두 클래스에 대해 자세히 알아보겠습니다.
LlamaMLP
MLP는 Multi-layer Perceptron의 약자로 가장 기본적인 형태의 딥러닝 모델입니다. 입력층 (input layer) - 하나 이상의 은닉층 (hidden layer) - 출력층 (output layer) 으로 심플하게 구성되어 있죠.
하지만 Llama 모델에 쓰이는 MLP는 (트랜스포머에서는 feed-forward network이라고 지칭) 이 기본 구조보다 살짝 더 발전된 형태입니다. 가장 큰 차이점은 바로 '게이트(Gate)' 메커니즘을 도입했다는 점입니다. 이 게이트는 정보의 흐름을 선택적으로 조절하여, 모델이 더욱 정교하게 데이터를 처리하도록 돕습니다.
이러한 게이트 구조를 GLU (Gated Linear Unit) 라고 하며, 라마 모델은 이 게이트 구조에 Swish 활성화 함수를 사용했기 때문에 SwiGLU (Swish Gated Linear Unit) 라 칭합니다. 아래의 그림을 보시면 기존 MLP와 SwiGLU가 적용된 MLP의 차이를 더 쉽게 이해하실 수 있습니다.
💡 Swish 활성화 함수, 잠깐 알고 가기!
Swish 활성화 함수에 대한 자세한 설명은 이 글을 참조해 주세요. 이 포스팅에서는 Swish 함수를 통과한 값들이 일종의 '게이트 스코어' 역할을 한다는 점만 기억하시면 됩니다! 이 스코어는 항상 특정 값 (약 -0.278) 이상을 가지며, 이 값에 따라 hidden states에 포함된 각 정보의 중요도가 섬세하게 조절된답니다.
자, 이제 코드로 직접 확인해 볼까요?
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
코드를 보면, gate_proj와 up_proj은 입력값 x를 더 높은 차원의 벡터로 확장하는 역할을 합니다. 그리고 act_fn (여기서는 Swish 함수) 이 gate_proj의 출력값을 '게이트 스코어'로 변환시키죠. 이 게이트 스코어가 up_proj를 통과한 hidden state 벡터의 각 요소에 곱해져 정보의 양이 조절됩니다. 마치 중요한 정보는 더 많이, 덜 중요한 정보는 적게 흘려보내는 필터처럼요! 마지막으로 down_proj을 통해 다시 입력값과 동일한 저차원의 벡터로 축소됩니다.
LlamaRMSNorm
다음으로 살펴볼 클래스는 RMSNorm (Root Mean Square Normalization) 입니다. RMSNorm은 LayerNorm (Layer Normalization) 과 비슷한 정규화 기법으로, 모델의 안정적인 학습을 돕습니다.
LayerNorm은 각 레이어의 입력값에 대해 평균을 0, 표준편차를 1로 만들어주는 정규화 방식인데요. 좀 더 풀어서 설명하면, 벡터의 각 요소에서 벡터 전체의 평균을 빼고, 그 값을 표준편차로 나누어 데이터의 분포를 안정시키죠. 덕분에 극단적인 값의 영향을 줄여 학습을 안정적으로 만듭니다.
RMSNorm은 여기서 평균을 빼는 과정을 생략하여 계산을 단순화합니다. 그리고 표준편차 대신 RMS (Root Mean Square, 제곱평균제곱근) 로 각 요소를 나누어 정규화를 수행합니다. 이런 차이 덕분에 RMSNorm은 LayerNorm보다 연산량이 적어 더 효율적이라는 장점을 가집니다.
수식으로 비교하면 아래와 같습니다.
코드는 다음과 같습니다.
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True) --- (1)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) --- (2)
return self.weight * hidden_states.to(input_dtype) --- (3)
(1) Mean Square 계산: 입력값 hidden_states의 각 요소를 제곱한 후, 마지막 차원을 기준으로 평균을 계산합니다.
(2) 정규화 수행: (1)에서 구한 Mean Square에 아주 작은 값 (self.variance_epsilon, 분모가 0이 되는 것을 방지) 을 더한 후, 역제곱근 (reciprocal sqaure root) 을 구합니다. torch.rsqrt(x) 은 1 / torch.sqrt(x) 와 동일하며, 이 값이 바로 1 / RMS에 해당합니다. 이제 입력값 hidden_states에 1 / RMS 값을 곱하여 정규화를 수행합니다.
(3) 가중치 적용: 정규화된 hidden_states에 학습 가능한 가중치 self.weight 을 곱하여 최종 결과를 리턴합니다.
RMSNorm 모듈의 최종 목표는 hidden_states의 각 요소 값을 안정적으로 만드는 것입니다. 이 과정을 거치면 각 요소의 값은 변화하지만, 입력값과 출력값의 차원/모양은 그대로 유지됩니다.
디코더 완성하기
자, 이제 디코더의 구성요소에 대해 모두 알아보았으니 이들을 조합해 볼까요? 각 구성요소는 다음의 순서로 조립되는데요. 그림의 아래에서 위 방향으로 따라가시면 됩니다.
디코더는 크게 두 개의 블록으로 이루어져 있다고 생각하면 쉽습니다. 바로 어텐션 블록과 MLP 블록이죠. 그리고 각 블록이 시작될 때마다, 입력 벡터를 다듬는 정규화 모듈, 즉 RMSNorm이 첫 관문으로 자리 잡고 있습니다.
디코더 순서: [ 정규화 -> 어텐션 계산 ] -> [ 정규화 -> MLP ]
여기서 또 하나의 중요한 특징은, 각 블록의 연산이 끝날 때마다 입력 벡터를 출력 벡터에 더해주는 residual connection (잔차 연결) 이 일어난다는 점입니다. 이 기법은 모델이 아무리 깊어지더라도 초기 입력에 대한 정보를 잃지 않고 끝까지 전달할 수 있게 하여 모델의 안정적인 학습을 돕습니다. 실제로 라마 3.1 모델의 경우, 가장 작은 모델도 32개의 디코더 레이어를 쌓아 올린 아주 깊은 네트워크 (deep network) 입니다.
* 전체 코드는 포스트가 너무 길어질 수 있어 생략하겠습니다. 궁금하신 분들은 이곳을 클릭하시면 디코더 forward 함수로 이동합니다.
마무리
어떠셨나요? 직전의 어텐션 포스팅보다는 한결 수월하게 이해되셨기를 바랍니다. 오늘 내용을 정리하자면, 디코더는 다음과 같은 핵심 역할을 수행합니다.
1. 어텐션 모듈을 통해 입력 문장 내 단어들 사이의 관계과 중요도를 파악하고,
2. MLP 레이어를 통해 학습된 정보를 더욱 정교하고 풍부하게 다듬습니다.
라마를 비롯한 수많은 LLM의 심장과도 같은 디코더! 이제 그 다음 단계는 이 핵심 부품들을 차곡차곡 쌓아 올려 완성된 '모델' 클래스에 대해 알아보는 것입니다. 다음 포스팅은 이번 포스팅보다도 더 쉬울 예정이니, 포기하지 마시고 수달이를 잘 따라와 주시길 바랍니다.
그럼 모두 Happy LLM!

'LLM' 카테고리의 다른 글
라마 모델, 코드와 그림으로 이해하기 파트 3 - RoPE를 중심으로 (0) | 2025.06.16 |
---|---|
라마 모델, 코드와 그림으로 이해하기 파트 1 (1) | 2025.05.13 |
[Position Embedding] RoPE 로타리 포지션 임베딩 (0) | 2025.03.01 |
[Attention] Sliding Window Attention 슬라이딩 윈도우 어텐션 (0) | 2025.02.15 |