Timmyafolami's picture
Update main.py
06f8aa3 verified
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from detection import initialize_directories, slice_geotiff, detect_weeds_in_slices, cleanup, create_zip # Import detection functions
from db_user_info import insert_user_info
from db_bucket import upload_file_to_bucket
from typing import List
from io import BytesIO
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Or specify the frontend origin
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files
app.mount("/static", StaticFiles(directory="static"), name="static")
# WebSocket manager
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
async def connect(self, websocket: WebSocket):
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_message(self, message: str):
for connection in self.active_connections:
await connection.send_text(message)
manager = ConnectionManager()
@app.get("/", response_class=HTMLResponse)
async def read_index():
with open("static/index.html") as f:
return HTMLResponse(content=f.read())
@app.get("/app", response_class=HTMLResponse)
async def read_app():
with open("static/app.html") as f:
return HTMLResponse(content=f.read())
@app.post("/register")
async def register_user(request: Request):
user_info = await request.json()
try:
insert_user_info(user_info)
return JSONResponse(content={"detail": "User registered successfully"}, status_code=200)
except Exception as e:
return JSONResponse(content={"detail": str(e)}, status_code=400)
@app.post("/upload_geotiff/")
async def upload_geotiff(file: UploadFile = File(...)):
# Initialize directories at the start
initialize_directories()
file_location = f"uploaded_geotiff.tif"
with open(file_location, "wb") as f:
f.write(file.file.read())
await manager.send_message("GeoTIFF file uploaded successfully. Slicing started.")
slices = await slice_geotiff(file_location, slice_size=3000)
await manager.send_message("Slicing complete. Starting weed detection.")
weed_bboxes = await detect_weeds_in_slices(slices)
await manager.send_message("Weed detection complete. Generating shapefile.")
# Create zip file
zip_file_path = await create_zip()
await manager.send_message("Shapefiles Generated. Zipping shapefile.")
# Upload the zip file to the bucket
response = upload_file_to_bucket(zip_file_path)
print(response)
await manager.send_message("Zip file uploaded to bucket storage.")
# Read zip file into buffer for download
zip_buffer = BytesIO()
with open(zip_file_path, 'rb') as f:
zip_buffer.write(f.read())
zip_buffer.seek(0)
# Cleanup files and directories
cleanup()
return StreamingResponse(zip_buffer, media_type="application/zip", headers={"Content-Disposition": "attachment; filename=weed_detections.zip"})
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
try:
while True:
data = await websocket.receive_text()
await manager.send_message(data)
except WebSocketDisconnect:
manager.disconnect(websocket)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)