# Llama2-TwAddr-LoRA / scripts / step3_evaluation.py
# Author: penut85420 — commit 6512525 ("Add README, Data, Scripts")
import json
from vllm import LLM, SamplingParams
# Prompt template in ChatML format. The first "{}" slot takes the raw
# address string; the second takes the answer, left empty at inference
# time so the model generates it.
# NOTE(review): an earlier "### USER:"-style template assigned here was
# immediately overwritten by this one and has been removed as dead code.
#
# Example:
#   input : 臺北市中正區八德路
#   output: {"city": "臺北市", "town": "中正區", "road": "八德路"}
template = """<|im_start|>user
請將以下路名解析為 JSON 格式。
輸入:{}
<|im_end|>
<|im_start|>assistant
{}"""
def build_prompt(inn, out=""):
    """Fill the module-level prompt template with an input address and an
    (optional) expected answer; the answer defaults to empty for inference."""
    prompt = template.format(inn, out)
    return prompt
def iter_dataset(file_path):
    """Yield (full_address, record) for every record in a JSON address list.

    Each record is expected to carry "city", "town" and "road" keys; the
    full address is their plain concatenation.
    """
    for record in load_json(file_path):
        full_address = f"{record['city']}{record['town']}{record['road']}"
        yield full_address, record
def load_json(file_path):
    """Read a UTF-8 JSON file and return the parsed object."""
    with open(file_path, "rt", encoding="UTF-8") as fp:
        data = json.load(fp)
    return data
# Build one prompt per test-set record, keeping the raw records in a
# parallel list so each generation can later be scored against its label.
prompts = []
items = []
for full_address, record in iter_dataset("data/test.json"):
    prompts.append(build_prompt(full_address))
    items.append(record)

# Load the merged (base + LoRA) model.
model_name = "models/Llama-7B-TwAddr-Merged"
llm = LLM(model_name, dtype="float16")

# temperature 0.0 means greedy decoding, so every run is reproducible.
# Generation stops at the first "}" because the answer is a flat JSON object.
sampling_params = SamplingParams(
    max_tokens=256,
    temperature=0.0,
    stop=["}"],
)

# Run inference over all prompts in a single batch.
outputs = llm.generate(prompts, sampling_params)
# Score each generation against its gold record by exact dict equality.
results = list()
for out, item in zip(outputs, items):
    text = out.outputs[0].text
    # The sampler stops on "}", so the closing brace is missing from the
    # output; slice from the first "{" and restore it before parsing.
    try:
        begin = text.index("{")  # ValueError if no "{" in the output
        pred = json.loads(text[begin:] + "}")
    # json.JSONDecodeError subclasses ValueError, so this catches both
    # a missing brace and malformed JSON — without swallowing
    # KeyboardInterrupt/SystemExit like the previous bare `except:` did.
    except ValueError:
        pred = None
    results.append(pred == item)
    if pred != item:
        # Print mismatches for manual inspection.
        print(pred, item)

# Report exact-match accuracy; guard against an empty test set so the
# script fails gracefully instead of raising ZeroDivisionError.
accuracy = sum(results) / len(results) if results else 0.0
print(f"Accuracy: {accuracy:.2%}")