Really-amin commited on
Commit
a065de3
·
verified ·
1 Parent(s): f8c382e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -31
app.py CHANGED
@@ -1,49 +1,73 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
- # بارگذاری مدل
 
 
 
6
  @st.cache_resource
7
- def load_model():
 
 
 
 
8
  tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-base-uncased")
9
- model = AutoModelForSequenceClassification.from_pretrained("HooshvareLab/bert-fa-base-uncased")
10
- return tokenizer, model
 
 
 
 
 
 
 
 
11
 
12
- tokenizer, model = load_model()
 
 
 
13
 
14
- # تولید پاسخ
15
- def generate_response(text):
16
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
17
- with torch.no_grad():
18
- outputs = model(**inputs)
19
  logits = outputs.logits
20
- predicted_class = torch.argmax(logits, dim=1).item()
21
- response = f"پاسخ مدل: دسته‌بندی {predicted_class}"
22
- return response
 
 
 
 
23
 
24
  # رابط کاربری
25
  def main():
26
- st.set_page_config(page_title="دستیار هوش مصنوعی", layout="wide")
27
- st.title("دستیار هوش مصنوعی (BERT فارسی)")
28
 
29
- # مدیریت پیام‌ها با Session State
30
- if "messages" not in st.session_state:
31
- st.session_state.messages = []
 
 
 
32
 
33
- # نمایش پیام‌ها
34
- for message in st.session_state.messages:
35
- if message["role"] == "user":
36
- st.write(f"👤 کاربر: {message['content']}")
37
  else:
38
- st.write(f"🤖 دستیار: {message['content']}")
 
 
 
39
 
40
- # ورودی کاربر
41
- user_input = st.text_input("پیام خود را وارد کنید:")
42
- if user_input:
43
- st.session_state.messages.append({"role": "user", "content": user_input})
44
- response = generate_response(user_input)
45
- st.session_state.messages.append({"role": "assistant", "content": response})
46
- st.experimental_rerun()
 
47
 
48
  if __name__ == "__main__":
49
  main()
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForMaskedLM
3
  import torch
4
 
5
+ # تنظیمات صفحه
6
+ st.set_page_config(page_title="دستیار هوش مصنوعی (پر کردن ماسک)", layout="wide")
7
+
8
+ # تابع بارگذاری مدل و pipeline
9
  @st.cache_resource
10
+ def load_model_and_pipeline():
11
+ # بارگذاری pipeline
12
+ pipe = pipeline("fill-mask", model="HooshvareLab/bert-fa-base-uncased")
13
+
14
+ # بارگذاری مدل و توکنایزر به صورت مستقیم
15
  tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-fa-base-uncased")
16
+ model = AutoModelForMaskedLM.from_pretrained("HooshvareLab/bert-fa-base-uncased")
17
+
18
+ return pipe, tokenizer, model
19
+
20
+ pipe, tokenizer, model = load_model_and_pipeline()
21
+
22
+ # تابع پیش‌بینی با pipeline
23
+ def predict_with_pipeline(text):
24
+ results = pipe(text)
25
+ return [{"word": res["token_str"], "score": res["score"]} for res in results]
26
 
27
+ # تابع پیش‌بینی با مدل مستقیم
28
+ def predict_with_model(text):
29
+ inputs = tokenizer(text, return_tensors="pt")
30
+ mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
31
 
32
+ # اجرای مدل
33
+ outputs = model(**inputs)
 
 
 
34
  logits = outputs.logits
35
+
36
+ # پیدا کردن توکن‌های برتر
37
+ mask_token_logits = logits[0, mask_token_index, :]
38
+ top_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
39
+
40
+ results = [{"word": tokenizer.decode([token])} for token in top_tokens]
41
+ return results
42
 
43
  # رابط کاربری
44
  def main():
45
+ st.title("دستیار هوش مصنوعی (پر کردن ماسک)")
 
46
 
47
+ # ورودی کاربر
48
+ st.subheader("ورودی متن:")
49
+ text = st.text_input(
50
+ "متن خود را وارد کنید (از [MASK] برای نشان دادن کلمه حذف شده استفاده کنید):",
51
+ value="ما در هوشواره معتقدیم [MASK] دانش و آگاهی می‌تواند جامعه را تغییر دهد.",
52
+ )
53
 
54
+ if st.button("پیش‌بینی با pipeline"):
55
+ if "[MASK]" not in text:
56
+ st.error("لطفاً یک متن شامل [MASK] وارد کنید.")
 
57
  else:
58
+ st.subheader("نتایج پیش‌بینی با pipeline:")
59
+ predictions = predict_with_pipeline(text)
60
+ for pred in predictions:
61
+ st.write(f"کلمه: {pred['word']} - احتمال: {pred['score']:.2f}")
62
 
63
+ if st.button("پیش‌بینی با مدل مستقیم"):
64
+ if "[MASK]" not in text:
65
+ st.error("لطفاً یک متن شامل [MASK] وارد کنید.")
66
+ else:
67
+ st.subheader("نتایج پیش‌بینی با مدل مستقیم:")
68
+ predictions = predict_with_model(text)
69
+ for pred in predictions:
70
+ st.write(f"کلمه: {pred['word']}")
71
 
72
  if __name__ == "__main__":
73
  main()