File size: 3,109 Bytes
e199da3
 
dee6e4c
affc855
e199da3
0f197c4
5a982ce
a7a041a
 
 
ea8a0d1
 
 
 
 
 
 
 
 
e199da3
0f197c4
ea8a0d1
e199da3
dee6e4c
 
c064275
dee6e4c
 
 
 
 
 
 
 
 
e199da3
dee6e4c
 
 
 
 
 
 
e199da3
dee6e4c
 
e199da3
0f197c4
 
e199da3
dee6e4c
 
 
 
 
8d29889
dee6e4c
 
 
5a982ce
dee6e4c
 
 
 
 
 
 
 
 
 
e199da3
 
dee6e4c
e199da3
 
 
 
 
ea8a0d1
a7a041a
e199da3
 
 
ea8a0d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import base64
import os
import threading
from io import BytesIO

import gradio as gr
from diffusers import AutoPipelineForText2Image
from PIL import Image

from generate_prompts import generate_prompt

# Passed to gr.Interface(concurrency_limit=...) below: the maximum number of
# inference() requests Gradio will process at the same time.
CONCURRENCY_LIMIT = 10

# Pipeline shared by all requests; populated lazily by load_model().
_model = None
# Guards the first load so concurrent requests don't each load their own copy.
_model_lock = threading.Lock()


def load_model():
    """Return the SDXL-Turbo text-to-image pipeline, loading it on first use.

    The pipeline is loaded at most once and reused by later calls, so callers
    that invoke this per request no longer pay the full load cost each time.

    Returns:
        The loaded pipeline, or ``None`` if loading failed. A failed load is
        not cached, so the next call tries again.
    """
    global _model
    # Fast path: already loaded, no locking needed.
    if _model is not None:
        return _model
    with _model_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _model is not None:
            return _model
        print("Loading the Stable Diffusion model...")
        try:
            model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
        except Exception as e:
            print(f"Error loading model: {e}")
            return None
        print("Model loaded successfully.")
        _model = model
        return model

# Serializes calls into the diffusion pipeline. Gradio may run several requests
# in parallel threads (CONCURRENCY_LIMIT); if load_model() returns a shared
# pipeline, concurrent calls would interleave on its scheduler's per-run state.
_inference_lock = threading.Lock()


def generate_image(prompt):
    """Render *prompt* to an image and return it as a base64-encoded JPEG string.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        ``(img_str, None)`` on success, where ``img_str`` is the first generated
        image saved as JPEG and base64-encoded (decoded to ``str``).
        ``(None, message)`` on any failure, including a model that failed to load.
    """
    model = load_model()
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        # Single sampling step with guidance_scale=0.0 (no classifier-free guidance).
        with _inference_lock:
            output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")

        if output is None:
            raise ValueError("Model returned None")

        images = getattr(output, "images", None)
        if not images:
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated successfully")
        buffered = BytesIO()
        images[0].save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded to base64")
        print(f"img_str: {img_str[:100]}...")  # Only a snippet; the full string is large
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)

def inference(sentence_mapping, character_dict, selected_style):
    """Generate one image per paragraph, keyed by paragraph number.

    Args:
        sentence_mapping: Mapping from paragraph number to that paragraph's
            sentences, either a list of strings (joined with spaces) or a
            single string used as-is.
        character_dict: Passed through unchanged to ``generate_prompt``.
        selected_style: Style name from the UI dropdown, passed to
            ``generate_prompt``.

    Returns:
        A dict mapping each paragraph number to either the image string
        returned by ``generate_image`` or ``"Error: <message>"`` when that
        paragraph's image failed. If anything else raises (e.g. a mapping
        without ``.items()``, or ``generate_prompt`` failing), returns
        ``{"error": <message>}`` instead.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")

        images = {}
        for paragraph_number, sentences in sentence_mapping.items():
            # " ".join on a bare string would space out its characters
            # ("abc" -> "a b c"), so a string value is taken as the whole paragraph.
            if isinstance(sentences, str):
                combined_sentence = sentences
            else:
                combined_sentence = " ".join(sentences)
            prompt = generate_prompt(combined_sentence, character_dict, selected_style)
            print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
            img_str, error = generate_image(prompt)
            if error:
                images[paragraph_number] = f"Error: {error}"
            else:
                images[paragraph_number] = img_str
        return images
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}

# Web UI wiring: the three inputs feed inference()'s parameters in order, and
# the dict it returns is rendered as JSON.
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),  # -> sentence_mapping
        gr.JSON(label="Character Dict"),  # -> character_dict
        gr.Dropdown(
            ["oil painting", "sketch", "watercolor"],
            label="Selected Style",
        ),  # -> selected_style
    ],
    outputs="json",
    concurrency_limit=CONCURRENCY_LIMIT,
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    print("Launching Gradio interface...")
    gradio_interface.launch()