from fastapi import APIRouter, Form, BackgroundTasks
from config import settings
import os
import json
import utils
import torch
import requests
from PIL import Image
from io import BytesIO
from pydantic import BaseModel
import base64
import uuid
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline

# tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")
# model = AutoModelForCausalLM.from_pretrained(
#     "openlm-research/open_llama_7b", device_map="auto", load_in_4bit=True
# )

# model_id_or_path = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)

# Load the inpainting pipeline once at import time, in half precision, on the GPU.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

router = APIRouter()


class ActionBody(BaseModel):
    url: str  # source image to edit
    maskUrl: str  # mask image: white pixels are repainted, black pixels are kept
    prompt: str
    strength: float
    guidance_scale: float
    resizeW: int  # Stable Diffusion expects width and height divisible by 8
    resizeH: int


def download_image(url: str) -> Image.Image:
    """Fetch an image over HTTP, raising on timeouts and non-2xx responses."""
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


# Declared with plain `def` so FastAPI runs it in a worker thread: the blocking
# `requests` calls and GPU inference would otherwise stall the event loop.
@router.post("/perform-action")
def performAction(actionBody: ActionBody):
    # model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt").to("cuda")
    # generated_ids = model.generate(**model_inputs)
    # output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    size = (actionBody.resizeW, actionBody.resizeH)

    # Source image and mask must have the same dimensions.
    init_image = download_image(actionBody.url).convert("RGB").resize(size)
    mask_image = download_image(actionBody.maskUrl).convert("L").resize(size)

    images = pipe(
        prompt=actionBody.prompt,
        image=init_image,
        mask_image=mask_image,
        strength=actionBody.strength,
        guidance_scale=actionBody.guidance_scale,
    ).images
    print(images)

    # Return the result inline as a base64 JPEG data URL...
    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())

    # ...and also keep a PNG copy on disk under a random name.
    imgUUID = str(uuid.uuid4())
    images[0].save(imgUUID + ".png")

    return {
        "imageName": imgUUID + ".png",
        "image": "data:image/jpeg;base64," + img_str.decode(),
        # "output": output
    }


@router.get("/hi")
def hifunction():
    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
    init_image = download_image(url).convert("RGB").resize((768, 512))

    # The inpainting pipeline needs a mask; an all-white mask marks every pixel
    # for repainting, so the whole image is regenerated from the prompt.
    mask_image = Image.new("L", init_image.size, 255)

    prompt = "A fantasy landscape, trending on artstation"
    images = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        strength=0.75,
        guidance_scale=7.5,
    ).images
    print(images)

    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    # images[0].save("fantasy_landscape.png")
    return {
        "image": "data:image/jpeg;base64," + img_str.decode()
    }
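

# A possible hardening step, sketched under assumptions: Stable Diffusion's VAE
# downsamples images by a factor of 8, so the pipelines expect a height and
# width divisible by 8. Client-supplied resizeW/resizeH could be snapped down
# before resizing. `snap_to_multiple` and `to_data_url` are hypothetical
# helpers (not part of FastAPI or diffusers) and use only the standard library.
import base64


def snap_to_multiple(value: int, multiple: int = 8) -> int:
    """Round `value` down to the nearest positive multiple of `multiple`."""
    if value < multiple:
        raise ValueError(f"dimension {value} is smaller than {multiple}")
    return value - value % multiple


def to_data_url(data: bytes, mime: str = "image/jpeg") -> str:
    """Wrap raw bytes in a base64 data URL, the format the endpoints return."""
    return f"data:{mime};base64," + base64.b64encode(data).decode("ascii")


# Hypothetical usage inside performAction:
# size = (snap_to_multiple(actionBody.resizeW), snap_to_multiple(actionBody.resizeH))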