import json

from vllm import LLM, SamplingParams

# Prompt template in the "### USER / ### ASSISTANT" format, kept for reference.
# It is overridden by the ChatML-style template below, which is the one this
# script actually uses. The instruction line asks the model to parse the road
# name into JSON format.
template = """### USER:
請將以下路名解析為 JSON 格式。

輸入:臺北市中正區八德路
輸出:{{"city": "臺北市", "town": "中正區", "road": "八德路"}}

輸入:{}

### ASSISTANT:
{}"""

# ChatML-style prompt template. The few-shot example from the template above
# is left out here:
# 輸入:臺北市中正區八德路
# 輸出:{{"city": "臺北市", "town": "中正區", "road": "八德路"}}
template = """<|im_start|>user
請將以下路名解析為 JSON 格式。

輸入:{}
<|im_end|>
<|im_start|>assistant
{}"""


def build_prompt(inn, out=""):
    return template.format(inn, out)


def iter_dataset(file_path):
    # Yield (full address string, ground-truth item) pairs from a JSON file.
    data = load_json(file_path)
    for item in data:
        city = item["city"]
        town = item["town"]
        road = item["road"]
        full = f"{city}{town}{road}"
        yield full, item


def load_json(file_path):
    with open(file_path, "rt", encoding="UTF-8") as fp:
        return json.load(fp)


# Build the list of prompts for the test set
prompts, items = list(), list()
for full, item in iter_dataset("data/test.json"):
    prompt = build_prompt(full)
    prompts.append(prompt)
    items.append(item)

# Load the model
model_name = "models/Llama-7B-TwAddr-Merged"
llm = LLM(model_name, dtype="float16")

# temperature=0.0 selects greedy decoding,
# so every run produces the same result.
# Generation stops at "}", i.e. right after the JSON object; vLLM excludes
# the stop string from the output, so it is appended back during parsing.
sampling_params = SamplingParams(
    max_tokens=256,
    temperature=0.0,
    stop=["}"],
)

# Run inference on all prompts in a single batch
outputs = llm.generate(prompts, sampling_params)

# Evaluate the generated results
results = list()
for out, item in zip(outputs, items):
    text = out.outputs[0].text

    # Try to parse the model output as JSON
    try:
        begin = text.index("{")
        text = text[begin:] + "}"
        pred = json.loads(text)
    except ValueError:  # raised by .index() on a miss and by json.loads()
        pred = None

    results.append(pred == item)
    if pred != item:
        print(pred, item)

# Report the accuracy
accuracy = sum(results) / len(results)
print(f"Accuracy: {accuracy:.2%}")
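
# A minimal sketch of the record layout that data/test.json is assumed to
# follow, inferred from the keys read in iter_dataset(); the actual file is
# not shown here, so the sample values below are illustrative only:
#
# [
#     {"city": "臺北市", "town": "中正區", "road": "八德路"},
#     ...
# ]
#
# Quick sanity check (illustrative): format one hypothetical record with
# build_prompt() and inspect the resulting prompt string.
sample = {"city": "臺北市", "town": "中正區", "road": "八德路"}
print(build_prompt(f"{sample['city']}{sample['town']}{sample['road']}"))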