|
from Prompter import Prompter |
|
from Callback import Stream, Iteratorize |
|
import os |
|
import sys |
|
|
|
import gradio as gr |
|
import torch |
|
import transformers |
|
from peft import PeftModel |
|
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer |
|
import pandas as pd |
|
import numpy as np |
|
|
|
if torch.cuda.is_available(): |
|
device = "cuda" |
|
else: |
|
device = "cpu" |
|
|
|
try: |
|
if torch.backends.mps.is_available(): |
|
device = "mps" |
|
except: |
|
pass |
|
|
|
base_model = "openthaigpt/openthaigpt-1.0.0-beta-7b-chat-ckpt-hf" |
|
load_8bit = True |
|
|
|
lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendation-2.0" |
|
prompter = Prompter("alpaca") |
|
tokenizer = LlamaTokenizer.from_pretrained(base_model) |
|
|
|
model = LlamaForCausalLM.from_pretrained( |
|
base_model, |
|
load_in_8bit=load_8bit, |
|
torch_dtype=torch.float16, |
|
device_map="auto", |
|
offload_folder = "./offload" |
|
) |
|
model = PeftModel.from_pretrained( |
|
model, |
|
lora_weights, |
|
torch_dtype=torch.float16, |
|
offload_folder = "./offload" |
|
) |
|
|
|
|
|
model.config.pad_token_id = tokenizer.pad_token_id = 0 |
|
model.config.bos_token_id = 1 |
|
model.config.eos_token_id = 2 |
|
|
|
if not load_8bit: |
|
model.half() |
|
|
|
model.eval() |
|
if torch.__version__ >= "2" and sys.platform != "win32": |
|
model = torch.compile(model) |
|
|
|
def evaluate( |
|
instruction, |
|
input=None, |
|
stream_output=False, |
|
): |
|
temperature=0.1 |
|
top_p=0.9 |
|
top_k=10 |
|
num_beams=1 |
|
max_new_tokens=380 |
|
|
|
prompt = prompter.generate_prompt(instruction, input) |
|
inputs = tokenizer(prompt, return_tensors="pt") |
|
input_ids = inputs["input_ids"].to(device) |
|
|
|
generation_config = GenerationConfig( |
|
temperature=temperature, |
|
top_p=top_p, |
|
top_k=top_k, |
|
num_beams=num_beams, |
|
do_sample = True, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
generate_params = { |
|
"input_ids": input_ids, |
|
"generation_config": generation_config, |
|
"return_dict_in_generate": True, |
|
"output_scores": True, |
|
"max_new_tokens": max_new_tokens, |
|
} |
|
|
|
if stream_output: |
|
|
|
|
|
|
|
|
|
def generate_with_callback(callback=None, **kwargs): |
|
kwargs.setdefault( |
|
"stopping_criteria", transformers.StoppingCriteriaList() |
|
) |
|
kwargs["stopping_criteria"].append( |
|
Stream(callback_func=callback) |
|
) |
|
with torch.no_grad(): |
|
model.generate(**kwargs) |
|
|
|
def generate_with_streaming(**kwargs): |
|
return Iteratorize( |
|
generate_with_callback, kwargs, callback=None |
|
) |
|
|
|
with generate_with_streaming(**generate_params) as generator: |
|
for output in generator: |
|
|
|
decoded_output = tokenizer.decode(output) |
|
|
|
if output[-1] in [tokenizer.eos_token_id]: |
|
break |
|
|
|
yield prompter.get_response(decoded_output) |
|
return |
|
|
|
|
|
torch.manual_seed(42) |
|
with torch.no_grad(): |
|
generation_output = model.generate( |
|
input_ids=input_ids, |
|
generation_config=generation_config, |
|
return_dict_in_generate=True, |
|
output_scores=True, |
|
max_new_tokens=max_new_tokens, |
|
|
|
) |
|
s = generation_output.sequences[0] |
|
output = tokenizer.decode(s) |
|
yield prompter.get_response(output) |
|
|
|
|
|
|
|
fourNSMOTE = pd.read_csv("FILTER_GREATERTHANTHREE_FROM_SHEETS_SMOTE_train.csv") |
|
|
|
with gr.Blocks(fill_height = True, title="Expert Recommendations") as demo: |
|
gr.Markdown( |
|
""" |
|
# Expert Recommendations |
|
วิธีการใช้งาน |
|
* เลือกค่าที่ต้องการในแต่ละตัวเลือก |
|
* กดปุ่ม 'GENERATE INPUT' จากนั้นจะแสดงผลลัพธ์ในช่อง 'full prompt' |
|
* กดปุ่ม 'GENERATE OUTPUT' จากนั้นรอประมาณ 20 ถึง 60 วินาที ผลลัพธ์จะแสดงในช่อง 'ผลลัพธ์ (output)' |
|
""") |
|
with gr.Row(): |
|
birth_year = gr.components.Number(minimum = 2536, maximum = 2557, value= 2545, |
|
label="ปีเกิด", |
|
info="ต่ำสุด : 2536 สูงสุด : 2557") |
|
nationality_name = gr.components.Dropdown(choices=fourNSMOTE.NATIONALITY_NAME.unique().tolist(), |
|
label="สัญชาติ", |
|
value = fourNSMOTE.NATIONALITY_NAME.unique().tolist()[0]) |
|
religion_name = gr.components.Dropdown(choices=fourNSMOTE.RELIGION_NAME.unique().tolist(), |
|
label="ศาสนา", |
|
value = fourNSMOTE.RELIGION_NAME.unique().tolist()[0]) |
|
with gr.Row(): |
|
sex = gr.components.Dropdown(choices=fourNSMOTE.JVN_SEX.unique().tolist(), |
|
label="เพศ", |
|
value = fourNSMOTE.JVN_SEX.unique().tolist()[0]) |
|
inform_status = gr.components.Dropdown(choices=fourNSMOTE.INFORM_STATUS_TXT.unique().tolist(), |
|
label="เหตุที่นำมาสู่การดำเนินคดี", |
|
value = fourNSMOTE.INFORM_STATUS_TXT.unique().tolist()[0]) |
|
age = gr.components.Number(minimum = 10, maximum = 19, value= 17, |
|
label="อายุตอนกระทำผิด", |
|
info="ต่ำสุด : 10 ปี สูงสุด : 19") |
|
with gr.Row(): |
|
|
|
offense_name = gr.components.Dropdown(choices=fourNSMOTE.OFFENSE_NAME.unique().tolist(), |
|
label="คดีที่กระทำผิด", |
|
value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0]) |
|
|
|
ref_value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0] |
|
|
|
allegation_name = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_NAME.unique().tolist(), label="ชื่อของข้อกล่าวหา", |
|
value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_NAME"].unique().tolist()[0]) |
|
|
|
allegation_desc = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_DESC.unique().tolist(), label="รายละเอียดของข้อกล่าวหา", |
|
value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_DESC"].unique().tolist()[0]) |
|
|
|
def update_dropDown_allegation(value): |
|
allegation_query = fourNSMOTE.query("OFFENSE_NAME == @value") |
|
data = allegation_query["ALLEGATION_NAME"].unique().tolist() |
|
allegation_name = gr.components.Dropdown(choices=data, value=data[0]) |
|
|
|
return allegation_name |
|
|
|
def update_dropDown_allegation_desc(offense_name, allegation_name): |
|
allegationDesc_query = fourNSMOTE.query("OFFENSE_NAME == @offense_name and ALLEGATION_NAME == @allegation_name") |
|
data = allegationDesc_query["ALLEGATION_DESC"].unique().tolist() |
|
allegation_desc = gr.components.Dropdown(choices=data, value=data[0]) |
|
|
|
return allegation_desc |
|
|
|
offense_name.change(fn=update_dropDown_allegation, inputs=offense_name, outputs=[allegation_name]) |
|
offense_name.change(fn=update_dropDown_allegation_desc, inputs=[offense_name, allegation_name], outputs=[allegation_desc]) |
|
allegation_name.change(fn=update_dropDown_allegation_desc, inputs=[offense_name, allegation_name], outputs=[allegation_desc]) |
|
|
|
|
|
|
|
|
|
with gr.Row(): |
|
|
|
rn1 = gr.components.Radio(choices=["ถูก", "ผิด"], |
|
label="ปรากฎลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่", |
|
value="ถูก") |
|
rn2 = gr.components.Radio(choices=["ถูก", "ผิด"], |
|
label="ปรากฎประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย", |
|
value = "ถูก") |
|
rn3 = gr.components.Radio(choices=["ถูก", "ผิด"], |
|
label="ปรากฎประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว", |
|
value = "ถูก") |
|
with gr.Row(): |
|
|
|
education = gr.components.Dropdown(choices=fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist(), |
|
label="สถาณะการศึกษา", |
|
value = fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist()[0]) |
|
occupation = gr.components.Dropdown(choices=fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist(), |
|
label="สถาณะการประกอบอาชีพ", |
|
value = fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist()[0]) |
|
province = gr.components.Dropdown(choices=fourNSMOTE.PROVINCE_NAME.unique().tolist(), |
|
label="จังหวัดที่กระทำผิด", |
|
value = fourNSMOTE.PROVINCE_NAME.unique().tolist()[0]) |
|
|
|
|
|
def generate_input(birth_year, nationality_name, religion_name, sex, |
|
inform_status, age, offense_name, allegation_name, |
|
allegation_desc, rn1, rn2, rn3, education, occupation, province): |
|
|
|
birth_year = f"เกิดเมื่อปี พ.ศ. {int(birth_year)}" |
|
|
|
if int(age) >= 10 and int(age) <=15: |
|
age = f"มีอายุอยู่ในช่วง 10 ถึง 15 ปี" |
|
elif int(age) >=16 and int(age) <= 20: |
|
age = f"มีอายุอยู่ในช่วง 16 ถึง 20 ปี" |
|
elif int(age) >=21 and int(age) <= 25: |
|
age = f"มีอายุอยู่ในช่วง 21 ถึง 25 ปี" |
|
elif int(age) >=26: |
|
age = f"มีอายุอยู่ในช่วง 26 ปีขึ้นไป" |
|
|
|
if rn1 == "ถูก": |
|
rn1 = "มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่" |
|
else: |
|
rn1 = "ไม่มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่" |
|
|
|
if rn2 == "ถูก": |
|
rn2 = "มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย" |
|
else: |
|
rn2 = "ไม่มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย" |
|
|
|
if rn3 == "ถูก": |
|
rn3 = "มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว" |
|
else: |
|
rn3 = "ไม่มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว" |
|
|
|
instruciton = "จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้" |
|
input = f"{birth_year} {nationality_name} {religion_name} {sex} {inform_status} {age} {offense_name} {allegation_name} {allegation_desc} {rn1} {rn2} {rn3} {education} {occupation} {province}" |
|
|
|
|
|
return input |
|
|
|
def generate_full_input(inst ,input): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return f"{inst} {input}" |
|
|
|
def test_fucn(inst, input, stream): |
|
return str(inst) |
|
|
|
|
|
|
|
instruction = gr.Textbox(label = "คำสั่ง", value="จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้", visible=False, interactive=False) |
|
input_compo = gr.Textbox(label = "ข้อมูลเข้า (input)", show_copy_button = True, visible=False) |
|
|
|
|
|
|
|
full_input = gr.Textbox(label = "full prompt", visible=True, show_copy_button=True) |
|
btn1 = gr.Button("GENERATE INPUT") |
|
|
|
btn1.click(fn=generate_input, inputs=[birth_year, nationality_name, religion_name, sex, |
|
inform_status, age, offense_name, allegation_name, |
|
allegation_desc, rn1, rn2, rn3, education, occupation, province], |
|
outputs=input_compo) |
|
|
|
|
|
input_compo.change(fn = generate_full_input, inputs=[instruction, input_compo], outputs=full_input) |
|
|
|
|
|
outputModel = gr.Textbox(label= "ผลลัพธ์ (output)") |
|
btn2 = gr.Button("GENERATE OUTPUT") |
|
btn2.click(fn=evaluate, inputs=[instruction, input_compo], outputs=outputModel) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.launch(debug=True, share=True) |