# --- Hugging Face file-viewer header (page residue, not part of app.py) ---
# kedimestan's picture
# Update app.py
# 03b5980 verified
# raw
# history blame
# 2.18 kB
import base64
import io

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation

# Non-interactive backend: figures are only ever rendered into in-memory buffers.
matplotlib.use('Agg')
# --- Application and model setup (runs once, when the module is imported) ---
app = FastAPI()

# The preprocessing config is fetched under the "Intel/dpt-large" identifier,
# while the network weights are loaded from "model/" (presumably a local
# directory shipped next to this file -- confirm).
processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("model/")

# Camera parameters; adjust these to the device that captured the images.
# NOTE(review): units are not stated here -- presumably millimetres for the
# two lengths and pixels for the width; confirm.
focal_length: float = 14.35
sensor_width: float = 4.88
image_width: int = 3072

# Focal length rescaled by image width over sensor width (pixels, if the two
# lengths above share a unit).
focal_length_px: float = image_width * focal_length / sensor_width
@app.post("/predict/")
async def predict_depth(file: UploadFile = File(...)):
    """Estimate depth for an uploaded image and return it as a PNG heat map.

    The upload is decoded with Pillow, passed through the module-level
    ``processor`` and ``model``, and the predicted depth is resized back to
    the input resolution. Every pixel is then mapped through
    ``focal_length_px / (depth + 1e-6)`` and rendered with the "plasma"
    colormap and a colorbar.

    Args:
        file: Image uploaded as multipart form data.

    Returns:
        A JSON response whose ``depth_map`` key holds the rendered PNG as a
        base64 ``data:image/png`` URI.

    Raises:
        HTTPException: Status 400 when the upload cannot be decoded as an
            image (empty, truncated, or not an image at all).
    """
    image_bytes = await file.read()
    try:
        # Normalise to three channels so RGBA, palette and grayscale uploads
        # reach the processor in the same form as plain RGB ones. convert()
        # also forces the lazy decode, so corrupt data fails here.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # NOTE: inference below runs synchronously inside this coroutine.
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        predicted_depth = model(**inputs).predicted_depth

    # PIL reports (width, height); interpolate() wants (height, width).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    output = prediction.squeeze().cpu().numpy()

    # The epsilon guards against division by zero.
    # NOTE(review): the original comment calls this a conversion to
    # real-world centimeters; nothing in this file establishes the scale of
    # the model output, and focal_length_px is derived from a fixed image
    # width rather than the uploaded image's width -- confirm calibration.
    cm = focal_length_px / (output + 1e-6)

    # Draw on the explicit figure (not pyplot's implicit "current figure")
    # and always close it; otherwise pyplot keeps every figure alive and
    # memory grows with each request.
    fig, ax = plt.subplots()
    try:
        heat = ax.imshow(cm, cmap="plasma")
        fig.colorbar(heat, ax=ax)
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    encoded = base64.b64encode(buf.getvalue()).decode()
    return JSONResponse({"depth_map": f"data:image/png;base64,{encoded}"})
# Script entry point for local testing: serve `app` with uvicorn, listening
# on every network interface at port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )