본문 바로가기

LLM

라마 모델, 코드와 그림으로 이해하기 파트 1

안녕하세요, 수달이입니다.

 

오늘부터 메타 라마(Llama) 모델의 아키텍처를 깊-게 살펴보는 새로운 시리즈를 시작하겠습니다. 논문으로 공부하는 것도 좋은 방법이지만, 이번 시리즈에서는 이미지와 HuggingFace Transformer 코드를 활용하여 좀 더 쉽게 배워보고자 합니다. 그럼 공부 순서부터 함께 볼까요?

 

  1. LlamaAttention
  2. LlamaDecoderLayer
  3. LlamaModel
  4. LlamaForCausalLM

최근 언어 모델들은 대부분 트랜스포머를 기반으로 하는데요. 이 트랜스포머의 핵심인 어텐션 모듈부터 시작해서, 점차 범위를 넓혀 디코더, 베이스 모델, 언어 모델링 모델 순으로 살펴보겠습니다. 나무에서 시작하여 숲을 이해하는 눈을 키우는 것이죠.

 

 

 

그럼 어텐션 모듈부터 시작해 볼까요? 출-발! ⊹ ࣪ ﹏𓊝﹏𓂁﹏⊹ ࣪ ˖

Attention이란?

Attention (self-attention이라고도 칭함) 메커니즘은 문장 내의 단어들이 현재 처리 중인 단어와 얼마나 관련 있는지, 그 중요도를 계산하여 더 중요한 정보에 집중 (Attention!) 하도록 돕는 기술입니다. 단방향성 (예: 현재 처리 중인 단어보다 앞서 위치한 단어들만 고려함) 을 가졌던 과거 모델들에 비해 문맥 (context)을 양방향으로 더 깊이 이해할 수 있게 되는 것이죠.

LlamaAttention 모듈 - 인풋

(타이틀의 링크를 클릭하면 HuggingFace Transformers 코드로 이동합니다 ت)

def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

 

위의 코드 블록은 LlamaAttention forward 함수의 시그니쳐입니다. 주요 입력값은 크게 5가지임을 알 수 있죠.

 

  • hidden_states: 입력 문장의 단어들을 컴퓨터가 이해할 수 있는 벡터로 치환한 값 
  • position_embeddings: 문장 내 단어의 '위치' 정보를 담고 있는 임베딩. 라마에서는 RoPE를 사용
  • attention_mask: 어텐션 계산을 할 때 '보지 말아야 할' 부분 (미래 단어, 패딩 토큰 등) 을 가려주는 장치
  • past_key_value: 앞서 처리한 Key, Value를 저장해 두는 캐시
  • cache_position: 새롭게 형성된 Key, Value를 저장할 캐시 안의 위치

마지막 두 입력값인 past_key_value와 cache_position은 연산 속도를 높이는 KV Cache (Key-Value Cache) 와 관련된 값으로, 자세한 내용은 다음 글에서 설명하도록 하겠습니다.

 

먼저 hidden_states의 형태 (shape) 를 알아보죠. 앞서 언급했듯, hidden_states는 단어들을 벡터로 변환한 값인데요. 입력 문장의 길이를 S, 각 단어 벡터의 크기 (dimension) 를 HS (Hidden Size) 라고 한다면, hidden_states는 크기 S x HS의 행렬 (matrix) 이 됩니다. 만약 여러 문장을 처리한다면 배치 (batch) 사이즈 B를 포함하여 B x S x HS의 행렬이 됩니다. 이제 메인 로직으로 들어가 보겠습니다.

LlamaAttention - 메인 로직

input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

 

본격적인 연산에 들어가기 전에, 중간 결과물의 모양을 먼저 설정하도록 하겠습니다. 위 코드에 따르면 중간 결과물 (여전히 hidden_states이라 칭합니다) 의 모양은 B x S x -1 x HD (Head Dimension) 이 됩니다. 즉, hidden_shape = (B, S, -1, HD) 가 되는 거죠. (여기서 -1은 자동으로 Number of Heads, 즉 NH로 결정됩니다.)

 

Query, Key, Value란?

이쯤에서 attention을 구하는 공식을 살펴볼까요.

 

Source: lena-voita.github.io

 

여기서 q, k, v는 각각 쿼리 (Query), 키 (Key), 밸류 (Vector) 벡터를 뜻합니다.

 

  • 쿼리: 현재 처리 중인 단어
  • 키: 입력 문장 내의 모든 단어의 '색인' 또는 '꼬리표' (쿼리와의 유사도, 즉 attention weight을 구하는 데 쓰임)
  • 밸류: 입력 문장 내의 모든 단어의 '내용' 또는 '표현' (최종 attention output을 구하는데 쓰임)

키와 밸류가 살짝 헷갈릴 수 있는데, 모든 단어가 용도가 다른 두 버전의 벡터를 가지고 있다고 생각하면 됩니다. 하나 (키) 는 쿼리와의 유사성을 구하는데만 쓰이고, 다른 하나 (밸류) 는 최종 어텐션 값을 구하는데 쓰입니다. 

 

Query, Key, Value 구하기

이제 Query, Key, Value 매트릭스를 계산해 보겠습니다. 아래 코드에서 볼 수 있듯, Q, K, V 매트릭스는 모두 동일한 hidden_states (입력 단어들의 벡터 집합) 에서 출발합니다. 다만, 각각 다른 가중치 행렬 (projection matrix) 과 곱해질 뿐이죠.

 

query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

 

대표적으로 쿼리 행렬을 만드는 q_proj 연산을 살펴볼 텐데요. 시각화를 위해 몇 가지 가정을 하겠습니다.

 

  1. 샘플은 하나만 있다고 가정. 즉, B=1
  2. 샘플은 토큰 2개로 이루어진 짧은 문장. 즉, S=2
  3. 단어 벡터는 3차원. 즉, HS=3
  4. attention head는 총 2개, 각 head의 차원은 2. 즉, NH (num_heads) = 2, HD (head_dim) = 2.

q_proj은 선형 (nn.Linear) 모듈로 hidden_states을 쿼리 가중치 행렬 W_query에 곱하는 역할을 합니다. 이 곱셈의 결과는 S x (NH x ND) = 2 X 4 차원의 매트릭스입니다. 배치 사이즈도 고려한다면 1 x 2 x 4 차원의 매트릭스가 되겠습니다.

 

 

 

좀 더 일반화하면,

 

  • q_proj(hidden_states): 두 매트릭스의 곱셈으로 B x S x (NH x HD) 차원의 매트릭스 탄생
  • .view(hidden_shape): hidden_shape = (B, S, -1, HD) 이므로, B x S x NH x HD 형태로 변형
  • .transpose(1, 2):  1, 2차원 위치 바꾸기 -> 최종 매트릭스의 모양은 B x NH x S x HD 

앞선 코드에서 보았듯, 키와 밸류 또한 동일한 과정을 거치기 때문에 쿼리, 키, 밸류 모두 B x NH x S x HD 차원의 매트릭스가 되었습니다. (Group Query Attention의 경우 쿼리와 키/밸류의 헤드 수가 상이할 수 있지만 이 포스팅에서는 모두 헤드의 수(NH)가 모두 동일하다고 가정) 

 

Attention Score

이제 남은 단계는 본격적으로 attention 점수를 계산하는 것입니다. attention 공식에 있는 것들을 하나씩 계산해 보죠.

 

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attn_weights = attn_weights + causal_mask

 

첫 스텝은 attention score를 계산하는 것입니다. 이는 앞서 소개한 attention 공식에서 softmax 함수의 입력값에 해당하며, 위 코드 블록의 첫 번째 줄에서 확인할 수 있습니다. 먼저, 키 매트릭스를 전치 (transpose) 하여 B x NH x HD x S 형태로 만듭니다. 이렇게 변환된 키 매트릭스는 B x NH x S x HD 형태의 쿼리 매트릭스와 행렬곱셈이 가능해집니다. 이 곱셈의 결과로 B x NH x S x S 모양의 score 매트릭스가 생성됩니다.

 

다음으로, 이 매트릭스의 각 값을 sqrt (d_k)으로 나누어 스케일링합니다. 여기서 d_k는 키 벡터의 차원으로, 이 포스팅의 표기법에 따르면 HD (head dimension) 에 해당합니다. 스케일링을 하는 이유는 쿼리와 키의 곱셈 값이 지나치게 커지는 것을 방지하기 위함입니다. 만약 이 값이 너무 크면, softmax 함수를 통과했을 때 특정 값에만 확률이 과도하게 집중되어, (예: 가장 큰 값에 가까운 확률 1, 나머지는 0에 가까운 확률) 극단적인 확률 분포가 생성될 수 있습니다. 이러한 극단적인 분포는 모델 학습을 불안정하게 만드는 요인이 됩니다.

 

attention score 계산의 마지막 단계는 attention mask를 적용하는 것입니다. mask가 적용되는 토큰값에는 아주 큰 음수 (음의 무한대 값처럼 동작) 를 더합니다. 아래와 같이 말이죠. 각 셀은 해당 행과 열의 두 토큰 사이의 최종 attention score를 나타냅니다.

 

 

 

attention mask는 크게 두 가지 이유로 필요합니다. 첫째, 패딩 토큰 가리기. 여러 시퀀스를 동시에 처리하는 경우, 시퀀스마다 길이가 다르기 때문에 짧은 시퀀스들은 뒤에 패딩 토큰을 더하여 가장 긴 시퀀스와 같은 길이가 되도록 만들어집니다. 이런 패딩 토큰은 실제론 의미가 없는 단어이기 때문에 최종 attention output 계산시 가려주어야 합니다.

 

둘째, 미래 토큰 가리기. 언어 모델은 현재까지의 토큰을 활용하여 다음 토큰을 추론하도록 훈련됩니다. 따라서 미래 토큰을 attention 계산에 고려하는 것은 마치 컨닝을 하는 것과 같기 때문에 미래 토큰들을 마스크를 이용하여 가려주어야 합니다. 이런 방식을 Causal (인과적) Masking이라고도 합니다. 

 

자, 이제 거의 다왔습니다. attention score를 구했으니 이제 이 점수를 가중치로 변환해봅시다.

 

Attention Weight

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

 

attention score를 가중치로 만드는 것은 쉽습니다. softmax 함수를 이용하여 확률 분포로 치환해주는 것이죠. 위 코드 블럭의 첫 줄에 해당합니다. 다음으로는 dropout 레이어를 활용하여 score matrix의 일부를 랜덤하게 0으로 만들어줍니다. 단, 이 dropout은 코드에서 볼 수 있듯 모델 학습단계에만 적용됩니다. 모델의 overfitting을 방지하기 위한 regularization (정규화) 테크닉의 일종이죠.

 

Attention Output

attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)

 

드디어 마지막 단계입니다. 잊고 있었던 밸류 매트릭스가 이제 등장합니다. attention weight 매트릭스는 attention score와 같은 B x NH x S x S 형태를 가지고 있습니다. 그리고 밸류 매트릭스는 B x NH x S x HD 형태이죠. 이 둘을 곱하면 B x NH x S x HD 형태의 매트릭스가 생성됩니다.

 

그리고 이 매트릭스는 o_proj 레이어의 아웃풋 가중치 행렬 (W_out)과의 행렬곱셈 연산을 위해 형태가 여러번 바뀝니다.

 

  • .transpose(1,2): B x S x NH x HD 형태로 치환
  • .reshape(*input_shape, -1): (B, S, -1) 형태로 형태 변환. 최종적으로 B x S x (NH x ND)

.contiguous()는 매트릭스의 연산을 위해 텐서들을 메모리에 연속적으로 배치하는 것을 의미합니다. 모델을 이해하는데에는 크게 중요하지 않습니다.

 

이렇게 형태가 변환된 attention output은 o_proj 레이어의 아웃풋 가중치 행렬 W_out과 곱해져 최종 attention output을 생성합니다. 아래의 예시처럼 말이죠. 

 

 

 

마무리

꽤 복잡한 여정이였지만, attention 모듈 이야기를 요약하면 다음과 같습니다.

 

  1. 입력 문장이 벡터로 변환된 hidden states (B x S x HS) 가 입력값으로 들어옴
  2. hidden states는 쿼리/키/밸류 매트릭스와 각각 곱해져 query state, key state, value state로 변환됨
  3. query state, key state를 곱하여 각 단어 사이의 연관성 정도를 나타내는 attention weight을 계산함
  4. attention weigth은 차례로 value state아웃풋 매트릭스와 곱해짐
  5. 최종적으로 입력값인 hidden states와 동일한 형태의 attention output (B x S x HS) 을 생성함

즉, 입력 문장의 벡터들이 어텐션 모듈을 통과한 후 각 단어간의 연관성 정보를 담은 새로운 벡터로 탈바꿈한 것입니다. 단순히 각 단어의 의미만 담고 있던 벡터가 다른 단어와의 관계까지 품게 된 것입니다. 이게 바로 LLM 모델이 문맥을 이해할 수 있는 힘이죠.

 

이제 가장 큰 산을 넘었습니다. 다음 포스팅에서는 attention을 포함하고 있는 decoder 레이어를 살펴볼 텐데요. 여기까지 잘 따라오셨다면 80%는 이해하신 셈이니 가벼운 마음으로 다음 글을 함께해주시기 바랍니다.

 

그럼 모두 Happy LLM!

 

 

* 본 포스팅에서는 설명을 단순하게 하기 위해, position embedding, KV cache, Group-query attention에 대한 아주 자세한 설명은 다음으로 미뤘습니다. 이 내용들은 다음 기회에 더 자세히 다루도록 할게요. :)