Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import requests
|
5 |
+
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
|
6 |
+
from stable_baselines3 import PPO
|
7 |
+
from gym import spaces
|
8 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
|
11 |
+
# 1. Set up Groq API Integration
|
12 |
+
GROQ_API_KEY = "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" # Replace with your API key
|
13 |
+
groq_api_url = "https://api.groq.com/v1/traffic/optimize" # Replace with correct endpoint
|
14 |
+
|
15 |
+
headers = {
|
16 |
+
'Authorization': f'Bearer {GROQ_API_KEY}',
|
17 |
+
}
|
18 |
+
|
19 |
+
# Load Traffic Data from CSV (Replace with uploaded file handling for Gradio)
|
20 |
+
def load_traffic_data(file_path):
|
21 |
+
data = pd.read_csv(file_path)
|
22 |
+
return data
|
23 |
+
|
24 |
+
# 3. Function to interact with Groq API to get traffic optimization strategies
|
25 |
+
def get_optimization_strategy(traffic_data):
|
26 |
+
traffic_data = [int(x) if isinstance(x, np.int64) else x for x in traffic_data]
|
27 |
+
response = requests.post(groq_api_url, json={'traffic_data': traffic_data}, headers=headers)
|
28 |
+
if response.status_code == 200:
|
29 |
+
optimization_strategy = response.json()
|
30 |
+
return optimization_strategy
|
31 |
+
else:
|
32 |
+
return f"Error: {response.status_code}, {response.text}"
|
33 |
+
|
34 |
+
# 4. Create a Custom Traffic Environment for RL Simulation
|
35 |
+
class TrafficEnv(gym.Env):
|
36 |
+
def __init__(self, traffic_data):
|
37 |
+
super(TrafficEnv, self).__init__()
|
38 |
+
|
39 |
+
self.action_space = spaces.Discrete(3) # 3 possible actions
|
40 |
+
self.observation_space = spaces.Box(low=0, high=50, shape=(5,), dtype=np.float32)
|
41 |
+
self.current_state = np.zeros(5) # Start with zero traffic data
|
42 |
+
self.traffic_data = traffic_data
|
43 |
+
|
44 |
+
def reset(self):
|
45 |
+
self.current_state = np.array(self.traffic_data.iloc[0, 3:7], dtype=np.float32) # Use first row for starting state
|
46 |
+
return self.current_state
|
47 |
+
|
48 |
+
def step(self, action):
|
49 |
+
if action == 0: # Decrease traffic
|
50 |
+
self.current_state = self.current_state - np.random.randint(1, 5, size=self.current_state.shape)
|
51 |
+
elif action == 1: # No change
|
52 |
+
self.current_state = self.current_state
|
53 |
+
elif action == 2: # Increase traffic
|
54 |
+
self.current_state = self.current_state + np.random.randint(1, 5, size=self.current_state.shape)
|
55 |
+
self.current_state = np.clip(self.current_state, 0, None)
|
56 |
+
reward = -np.sum(self.current_state) # Minimize traffic (negative sum as reward)
|
57 |
+
done = np.sum(self.current_state) < 50
|
58 |
+
return self.current_state, reward, done, {}
|
59 |
+
|
60 |
+
def create_environment(traffic_data):
|
61 |
+
return DummyVecEnv([lambda: TrafficEnv(traffic_data)])
|
62 |
+
|
63 |
+
# Visualize Traffic Flow using Plotly
|
64 |
+
def visualize_traffic_flow(traffic_data):
|
65 |
+
locations = ['CarCount', 'BikeCount', 'BusCount', 'TruckCount']
|
66 |
+
traffic_flow = traffic_data.iloc[0, 3:7] # Use first row of traffic counts
|
67 |
+
fig = go.Figure(data=[go.Bar(x=locations, y=traffic_flow)])
|
68 |
+
fig.update_layout(title='Real-Time Traffic Flow', xaxis_title='Location', yaxis_title='Traffic Volume')
|
69 |
+
return fig
|
70 |
+
|
71 |
+
# RAG-based Optimization Strategy using Hugging Face Transformers
|
72 |
+
def rag_based_optimization(query):
|
73 |
+
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
|
74 |
+
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
|
75 |
+
|
76 |
+
inputs = tokenizer(query, return_tensors="pt")
|
77 |
+
generated_ids = model.generate(input_ids=inputs['input_ids'])
|
78 |
+
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
79 |
+
return generated_text
|
80 |
+
|
81 |
+
# Gradio Interface for the App
|
82 |
+
def optimize_traffic(traffic_file, query):
|
83 |
+
# Load Traffic Data from uploaded file
|
84 |
+
traffic_data = load_traffic_data(traffic_file.name)
|
85 |
+
|
86 |
+
# Get optimization strategy from Groq API
|
87 |
+
optimization_strategy = get_optimization_strategy(traffic_data.iloc[0, 3:7].values.tolist())
|
88 |
+
|
89 |
+
# Visualize traffic flow
|
90 |
+
traffic_fig = visualize_traffic_flow(traffic_data)
|
91 |
+
|
92 |
+
# Get RAG-based optimization strategy
|
93 |
+
rag_strategy = rag_based_optimization(query)
|
94 |
+
|
95 |
+
return optimization_strategy, traffic_fig, rag_strategy
|
96 |
+
|
97 |
+
# Create Gradio Interface
|
98 |
+
iface = gr.Interface(
|
99 |
+
fn=optimize_traffic,
|
100 |
+
inputs=[
|
101 |
+
gr.File(label="Upload Traffic Data CSV"),
|
102 |
+
gr.Textbox(label="Enter Optimization Query", value="Optimize traffic flow for downtown area.")
|
103 |
+
],
|
104 |
+
outputs=[
|
105 |
+
gr.JSON(label="Optimization Strategy from Groq API"),
|
106 |
+
gr.Plot(label="Traffic Flow Visualization"),
|
107 |
+
gr.Textbox(label="RAG-based Optimization Strategy")
|
108 |
+
],
|
109 |
+
live=True,
|
110 |
+
title="Traffic Optimization App",
|
111 |
+
description="This app optimizes traffic flow using RL, Groq API, and RAG model-based optimization strategies."
|
112 |
+
)
|
113 |
+
|
114 |
+
if __name__ == "__main__":
|
115 |
+
iface.launch()
|