# petrockapi/routers/training.py
# (branch: woods-today, commit 0dfd0db — "Workin on it")
from fastapi import APIRouter, Form, BackgroundTasks
from config import settings
import os
import json
import utils
import torch
import requests
from PIL import Image
from io import BytesIO
from pydantic import BaseModel
import base64
import uuid
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
# Earlier experiments (open-llama text generation and plain img2img) kept
# for reference; only the inpainting pipeline below is active.
# tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b")
# model = AutoModelForCausalLM.from_pretrained(
# "openlm-research/open_llama_7b", device_map="auto", load_in_4bit=True
# )
# model_id_or_path = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)

# Load the Stable Diffusion inpainting pipeline in half precision at import
# time — this blocks module import and downloads the weights on first run.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
)
# Move the whole pipeline to the GPU; assumes a CUDA device is available —
# import will fail on CPU-only hosts.
pipe = pipe.to("cuda")

# Router mounted by the main FastAPI application.
router = APIRouter()
class ActionBody(BaseModel):
    """Request payload for the /perform-action endpoint."""

    url: str                # URL of the source image to edit
    maskUrl: str            # URL of the mask image (region to repaint)
    prompt: str             # text prompt driving the inpainting
    strength: float         # NOTE(review): accepted but currently unused by the endpoint
    guidance_scale: float   # NOTE(review): accepted but currently unused by the endpoint
    resizeW: int            # width both images are resized to before inference
    resizeH: int            # height both images are resized to before inference
@router.post("/perform-action")
async def performAction(actionBody: ActionBody):
    """Run Stable Diffusion inpainting on a remote image.

    Downloads the source image and the mask from the URLs in the request
    body, resizes both to (resizeW, resizeH), runs the inpainting pipeline
    with the given prompt, writes the first result to disk as a PNG named
    by a fresh UUID, and returns that filename plus the image as an inline
    JPEG data URI.

    NOTE(review): requests.get is blocking; inside an async endpoint it
    stalls the event loop for the duration of the download. Consider an
    async HTTP client or run_in_executor.
    """
    # Fetch the source image. The timeout prevents a dead URL from hanging
    # the worker forever; raise_for_status stops an HTML error page from
    # being handed to PIL (which would fail with a confusing error).
    response = requests.get(actionBody.url, timeout=30)
    response.raise_for_status()
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((actionBody.resizeW, actionBody.resizeH))

    # Fetch the mask image and resize it to match the source.
    response = requests.get(actionBody.maskUrl, timeout=30)
    response.raise_for_status()
    mask_image = Image.open(BytesIO(response.content)).convert("RGB")
    mask_image = mask_image.resize((actionBody.resizeW, actionBody.resizeH))

    # NOTE(review): strength and guidance_scale from the request body are
    # not forwarded to the pipeline — confirm whether they should be.
    images = pipe(prompt=actionBody.prompt, image=init_image, mask_image=mask_image).images

    # Encode the first result as JPEG for the inline data URI...
    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())

    # ...and persist a lossless PNG copy in the working directory under a
    # collision-free UUID filename.
    imgUUID = str(uuid.uuid4())
    images[0].save(imgUUID + ".png")
    return {
        "imageName": imgUUID + ".png",
        "image": "data:image/jpeg;base64," + img_str.decode(),
    }
@router.get("/hi")
async def hifunction():
    """Smoke-test endpoint: generate from a fixed sample sketch image.

    Downloads a known sample image, runs the pipeline with a fixed prompt,
    and returns the result as an inline JPEG data URI.

    NOTE(review): requests.get is blocking inside an async endpoint — it
    stalls the event loop while downloading.
    """
    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
    # Timeout + status check so a dead URL or an error page fails fast
    # instead of hanging or confusing PIL.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((768, 512))
    prompt = "A fantasy landscape, trending on artstation"
    # NOTE(review): `pipe` is the module-level inpainting pipeline, but this
    # call is img2img-style (strength/guidance_scale, no mask_image) — an
    # inpainting pipeline requires mask_image, so this likely errors at
    # runtime. Confirm whether this endpoint should use the commented-out
    # StableDiffusionImg2ImgPipeline instead.
    images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return {
        "image": "data:image/jpeg;base64," + img_str.decode()
    }