wedo2910 committed on
Commit
52c3a3a
·
verified ·
1 Parent(s): 9edd3a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -49
app.py CHANGED
@@ -1,71 +1,69 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
 
4
- from transformers import AutoTokenizer, AutoModelForQuestionAnswering
 
 
5
 
6
- model_name = "wedo2910/qa_arabic_model"
7
- tokenizer = AutoTokenizer.from_pretrained("aubmindlab/bert-base-arabertv02")
8
- model = AutoModelForQuestionAnswering.from_pretrained(model_name)
9
 
10
- qa_pipeline = pipeline(
11
- "question-answering",
12
- model=model,
13
- tokenizer=tokenizer
14
- )
 
 
 
 
 
 
 
 
 
15
 
16
- # Default settings
17
- default_settings = {
18
- "max_new_tokens": 512,
19
- "temperature": 0.7,
20
- "top_p": 0.9,
21
- "min_p": 0,
22
- "top_k": 0,
23
- "repetition_penalty": 1.0,
24
- "presence_penalty": 0,
25
- "frequency_penalty": 0,
26
- "max_answer_len": 50,
27
- "doc_stride": 128,
28
- }
29
 
30
- # Define a default context (e.g., a general knowledge text or topic)
31
- default_context = """
32
- التزم بنص السؤال.
33
- """
 
 
 
 
 
 
 
 
 
34
 
35
  # Streamlit UI
36
- st.title("Arabic AI Question Answering")
37
- st.subheader("Ask a question to get an answer.")
38
 
39
- # Input field for the question only
40
  question = st.text_input("Question", placeholder="Enter your question here...")
41
 
42
- # Settings sliders
43
  st.subheader("Settings")
44
- max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000000, value=512)
45
- temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.1)
46
- top_p = st.slider("Top P", min_value=0.0, max_value=1.0, value=0.9, step=0.1)
47
- min_p = st.slider("Min P", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
48
- top_k = st.number_input("Top K", min_value=0, max_value=1000, value=0)
49
- repetition_penalty = st.slider("Repetition Penalty", min_value=0.01, max_value=5.0, value=1.0, step=0.1)
50
- presence_penalty = st.slider("Presence Penalty", min_value=-2.0, max_value=2.0, value=0.0, step=0.1)
51
- frequency_penalty = st.slider("Frequency Penalty", min_value=-2.0, max_value=2.0, value=0.0, step=0.1)
52
- max_answer_len = st.number_input("Max Answer Length", min_value=1, value=50)
53
- doc_stride = st.number_input("Document Stride", min_value=1, value=128)
54
 
55
  # Generate Answer button
56
  if st.button("Get Answer"):
57
  if not question:
58
  st.error("The question field is required.")
59
  else:
60
- # Generate answer using the default context
61
  try:
62
- prediction = qa_pipeline(
63
- {"context": default_context, "question": question},
64
- max_answer_len=max_answer_len,
65
- doc_stride=doc_stride
66
- )
67
  st.subheader("Result")
68
  st.write(f"**Question:** {question}")
69
- st.write(f"**Answer:** {prediction['answer']}")
70
  except Exception as e:
71
- st.error(f"Error: {e}")
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
# Hub identifiers of the causal language model and its tokenizer.
model_name = "wedo2910/research_ai"
tokenizer_name = "wedo2910/research_ai_tok"


@st.cache_resource
def _load_model_and_tokenizer(model_id, tokenizer_id):
    """Load the tokenizer and causal LM once and reuse them across reruns.

    Streamlit re-executes this script from the top on every widget
    interaction; without caching, both ``from_pretrained`` calls would run
    again each time the user moves a slider or presses the button.

    NOTE(review): ``st.cache_resource`` requires a reasonably recent
    Streamlit release -- confirm the version pinned for this app.

    Args:
        model_id: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_id: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(model_id)
    return loaded_tokenizer, loaded_model


# Keep the original module-level names so the rest of the script is unchanged.
tokenizer, model = _load_model_and_tokenizer(model_name, tokenizer_name)
 
10
 
11
+ # Define the custom inference function
12
def single_inference(question, max_new_tokens, temperature):
    """Generate an answer to ``question`` with the module-level chat model.

    The question is wrapped in a two-message chat (a fixed Arabic system
    instruction plus the user question), rendered with the tokenizer's chat
    template, and passed to ``model.generate``.

    Args:
        question: The user's question text.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature. A value of 0 or less selects
            greedy decoding, because sampling with a non-positive
            temperature is rejected by the generation code.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens removed.
    """
    # Prepare the prompt messages
    messages = [
        {"role": "system", "content": "اجب علي الاتي بالعربي فقط."},
        {"role": "user", "content": question},
    ]

    # Use the tokenizer's chat template functionality
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # End-of-sequence markers. Either lookup may yield None (no EOS token
    # configured, or "<|eot_id|>" missing from the vocabulary), so drop
    # those entries instead of handing None to generate().
    candidate_ids = (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    terminators = [token_id for token_id in candidate_ids if token_id is not None]

    # number_input may hand back a float depending on how it is configured.
    generation_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if terminators:
        generation_kwargs["eos_token_id"] = terminators

    if temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
    else:
        # The UI slider allows 0.0; fall back to deterministic decoding
        # rather than letting generate() raise on a zero temperature.
        generation_kwargs["do_sample"] = False

    outputs = model.generate(input_ids, **generation_kwargs)

    # Decode only the newly generated tokens (i.e. skip the prompt)
    prompt_length = input_ids.shape[-1]
    response = outputs[0][prompt_length:]
    return tokenizer.decode(response, skip_special_tokens=True)
45
 
46
  # Streamlit UI
47
st.title("Arabic AI Research QA")
st.subheader("Ask a question to get an answer from the research AI model.")

# Input field for the question
question = st.text_input("Question", placeholder="Enter your question here...")

# Generation parameters exposed to the user
st.subheader("Settings")
max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000, value=256)
temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.4, step=0.1)

# Generate Answer button
if st.button("Get Answer"):
    # Strip first so a whitespace-only entry is rejected like an empty one
    # instead of being sent to the model.
    cleaned_question = question.strip()
    if not cleaned_question:
        st.error("The question field is required.")
    else:
        try:
            # Generation can take a while; show progress while it runs.
            with st.spinner("Generating answer..."):
                answer = single_inference(cleaned_question, max_new_tokens, temperature)
        except Exception as e:
            # Top-level UI boundary: report the failure to the user instead
            # of crashing the app.
            st.error(f"Error: {e}")
        else:
            st.subheader("Result")
            st.write(f"**Question:** {cleaned_question}")
            st.write(f"**Answer:** {answer}")