King-Afridi commited on
Commit
b0f7449
·
verified ·
1 Parent(s): 798b261

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py CHANGED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import requests
5
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
6
+ from stable_baselines3 import PPO
7
+ from gym import spaces
8
+ from stable_baselines3.common.vec_env import DummyVecEnv
9
+ import plotly.graph_objects as go
10
+
11
+ # 1. Set up Groq API Integration
12
+ GROQ_API_KEY = "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" # Replace with your API key
13
+ groq_api_url = "https://api.groq.com/v1/traffic/optimize" # Replace with correct endpoint
14
+
15
+ headers = {
16
+ 'Authorization': f'Bearer {GROQ_API_KEY}',
17
+ }
18
+
19
+ # Load Traffic Data from CSV (Replace with uploaded file handling for Gradio)
20
+ def load_traffic_data(file_path):
21
+ data = pd.read_csv(file_path)
22
+ return data
23
+
24
+ # 3. Function to interact with Groq API to get traffic optimization strategies
25
+ def get_optimization_strategy(traffic_data):
26
+ traffic_data = [int(x) if isinstance(x, np.int64) else x for x in traffic_data]
27
+ response = requests.post(groq_api_url, json={'traffic_data': traffic_data}, headers=headers)
28
+ if response.status_code == 200:
29
+ optimization_strategy = response.json()
30
+ return optimization_strategy
31
+ else:
32
+ return f"Error: {response.status_code}, {response.text}"
33
+
34
+ # 4. Create a Custom Traffic Environment for RL Simulation
35
+ class TrafficEnv(gym.Env):
36
+ def __init__(self, traffic_data):
37
+ super(TrafficEnv, self).__init__()
38
+
39
+ self.action_space = spaces.Discrete(3) # 3 possible actions
40
+ self.observation_space = spaces.Box(low=0, high=50, shape=(5,), dtype=np.float32)
41
+ self.current_state = np.zeros(5) # Start with zero traffic data
42
+ self.traffic_data = traffic_data
43
+
44
+ def reset(self):
45
+ self.current_state = np.array(self.traffic_data.iloc[0, 3:7], dtype=np.float32) # Use first row for starting state
46
+ return self.current_state
47
+
48
+ def step(self, action):
49
+ if action == 0: # Decrease traffic
50
+ self.current_state = self.current_state - np.random.randint(1, 5, size=self.current_state.shape)
51
+ elif action == 1: # No change
52
+ self.current_state = self.current_state
53
+ elif action == 2: # Increase traffic
54
+ self.current_state = self.current_state + np.random.randint(1, 5, size=self.current_state.shape)
55
+ self.current_state = np.clip(self.current_state, 0, None)
56
+ reward = -np.sum(self.current_state) # Minimize traffic (negative sum as reward)
57
+ done = np.sum(self.current_state) < 50
58
+ return self.current_state, reward, done, {}
59
+
60
+ def create_environment(traffic_data):
61
+ return DummyVecEnv([lambda: TrafficEnv(traffic_data)])
62
+
63
+ # Visualize Traffic Flow using Plotly
64
+ def visualize_traffic_flow(traffic_data):
65
+ locations = ['CarCount', 'BikeCount', 'BusCount', 'TruckCount']
66
+ traffic_flow = traffic_data.iloc[0, 3:7] # Use first row of traffic counts
67
+ fig = go.Figure(data=[go.Bar(x=locations, y=traffic_flow)])
68
+ fig.update_layout(title='Real-Time Traffic Flow', xaxis_title='Location', yaxis_title='Traffic Volume')
69
+ return fig
70
+
71
+ # RAG-based Optimization Strategy using Hugging Face Transformers
72
+ def rag_based_optimization(query):
73
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
74
+ model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
75
+
76
+ inputs = tokenizer(query, return_tensors="pt")
77
+ generated_ids = model.generate(input_ids=inputs['input_ids'])
78
+ generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
79
+ return generated_text
80
+
81
+ # Gradio Interface for the App
82
+ def optimize_traffic(traffic_file, query):
83
+ # Load Traffic Data from uploaded file
84
+ traffic_data = load_traffic_data(traffic_file.name)
85
+
86
+ # Get optimization strategy from Groq API
87
+ optimization_strategy = get_optimization_strategy(traffic_data.iloc[0, 3:7].values.tolist())
88
+
89
+ # Visualize traffic flow
90
+ traffic_fig = visualize_traffic_flow(traffic_data)
91
+
92
+ # Get RAG-based optimization strategy
93
+ rag_strategy = rag_based_optimization(query)
94
+
95
+ return optimization_strategy, traffic_fig, rag_strategy
96
+
97
+ # Create Gradio Interface
98
+ iface = gr.Interface(
99
+ fn=optimize_traffic,
100
+ inputs=[
101
+ gr.File(label="Upload Traffic Data CSV"),
102
+ gr.Textbox(label="Enter Optimization Query", value="Optimize traffic flow for downtown area.")
103
+ ],
104
+ outputs=[
105
+ gr.JSON(label="Optimization Strategy from Groq API"),
106
+ gr.Plot(label="Traffic Flow Visualization"),
107
+ gr.Textbox(label="RAG-based Optimization Strategy")
108
+ ],
109
+ live=True,
110
+ title="Traffic Optimization App",
111
+ description="This app optimizes traffic flow using RL, Groq API, and RAG model-based optimization strategies."
112
+ )
113
+
114
+ if __name__ == "__main__":
115
+ iface.launch()