lotte-generator.py 3.71 KB
Newer Older
kihoon.lee's avatar
백업  
kihoon.lee committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
import argparse
import os
import pandas as pd
from templates import LOTTE_PROMPT_STRATEGY

try:
    from aphrodite import LLM, SamplingParams

    print("- Using aphrodite-engine")

except ImportError:
    from vllm import LLM, SamplingParams

    print("- Using vLLM")

# Logging setup: timestamped INFO-level messages for progress reporting.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument("-g", "--gpu_devices", help=" : CUDA_VISIBLE_DEVICES", default="0")
parser.add_argument(
    "-m",
    "--model",
    help=" : write huggingface model name to evaluate",
    default="LDCC/Chat-Mistral-Nemo-12B-32k",
)
parser.add_argument(
    "-ml", "--model_len", help=" : Maximum Model Length", default=4096, type=int
)
args = parser.parse_args()

logger.info(f"Args - {args}")

# Normalise the comma-separated device list: strip whitespace and drop empty
# entries, so inputs like "0, 1" or "0,1," yield exactly two devices.
# gpu_counts is used as the tensor-parallel size when the engine is built, so
# a stray comma must not inflate it.
gpu_ids = [device.strip() for device in args.gpu_devices.split(",") if device.strip()]
if not gpu_ids:
    # parser.error prints usage and exits with status 2.
    parser.error("--gpu_devices must list at least one device id, e.g. '0' or '0,1'")
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
gpu_counts = len(gpu_ids)

# Build the inference engine; tensor parallelism spans every device listed in
# CUDA_VISIBLE_DEVICES (gpu_counts).
logger.info(f"Initializing LLM with model: {args.model}")
engine_kwargs = dict(
    model=args.model,
    tensor_parallel_size=gpu_counts,
    max_model_len=args.model_len,
    gpu_memory_utilization=0.8,
    # Permit custom model/tokenizer code from the model repository.
    trust_remote_code=True,
)
llm = LLM(**engine_kwargs)
logger.info("LLM initialized successfully")

# Stop strings; these look like end-of-turn / role markers from several common
# chat formats, so generation halts at the end of the assistant's turn.
stop_sequences = [
    "<|endoftext|>",
    "[INST]",
    "[/INST]",
    "<|im_end|>",
    "<|end|>",
    "<|eot_id|>",
    "<end_of_turn>",
    "<eos>",
]
# Greedy decoding (temperature 0). The generation budget equals max_model_len.
# NOTE(review): prompt length + max_tokens can exceed max_model_len.
sampling_params = SamplingParams(
    max_tokens=args.model_len,
    temperature=0,
    skip_special_tokens=True,
    stop=stop_sequences,
)

# If the tokenizer ships without a chat template, install a ChatML-style
# fallback (<|im_start|> / <|im_end|> markers) so apply_chat_template works.
tokenizer = llm.llm_engine.tokenizer.tokenizer

if tokenizer.chat_template is None:
    logger.info("chat template가 없으므로 default로 설정")
    # Every message renders as "<|im_start|>{role}\n{content}<|im_end|>\n";
    # when add_generation_prompt is set, an open "<|im_start|>assistant\n"
    # turn is appended for the model to complete.
    default_chat_template = (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )
    tokenizer.chat_template = default_chat_template

# Load the evaluation questions (one JSON object per line). The "utf-8-sig"
# codec also accepts a file that begins with a UTF-8 byte-order mark.
logger.info("Loading questions from questions.jsonl")
df_questions = pd.read_json(
    "questions.jsonl", orient="records", encoding="utf-8-sig", lines=True
)
logger.info(f"Loaded {len(df_questions)} questions")

# Create the per-model output directory. exist_ok=True replaces the racy
# exists-then-create check (another process creating the directory in between
# would raise FileExistsError). Plain string concatenation is kept on purpose:
# os.path.join would discard "./generated" for an absolute model path, and the
# results file path is also built by concatenation.
os.makedirs("./generated/" + args.model, exist_ok=True)


logger.info("Processing LOTTE questions")

def format_lotte_question(row: pd.Series) -> str:
    """Render one question row as a chat-formatted prompt string.

    The system turn carries the prompt strategy registered for the row's
    category in ``LOTTE_PROMPT_STRATEGY`` (an empty string when the category
    has no entry). The user turn carries only the first entry of
    ``row['questions']``.

    Args:
        row: One row of the questions DataFrame; must provide ``category`` and
            ``questions`` fields.

    Returns:
        The prompt text from ``tokenizer.apply_chat_template`` (untokenized),
        with the generation prompt appended.
    """
    category = row['category']
    first_question = row['questions'][0]  # only the first question is used
    system_prompt = LOTTE_PROMPT_STRATEGY.get(category, "")

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": first_question},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

def _first_reference(value):
    """Return the first reference stored in a ``references`` cell, or None.

    Cells are assumed to hold a list of references (TODO confirm against
    questions.jsonl). Missing values (None, or NaN that pandas may insert),
    empty lists and other non-list values yield None instead of raising.
    A bare string is treated as a single reference and returned whole.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return value[0] if value else None
    return None


# Build one prompt per question row as a plain list of strings rather than
# handing the engine a pandas Series. On an empty DataFrame,
# apply(axis=1) would return an empty DataFrame instead of an empty sequence.
formatted_questions = [
    format_lotte_question(row) for _, row in df_questions.iterrows()
]

logger.info("Generating LOTTE outputs")
lotte_outputs = []
if formatted_questions:
    # Keep only the first completion per prompt, trimmed of surrounding
    # whitespace; outputs come back in prompt order.
    lotte_outputs = [
        output.outputs[0].text.strip()
        for output in llm.generate(formatted_questions, sampling_params)
    ]
logger.info(f"Generated {len(lotte_outputs)} LOTTE outputs")

df_output = pd.DataFrame(
    {
        "id": df_questions["id"],
        "category": df_questions["category"],
        "questions": df_questions["questions"].apply(lambda x: x[0]),  # first question only
        "outputs": lotte_outputs,
        "references": df_questions["references"].apply(_first_reference),  # first reference only
    }
)

# Save results as JSON Lines in the per-model output directory.
# force_ascii=False writes non-ASCII (e.g. Korean) text as-is instead of
# \u-escaping it.
output_file = f"./generated/{args.model}/lotte_single_turn.jsonl"
logger.info(f"Saving results to {output_file}")
df_output.to_json(
    output_file,
    orient="records",
    lines=True,
    force_ascii=False,
)
logger.info("LOTTE generation process completed")