"""FastAPI service exposing a single image-captioning endpoint backed by BLIP-2.

The model, its adapter and the processor are loaded once at import time; the
``/generate-caption/`` endpoint accepts an uploaded image and returns a caption.
"""

import io

import torch
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoProcessor, Blip2ForConditionalGeneration

app = FastAPI()

# NOTE(review): wildcard origins/methods/headers combined with
# allow_credentials=True is discouraged by the FastAPI CORS documentation;
# consider listing the allowed origins explicitly. Left unchanged here so
# existing clients keep working.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model and processor once, at import time, so every request reuses
# the same weights. Any failure aborts startup with a chained RuntimeError.
try:
    model = Blip2ForConditionalGeneration.from_pretrained(
        "ybelkada/blip2-opt-2.7b-fp16-sharded"
    )
    model.load_adapter('blip-cpu-model')
    processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
except Exception as exc:
    raise RuntimeError(
        f"Failed to load the model or processor: {exc}"
    ) from exc


@app.post("/generate-caption/")
async def generate_caption(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict of the form ``{"caption": <str>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 if reading the upload or generating the caption fails.
    """
    try:
        payload = await file.read()
    except Exception as exc:
        # Catch any unexpected errors while reading the upload
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while processing the image: {exc}",
        ) from exc

    try:
        image = Image.open(io.BytesIO(payload))
        # Image.open is lazy: force the decode now so truncated or corrupt
        # files are rejected here rather than deep inside the processor.
        image.load()
    except OSError as exc:
        # UnidentifiedImageError is an OSError subclass, so this covers both
        # unrecognised formats and files that fail while decoding.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc
    except Exception as exc:
        # Catch any other unexpected errors related to image processing
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while processing the image: {exc}",
        ) from exc

    # NOTE(review): the calls below are synchronous and run inside an
    # ``async def`` handler, so other requests wait while a caption is being
    # generated. Consider offloading to a worker thread if that matters.
    try:
        # Cast inputs to the dtype the model actually holds instead of a
        # hard-coded float16, so inputs and weights always agree.
        inputs = processor(images=image, return_tensors="pt").to(
            device, model.dtype
        )
        with torch.no_grad():
            caption_ids = model.generate(**inputs, max_length=128)
        caption = processor.decode(caption_ids[0], skip_special_tokens=True)
        return {"caption": caption}
    except Exception as exc:
        # Catch any errors during the caption generation process
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while generating the caption: {exc}",
        ) from exc