"""FastAPI service exposing pill counting, wheat-grain counting and image moderation.

Endpoints:
    POST /predict        -- count capsules/tablets with the local ``pillmodel``.
    POST /predict_wheat  -- count wheat grains via a Roboflow hosted model.
    POST /analyze-image  -- ask Gemini for a marketplace safety assessment.
    GET  /               -- empty HTML page; everything else is served from ``static/``.
"""

import base64
import io
import json
import logging
import os
import random
import re
import tempfile

import cv2
import google.ai.generativelanguage as glm
import google.generativeai as genai
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from google.generativeai.types import HarmBlockThreshold, HarmCategory
from inference_sdk import InferenceHTTPClient
from PIL import Image

from pillmodel import get_prediction

logger = logging.getLogger(__name__)

# Comma-separated list of Gemini API keys. An unset variable yields an empty
# list instead of crashing at import time, so the other endpoints keep working.
# The keys themselves are deliberately never printed or logged.
api_keys = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
logger.info("Loaded %d Gemini API key(s)", len(api_keys))

# SECURITY: this key was committed to source control and should be rotated.
# Supply the replacement through ROBOFLOW_API_KEY; the literal is kept only as
# a backward-compatible fallback for existing deployments.
_ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "PpEebXofNuob5VSx7YP3")
_ROBOFLOW_API_URL = "https://detect.roboflow.com"

_GENERATION_CONFIG = {
    "temperature": 1,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 8192,
    "response_mime_type": "text/plain",
}

_SAFETY_SETTINGS = {
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}

# Prompt text is runtime data sent to the model; it is reproduced verbatim.
_MODERATION_PROMPT = (
    " give a safety score for a website called unipall which is a olx, "
    "now when a user is uploading a product, tell me this in json like: "
    "only give this json nothing else "
    "not be too harmful "
    "when a picture contains some accessories in a scene focus on them "
    "and don't flag it "
    "don't flag text on the product "
    "{ useable_on_website: true/false, safety_score: /100, "
    'category: "", reason: "", '
    'suggested_product_title: "", suggested_product_description: "" } '
)

# Matches a Markdown fenced JSON block in the model output.
_JSON_FENCE_RE = re.compile(r'```json(.*?)```', re.DOTALL)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _decode_upload(contents: bytes) -> np.ndarray:
    """Decode uploaded bytes into a BGR image, or raise HTTP 400 if undecodable."""
    img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imdecode signals failure by returning None rather than raising.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a decodable image."
        )
    return img


def _encode_jpeg_base64(img: np.ndarray) -> str:
    """Encode an image as JPEG and return it as a base64 text string."""
    ok, buffer = cv2.imencode('.jpg', img)
    if not ok:
        raise HTTPException(status_code=500, detail="Failed to encode result image.")
    return base64.b64encode(buffer).decode('utf-8')


@app.post("/predict")
async def predict(image: UploadFile = File(...)) -> JSONResponse:
    """Count capsules and tablets in the uploaded image.

    Returns a JSON object with a human-readable ``message``, the raw ``count``
    mapping from ``get_prediction`` and the annotated image as a base64 JPEG
    in ``predicted_image``. Responds with HTTP 400 if the upload is not an image.
    """
    img = _decode_upload(await image.read())

    predicted_image, count_dict = get_prediction(img)

    capsules = count_dict.get('capsules', 0)
    tablets = count_dict.get('tablets', 0)
    message_to_send = (
        f"There are {capsules} capsules and {tablets} tablets. "
        f"A total of {capsules + tablets} pills."
    )

    return JSONResponse(
        content={
            "message": message_to_send,
            "count": count_dict,
            "predicted_image": _encode_jpeg_base64(predicted_image),
        }
    )


@app.post("/predict_wheat")
async def predict_wheat(
    image: UploadFile = File(...), model_id: str = "grian/1"
) -> JSONResponse:
    """Count wheat grains in the uploaded image using a Roboflow hosted model.

    Args:
        image: Uploaded image file.
        model_id: Roboflow model identifier passed straight to ``infer``.

    Returns a JSON object with ``message``, integer ``count`` and the image
    with detection boxes drawn, base64-encoded, in ``predicted_image``.
    Responds with HTTP 400 if the upload is not an image.
    """
    img = _decode_upload(await image.read())

    # A unique temp file per request: a shared fixed filename would let
    # concurrent requests overwrite each other's image.
    fd, temp_image_path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)
    try:
        if not cv2.imwrite(temp_image_path, img):
            raise HTTPException(
                status_code=500, detail="Failed to stage image for inference."
            )
        client = InferenceHTTPClient(
            api_url=_ROBOFLOW_API_URL,
            api_key=_ROBOFLOW_API_KEY,
        )
        result = client.infer(temp_image_path, model_id=model_id)
    finally:
        try:
            os.remove(temp_image_path)
        except OSError:
            logger.warning("Could not remove temporary file %s", temp_image_path)

    predictions = result['predictions']
    predicted_count = len(predictions)
    message_to_send = (
        f"There are {predicted_count} wheat grains."
    )

    for prediction in predictions:
        # Roboflow reports x/y as the CENTER of the box, so the corners are
        # offset by half the width/height in each direction.
        half_w = prediction['width'] / 2
        half_h = prediction['height'] / 2
        top_left = (int(prediction['x'] - half_w), int(prediction['y'] - half_h))
        bottom_right = (int(prediction['x'] + half_w), int(prediction['y'] + half_h))
        cv2.rectangle(img, top_left, bottom_right, (0, 255, 0), 2)

    return JSONResponse(
        content={
            "message": message_to_send,
            "count": predicted_count,
            "predicted_image": _encode_jpeg_base64(img),
        }
    )


def process_image(file: UploadFile) -> glm.Blob:
    """Convert an uploaded image into a JPEG ``glm.Blob`` for the Gemini API.

    The image is converted to RGB first when it is in any other mode, because
    it is re-encoded as JPEG before being wrapped in the blob.
    """
    image = Image.open(file.file)

    if image.mode != 'RGB':
        image = image.convert('RGB')

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='JPEG')

    return glm.Blob(
        mime_type='image/jpeg',
        data=img_byte_arr.getvalue(),
    )


@app.post("/analyze-image")
async def analyze_image(file: UploadFile = File(...)) -> JSONResponse:
    """Ask Gemini for a marketplace safety assessment of the uploaded image.

    When the model wraps its answer in a fenced JSON block, the block is parsed
    and re-serialised with 4-space indentation; otherwise the raw model text is
    returned. Responds with HTTP 503 when no Gemini API key is configured and
    HTTP 400 when the upload cannot be opened as an image.

    NOTE(review): the payload handed to ``JSONResponse`` is a *string*, so the
    HTTP body is a JSON string that itself contains JSON. This double encoding
    is preserved on purpose because existing clients may rely on it.
    """
    if not api_keys:
        raise HTTPException(
            status_code=503, detail="No Gemini API key is configured."
        )

    # The chosen key is intentionally not logged.
    genai.configure(api_key=random.choice(api_keys))

    try:
        blob = process_image(file)
    except OSError as exc:
        # PIL raises OSError (incl. UnidentifiedImageError) for unreadable images.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=_GENERATION_CONFIG,
        safety_settings=_SAFETY_SETTINGS,
    )

    response = model.generate_content([_MODERATION_PROMPT, blob])
    text = response.text

    if '```json' not in text:
        return JSONResponse(content=text, media_type="application/json")

    fenced = _JSON_FENCE_RE.search(text)
    if fenced is None:
        # Opening fence without a closing one: fall back to the raw text
        # instead of failing on a missing match.
        return JSONResponse(content=text, media_type="application/json")

    try:
        data = json.loads(fenced.group(1).strip())
    except json.JSONDecodeError:
        logger.warning("Model returned a fenced block that is not valid JSON")
        return JSONResponse(content=text, media_type="application/json")

    fd = json.dumps(data, indent=4)
    return JSONResponse(content=fd, media_type="application/json")


# Registered BEFORE the static mount: a mount at "/" matches every path, so
# any route added after it would never be reached.
@app.get("/")
async def home() -> HTMLResponse:
    """Serve an empty HTML page at the site root."""
    return HTMLResponse(content="")


app.mount("/", StaticFiles(directory="static"), name="static")