|
import torch |
|
import io |
|
from fireworks.flumina import FluminaModule, main as flumina_main |
|
from fireworks.flumina.route import post |
|
import pydantic |
|
from pydantic import BaseModel |
|
from fastapi import File, Form, Header, UploadFile, HTTPException |
|
from fastapi.responses import Response |
|
import math |
|
import os |
|
import re |
|
import PIL.Image as Image |
|
from typing import Dict, Optional, Set, Tuple |
|
|
|
from diffusers import StableDiffusion3Pipeline |
|
from diffusers.models import FluxMultiControlNetModel |
|
|
|
|
|
|
|
def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]: |
|
""" |
|
Convert specified aspect ratio to a height/width pair. |
|
""" |
|
if ":" not in aspect_ratio: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
w, h = aspect_ratio.split(":") |
|
try: |
|
w, h = int(w), int(h) |
|
except ValueError: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9" |
|
) |
|
|
|
valid_aspect_ratios = [ |
|
(1, 1), |
|
(21, 9), |
|
(16, 9), |
|
(3, 2), |
|
(5, 4), |
|
(4, 5), |
|
(2, 3), |
|
(9, 16), |
|
(9, 21), |
|
] |
|
if (w, h) not in valid_aspect_ratios: |
|
raise ValueError( |
|
f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}" |
|
) |
|
|
|
|
|
TARGET_SIZE_MP = 1 |
|
target_size = TARGET_SIZE_MP * 2**20 |
|
|
|
width = math.sqrt(target_size / (w * h)) * w |
|
height = math.sqrt(target_size / (w * h)) * h |
|
|
|
PAD_MULTIPLE = 64 |
|
|
|
if PAD_MULTIPLE: |
|
width = width // PAD_MULTIPLE * PAD_MULTIPLE |
|
height = height // PAD_MULTIPLE * PAD_MULTIPLE |
|
|
|
return int(width), int(height) |
|
|
|
|
|
def encode_image(
    image: Image.Image, mime_type: str, jpeg_quality: int = 95
) -> bytes:
    """
    Encode a PIL image as JPEG or PNG bytes.

    Args:
        image: Image to encode. For JPEG output, images in a mode the JPEG
            encoder cannot store (e.g. RGBA, LA, P) are converted to RGB
            first, discarding any alpha channel.
        mime_type: Target format, ``"image/jpeg"`` or ``"image/png"``.
        jpeg_quality: JPEG quality setting in [0, 100]; unused for PNG.

    Returns:
        The encoded image bytes.

    Raises:
        ValueError: If ``mime_type`` is not supported, or if JPEG output is
            requested with ``jpeg_quality`` outside [0, 100].
    """
    buffered = io.BytesIO()
    if mime_type == "image/jpeg":
        if not 0 <= jpeg_quality <= 100:
            raise ValueError(
                f"jpeg_quality must be between 0 and 100, not {jpeg_quality}"
            )
        # Pillow's JPEG writer only accepts these modes and raises OSError
        # for others (alpha / palette images), so flatten those to RGB.
        if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            image = image.convert("RGB")
        image.save(buffered, format="JPEG", quality=jpeg_quality)
    elif mime_type == "image/png":
        image.save(buffered, format="PNG")
    else:
        raise ValueError(f"invalid mime_type {mime_type}")
    return buffered.getvalue()
|
|
|
|
|
def parse_accept_header(accept: str) -> str:
    """
    Choose the response image MIME type from an HTTP ``Accept`` header.

    Each comma-separated entry is a media range optionally followed by
    ``;``-separated parameters; only the ``q`` (quality weight) parameter is
    used. Entries are considered in descending ``q`` order, and entries with
    equal weight keep their header order. The first entry that names a
    supported type, ``image/*`` or ``*/*`` wins. Wildcards resolve to the
    first supported type (JPEG).

    Media types are compared case-insensitively. Whitespace around ``;`` and
    ``=`` is tolerated, and empty list elements are ignored. Entries with
    ``q=0`` are treated as "not acceptable".

    Args:
        accept: Raw value of the ``Accept`` request header.

    Returns:
        ``"image/jpeg"`` or ``"image/png"``.

    Raises:
        ValueError: If an entry has no media type or an unparseable ``q``
            value, or if no acceptable entry maps to a supported type.
    """
    supported_types = ["image/jpeg", "image/png"]
    weighted_types = []

    for raw_part in accept.split(","):
        part = raw_part.strip()
        if not part:
            # Empty list elements (e.g. from a trailing comma) are permitted
            # by HTTP list syntax and carry no preference.
            continue

        media_type, *params = [piece.strip() for piece in part.split(";")]
        if not media_type:
            raise ValueError(f"Malformed Accept header value: {part}")

        q_factor = 1.0
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                # Other media-type parameters don't affect the choice here.
                continue
            value = value.strip()
            if not re.fullmatch(r"\d+(\.\d+)?", value):
                raise ValueError(f"Malformed Accept header value: {part}")
            q_factor = float(value)

        if q_factor <= 0:
            # q=0 explicitly marks the media range as not acceptable.
            continue
        weighted_types.append((media_type.lower(), q_factor))

    # Stable sort: entries with equal weight keep their header order.
    weighted_types.sort(key=lambda entry: entry[1], reverse=True)

    for media_type, _ in weighted_types:
        if media_type in supported_types:
            return media_type
        if media_type in ("*/*", "image/*"):
            return supported_types[0]

    raise ValueError(f"Accept header did not include any supported MIME types: {supported_types}")
|
|
|
|
|
|
|
class Text2ImageRequest(BaseModel):
    """Request body for the ``/text_to_image`` endpoint."""

    # Text prompt passed to the diffusion pipeline.
    prompt: str
    # "w:h" string; must be one of the ratios _aspect_ratio_to_width_height accepts.
    aspect_ratio: str = "16:9"
    # Passed straight through as the pipeline's guidance_scale argument.
    guidance_scale: float = 0.0
    # Number of denoising steps (also reported for billing). Must be at least
    # 1: zero or negative step counts are not meaningful for the sampler.
    num_inference_steps: int = pydantic.Field(4, ge=1)
    # Seed for the CUDA torch.Generator handed to the pipeline.
    seed: int = 0
|
|
|
|
|
class Error(BaseModel):
    """Error details serialized under the ``error`` key of ``ErrorResponse``."""

    # Constant tag identifying the payload as an error object.
    object: str = pydantic.Field(default="error")
    # Error category; _error_response relies on this default.
    type: str = pydantic.Field(default="invalid_request_error")
    # Human-readable description of the failure (required).
    message: str
|
|
|
|
|
class ErrorResponse(BaseModel):
    """JSON envelope for failed requests, shaped as ``{"error": {...}}``."""

    # Required. ``Error.message`` has no default, so there is no meaningful
    # default ``Error``; a ``default_factory=Error`` could never succeed.
    error: Error
|
|
|
|
|
class BillingInfo(BaseModel):
    """
    Per-request usage metadata attached to image responses.

    ``_image_response`` serializes it to JSON in the
    ``Fireworks-Billing-Properties`` header; field order is the JSON key order.
    """

    # Denoising steps requested for the generation.
    steps: int = pydantic.Field(...)
    # Output image height in pixels.
    height: int = pydantic.Field(...)
    # Output image width in pixels.
    width: int = pydantic.Field(...)
    # Whether ControlNet conditioning was used; text_to_image passes False.
    is_control_net: bool = pydantic.Field(...)
|
|
|
|
|
class FluminaModule(FluminaModule):
    """
    Flumina module serving text-to-image generation with a Stable Diffusion 3
    pipeline.

    This class deliberately reuses the name of the imported ``FluminaModule``
    base class. The base is looked up when this ``class`` statement runs, and
    afterwards the module-level name refers to this subclass.
    """

    def __init__(self):
        super().__init__()
        # Load the pipeline weights from ./data in bfloat16 and move the
        # pipeline onto the GPU.
        self.hf_model = StableDiffusion3Pipeline.from_pretrained(
            './data', torch_dtype=torch.bfloat16
        )
        self.hf_model.to(device='cuda', dtype=torch.bfloat16)

        # When True, handlers return raw payloads (image bytes / JSON str)
        # instead of fastapi ``Response`` objects, so they can be called
        # directly without an HTTP server.
        self._test_return_sync_response = False

    def _error_response(self, code: int, message: str) -> Response:
        """
        Build a JSON error response wrapping ``message`` in ``ErrorResponse``.

        Args:
            code: HTTP status code to send.
            message: Human-readable error description.

        Returns:
            A JSON ``Response`` with status ``code``; in test mode the bare
            JSON string instead.
        """
        response_json = ErrorResponse(error=Error(message=message)).json()
        if self._test_return_sync_response:
            return response_json
        return Response(
            response_json,
            status_code=code,
            media_type="application/json",
        )

    def _image_response(self, img: Image.Image, mime_type: str, billing_info: BillingInfo):
        """
        Encode ``img`` as ``mime_type`` and wrap it in a 200 response.

        The billing metadata is attached as JSON in the
        ``Fireworks-Billing-Properties`` header. In test mode only the encoded
        bytes are returned.
        """
        image_bytes = encode_image(img, mime_type)
        if self._test_return_sync_response:
            return image_bytes
        headers = {'Fireworks-Billing-Properties': billing_info.json()}
        return Response(image_bytes, status_code=200, media_type=mime_type, headers=headers)

    @post('/text_to_image')
    async def text_to_image(
        self,
        body: Text2ImageRequest,
        accept: str = Header("image/jpeg"),
    ):
        """
        Generate a single image from a text prompt.

        Args:
            body: Prompt and generation parameters.
            accept: HTTP ``Accept`` header; negotiated to JPEG or PNG.

        Returns:
            The encoded image with billing metadata. If the Accept header or
            the aspect ratio is invalid, a JSON error response with status 400
            is returned instead.
        """
        # Validate all user-supplied inputs before touching the GPU. Report
        # failures as JSON 400 errors instead of letting the ValueError
        # escape as an unhandled server error.
        try:
            mime_type = parse_accept_header(accept)
            width, height = _aspect_ratio_to_width_height(body.aspect_ratio)
        except ValueError as e:
            return self._error_response(400, str(e))

        # NOTE(review): this synchronous call blocks the event loop for the
        # whole generation. It is intentionally not moved to a worker thread:
        # concurrent calls into one pipeline instance are presumably unsafe
        # (shared scheduler state). Confirm before changing.
        output = self.hf_model(
            prompt=body.prompt,
            height=height,
            width=width,
            guidance_scale=body.guidance_scale,
            num_inference_steps=body.num_inference_steps,
            generator=torch.Generator('cuda').manual_seed(body.seed),
        )
        assert len(output.images) == 1, len(output.images)

        billing_info = BillingInfo(
            steps=body.num_inference_steps,
            height=height,
            width=width,
            is_control_net=False,
        )
        return self._image_response(output.images[0], mime_type, billing_info)

    @property
    def supported_addon_types(self):
        """Add-on types this module supports: none."""
        return []
|
|
|
|
|
if __name__ == "__flumina_main__":
    # Entry point when the Flumina runtime loads this file (it presumably sets
    # __name__ to "__flumina_main__"; confirm). Constructing the module loads
    # the pipeline onto the GPU before handing it to flumina_main.
    f = FluminaModule()
    flumina_main(f)
|
|
|
if __name__ == "__main__":
    # Local smoke test: call the handler directly (no HTTP server) and save
    # the generated image to disk.
    import asyncio

    f = FluminaModule()
    # Make handlers return raw bytes instead of fastapi Response objects.
    f._test_return_sync_response = True

    t2i_out = asyncio.run(f.text_to_image(
        Text2ImageRequest(
            prompt="A quick brown fox",
            aspect_ratio="16:9",
            guidance_scale=0.0,
            num_inference_steps=4,
            seed=0,
        ),
        accept="image/jpeg",
    ))
    assert isinstance(t2i_out, bytes), t2i_out

    # The request asked for image/jpeg, so save with a matching extension.
    with open('output.jpg', 'wb') as out_file:
        out_file.write(t2i_out)
|
|