본문 바로가기
Research/NLP

Generation Configuration - 생성 인퍼런스에 사용되는 config 이해하기

by yooonlp 2023. 7. 16.

※ 본 포스트는 Coursera 강의인 "Generative AI with Large Language Models"의 Week 1의 내용의 일부를 정리하고 필요한 내용을 추가하여 작성한 글입니다. 

 

 

이전 포스트에서는 사전 훈련 모델의 구조를 소개하며, LLM을 Encoder, Decoder, Encoder-Decoder 구조로 분류하였습니다. 이번 포스트에서는 훈련된 모델을 가지고 생성 Inference를 진행할 때, Configuration을 통해 생성되는 토큰을 제어하는 방법에 대하여 이야기합니다. 

 

 

다음은 Flan-t5 모델에 "dialogue"를 입력으로 넣어 생성을 하는 코드입니다. 입력을 토크나이징 하고, 모델에 입력으로 넣어 다음 토큰을 예측하도록 하여, 결과를 다시 토크나이저로 디코딩하여 출력 문장을 받아옵니다. dialogue는 "knkarthick/dialogsum"에서 가져온 데이터이며, 대화와 대화의 요약 문장이 pair로 있어 대화 요약 모델을 위해 만들어진 데이터셋입니다. 생성 config 사용 방법에 집중하기 위해 요약 태스크에 대해서는 생각하지 않고 본 포스트에서는 대화 입력만 사용합니다. 

from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer

model_name='google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

dialogue = """
#Person1#: What time is it, Tom?
#Person2#: Just a minute. It's ten to nine by my watch.
#Person1#: Is it? I had no idea it was so late. I must be off now.
#Person2#: What's the hurry?
#Person1#: I must catch the nine-thirty train.
#Person2#: You've plenty of time yet. The railway station is very close. It won't take more than twenty minutes to get there.
"""

inputs = tokenizer(dialogue, return_tensors='pt')
output = tokenizer.decode(
        model.generate(
            inputs["input_ids"], 
            max_new_tokens=50,
        )[0], 
        skip_special_tokens=True
    )

generate 메소드에 현재는 max_new_tokens 파라미터만 있지만, 여러가지 파라미터를 추가하여 생성되는 토큰을 제어할 수 있습니다. 

 

 

Max new tokens

이 파라미터는 생성할 토큰의 최대 개수를 제한하는 역할을 합니다. 즉, generate 메소드가 얼마나 많은 토큰을 생성할 것인지를 지정할 수 있습니다. 이는 결과적으로 생성되는 문장의 길이를 제어하여, 생성되는 결과가 너무 길거나 짧지 않게 조절할 수 있습니다.

 

문장 생성을 할 때, 문장의 끝을 알리는 stop token(".", "<eos>" 등)이 등장하면 모델은 더 이상 토큰을 생성하지 않습니다. max new tokens로 설정된 토큰 수에 도달하지 않더라도, stop token이 등장하면 모델은 생성을 중단하게 됩니다.

 

Greedy vs random sampling

토큰을 샘플링하는 방식을 제어하는 방법입니다. 다음과 같은 토큰에 대한 확률이 있다고 가정합니다.

0.20 케익
0.10 도넛
0.02 바나나 
0.01 사과

 

Greedy Sampling은 가장 높은 확률을 가진 토큰을 선택하기 때문에, 예시에서는 '케익'을 선택하게 됩니다. 이는 의미적으로 봤을 때, 일관된 결과를 생성하지만, 그 결과가 상대적으로 뻔할 수 있습니다. 반면에 Random sampling은 모든 토큰의 확률을 고려하여 무작위로 토큰을 선택합니다. 이 방법은 훨씬 창의적인 결과를 생성할 수 있지만, 의미가 없는 결과를 생성할 수도 있습니다. 랜덤 샘플링의 이러한 불확실성을 제한하여 합리적인 예측을 위해, Top-k sampling과 Top-p sampling이라는 기법이 사용됩니다.

  • Top-k sampling: 확률이 높은 상위 k개의 토큰 중에서 무작위로 토큰을 선택하는 방법
  • Top-p sampling: 누적 확률이 p를 초과하지 않는 토큰들 중에서 무작위로 토큰을 선택하는 방법

 

 

Temperature

이 파라미터는 모델의 예측이 얼마나 불확실하게 보일지를 결정하는 역할을 합니다. 앞서 설명한 top-k, top-p sampling은 확률값을 샘플링하는 방법이라면, temperature는 이 값에 따라 다음 토큰의 확률 분포의 형태가 달라집니다. 낮은 값은 강하게 봉우리를 이루는 확률 분포를 만들고, 높은 값은 넓고 평평한 확률 분포를 만듭니다. 낮은 값에서는 모델이 가장 확률이 높은 토큰에 대해 강하게 기울어져 있다는 것을 의미하여, 모델이 가장 가능성이 높은 토큰을 선택하는 경향이 높아져 일관적이지만 덜 창의적인 결과를 생성할 수 있습니다. 

반면에 높은 값은 모델이 가능한 모든 토큰에 대해 거의 동일한 확률을 부여하게 만들어, 모델이 더 무작위적으로 토큰을 선택하는 경향이 있습니다. 이는 더 창의적이고 예측 불가능한 결과를 생성할 수 있지만, 때로는 의미가 모호하거나 무의미한 결과를 생성할 수도 있습니다.

따라서 temperature는 모델의 "창의성"과 "일관성" 사이의 균형을 조절하는 데 사용됩니다. 온도를 높이면 결과의 무작위성이 증가하고, 온도를 낮추면 결과의 일관성이 증가합니다.

Source: Lecture slide from&nbsp;https://www.coursera.org/learn/generative-ai-with-llms/lecture/18SPI/generative-configuration

 

 

이 파라미터들은 GenerationConfig 클래스를 통해 지정할 수 있습니다.

from transformers import GenerationConfig

generation_config = GenerationConfig(max_new_tokens=50)
# generation_config = GenerationConfig(max_new_tokens=10)
# generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, top_p=0.9)
# generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, top_k=10, top_p=0.9)
# generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.1)
# generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=0.5)
# generation_config = GenerationConfig(max_new_tokens=50, do_sample=True, temperature=1.0)

inputs = tokenizer(dialogue, return_tensors='pt')
output = tokenizer.decode(
    model.generate(
        inputs["input_ids"],
        generation_config=generation_config,
    )[0], 
    skip_special_tokens=True
)

추가 파라미터들은 허깅페이스 공식 문서에서 확인할 수 있습니다.

https://huggingface.co/docs/transformers/v4.30.0/main_classes/text_generation#transformers.GenerationConfig