RanM committed on
Commit
dee6e4c
·
verified ·
1 Parent(s): b71021a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -48
app.py CHANGED
@@ -1,79 +1,82 @@
1
  import os
2
  from io import BytesIO
 
3
  from diffusers import AutoPipelineForText2Image
4
  import gradio as gr
5
  import base64
6
- from generate_prompts import generate_prompt
7
 
8
# Load the model once at the start so every request reuses the same weights.
# NOTE(review): this loads eagerly at import time with no error handling — if the
# download/load fails, the module fails to import. TODO confirm that is intended.
print("Loading the Stable Diffusion model...")
model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
print("Model loaded successfully.")
12
-
13
def truncate_prompt(prompt, max_length=77):
    """Clamp *prompt* to at most *max_length* whitespace-separated words.

    Prompts already within the limit are returned unchanged; longer prompts
    are rebuilt from their first *max_length* words joined by single spaces.
    """
    words = prompt.split()
    if len(words) <= max_length:
        return prompt
    return " ".join(words[:max_length])
18
 
19
def generate_image(prompt):
    """Run the shared pipeline on a word-truncated prompt.

    Returns ``(image, None)`` on success, or ``(None, error_message)`` when
    the pipeline produced nothing or raised.
    """
    try:
        truncated_prompt = truncate_prompt(prompt)
        print(f"Generating image with truncated prompt: {truncated_prompt}")

        output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)

        # Guard clause: bail out unless the pipeline returned a non-empty image list.
        produced = output is not None and hasattr(output, 'images') and output.images
        if not produced:
            print(f"No images found or generated output is None")
            return None, "No images found or generated output is None"

        print(f"Image generated")
        return output.images[0], None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
38
 
39
def inference(prompt):
    """Generate one image for *prompt* and return it as a base64 PNG string.

    On failure, returns a string beginning with ``"Error: "`` instead.
    """
    print(f"Received prompt: {prompt}")  # Debugging statement
    image, error = generate_image(prompt)
    if error:
        print(f"Error generating image: {error}")  # Debugging statement
        return "Error: " + error

    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
50
-
51
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and return {paragraph_number: base64_image}.

    Each paragraph's sentences are joined, turned into a model prompt via
    generate_prompt, and rendered through inference().
    """
    print("Processing prompt...")
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")

    results = {}
    for para_no, para_sentences in sentence_mapping.items():
        text = " ".join(para_sentences)
        prompt, _ = generate_prompt(text, sentence_mapping, character_dict, selected_style)
        results[para_no] = inference(prompt)

    print("Prompt processing complete. Generated images: ", results)
    return results
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
# JSON-in / JSON-out web endpoint wrapping process_prompt; queued so concurrent
# requests share the single loaded pipeline.
gradio_interface = gr.Interface(
    fn=process_prompt,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
).queue(default_concurrency_limit=20)  # Set concurrency limit if needed
76
 
77
# Script entry point: start the Gradio server (blocks until shut down).
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()
 
1
  import os
2
  from io import BytesIO
3
+ from PIL import Image
4
  from diffusers import AutoPipelineForText2Image
5
  import gradio as gr
6
  import base64
 
7
 
8
# Load the model once at the start
print("Loading the Stable Diffusion model...")
try:
    model = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo")
    print("Model loaded successfully.")
except Exception as e:
    # Deliberate best-effort: keep the app importable even if the weights fail
    # to load; generate_image() checks `model is None` and reports per-request.
    print(f"Error loading model: {e}")
    model = None
 
 
16
 
17
def generate_image(prompt):
    """Generate one image for *prompt* with the module-level sdxl-turbo pipeline.

    Returns a tuple ``(img_str, error)``: on success ``img_str`` is a
    base64-encoded JPEG and ``error`` is None; on any failure ``img_str`` is
    None and ``error`` is a human-readable message.
    """
    try:
        if model is None:
            raise ValueError("Model not loaded properly.")

        print(f"Generating image with prompt: {prompt}")
        # sdxl-turbo is designed for single-step, guidance-free inference.
        output = model(prompt=prompt, num_inference_steps=1, guidance_scale=0.0)
        print(f"Model output: {output}")

        if output is None:
            raise ValueError("Model returned None")

        if not (hasattr(output, 'images') and output.images):
            print("No images found in model output")
            raise ValueError("No images found in model output")

        print("Image generated successfully")
        image = output.images[0]
        buffered = BytesIO()
        # JPEG cannot encode alpha/palette modes; normalize to RGB before saving.
        image.convert("RGB").save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        print("Image encoded to base64")
        print(f'img_str: {img_str[:100]}...')  # Print a snippet of the base64 string
        return img_str, None
    except Exception as e:
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
45
 
46
def inference(sentence_mapping, character_dict, selected_style):
    """Render one image per paragraph.

    Returns ``{paragraph_number: base64_jpeg_or_error_string}``, or
    ``{"error": message}`` when inputs are missing or an exception escapes.
    """
    try:
        print(f"Received sentence_mapping: {sentence_mapping}, type: {type(sentence_mapping)}")
        print(f"Received character_dict: {character_dict}, type: {type(character_dict)}")
        print(f"Received selected_style: {selected_style}, type: {type(selected_style)}")

        # Guard clause: every input must be present before any work starts.
        if any(arg is None for arg in (sentence_mapping, character_dict, selected_style)):
            return {"error": "One or more inputs are None"}

        results = {}
        for para_no, sentences in sentence_mapping.items():
            combined_sentence = " ".join(sentences)
            prompt = f"Make an illustration in {selected_style} style from: {combined_sentence}"
            print(f"Generated prompt for paragraph {para_no}: {prompt}")
            img_str, error = generate_image(prompt)
            results[para_no] = f"Error: {error}" if error else img_str
        return results
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return {"error": str(e)}
69
 
70
# JSON-in / JSON-out web endpoint wrapping inference(); inputs mirror its
# three parameters (mapping, character dict, style choice).
gradio_interface = gr.Interface(
    fn=inference,
    inputs=[
        gr.JSON(label="Sentence Mapping"),
        gr.JSON(label="Character Dict"),
        gr.Dropdown(["oil painting", "sketch", "watercolor"], label="Selected Style")
    ],
    outputs="json"
)
79
 
80
# Script entry point: start the Gradio server (blocks until shut down).
if __name__ == "__main__":
    print("Launching Gradio interface...")
    gradio_interface.launch()