# petrockapi/routers/training.py
# woods-today
# Working on it
# 771c3c3
# raw
# history blame
# 2.14 kB
from fastapi import APIRouter, Form, BackgroundTasks
from config import settings
import os
import json
import utils
import torch
import requests
from PIL import Image
from io import BytesIO
from pydantic import BaseModel
import base64
import uuid
from diffusers import StableDiffusionImg2ImgPipeline
# Model identifier (or local path) handed to `from_pretrained` below.
model_id_or_path = "runwayml/stable-diffusion-v1-5"

# Build the img2img pipeline once at import time, with float16 weights, and
# move it to the CUDA device. Both endpoints in this module call this single
# shared instance.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    model_id_or_path,
    torch_dtype=torch.float16,
).to("cuda")

# Router on which the endpoints of this module are registered.
router = APIRouter()
# Request payload for POST /perform-action.
# (Kept as `#` comments rather than a docstring so the generated schema for
# this model is unchanged.)
class ActionBody(BaseModel):
    # Address of the source image; fetched with `requests.get`.
    url: str
    # Text prompt passed to the pipeline.
    prompt: str
    # Forwarded unchanged as the pipeline's `strength` argument.
    strength: float
    # Forwarded unchanged as the pipeline's `guidance_scale` argument.
    guidance_scale: float
    # Target width the source image is resized to before inference.
    resizeW: int
    # Target height the source image is resized to before inference.
    resizeH: int
@router.post("/perform-action")
async def performAction(actionBody: ActionBody):
    """Run img2img on a remote image and return the generated picture.

    Downloads the image at ``actionBody.url``, converts it to RGB, resizes it
    to ``(actionBody.resizeW, actionBody.resizeH)`` and feeds it to the
    module-level ``pipe`` together with the prompt, strength and guidance
    scale from the request. The first generated image is written to
    ``<uuid4>.png`` (a relative path) and also returned inline.

    Args:
        actionBody: Parsed request body.

    Returns:
        A dict with ``imageName`` (name of the PNG written to disk) and
        ``image`` (the same picture as a base64 JPEG data URI).

    Raises:
        HTTPException: 422 if a resize dimension is not positive; 400 if the
            source image cannot be downloaded or decoded.
    """
    # Function-scope import: the module's top-level fastapi import does not
    # bring HTTPException into scope.
    from fastapi import HTTPException

    if actionBody.resizeW <= 0 or actionBody.resizeH <= 0:
        raise HTTPException(
            status_code=422,
            detail="resizeW and resizeH must be positive",
        )

    # NOTE(review): the URL is caller-supplied and fetched server-side; if this
    # endpoint is reachable by untrusted clients, consider restricting the
    # allowed hosts/schemes.
    try:
        # A timeout keeps an unresponsive host from hanging the request
        # forever; raise_for_status stops error pages being parsed as images.
        response = requests.get(actionBody.url, timeout=30)
        response.raise_for_status()
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=400,
            detail="Could not download the source image",
        ) from exc

    try:
        init_image = Image.open(BytesIO(response.content)).convert("RGB")
    except OSError as exc:
        # PIL signals unreadable/unsupported image data with OSError subclasses.
        raise HTTPException(
            status_code=400,
            detail="The URL did not return a readable image",
        ) from exc
    init_image = init_image.resize((actionBody.resizeW, actionBody.resizeH))

    generated = pipe(
        prompt=actionBody.prompt,
        image=init_image,
        strength=actionBody.strength,
        guidance_scale=actionBody.guidance_scale,
    ).images[0]

    buffered = BytesIO()
    generated.save(buffered, format="JPEG")
    # b64encode returns bytes; decode so it can be joined to the str prefix
    # below (str + bytes raises TypeError) and serialised as JSON.
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")

    image_name = str(uuid.uuid4()) + ".png"
    generated.save(image_name)

    return {
        "imageName": image_name,
        "image": "data:image/jpeg;base64," + img_str,
    }
@router.get("/hi")
async def hifunction():
    """Run the pipeline on a fixed sample image and return the result.

    Downloads a hard-coded sample sketch, converts it to RGB, resizes it to
    768x512 and runs the module-level ``pipe`` on it with a fixed prompt,
    ``strength=0.75`` and ``guidance_scale=7.5``.

    Returns:
        A dict whose ``image`` entry is the first generated picture encoded
        as a base64 JPEG data URI.

    Raises:
        requests.RequestException: If the sample image cannot be downloaded
            (connection error, timeout or non-success HTTP status).
    """
    url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
    # A timeout keeps an unresponsive host from hanging the request forever;
    # raise_for_status stops an error page being handed to PIL as image data.
    response = requests.get(url, timeout=30)
    response.raise_for_status()

    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((768, 512))

    prompt = "A fantasy landscape, trending on artstation"
    generated = pipe(
        prompt=prompt,
        image=init_image,
        strength=0.75,
        guidance_scale=7.5,
    ).images[0]

    buffered = BytesIO()
    generated.save(buffered, format="JPEG")
    # b64encode returns bytes; decode so it can be joined to the str prefix
    # below (str + bytes raises TypeError) and serialised as JSON.
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")

    return {
        "image": "data:image/jpeg;base64," + img_str,
    }