wedo2910 commited on
Commit
973b318
·
verified ·
1 Parent(s): 2f43e16

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import pipeline
3
+
4
+ # Load the model and tokenizer from the models/ directory
5
+ qa_pipeline = pipeline(
6
+ "question-answering",
7
+ model="models/qa_arabic_model_final",
8
+ tokenizer="models/qa_arabic_model_final"
9
+ )
10
+
11
+ # Default settings
12
+ default_settings = {
13
+ "max_new_tokens": 512,
14
+ "temperature": 0.7,
15
+ "top_p": 0.9,
16
+ "min_p": 0,
17
+ "top_k": 0,
18
+ "repetition_penalty": 1.0,
19
+ "presence_penalty": 0,
20
+ "frequency_penalty": 0,
21
+ "max_answer_len": 50,
22
+ "doc_stride": 128,
23
+ }
24
+
25
+ # Streamlit UI
26
+ st.title("Arabic AI Question Answering")
27
+ st.subheader("Provide context and ask a question to get answers.")
28
+
29
+ # Input fields
30
+ context = st.text_area("Context", placeholder="Enter the context here...", height=200)
31
+ question = st.text_input("Question", placeholder="Enter your question here...")
32
+
33
+ # Settings sliders
34
+ st.subheader("Settings")
35
+ max_new_tokens = st.number_input("Max New Tokens", min_value=1, max_value=1000000, value=512)
36
+ temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.7, step=0.1)
37
+ top_p = st.slider("Top P", min_value=0.0, max_value=1.0, value=0.9, step=0.1)
38
+ min_p = st.slider("Min P", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
39
+ top_k = st.number_input("Top K", min_value=0, max_value=1000, value=0)
40
+ repetition_penalty = st.slider("Repetition Penalty", min_value=0.01, max_value=5.0, value=1.0, step=0.1)
41
+ presence_penalty = st.slider("Presence Penalty", min_value=-2.0, max_value=2.0, value=0.0, step=0.1)
42
+ frequency_penalty = st.slider("Frequency Penalty", min_value=-2.0, max_value=2.0, value=0.0, step=0.1)
43
+ max_answer_len = st.number_input("Max Answer Length", min_value=1, value=50)
44
+ doc_stride = st.number_input("Document Stride", min_value=1, value=128)
45
+
46
+ # Generate Answer button
47
+ if st.button("Get Answer"):
48
+ if not context or not question:
49
+ st.error("Both context and question fields are required.")
50
+ else:
51
+ # Generate answer
52
+ try:
53
+ prediction = qa_pipeline(
54
+ {"context": context, "question": question},
55
+ max_answer_len=max_answer_len,
56
+ doc_stride=doc_stride
57
+ )
58
+ st.subheader("Result")
59
+ st.write(f"**Question:** {question}")
60
+ st.write(f"**Answer:** {prediction['answer']}")
61
+ except Exception as e:
62
+ st.error(f"Error: {e}")