import io
import math
import re
from typing import Tuple

import PIL.Image as Image
import pydantic
import torch
from fastapi import Header
from fastapi.responses import Response
from fireworks.flumina import FluminaModule, main as flumina_main
from fireworks.flumina.route import post
from pydantic import BaseModel
from tqdm import tqdm

from sd3_infer import SD3Inferencer, CONFIGS
from sd3_impls import SD3LatentFormat

def _aspect_ratio_to_width_height(aspect_ratio: str) -> Tuple[int, int]:
    """
    Convert the specified aspect ratio to a (width, height) pair in pixels.
    """
    if ":" not in aspect_ratio:
        raise ValueError(
            f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9"
        )

    try:
        w, h = map(int, aspect_ratio.split(":"))
    except ValueError:
        # Covers both non-integer parts and malformed values like "16:9:2".
        raise ValueError(
            f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be in w:h format, e.g. 16:9"
        )

    valid_aspect_ratios = [
        (1, 1),
        (21, 9),
        (16, 9),
        (3, 2),
        (5, 4),
        (4, 5),
        (2, 3),
        (9, 16),
        (9, 21),
    ]
    if (w, h) not in valid_aspect_ratios:
        raise ValueError(
            f"Invalid aspect ratio: {aspect_ratio}. Aspect ratio must be one of {valid_aspect_ratios}"
        )

    # Target a total budget of ~1 megapixel (2**20 pixels), scaled to the ratio.
    TARGET_SIZE_MP = 1
    target_size = TARGET_SIZE_MP * 2**20

    width = math.sqrt(target_size / (w * h)) * w
    height = math.sqrt(target_size / (w * h)) * h

    # Round both dimensions down to a multiple of 64.
    PAD_MULTIPLE = 64
    if PAD_MULTIPLE:
        width = width // PAD_MULTIPLE * PAD_MULTIPLE
        height = height // PAD_MULTIPLE * PAD_MULTIPLE

    return int(width), int(height)
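
# Worked examples at the 1 MP budget above:
#   _aspect_ratio_to_width_height("16:9") -> (1344, 768)
#   _aspect_ratio_to_width_height("1:1")  -> (1024, 1024)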


def encode_image(
    image: Image.Image, mime_type: str, jpeg_quality: int = 95
) -> bytes:
    """Serialize a PIL image to JPEG or PNG bytes according to mime_type."""
    buffered = io.BytesIO()
    if mime_type == "image/jpeg":
        if jpeg_quality < 0 or jpeg_quality > 100:
            raise ValueError(
                f"jpeg_quality must be between 0 and 100, not {jpeg_quality}"
            )
        image.save(buffered, format="JPEG", quality=jpeg_quality)
    elif mime_type == "image/png":
        image.save(buffered, format="PNG")
    else:
        raise ValueError(f"invalid mime_type {mime_type}")
    return buffered.getvalue()
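
# Example usage (jpeg_quality only applies to JPEG output):
#   jpeg_bytes = encode_image(img, "image/jpeg", jpeg_quality=90)
#   png_bytes = encode_image(img, "image/png")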


def parse_accept_header(accept: str) -> str:
    """
    Pick the preferred supported MIME type from an HTTP Accept header,
    honoring q-factor weights.
    """
    parts = accept.split(",")
    weighted_types = []

    for part in parts:
        # Each part looks like "media/type" or "media/type;q=0.9".
        match = re.match(
            r"(?P<media_type>[^;]+)(;q=(?P<q_factor>\d+(\.\d+)?))?", part.strip()
        )
        if match:
            media_type = match.group("media_type").strip()
            q_factor = (
                float(match.group("q_factor")) if match.group("q_factor") else 1.0
            )
            weighted_types.append((media_type, q_factor))
        else:
            raise ValueError(f"Malformed Accept header value: {part.strip()}")

    # Highest q-factor wins; sorted() is stable, so ties keep the client's order.
    sorted_types = sorted(weighted_types, key=lambda x: x[1], reverse=True)

    supported_types = ["image/jpeg", "image/png"]

    for media_type, _ in sorted_types:
        if media_type in supported_types:
            return media_type
        elif media_type in ("*/*", "image/*"):
            # Wildcards fall back to the default (first) supported type.
            return supported_types[0]

    raise ValueError(
        f"Accept header did not include any supported MIME types: {supported_types}"
    )


class Text2ImageRequest(BaseModel):
    prompt: str
    aspect_ratio: str = "16:9"
    guidance_scale: float = 4.5
    num_inference_steps: int = 28
    seed: int = 0
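
# A minimal request body only needs "prompt"; the other fields fall back to
# the defaults above, e.g. {"prompt": "A quick brown fox"}.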


class Error(BaseModel):
    object: str = "error"
    type: str = "invalid_request_error"
    # Default to an empty message so the default_factory in ErrorResponse can
    # construct an Error without arguments.
    message: str = ""


class ErrorResponse(BaseModel):
    error: Error = pydantic.Field(default_factory=Error)


class BillingInfo(BaseModel):
    steps: int
    height: int
    width: int
    is_control_net: bool


MODEL = "models/sd3.5_medium.safetensors"
VERBOSE = True


class SD3InferencerInMemoryOutput(SD3Inferencer):
    """SD3Inferencer variant that returns the decoded image in memory
    instead of writing it to disk."""

    def gen_image(
        self,
        prompts,
        width,
        height,
        steps,
        cfg_scale,
        sampler,
        seed,
        seed_type,
        init_image,
        denoise,
    ):
        latent = self.get_empty_latent(width, height)
        if init_image:
            image_data = Image.open(init_image)
            image_data = image_data.resize((width, height), Image.LANCZOS)
            latent = self.vae_encode(image_data)
            latent = SD3LatentFormat().process_in(latent)
        neg_cond = self.get_cond("")
        seed_num = None
        assert len(prompts) == 1
        pbar = tqdm(enumerate(prompts), total=len(prompts), position=0, leave=True)
        for i, prompt in pbar:
            # "roll" increments the seed per prompt, "rand" draws a fresh one,
            # anything else reuses the fixed seed.
            if seed_type == "roll":
                seed_num = seed if seed_num is None else seed_num + 1
            elif seed_type == "rand":
                seed_num = torch.randint(0, 100000, (1,)).item()
            else:
                seed_num = seed
            conditioning = self.get_cond(prompt)
            sampled_latent = self.do_sampling(
                latent,
                seed_num,
                conditioning,
                neg_cond,
                steps,
                cfg_scale,
                sampler,
                denoise if init_image else 1.0,
            )
        return self.vae_decode(sampled_latent)
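
# Minimal direct-use sketch (assumes load() has run first; values mirror the
# /text_to_image route below):
#   inf = SD3InferencerInMemoryOutput()
#   inf.load(model=MODEL, vae=MODEL, shift=CONFIGS["sd3.5_medium"]["shift"], verbose=True)
#   img = inf.gen_image(["a fox"], 1024, 1024, 28, 4.5,
#                       CONFIGS["sd3.5_medium"]["sampler"], 0, "roll", None, 1.0)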


# Note: this class shadows the imported FluminaModule name; Python evaluates
# the base-class expression before binding the new name, so the subclassing
# still resolves correctly.
class FluminaModule(FluminaModule):
    def __init__(self):
        super().__init__()
        self.inferencer = SD3InferencerInMemoryOutput()
        with torch.inference_mode():
            self.inferencer.load(
                model=MODEL,
                vae=MODEL,
                shift=CONFIGS["sd3.5_medium"]["shift"],
                verbose=VERBOSE,
            )
            # Move all sub-models (text encoders, diffusion model, VAE) to GPU.
            self.inferencer.clip_l.model.to("cuda")
            self.inferencer.clip_g.model.to("cuda")
            self.inferencer.t5xxl.model.to("cuda")
            self.inferencer.sd3.model.to("cuda")
            self.inferencer.vae.model.to("cuda")

        # When True, handlers return raw bytes/JSON strings instead of FastAPI
        # Response objects; used by the smoke test under __main__ below.
        self._test_return_sync_response = False

    def _error_response(self, code: int, message: str) -> Response:
        response_json = ErrorResponse(
            error=Error(message=message),
        ).json()
        if self._test_return_sync_response:
            return response_json
        else:
            return Response(
                response_json,
                status_code=code,
                media_type="application/json",
            )
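
    # Produces a body like:
    #   {"error": {"object": "error", "type": "invalid_request_error",
    #              "message": "..."}}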

    def _image_response(
        self, img: Image.Image, mime_type: str, billing_info: BillingInfo
    ):
        image_bytes = encode_image(img, mime_type)
        if self._test_return_sync_response:
            return image_bytes
        else:
            # Billing metadata travels back as a response header.
            headers = {"Fireworks-Billing-Properties": billing_info.json()}
            return Response(
                image_bytes, status_code=200, media_type=mime_type, headers=headers
            )

    @post("/text_to_image")
    async def text_to_image(
        self,
        body: Text2ImageRequest,
        accept: str = Header("image/jpeg"),
    ):
        mime_type = parse_accept_header(accept)
        width, height = _aspect_ratio_to_width_height(body.aspect_ratio)
        with torch.inference_mode():
            img = self.inferencer.gen_image(
                prompts=[body.prompt],
                width=width,
                height=height,
                steps=body.num_inference_steps,
                cfg_scale=body.guidance_scale,
                sampler=CONFIGS["sd3.5_medium"]["sampler"],
                seed=body.seed,
                seed_type="roll",
                init_image=None,
                # With no init_image, gen_image forces denoise to 1.0, so this
                # value is effectively unused for text-to-image.
                denoise=0.0,
            )

        billing_info = BillingInfo(
            steps=body.num_inference_steps,
            height=height,
            width=width,
            is_control_net=False,
        )
        return self._image_response(img, mime_type, billing_info)
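
    # Example request against a running deployment (host/port illustrative):
    #   curl -X POST http://localhost:8000/text_to_image \
    #     -H 'Content-Type: application/json' -H 'Accept: image/jpeg' \
    #     -d '{"prompt": "A quick brown fox", "aspect_ratio": "16:9"}' \
    #     --output fox.jpg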

    @property
    def supported_addon_types(self):
        return []


if __name__ == "__flumina_main__":
    f = FluminaModule()
    flumina_main(f)

if __name__ == "__main__":
    # Local smoke test: bypass the HTTP layer and get raw bytes back.
    f = FluminaModule()
    f._test_return_sync_response = True

    import asyncio

    t2i_out = asyncio.run(
        f.text_to_image(
            Text2ImageRequest(
                prompt="A quick brown fox",
                aspect_ratio="16:9",
                guidance_scale=3.5,
                num_inference_steps=30,
                seed=0,
            ),
            accept="image/jpeg",
        )
    )
    assert isinstance(t2i_out, bytes), t2i_out
    # The Accept header above selects JPEG, so save with a matching extension.
    with open("output.jpg", "wb") as out_file:
        out_file.write(t2i_out)