wedo2910 committed on
Commit 33b295a · verified · 1 Parent(s): 930a873

Update app.py

Files changed (1)
  1. app.py +64 -42
app.py CHANGED
@@ -1,72 +1,94 @@
-import os
-os.environ["TRANSFORMERS_NO_BITSANDBYTES"] = "1"  # Disable bitsandbytes integration
-
 import streamlit as st
 from transformers import AutoTokenizer, AutoModelForCausalLM

-# Load the new model and tokenizer
-model_name = "wedo2910/research_ai"
-tokenizer_name = "wedo2910/research_ai_tok"

-tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=False)

-# Define the custom inference function
-def single_inference(question, max_new_tokens, temperature):
-    # Prepare the prompt messages
     messages = [
         {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
         {"role": "user", "content": question},
     ]

-    # Use the tokenizer's chat template functionality
-    input_ids = tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    ).to(model.device)
-
-    # Define terminator tokens (end-of-sequence markers)
-    terminators = [
-        tokenizer.eos_token_id,
-        tokenizer.convert_tokens_to_ids("<|eot_id|>")
-    ]
-
-    # Generate the output
     outputs = model.generate(
         input_ids,
         max_new_tokens=max_new_tokens,
-        eos_token_id=terminators,
         do_sample=True,
         temperature=temperature,
     )

-    # Decode only the newly generated tokens (i.e. skip the prompt)
-    response = outputs[0][input_ids.shape[-1]:]
-    output = tokenizer.decode(response, skip_special_tokens=True)
-    return output

 # Streamlit UI
 st.title("Arabic AI Research QA")
-st.subheader("Ask a question to get an answer from the research AI model.")

-# Input field for the question
 question = st.text_input("Question", placeholder="Enter your question here...")

-# Settings sliders for generation parameters
 st.subheader("Settings")
 max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)
 temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.4, step=0.1)

-# Generate Answer button
 if st.button("Get Answer"):
     if not question:
-        st.error("The question field is required.")
     else:
-        try:
-            answer = single_inference(question, max_new_tokens, temperature)
-            st.subheader("Result")
-            st.write(f"**Question:** {question}")
-            st.write(f"**Answer:** {answer}")
-        except Exception as e:
-            st.error(f"Error: {e}")

 import streamlit as st
+import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

+# Define your repository names.
+# For a fully merged model, you typically use the model repo (and a matching tokenizer repo).
+MODEL_NAME = "wedo2910/research_ai"
+TOKENIZER_NAME = "wedo2910/research_ai_tok"
+
+# Load the tokenizer and model.
+# Note: Use trust_remote_code=True if your model repo uses custom code.
+tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
+
+# Move model to the appropriate device.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)

+# Optionally set model to evaluation mode.
+model.eval()

+def single_inference(question: str, max_new_tokens: int, temperature: float) -> str:
+    """
+    Generates an answer for the given question.
+
+    The prompt is constructed using a system instruction in Arabic, and the question is appended.
+    """
+    # Define the messages that simulate a chat conversation.
     messages = [
         {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
         {"role": "user", "content": question},
     ]

+    # Some tokenizers provided by custom repos may implement apply_chat_template.
+    # If available, use it; otherwise, build a prompt manually.
+    if hasattr(tokenizer, "apply_chat_template"):
+        input_ids = tokenizer.apply_chat_template(
+            messages,
+            add_generation_prompt=True,
+            return_tensors="pt"
+        ).to(device)
+    else:
+        # Manually build the prompt
+        system_prompt = "اجب علي الاتي بالعربي فقط.\n"
+        user_prompt = f"السؤال: {question}\n"
+        full_prompt = system_prompt + user_prompt
+        input_ids = tokenizer(full_prompt, return_tensors="pt").input_ids.to(device)
+
+    # Define the terminator tokens.
+    # (For a merged model, usually the eos_token_id is sufficient.)
+    terminators = [tokenizer.eos_token_id]
+
+    # Generate the output.
     outputs = model.generate(
         input_ids,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
+        # Optionally add other generation parameters (top_p, top_k, etc.) if needed.
     )

+    # Remove the prompt part from the output.
+    generated_ids = outputs[0][input_ids.shape[-1]:]
+
+    # Decode the tokens into a string.
+    output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    return output_text

 # Streamlit UI
 st.title("Arabic AI Research QA")
+st.subheader("Ask a question and get an answer from the research AI model.")

+# Input field for the question.
 question = st.text_input("Question", placeholder="Enter your question here...")

+# Settings for generation.
 st.subheader("Settings")
 max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)
 temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.4, step=0.1)

+# When the button is pressed, generate the answer.
 if st.button("Get Answer"):
     if not question:
+        st.error("Please enter a question.")
     else:
+        with st.spinner("Generating answer..."):
+            try:
+                answer = single_inference(question, max_new_tokens, temperature)
+                st.subheader("Result")
+                st.markdown(f"**Question:** {question}")
+                st.markdown(f"**Answer:** {answer}")
+            except Exception as e:
+                st.error(f"Error: {e}")
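For context, a minimal sketch (not part of this commit) of how the loading step above could be cached: Streamlit re-executes app.py on every widget interaction, so the unguarded from_pretrained calls run again on each rerun. Wrapping them in st.cache_resource keeps a single tokenizer/model in memory across reruns; the repo names and trust_remote_code flag are taken from the diff above, and the helper name load_model_and_tokenizer is illustrative.

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "wedo2910/research_ai"
TOKENIZER_NAME = "wedo2910/research_ai_tok"

@st.cache_resource  # load once, reuse the same objects on every Streamlit rerun
def load_model_and_tokenizer():
    # Same loading logic as in the committed app.py, just wrapped for caching.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    return tokenizer, model, device

tokenizer, model, device = load_model_and_tokenizer()

With or without caching, the app is launched locally with streamlit run app.py.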