import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model and tokenizer identifiers (Hugging Face Hub repo IDs).
MODEL_NAME = "wedo2910/research_ai"
TOKENIZER_NAME = "wedo2910/research_ai_tok"
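
# This script is a Streamlit app; launch it with `streamlit run <script>.py`,
# substituting whatever filename this script is saved under.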

# Load the model on GPU if one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        device_map="auto",
    )
else:
    device = "cpu"
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        device_map="cpu",
    )

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)

# Inference only: switch off dropout and other training-time behavior.
model.eval()
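
# Note: Streamlit reruns this script from the top on every interaction, so the
# model and tokenizer above are reloaded each time. Caching them with
# st.cache_resource keeps a single copy in memory across reruns. A minimal
# sketch (the helper name load_model is illustrative, not part of the app):
#
# @st.cache_resource
# def load_model():
#     tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
#     mdl = AutoModelForCausalLM.from_pretrained(
#         MODEL_NAME,
#         trust_remote_code=True,
#         device_map="auto" if torch.cuda.is_available() else "cpu",
#     )
#     mdl.eval()
#     return tok, mdl
#
# tokenizer, model = load_model()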

def single_inference(question: str, max_new_tokens: int, temperature: float) -> str:
    """
    Generate an answer for the given question.

    The prompt combines an Arabic system instruction ("Answer the following
    in Arabic only.") with the user's question.
    """
    messages = [
        {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
        {"role": "user", "content": question},
    ]

    # Prefer the tokenizer's chat template; fall back to a plain text prompt
    # if the tokenizer does not provide one.
    if hasattr(tokenizer, "apply_chat_template"):
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)
    else:
        system_prompt = "اجب علي الاتي بالعربي فقط.\n"
        user_prompt = f"السؤال: {question}\n"
        full_prompt = system_prompt + user_prompt
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)

    # A temperature of 0 is not a valid sampling setting; fall back to greedy
    # decoding in that case instead of letting generate() raise an error.
    gen_kwargs = {"max_new_tokens": max_new_tokens}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        gen_kwargs.update(do_sample=False)

    outputs = model.generate(input_ids, **gen_kwargs)

    # Decode only the newly generated tokens, not the prompt.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return output_text
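
# Example call outside the UI (the question string is illustrative):
# answer = single_inference("ما هو التعلم العميق؟", max_new_tokens=128, temperature=0.4)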

st.title("Arabic AI Research QA")
st.subheader("Ask a question and get an answer from the research AI model.")

question = st.text_input("Question", placeholder="Enter your question here...")

st.subheader("Settings")
max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.4, step=0.1)

if st.button("Get Answer"):
    if not question:
        st.error("Please enter a question.")
    else:
        with st.spinner("Generating answer..."):
            try:
                answer = single_inference(question, max_new_tokens, temperature)
                st.subheader("Result")
                st.markdown(f"**Question:** {question}")
                st.markdown(f"**Answer:** {answer}")
            except Exception as e:
                st.error(f"Error: {e}")