|
import streamlit as st |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
# Hugging Face Hub identifiers for the fine-tuned model and its tokenizer.
model_name = "wedo2910/research_ai"

tokenizer_name = "wedo2910/research_ai_tok"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load and cache the tokenizer/model once per Streamlit server process.

    Streamlit re-runs the entire script on every widget interaction; without
    caching, the (large) model would be re-loaded from disk/network on each
    rerun. `st.cache_resource` keeps a single shared instance alive.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    mdl = AutoModelForCausalLM.from_pretrained(model_name)
    return tok, mdl


# Keep the original module-level names so any other code using them still works.
tokenizer, model = _load_model_and_tokenizer()
|
|
|
|
|
def single_inference(question, max_new_tokens, temperature):
    """Generate an Arabic answer for *question* with the loaded model.

    Args:
        question: The user's question (the system prompt instructs the model
            to answer in Arabic only).
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature. Values <= 0 fall back to greedy
            decoding, because ``generate`` rejects ``temperature=0`` when
            ``do_sample=True`` (the UI slider allows 0.0).

    Returns:
        The decoded answer text with special tokens stripped.
    """
    messages = [
        # System prompt (Arabic): "Answer the following in Arabic only."
        {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
        {"role": "user", "content": question},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Stop on EOS or the Llama-3-style end-of-turn token.
    # convert_tokens_to_ids returns None when the token is absent from the
    # vocabulary, which would break `generate` — filter such entries out.
    terminators = [
        tid
        for tid in (
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        )
        if tid is not None
    ]

    # temperature=0 combined with do_sample=True raises inside `generate`;
    # use deterministic (greedy) decoding in that case instead.
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": terminators,
        "do_sample": do_sample,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature

    outputs = model.generate(input_ids, **gen_kwargs)

    # Slice off the prompt tokens so only the newly generated answer is decoded.
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)
|
|
|
|
|
# --- Streamlit UI -----------------------------------------------------------
st.title("Arabic AI Research QA")

st.subheader("Ask a question to get an answer from the research AI model.")

question = st.text_input("Question", placeholder="Enter your question here...")

# Generation settings exposed to the user.
st.subheader("Settings")

max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)

temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.4, step=0.1)

if st.button("Get Answer"):
    # Treat whitespace-only input as missing, not just the empty string.
    if not question.strip():
        st.error("The question field is required.")
    else:
        try:
            # Generation can take several seconds — show progress feedback.
            with st.spinner("Generating answer..."):
                answer = single_inference(question, max_new_tokens, temperature)
            st.subheader("Result")
            st.write(f"**Question:** {question}")
            st.write(f"**Answer:** {answer}")
        except Exception as e:
            # Broad catch is acceptable at this top-level UI boundary: surface
            # the failure to the user instead of a blank/crashed page.
            st.error(f"Error: {e}")
|
|