import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
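
# Assumed dependencies for this app: streamlit, torch, and transformers;
# passing a device_map to from_pretrained additionally requires the accelerate package.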

# Define your repository names.
MODEL_NAME = "wedo2910/research_ai"
TOKENIZER_NAME = "wedo2910/research_ai_tok"

# Load the model and tokenizer once and cache them, so Streamlit does not
# reload the weights on every rerun of the script.
@st.cache_resource
def load_model_and_tokenizer():
    if torch.cuda.is_available():
        device = "cuda"
        # On GPU, let the model auto-map across the available devices.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            device_map="auto",
        )
    else:
        device = "cpu"
        # Force CPU loading; this bypasses GPU-specific integrations such as bitsandbytes.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            device_map="cpu",
        )

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)

    # Put the model in evaluation mode (disables dropout).
    model.eval()
    return model, tokenizer, device

model, tokenizer, device = load_model_and_tokenizer()

def single_inference(question: str, max_new_tokens: int, temperature: float) -> str:
    """
    Generates an answer for the given question.
    
    The prompt is constructed using a system instruction in Arabic, and the question is appended.
    """
    # Messages for a simulated chat exchange; the system instruction means "Answer the following in Arabic only."
    messages = [
        {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
        {"role": "user", "content": question},
    ]
    
    # Use the tokenizer's chat template if one is configured; otherwise, build the prompt manually.
    # (hasattr alone is not enough: recent tokenizers always expose apply_chat_template,
    # but calling it without a configured chat template raises an error.)
    if getattr(tokenizer, "chat_template", None):
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(device)
    else:
        system_prompt = "اجب علي الاتي بالعربي فقط.\n"
        user_prompt = f"السؤال: {question}\n"
        full_prompt = system_prompt + user_prompt
        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
    
    # Generate the output.
    outputs = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # You can add more generation parameters if needed.
    )
    
    # Remove the prompt part from the generated output.
    generated_ids = outputs[0][input_ids.shape[-1]:]
    
    # Decode the tokens into a string.
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    return output_text
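
# Example of a direct call with a hypothetical question ("What is the capital of Egypt?"),
# e.g. for a quick smoke test outside the Streamlit UI:
#     print(single_inference("ما هي عاصمة مصر؟", max_new_tokens=64, temperature=0.4))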

# Streamlit UI
st.title("Arabic AI Research QA")
st.subheader("Ask a question and get an answer from the research AI model.")

# Input field for the question.
question = st.text_input("Question", placeholder="Enter your question here...")

# Settings for generation.
st.subheader("Settings")
max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)
# Keep the temperature strictly positive: temperature=0.0 with do_sample=True raises an error.
temperature = st.slider("Temperature", min_value=0.1, max_value=1.0, value=0.4, step=0.1)

# When the button is pressed, generate the answer.
if st.button("Get Answer"):
    if not question:
        st.error("Please enter a question.")
    else:
        with st.spinner("Generating answer..."):
            try:
                answer = single_inference(question, max_new_tokens, temperature)
                st.subheader("Result")
                st.markdown(f"**Question:** {question}")
                st.markdown(f"**Answer:** {answer}")
            except Exception as e:
                st.error(f"Error: {e}")