Sreekan commited on
Commit
8569432
·
verified ·
1 Parent(s): a35a697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -161,16 +161,28 @@ chain2 = LLMChain(llm=llm2, prompt=PromptTemplate(
161
  template="You are in state s2. {{query}}"
162
  ))
163
 
164
- # Create a state graph for managing the chatbot's states
165
- graph = StateGraph()
166
-
167
- # Create states and add them to the graph
168
- state1 = graph.add_state("s1") # State for the first agent
169
- state2 = graph.add_state("s2") # State for the second agent
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  # Define transitions
172
- graph.add_edge(state1, state2, "next") # Transition from s1 to s2
173
- graph.add_edge(state2, state1, "back") # Transition from s2 to s1
174
 
175
  # Initialize the current state
176
  current_state = state1
@@ -180,11 +192,11 @@ def handle_input(query):
180
  output = ''
181
 
182
  # Process user input based on current state
183
- if current_state == state1:
184
  output = chain1.invoke(input=query) # Invoke chain1 with user input
185
  response = agent1(output) # Process output through Agent 1
186
  current_state = state2 # Transition to state s2
187
- elif current_state == state2:
188
  output = chain2.invoke(input=query) # Invoke chain2 with user input
189
  response = agent2(output) # Process output through Agent 2
190
  current_state = state1 # Transition back to state s1
@@ -207,3 +219,4 @@ with gr.Blocks() as demo:
207
 
208
  # Launch the Gradio application
209
  demo.launch()
 
 
161
  template="You are in state s2. {{query}}"
162
  ))
163
 
164
+ # Define the state schema
165
+ state_schema = {
166
+ "s1": {
167
+ "inputs": ["query"],
168
+ "outputs": ["response"]
169
+ },
170
+ "s2": {
171
+ "inputs": ["query"],
172
+ "outputs": ["response"]
173
+ }
174
+ }
175
+
176
+ # Create a state graph with required schemas for inputs and outputs
177
+ graph = StateGraph(state_schema=state_schema)
178
+
179
+ # Add states to the graph
180
+ state1 = graph.add_state(name="s1")
181
+ state2 = graph.add_state(name="s2")
182
 
183
  # Define transitions
184
+ graph.add_edge(state1, state2, label="next") # Transition from s1 to s2
185
+ graph.add_edge(state2, state1, label="back") # Transition from s2 to s1
186
 
187
  # Initialize the current state
188
  current_state = state1
 
192
  output = ''
193
 
194
  # Process user input based on current state
195
+ if current_state.name == "s1":
196
  output = chain1.invoke(input=query) # Invoke chain1 with user input
197
  response = agent1(output) # Process output through Agent 1
198
  current_state = state2 # Transition to state s2
199
+ elif current_state.name == "s2":
200
  output = chain2.invoke(input=query) # Invoke chain2 with user input
201
  response = agent2(output) # Process output through Agent 2
202
  current_state = state1 # Transition back to state s1
 
219
 
220
  # Launch the Gradio application
221
  demo.launch()
222
+