# image-caption / app.py
# Author: AnkitS1997
# Commit 177e69b: "added exception handling"
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch
import io
# The FastAPI application object; the caption route is registered on it.
app = FastAPI()

# Wide-open CORS policy: any origin, any HTTP method and any request header
# is accepted, and credentialed cross-origin requests are permitted.
# NOTE(review): the Starlette docs state that wildcard origins/methods/headers
# are not meant to be combined with allow_credentials=True, and some Starlette
# versions respond by echoing the caller's Origin back, letting any site make
# credentialed calls. Nothing visible here uses cookies or auth, so
# allow_credentials=False (or an explicit origin list) is probably safer —
# confirm no frontend depends on credentials before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Load the model and processor once, at import time, so every request reuses
# them. Any exception here is re-raised as RuntimeError, which aborts importing
# this module (and therefore starting the app).
try:
    # BLIP-2 (OPT-2.7B) base weights, then the 'blip-cpu-model' adapter on top
    # (presumably a local adapter directory or Hub id — not visible here).
    # NOTE(review): no torch_dtype is passed, so weights are materialized in
    # from_pretrained's default dtype (float32 unless configured otherwise)
    # even though the checkpoint name says "fp16" — confirm this is intended.
    model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded")
    model.load_adapter('blip-cpu-model')

    # Image preprocessing + tokenizer/decoder from the upstream BLIP-2 repo.
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")

    # Prefer the GPU when one is available; request inputs must be moved to
    # this same device before calling the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as e:
    # Chain explicitly so the original loader error is reported as the cause.
    raise RuntimeError(f"Failed to load the model or processor: {str(e)}") from e
@app.post("/generate-caption/")
async def generate_caption(file: UploadFile = File(...)):
    """Generate a text caption for an uploaded image.

    Args:
        file: Multipart upload whose bytes should be an image format that
            Pillow can decode.

    Returns:
        ``{"caption": <str>}`` holding the model's decoded output with
        special tokens removed.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure while reading the upload or while
            generating the caption.
    """
    # NOTE(review): processor/model.generate below are synchronous, compute-
    # heavy calls made directly inside this coroutine, so they block the event
    # loop for the whole inference. Consider offloading to a worker thread
    # once memory use under concurrent requests has been evaluated.
    try:
        image = Image.open(io.BytesIO(await file.read()))
        # Image.open is lazy and only parses the header. Force a full decode
        # here so truncated/corrupt pixel data is reported as a bad upload
        # rather than surfacing later as a caption-generation failure.
        image.load()
    except (UnidentifiedImageError, OSError) as e:
        # Raise a 400 error if the file is not a valid image. Pillow raises
        # UnidentifiedImageError for unrecognized formats and OSError (e.g.
        # "image file is truncated") when pixel data cannot be decoded.
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from e
    except Exception as e:
        # Catch any other unexpected errors related to image processing
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred while processing the image: {str(e)}") from e

    try:
        # Cast floating-point inputs (pixel values) to the model's own dtype
        # instead of a hard-coded float16: the model is loaded without a dtype
        # override, so float16 inputs would not match its weights and would
        # at best lose precision.
        inputs = processor(images=image, return_tensors="pt").to(device, model.dtype)
        with torch.no_grad():
            caption_ids = model.generate(**inputs, max_length=128)
        caption = processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:
        # Catch any errors during the caption generation process
        raise HTTPException(status_code=500, detail=f"An error occurred while generating the caption: {str(e)}") from e

    return {"caption": caption}