mouadenna committed
Commit 90a4d87
Parent: 355da42

Update app.py

Files changed (1)
  1. app.py +17 -37
app.py CHANGED
@@ -12,9 +12,8 @@ import os
 
 import torch
 from peft import PeftModel, PeftConfig
-from transformers import AutoModelForCausalLM,AutoTokenizer, LlamaForCausalLM
+from transformers import AutoModelForCausalLM,AutoTokenizer
 
-import webbrowser
 import joblib
 from deployML import predd
 
@@ -22,21 +21,21 @@ import time
 
 #load the chatbot
 
-model = LlamaForCausalLM.from_pretrained(
-    "medalpaca/medalpaca-7b",
-    return_dict=True,
-    load_in_8bit=True,
-    device_map="auto",
-)
+peft_model_id = "medalpaca/medalpaca-7b"
+
+config = PeftConfig.from_pretrained(peft_model_id)
 
-tokenizer = AutoTokenizer.from_pretrained("medalpaca/medalpaca-7b")
+model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto')
+tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+
+# Load the Lora model
+model = PeftModel.from_pretrained(model, peft_model_id)
 
 
 #load the first interface
 
 
 def fn(*args):
-    global flag
     global symptoms
     all_symptoms = [symptom for symptom_list in args for symptom in symptom_list]
 
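Note: this hunk swaps the one-shot LlamaForCausalLM load for the standard PEFT pattern: read the adapter's PeftConfig, load the 8-bit base model and tokenizer from config.base_model_name_or_path, then wrap the base model with PeftModel.from_pretrained. A minimal sketch of using the wrapped model for generation, mirroring the predict function further down; the prompt string and torch.inference_mode() are illustrative assumptions, not code from this commit:

batch = tokenizer("What can cause knee pain?", return_tensors="pt").to(model.device)  # example prompt (assumption)
with torch.inference_mode():  # the app itself wraps generate in torch.cuda.amp.autocast()
    output_tokens = model.generate(**batch, max_new_tokens=100)
print(tokenizer.decode(output_tokens[0], skip_special_tokens=True))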
@@ -46,14 +45,12 @@ def fn(*args):
         raise gr.Error("Please select at least 3 symptoms.")
 
     symptoms = all_symptoms # Update global symptoms list
-    #webbrowser.open_new_tab(url)
-    flag=1
-    return symptoms
+    return predd(loaded_rf,symptoms)
+
 
 
 symptoms = []
 
-flag=0
 
 
 demo = gr.Interface(
@@ -72,20 +69,13 @@ demo = gr.Interface(
         gr.CheckboxGroup(['knee pain', 'hip joint pain', 'swelling joints'], label='Joint and Bone Issues'),
         gr.CheckboxGroup(['spinning movements', 'unsteadiness'], label='Neurological Movements')
     ],
-    outputs=None,
+    outputs="textbox",allow_flagging="never"
 )
 
 
 
 
 
-def wait_until_variable_changes( target_value):
-    global flag
-    while flag != target_value:
-        time.sleep(1)
-
-
-
 
 def predict(message, history):
     prompt = f"""
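Note: with fn now returning predd(loaded_rf, symptoms) directly, the interface can render the prediction itself, so outputs=None becomes outputs="textbox" and the flag-polling helper wait_until_variable_changes is dropped. A minimal sketch of the resulting wiring, with the inputs list shortened for illustration:

demo = gr.Interface(
    fn=fn,  # returns the classifier's prediction as a string
    inputs=[gr.CheckboxGroup(['knee pain', 'hip joint pain'], label='Joint and Bone Issues')],  # shortened (assumption)
    outputs="textbox",
    allow_flagging="never",  # hides Gradio's Flag button
)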
@@ -96,7 +86,7 @@ def predict(message, history):
     batch = tokenizer(prompt, return_tensors='pt')
     with torch.cuda.amp.autocast():
 
-        output_tokens = model.generate(**batch, max_new_tokens=60)
+        output_tokens = model.generate(**batch, max_new_tokens=100)
 
     return tokenizer.decode(output_tokens[0], skip_special_tokens=True).replace(prompt,"")
 
@@ -105,20 +95,10 @@ loaded_rf = joblib.load("model_joblib")
 Fmessage="hello im here to help you!"
 
 
-if __name__ == "__main__":
-    demo.launch()
-
-    wait_until_variable_changes(1)
-
-
-
-    if symptoms:
-        Fmessage=predd(loaded_rf,symptoms)
-
-
-    chatbot=gr.ChatInterface(predict,chatbot=gr.Chatbot(value=[(None, Fmessage)],), clear_btn=None, retry_btn=None, undo_btn=None)
 
+chatbot=gr.ChatInterface(predict, clear_btn=None, retry_btn=None, undo_btn=None)
 
 
-demo.close()
-chatbot.launch()
+gr.TabbedInterface(
+    [demo, chatbot], ["symptoms checker", "chatbot"]
+).launch()
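Note: the old flow launched demo, blocked on the global flag, computed Fmessage, then closed demo and launched the chatbot as a second app. The new flow hosts both interfaces as tabs in a single gr.TabbedInterface, so no polling or relaunch is needed. A self-contained sketch of that launch pattern, with placeholder handlers standing in for the model calls:

import gradio as gr

def fn(*args):
    return "prediction"  # placeholder for predd(loaded_rf, symptoms)

def predict(message, history):
    return "model reply"  # placeholder for the medalpaca generate call

demo = gr.Interface(fn=fn, inputs=gr.CheckboxGroup(['knee pain']), outputs="textbox")
chatbot = gr.ChatInterface(predict)

gr.TabbedInterface([demo, chatbot], ["symptoms checker", "chatbot"]).launch()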