# Page header captured with this paste: "Spaces: Paused" (status of the hosting Space).
from fastapi import APIRouter, Form, BackgroundTasks | |
from config import settings | |
import os | |
import json | |
import utils | |
import torch | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
from pydantic import BaseModel | |
import base64 | |
import uuid | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from diffusers import StableDiffusionImg2ImgPipeline | |
from diffusers import StableDiffusionInpaintPipeline | |
# Single Stable Diffusion inpainting pipeline shared by the handlers in this
# module. It is built once, at import time: the "fp16" weight revision is loaded
# as float16 tensors and the whole pipeline is then moved to the CUDA device.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
    revision="fp16",
).to("cuda")

# Router object for this module; no routes are registered on it in the code
# shown here — presumably that happens elsewhere (TODO confirm).
router = APIRouter()
class ActionBody(BaseModel):
    # Request payload for performAction: one inpainting job.
    # Field names and order define the request schema, so keep them as-is.
    url: str  # URL of the source image to edit (downloaded over HTTP)
    maskUrl: str  # URL of the mask image paired with the source image
    prompt: str  # text prompt for the diffusion pipeline
    strength: float  # intended for the pipeline's `strength` argument
    guidance_scale: float  # intended for the pipeline's `guidance_scale` argument
    resizeW: int  # target width in pixels for the downloaded images
    resizeH: int  # target height in pixels for the downloaded images
def _load_rgb_image(url, size, timeout=30.0):
    """Download the image at ``url`` and return it as an RGB PIL image resized to ``size``.

    ``size`` is a ``(width, height)`` tuple, as taken by ``PIL.Image.resize``.

    Raises:
        requests.HTTPError: if the server answers with an error status, so an
            HTML error page is never handed to PIL as if it were an image.
        requests.Timeout: if the server does not respond within ``timeout``
            seconds (without a timeout ``requests.get`` can wait forever).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB").resize(size)


async def performAction(actionBody: ActionBody):
    """Inpaint the image at ``actionBody.url`` using the mask at ``actionBody.maskUrl``.

    Both images are downloaded, converted to RGB and resized to the requested
    ``(resizeW, resizeH)``, which is rounded down to multiples of 8. They are
    then passed to the module-level inpainting ``pipe`` together with
    ``prompt`` and ``guidance_scale``. The first generated image is written to
    ``<uuid4>.png`` in the current working directory and also returned inline.

    Args:
        actionBody: the request payload.

    Returns:
        dict with ``"imageName"`` (file name of the saved PNG) and ``"image"``
        (the same image as a ``data:image/jpeg;base64,...`` string).

    Raises:
        requests.HTTPError / requests.Timeout: if either download fails.
    """
    # NOTE(review): requests.get and pipe(...) run synchronously inside this
    # `async def`, so they block the event loop for the whole request.

    # diffusers' Stable Diffusion pipelines require width/height divisible by 8.
    # Round down (minimum 8) so the inputs and the generated image share one
    # valid size, instead of the pipeline falling back to its default size.
    width = max(8, actionBody.resizeW // 8 * 8)
    height = max(8, actionBody.resizeH // 8 * 8)
    init_image = _load_rgb_image(actionBody.url, (width, height))
    mask_image = _load_rgb_image(actionBody.maskUrl, (width, height))

    # NOTE(review): actionBody.strength is still not forwarded; older diffusers
    # inpainting pipelines do not accept `strength` — confirm the installed
    # version before wiring it through.
    result = pipe(
        prompt=actionBody.prompt,
        image=init_image,
        mask_image=mask_image,
        height=height,
        width=width,
        guidance_scale=actionBody.guidance_scale,
    ).images[0]

    buffered = BytesIO()
    result.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()

    image_name = str(uuid.uuid4()) + ".png"
    # Relative path: lands in the process's current working directory.
    result.save(image_name)
    return {
        "imageName": image_name,
        "image": "data:image/jpeg;base64," + img_str,
    }
async def hifunction():
    """Run ``pipe`` on a fixed sample sketch as a smoke test and return the result inline.

    Downloads the CompVis img2img sample sketch and resizes it to 768x512. The
    whole image is then repainted (all-white mask) with a fixed fantasy-landscape
    prompt, ``strength=0.75`` and ``guidance_scale=7.5``. Nothing is written to
    disk.

    Returns:
        dict with ``"image"``: the first generated image as a
        ``data:image/jpeg;base64,...`` string.

    Raises:
        requests.HTTPError / requests.Timeout: if the sample download fails.
    """
    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
    # Explicit timeout + status check: never hang forever or feed an error page to PIL.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((768, 512))

    # `pipe` is a StableDiffusionInpaintPipeline, which needs a mask_image; this
    # demo was written for img2img and passed none. White mask pixels are
    # repainted, so an all-white mask repaints the entire image, approximating
    # the original img2img demo.
    mask_image = Image.new("RGB", init_image.size, "white")
    prompt = "A fantasy landscape, trending on artstation"
    # height/width keep the 768x512 size rather than the pipeline's default size.
    # NOTE(review): `strength` is only accepted by newer diffusers inpainting
    # pipelines — confirm the installed version.
    images = pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        height=init_image.height,
        width=init_image.width,
        strength=0.75,
        guidance_scale=7.5,
    ).images

    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return {
        "image": "data:image/jpeg;base64," + img_str.decode()
    }