Timmyafolami commited on
Commit
79b0a63
·
verified ·
1 Parent(s): 72ef8c1

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +99 -99
main.py CHANGED
@@ -1,99 +1,99 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
2
- from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from detection import initialize_directories, slice_geotiff, detect_weeds_in_slices, cleanup, create_zip # Import detection functions
5
- from db_user_info import insert_user_info
6
- from db_bucket import upload_file_to_bucket
7
- from typing import List
8
- from io import BytesIO
9
-
10
- app = FastAPI()
11
-
12
- # Mount static files
13
- app.mount("/static", StaticFiles(directory="static"), name="static")
14
-
15
- # WebSocket manager
16
- class ConnectionManager:
17
- def __init__(self):
18
- self.active_connections: List[WebSocket] = []
19
-
20
- async def connect(self, websocket: WebSocket):
21
- await websocket.accept()
22
- self.active_connections.append(websocket)
23
-
24
- def disconnect(self, websocket: WebSocket):
25
- self.active_connections.remove(websocket)
26
-
27
- async def send_message(self, message: str):
28
- for connection in self.active_connections:
29
- await connection.send_text(message)
30
-
31
- manager = ConnectionManager()
32
-
33
- @app.get("/", response_class=HTMLResponse)
34
- async def read_index():
35
- with open("static/index.html") as f:
36
- return HTMLResponse(content=f.read())
37
-
38
- @app.get("/app", response_class=HTMLResponse)
39
- async def read_app():
40
- with open("static/app.html") as f:
41
- return HTMLResponse(content=f.read())
42
-
43
- @app.post("/register")
44
- async def register_user(request: Request):
45
- user_info = await request.json()
46
- try:
47
- insert_user_info(user_info)
48
- return JSONResponse(content={"detail": "User registered successfully"}, status_code=200)
49
- except Exception as e:
50
- return JSONResponse(content={"detail": str(e)}, status_code=400)
51
-
52
- @app.post("/upload_geotiff/")
53
- async def upload_geotiff(file: UploadFile = File(...)):
54
- # Initialize directories at the start
55
- initialize_directories()
56
-
57
- file_location = f"uploaded_geotiff.tif"
58
- with open(file_location, "wb") as f:
59
- f.write(file.file.read())
60
-
61
- await manager.send_message("GeoTIFF file uploaded successfully. Slicing started.")
62
- slices = await slice_geotiff(file_location, slice_size=3000)
63
- await manager.send_message("Slicing complete. Starting weed detection.")
64
- weed_bboxes = await detect_weeds_in_slices(slices)
65
- await manager.send_message("Weed detection complete. Generating shapefile.")
66
-
67
- # Create zip file
68
- zip_file_path = await create_zip()
69
- await manager.send_message("Shapefiles Generated. Zipping shapefile.")
70
-
71
- # Upload the zip file to the bucket
72
- response = upload_file_to_bucket(zip_file_path)
73
- print(response)
74
- await manager.send_message("Zip file uploaded to bucket storage.")
75
-
76
- # Read zip file into buffer for download
77
- zip_buffer = BytesIO()
78
- with open(zip_file_path, 'rb') as f:
79
- zip_buffer.write(f.read())
80
- zip_buffer.seek(0)
81
-
82
- # Cleanup files and directories
83
- cleanup()
84
-
85
- return StreamingResponse(zip_buffer, media_type="application/zip", headers={"Content-Disposition": "attachment; filename=weed_detections.zip"})
86
-
87
- @app.websocket("/ws")
88
- async def websocket_endpoint(websocket: WebSocket):
89
- await manager.connect(websocket)
90
- try:
91
- while True:
92
- data = await websocket.receive_text()
93
- await manager.send_message(data)
94
- except WebSocketDisconnect:
95
- manager.disconnect(websocket)
96
-
97
-
98
- if __name__ == "__main__":
99
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
2
+ from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from detection import initialize_directories, slice_geotiff, detect_weeds_in_slices, cleanup, create_zip # Import detection functions
5
+ from db_user_info import insert_user_info
6
+ from db_bucket import upload_file_to_bucket
7
+ from typing import List
8
+ from io import BytesIO
9
+
10
+ app = FastAPI()
11
+
12
+ # Mount static files
13
+ app.mount("/static", StaticFiles(directory="static"), name="static")
14
+
15
+ # WebSocket manager
16
+ class ConnectionManager:
17
+ def __init__(self):
18
+ self.active_connections: List[WebSocket] = []
19
+
20
+ async def connect(self, websocket: WebSocket):
21
+ await websocket.accept()
22
+ self.active_connections.append(websocket)
23
+
24
+ def disconnect(self, websocket: WebSocket):
25
+ self.active_connections.remove(websocket)
26
+
27
+ async def send_message(self, message: str):
28
+ for connection in self.active_connections:
29
+ await connection.send_text(message)
30
+
31
+ manager = ConnectionManager()
32
+
33
+ @app.get("/", response_class=HTMLResponse)
34
+ async def read_index():
35
+ with open("static/index.html") as f:
36
+ return HTMLResponse(content=f.read())
37
+
38
+ @app.get("/app", response_class=HTMLResponse)
39
+ async def read_app():
40
+ with open("static/app.html") as f:
41
+ return HTMLResponse(content=f.read())
42
+
43
+ @app.post("/register")
44
+ async def register_user(request: Request):
45
+ user_info = await request.json()
46
+ try:
47
+ insert_user_info(user_info)
48
+ return JSONResponse(content={"detail": "User registered successfully"}, status_code=200)
49
+ except Exception as e:
50
+ return JSONResponse(content={"detail": str(e)}, status_code=400)
51
+
52
+ @app.post("/upload_geotiff/")
53
+ async def upload_geotiff(file: UploadFile = File(...)):
54
+ # Initialize directories at the start
55
+ initialize_directories()
56
+
57
+ file_location = f"uploaded_geotiff.tif"
58
+ with open(file_location, "wb") as f:
59
+ f.write(file.file.read())
60
+
61
+ await manager.send_message("GeoTIFF file uploaded successfully. Slicing started.")
62
+ slices = await slice_geotiff(file_location, slice_size=3000)
63
+ await manager.send_message("Slicing complete. Starting weed detection.")
64
+ weed_bboxes = await detect_weeds_in_slices(slices)
65
+ await manager.send_message("Weed detection complete. Generating shapefile.")
66
+
67
+ # Create zip file
68
+ zip_file_path = await create_zip()
69
+ await manager.send_message("Shapefiles Generated. Zipping shapefile.")
70
+
71
+ # Upload the zip file to the bucket
72
+ response = upload_file_to_bucket(zip_file_path)
73
+ print(response)
74
+ await manager.send_message("Zip file uploaded to bucket storage.")
75
+
76
+ # Read zip file into buffer for download
77
+ zip_buffer = BytesIO()
78
+ with open(zip_file_path, 'rb') as f:
79
+ zip_buffer.write(f.read())
80
+ zip_buffer.seek(0)
81
+
82
+ # Cleanup files and directories
83
+ cleanup()
84
+
85
+ return StreamingResponse(zip_buffer, media_type="application/zip", headers={"Content-Disposition": "attachment; filename=weed_detections.zip"})
86
+
87
+ @app.websocket("/ws")
88
+ async def websocket_endpoint(websocket: WebSocket):
89
+ await manager.connect(websocket)
90
+ try:
91
+ while True:
92
+ data = await websocket.receive_text()
93
+ await manager.send_message(data)
94
+ except WebSocketDisconnect:
95
+ manager.disconnect(websocket)
96
+
97
+
98
+ if __name__ == "__main__":
99
+ uvicorn.run(app, host="0.0.0.0", port=7860)