Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse | |
from fastapi.staticfiles import StaticFiles | |
from detection import initialize_directories, slice_geotiff, detect_weeds_in_slices, cleanup, create_zip # Import detection functions | |
from db_user_info import insert_user_info | |
from db_bucket import upload_file_to_bucket | |
from typing import List | |
from io import BytesIO | |
app = FastAPI() | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], # Or specify the frontend origin | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Mount static files | |
app.mount("/static", StaticFiles(directory="static"), name="static") | |
# WebSocket manager | |
class ConnectionManager: | |
def __init__(self): | |
self.active_connections: List[WebSocket] = [] | |
async def connect(self, websocket: WebSocket): | |
await websocket.accept() | |
self.active_connections.append(websocket) | |
def disconnect(self, websocket: WebSocket): | |
self.active_connections.remove(websocket) | |
async def send_message(self, message: str): | |
for connection in self.active_connections: | |
await connection.send_text(message) | |
manager = ConnectionManager() | |
async def read_index(): | |
with open("static/index.html") as f: | |
return HTMLResponse(content=f.read()) | |
async def read_app(): | |
with open("static/app.html") as f: | |
return HTMLResponse(content=f.read()) | |
async def register_user(request: Request): | |
user_info = await request.json() | |
try: | |
insert_user_info(user_info) | |
return JSONResponse(content={"detail": "User registered successfully"}, status_code=200) | |
except Exception as e: | |
return JSONResponse(content={"detail": str(e)}, status_code=400) | |
async def upload_geotiff(file: UploadFile = File(...)): | |
# Initialize directories at the start | |
initialize_directories() | |
file_location = f"uploaded_geotiff.tif" | |
with open(file_location, "wb") as f: | |
f.write(file.file.read()) | |
await manager.send_message("GeoTIFF file uploaded successfully. Slicing started.") | |
slices = await slice_geotiff(file_location, slice_size=3000) | |
await manager.send_message("Slicing complete. Starting weed detection.") | |
weed_bboxes = await detect_weeds_in_slices(slices) | |
await manager.send_message("Weed detection complete. Generating shapefile.") | |
# Create zip file | |
zip_file_path = await create_zip() | |
await manager.send_message("Shapefiles Generated. Zipping shapefile.") | |
# Upload the zip file to the bucket | |
response = upload_file_to_bucket(zip_file_path) | |
print(response) | |
await manager.send_message("Zip file uploaded to bucket storage.") | |
# Read zip file into buffer for download | |
zip_buffer = BytesIO() | |
with open(zip_file_path, 'rb') as f: | |
zip_buffer.write(f.read()) | |
zip_buffer.seek(0) | |
# Cleanup files and directories | |
cleanup() | |
return StreamingResponse(zip_buffer, media_type="application/zip", headers={"Content-Disposition": "attachment; filename=weed_detections.zip"}) | |
async def websocket_endpoint(websocket: WebSocket): | |
await manager.connect(websocket) | |
try: | |
while True: | |
data = await websocket.receive_text() | |
await manager.send_message(data) | |
except WebSocketDisconnect: | |
manager.disconnect(websocket) | |
if __name__ == "__main__": | |
uvicorn.run(app, host="0.0.0.0", port=7860) |