karthikeyan-r committed
Commit be79c33 · verified · 1 Parent(s): eeaf9ed

Update app.py

Files changed (1):
  1. app.py +73 -129
app.py CHANGED
@@ -34,59 +34,54 @@ if "tokenizer" not in st.session_state:
 if "qa_pipeline" not in st.session_state:
     st.session_state["qa_pipeline"] = None
 if "conversation" not in st.session_state:
-    # We'll store conversation as a list of dicts,
-    # e.g. [{"role": "assistant", "content": "Hello..."}, {"role": "user", "content": "..."}]
     st.session_state["conversation"] = []

 # ----- Load Model -----
+def load_model():
+    if st.session_state["model"] is None or st.session_state["tokenizer"] is None:
+        with st.spinner("Loading model..."):
+            try:
+                if model_choice == model_options["1"]:
+                    # Load the calculation model
+                    tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
+                    model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
+
+                    # Add special tokens if needed
+                    if tokenizer.pad_token is None:
+                        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+                        model.resize_token_embeddings(len(tokenizer))
+                    if tokenizer.eos_token is None:
+                        tokenizer.add_special_tokens({'eos_token': '[EOS]'})
+                        model.resize_token_embeddings(len(tokenizer))
+
+                    model.config.pad_token_id = tokenizer.pad_token_id
+                    model.config.eos_token_id = tokenizer.eos_token_id
+
+                    st.session_state["model"] = model
+                    st.session_state["tokenizer"] = tokenizer
+                    st.session_state["qa_pipeline"] = None  # Not needed for calculation model
+
+                elif model_choice == model_options["2"]:
+                    # Load the T5 model for general QA
+                    device = 0 if torch.cuda.is_available() else -1
+                    model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
+                    tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
+                    qa_pipe = pipeline(
+                        "text2text-generation",
+                        model=model,
+                        tokenizer=tokenizer,
+                        device=device
+                    )
+                    st.session_state["model"] = model
+                    st.session_state["tokenizer"] = tokenizer
+                    st.session_state["qa_pipeline"] = qa_pipe
+
+                st.success("Model loaded successfully and ready!")
+            except Exception as e:
+                st.error(f"Error loading model: {e}")
+
 if load_model_button:
-    with st.spinner("Loading model..."):
-        try:
-            if model_choice == model_options["1"]:
-                # Load the calculation model
-                tokenizer = AutoTokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
-                model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir="./model_cache")
-
-                # Add special tokens if needed
-                if tokenizer.pad_token is None:
-                    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
-                    model.resize_token_embeddings(len(tokenizer))
-                if tokenizer.eos_token is None:
-                    tokenizer.add_special_tokens({'eos_token': '[EOS]'})
-                    model.resize_token_embeddings(len(tokenizer))
-
-                model.config.pad_token_id = tokenizer.pad_token_id
-                model.config.eos_token_id = tokenizer.eos_token_id
-
-                st.session_state["model"] = model
-                st.session_state["tokenizer"] = tokenizer
-                st.session_state["qa_pipeline"] = None  # Not needed for calculation model
-
-            elif model_choice == model_options["2"]:
-                # Load the T5 model for general QA
-                device = 0 if torch.cuda.is_available() else -1
-                model = T5ForConditionalGeneration.from_pretrained(model_choice, cache_dir="./model_cache")
-                tokenizer = T5Tokenizer.from_pretrained(model_choice, cache_dir="./model_cache")
-                qa_pipe = pipeline(
-                    "text2text-generation",
-                    model=model,
-                    tokenizer=tokenizer,
-                    device=device
-                )
-                st.session_state["model"] = model
-                st.session_state["tokenizer"] = tokenizer
-                st.session_state["qa_pipeline"] = qa_pipe
-
-            # If conversation is empty, insert a welcome message
-            if len(st.session_state["conversation"]) == 0:
-                st.session_state["conversation"].append({
-                    "role": "assistant",
-                    "content": "Hello! I’m your assistant. How can I help you today?"
-                })
-
-            st.success("Model loaded successfully and ready!")
-        except Exception as e:
-            st.error(f"Error loading model: {e}")
+    load_model()

 # ----- Clear Model -----
 if clear_model_button:
@@ -103,93 +98,42 @@ if clear_conversation_button:
 # ----- Title -----
 st.title("Chat Conversation UI")

-
-user_input = None
-
-if st.session_state["qa_pipeline"]:
-    # T5 pipeline
-    user_input = st.chat_input("Enter your query:")
-    if user_input:
-        # 1) Save user message
-        st.session_state["conversation"].append({
-            "role": "user",
-            "content": user_input
-        })
-
-        # 2) Generate assistant response
+# ----- User Input and Processing -----
+user_input = st.chat_input("Enter your query:")
+if user_input:
+    # Save user input
+    st.session_state["conversation"].append({
+        "role": "user",
+        "content": user_input
+    })
+
+    # Generate response
+    if st.session_state["qa_pipeline"]:
         try:
-            response = st.session_state["qa_pipeline"](
-                f"Q: {user_input}", max_length=250
-            )
+            response = st.session_state["qa_pipeline"](f"Q: {user_input}", max_length=250)
             answer = response[0]["generated_text"]
         except Exception as e:
             answer = f"Error: {str(e)}"
-
-        # 3) Append assistant message to conversation
-        st.session_state["conversation"].append({
-            "role": "assistant",
-            "content": answer
-        })
-
-elif st.session_state["model"] and (model_choice == model_options["1"]):
-    # Calculation model
-    user_input = st.chat_input("Enter your query for calculation:")
-    if user_input:
-        # 1) Save user message
-        st.session_state["conversation"].append({
-            "role": "user",
-            "content": user_input
-        })
-
-        # 2) Generate assistant response
-        tokenizer = st.session_state["tokenizer"]
-        model = st.session_state["model"]
-
+    elif st.session_state["model"] and model_choice == model_options["1"]:
         try:
-            inputs = tokenizer(
-                f"Input: {user_input}\nOutput:",
-                return_tensors="pt",
-                padding=True,
-                truncation=True
-            )
-            input_ids = inputs.input_ids
-            attention_mask = inputs.attention_mask
-
-            output = model.generate(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=250,
-                pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-                do_sample=False
-            )
-
-            decoded_output = tokenizer.decode(
-                output[0],
-                skip_special_tokens=True
-            )
-            # Extract answer after 'Output:' if present
-            if "Output:" in decoded_output:
-                answer = decoded_output.split("Output:")[-1].strip()
-            else:
-                answer = decoded_output.strip()
+            tokenizer = st.session_state["tokenizer"]
+            model = st.session_state["model"]
+
+            inputs = tokenizer(f"Input: {user_input}\nOutput:", return_tensors="pt", padding=True, truncation=True)
+            output = model.generate(inputs.input_ids, max_length=250, pad_token_id=tokenizer.pad_token_id)
+            answer = tokenizer.decode(output[0], skip_special_tokens=True).split("Output:")[-1].strip()
         except Exception as e:
             answer = f"Error: {str(e)}"
+    else:
+        answer = "No model is loaded. Please select and load a model."

-        # 3) Append assistant message to conversation
-        st.session_state["conversation"].append({
-            "role": "assistant",
-            "content": answer
-        })
-else:
-    # If no model is loaded:
-    st.info("No model is loaded. Please select a model and click 'Load Model' from the sidebar.")
-
+    # Save assistant response
+    st.session_state["conversation"].append({
+        "role": "assistant",
+        "content": answer
+    })

+# Display conversation
 for message in st.session_state["conversation"]:
-    if message["role"] == "user":
-        with st.chat_message("user"):
-            st.write(message["content"])
-    else:
-        with st.chat_message("assistant"):
-            st.write(message["content"])
+    with st.chat_message(message["role"]):
+        st.write(message["content"])
 
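Note: the hunks above read sidebar controls and model options (model_options, model_choice, load_model_button, clear_model_button, clear_conversation_button) that are defined earlier in app.py and untouched by this commit. The sketch below shows what that surrounding set-up plausibly looks like; the model IDs are placeholders, not the repo's actual choices, and the widget labels are assumptions.

import streamlit as st
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    T5Tokenizer, T5ForConditionalGeneration, pipeline,
)

# Placeholder model IDs -- the real values are defined above line 34 of app.py.
model_options = {
    "1": "example-org/calculation-model",  # causal-LM branch (model_options["1"])
    "2": "google/flan-t5-base",            # T5 text2text branch (model_options["2"])
}

with st.sidebar:
    choice = st.radio("Model", list(model_options))  # "1" or "2"
    model_choice = model_options[choice]             # the diff compares this to model_options["1"]/["2"]
    load_model_button = st.button("Load Model")
    clear_model_button = st.button("Clear Model")
    clear_conversation_button = st.button("Clear Conversation")

# Session-state slots the changed code reads and writes, initialised like
# the "qa_pipeline"/"conversation" entries shown at the top of the diff.
if "model" not in st.session_state:
    st.session_state["model"] = None
if "tokenizer" not in st.session_state:
    st.session_state["tokenizer"] = None

With scaffolding along these lines in place, the updated load_model() helper and the single chat_input flow from this commit should run as a self-contained Streamlit app.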