File size: 5,675 Bytes
4448dda
c0490dd
 
 
68c48e6
54dc753
68c48e6
c0490dd
fcb0cff
c0490dd
 
 
68c48e6
fcb0cff
 
31ab06b
2dfa8ab
 
3b844a5
54dc753
061ce15
17869bc
fcb0cff
68c48e6
 
 
 
2dfa8ab
 
 
 
3b844a5
 
fcb0cff
a5422e0
 
 
 
 
 
 
c0490dd
fcb0cff
 
baa503f
 
68c48e6
baa503f
68c48e6
 
 
6a500ed
3b844a5
 
c8f5641
 
 
 
f81b3d1
e5efe2c
c0490dd
68c48e6
cb89b3f
6a500ed
cb89b3f
8651de6
68c48e6
 
 
6a500ed
cb89b3f
8651de6
68c48e6
c0490dd
 
 
 
d06c74e
fcb0cff
c0490dd
 
bfd8827
 
fcb0cff
c0490dd
54dc753
c0490dd
 
bc47113
 
c0490dd
 
bfd8827
c0490dd
fcb0cff
c0490dd
 
 
 
fcb0cff
c0490dd
d06c74e
c8f5641
3b844a5
fcb0cff
c0490dd
fcb0cff
c0490dd
 
 
fcb0cff
c0490dd
 
fcb0cff
3b844a5
c0490dd
 
 
bc47113
c0490dd
 
 
 
 
 
 
fcb0cff
bfd8827
fcb0cff
 
 
2dfa8ab
 
 
 
 
 
 
 
 
c0c87ac
 
 
 
2dfa8ab
 
17869bc
 
 
2dfa8ab
 
4988ed2
 
2dfa8ab
 
061ce15
 
 
2dfa8ab
061ce15
2dfa8ab
17869bc
 
 
 
fcb0cff
3b844a5
54dc753
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import spaces
import logging
import random
import warnings
import os
import shutil
import subprocess
import torch
import numpy as np
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from PIL import Image
from huggingface_hub import snapshot_download, login
import io
import base64
from fastapi import FastAPI, File, UploadFile,Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from concurrent.futures import ThreadPoolExecutor
import uvicorn
import asyncio
import time  # Import time module for measuring execution time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app for image processing; CORS is wide open (any origin/method/header)
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ThreadPoolExecutor used to run blocking GPU inference off the event loop
executor = ThreadPoolExecutor()

# Determine the device (GPU or CPU)
if torch.cuda.is_available():
    device = "cuda"
    logger.info("CUDA is available. Using GPU.")
else:
    device = "cpu"
    logger.info("CUDA is not available. Using CPU.")

# Authenticate with the Hugging Face Hub when a token is provided; the base
# FLUX.1-dev repo is gated, so anonymous downloads will likely fail.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
if huggingface_token:
    login(token=huggingface_token)
    logger.info("Hugging Face token found and logged in.")
else:
    logger.warning("Hugging Face token not found in environment variables.")

# Download base model weights using snapshot_download
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # BUG FIX: was "*..gitattributes" (double dot), a glob that never matches;
    # the intent is to skip markdown files and the .gitattributes file.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
logger.info("Model downloaded to: %s", model_path)

# Load the upscaler ControlNet, then the full Flux pipeline on top of it
logger.info("Loading ControlNet model.")
cache_dir = "./model_cache"
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
).to(device)
logger.info("ControlNet model loaded successfully.")

logger.info("Loading pipeline.")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
).to(device)
logger.info("Pipeline loaded successfully.")

# Upper bound for user-supplied seeds, and the maximum number of pixels the
# *output* image may contain (inputs are shrunk to respect this budget).
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024

#@spaces.GPU
def process_input(input_image, upscale_factor):
    w, h = input_image.size
    aspect_ratio = w / h
    was_resized = False

    # Resize if input size exceeds the maximum pixel budget
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn("Requested output image is too large. Resizing to fit within pixel budget.")
        input_image = input_image.resize(
            (
                int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
                int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
            )
        )
        was_resized = True

    # Adjust dimensions to be a multiple of 8
    w, h = input_image.size
    w = w - w % 8
    h = h - h % 8

    return input_image.resize((w, h)), was_resized

#@spaces.GPU
def run_inference(input_image, upscale_factor, seed, num_inference_steps, controlnet_conditioning_scale):
    """Upscale *input_image* with the Flux ControlNet upscaler pipeline.

    Args:
        input_image: PIL image to upscale.
        upscale_factor: integer multiplier for the output resolution.
        seed: RNG seed for reproducible sampling.
        num_inference_steps: diffusion step count.
        controlnet_conditioning_scale: ControlNet guidance strength.

    Returns:
        The upscaled image encoded as a base64 JPEG string.
    """
    logger.info("Processing inference.")
    # BUG FIX: capture the true original size *before* process_input(), since
    # input_image is rebound to the (possibly shrunk) processed copy below.
    # The old code read input_image.size after rebinding, which equals the
    # pipeline output size, making the "resize back" step a silent no-op.
    original_w, original_h = input_image.size
    input_image, was_resized = process_input(input_image, upscale_factor)

    # Upsample the conditioning image to the target output resolution.
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))

    # Seeded generator for deterministic sampling.
    generator = torch.Generator().manual_seed(seed)

    # Perform inference using the pipeline.
    logger.info("Running pipeline.")
    image = pipe(
        prompt="",
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]

    # If the input was shrunk to fit the pixel budget, scale the output up to
    # the size the caller originally requested (original dimensions * factor).
    if was_resized:
        image = image.resize((original_w * upscale_factor, original_h * upscale_factor))

    # Encode the result as a base64 JPEG for the JSON response.
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return image_base64

@app.post("/infer")
async def infer(input_image: UploadFile = File(...),
                upscale_factor: int = Form(4),  # Default value of 4
                seed: int = Form(42),            # Default value of 42
                num_inference_steps: int = Form(28),  # Default value of 28
                controlnet_conditioning_scale: float = Form(0.6)):
    """HTTP endpoint: upscale an uploaded image and return it as base64 JPEG.

    The response JSON contains ``base64_image`` (the encoded result) and
    ``time_taken`` (wall-clock seconds for the whole request).
    """
    logger.info("Received request for inference.")

    # Start timing the entire inference process.
    start_time = time.time()

    # Read the uploaded bytes and decode them into a PIL image.
    # (Removed a leftover debug print and a redundant bytes() conversion —
    # UploadFile.read() already returns bytes.)
    contents = await input_image.read()
    image = Image.open(io.BytesIO(contents))

    # Run the blocking GPU inference in the thread pool so the event loop
    # stays responsive. get_running_loop() is the correct API inside a
    # coroutine; get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    base64_image = await loop.run_in_executor(
        executor,
        run_inference,
        image,
        upscale_factor,
        seed,
        num_inference_steps,
        controlnet_conditioning_scale,
    )

    # Calculate the time taken.
    time_taken = time.time() - start_time

    return JSONResponse(content={"base64_image": base64_image, "time_taken": time_taken})

# Script entry point: serve the FastAPI app with uvicorn on port 7860.
if __name__ == "__main__":
    uvicorn.run(app=app, host="0.0.0.0", port=7860)