RanM committed on
Commit
0f197c4
·
verified ·
1 Parent(s): 8d29889

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -42
app.py CHANGED
@@ -1,9 +1,8 @@
1
  import os
2
- import asyncio
3
- from concurrent.futures import ProcessPoolExecutor
4
  from io import BytesIO
5
  from diffusers import AutoPipelineForText2Image
6
  import gradio as gr
 
7
  from generate_prompts import generate_prompt
8
 
9
  # Load the model once at the start
@@ -17,52 +16,37 @@ def truncate_prompt(prompt, max_length=77):
17
  prompt = " ".join(tokens[:max_length])
18
  return prompt
19
 
20
- def generate_image(prompt, prompt_name):
21
  try:
22
  truncated_prompt = truncate_prompt(prompt)
23
- print(f"Generating image for {prompt_name} with truncated prompt: {truncated_prompt}")
24
 
25
  # Call the model
26
  output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)
27
 
28
- # Debugging: Print full model output
29
- print(f"Full model output for {prompt_name}: {output}")
30
-
31
  # Check if output is valid
32
  if output is not None and hasattr(output, 'images') and output.images:
33
- print(f"Image generated for {prompt_name}")
34
  image = output.images[0]
35
- buffered = BytesIO()
36
- image.save(buffered, format="JPEG")
37
- image_bytes = buffered.getvalue()
38
- return image_bytes
39
  else:
40
- print(f"No images found or generated output is None for {prompt_name}")
41
- return None
42
  except Exception as e:
43
- print(f"An error occurred while generating image for {prompt_name}: {e}")
44
- return None
45
 
46
- async def queue_api_calls(sentence_mapping, character_dict, selected_style):
47
- print("Starting to queue API calls...")
48
- prompts = []
49
- for paragraph_number, sentences in sentence_mapping.items():
50
- combined_sentence = " ".join(sentences)
51
- prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
52
- prompts.append((paragraph_number, prompt))
53
- print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")
54
 
55
- loop = asyncio.get_running_loop()
56
- with ProcessPoolExecutor() as pool:
57
- tasks = [
58
- loop.run_in_executor(pool, generate_image, prompt, f"Prompt {paragraph_number}")
59
- for paragraph_number, prompt in prompts
60
- ]
61
- responses = await asyncio.gather(*tasks)
62
-
63
- images = {paragraph_number: response for (paragraph_number, _), response in zip(prompts, responses)}
64
- print("Finished queuing API calls. Generated images: ", images)
65
- return images
66
 
67
  def process_prompt(sentence_mapping, character_dict, selected_style):
68
  print("Processing prompt...")
@@ -70,14 +54,20 @@ def process_prompt(sentence_mapping, character_dict, selected_style):
70
  print(f"Character Dict: {character_dict}")
71
  print(f"Selected Style: {selected_style}")
72
 
73
- # Ensure we are in the right event loop context
74
- loop = asyncio.new_event_loop()
75
- asyncio.set_event_loop(loop)
76
- cmpt_return = loop.run_until_complete(queue_api_calls(sentence_mapping, character_dict, selected_style))
77
- loop.close()
 
 
 
 
 
 
78
 
79
- print("Prompt processing complete. Generated images: ", cmpt_return)
80
- return cmpt_return
81
 
82
  gradio_interface = gr.Interface(
83
  fn=process_prompt,
 
1
  import os
 
 
2
  from io import BytesIO
3
  from diffusers import AutoPipelineForText2Image
4
  import gradio as gr
5
+ import base64
6
  from generate_prompts import generate_prompt
7
 
8
  # Load the model once at the start
 
16
  prompt = " ".join(tokens[:max_length])
17
  return prompt
18
 
19
def generate_image(prompt):
    """Generate one image for *prompt* using the module-level pipeline.

    The prompt is truncated to the tokenizer limit first (see
    ``truncate_prompt``). Errors are reported, not raised, so one bad
    prompt cannot abort a whole batch.

    Args:
        prompt (str): the text prompt to render.

    Returns:
        tuple: ``(image, None)`` on success, where ``image`` is the first
        image produced by the pipeline; ``(None, error_message)`` on failure.
    """
    try:
        truncated_prompt = truncate_prompt(prompt)
        print(f"Generating image with truncated prompt: {truncated_prompt}")

        # Call the model. Single step / zero guidance — presumably a
        # turbo-style distilled pipeline; confirm against the model id.
        output = model(prompt=truncated_prompt, num_inference_steps=1, guidance_scale=0.0)

        # Check if output is valid before indexing into it.
        if output is not None and hasattr(output, 'images') and output.images:
            print("Image generated")  # no placeholders, so no f-string needed
            image = output.images[0]
            return image, None
        else:
            print("No images found or generated output is None")
            return None, "No images found or generated output is None"
    except Exception as e:
        # Boundary handler: convert any pipeline failure into an error
        # message for the caller instead of propagating the exception.
        print(f"An error occurred while generating image: {e}")
        return None, str(e)
38
 
39
def inference(prompt):
    """Run text-to-image for *prompt* and return the result as a base64 PNG string.

    On failure the return value is instead a plain string beginning with
    ``"Error: "`` followed by the error message from ``generate_image``.
    """
    print(f"Received prompt: {prompt}")  # Debugging statement
    image, error = generate_image(prompt)
    if error:
        print(f"Error generating image: {error}")  # Debugging statement
        return "Error: " + error

    # Serialize the PIL image into an in-memory PNG, then base64-encode it
    # so the result can travel through the Gradio/JSON layer as text.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
 
 
 
 
 
 
 
50
 
51
def process_prompt(sentence_mapping, character_dict, selected_style):
    """Build one prompt per paragraph and generate a base64 image for each.

    Args:
        sentence_mapping (dict): paragraph number -> list of sentence strings.
        character_dict: character descriptions, passed through to ``generate_prompt``.
        selected_style: style name, passed through to ``generate_prompt``.

    Returns:
        dict: paragraph number -> base64 PNG string (or an ``"Error: ..."``
        string when generation failed for that paragraph).
    """
    print("Processing prompt...")
    # NOTE(review): this line was elided in the diff view; reconstructed from
    # the surrounding debug prints — confirm against the actual file.
    print(f"Sentence Mapping: {sentence_mapping}")
    print(f"Character Dict: {character_dict}")
    print(f"Selected Style: {selected_style}")

    prompts = []
    for paragraph_number, sentences in sentence_mapping.items():
        combined_sentence = " ".join(sentences)
        # generate_prompt also returns a negative prompt; it is unused here.
        prompt, _negative_prompt = generate_prompt(combined_sentence, sentence_mapping, character_dict, selected_style)
        prompts.append((paragraph_number, prompt))
        print(f"Generated prompt for paragraph {paragraph_number}: {prompt}")

    # Generate images sequentially: the pipeline lives in this process, so
    # no executor/event loop is needed.
    images = {}
    for paragraph_number, prompt in prompts:
        images[paragraph_number] = inference(prompt)

    print("Prompt processing complete. Generated images: ", images)
    return images
71
 
72
  gradio_interface = gr.Interface(
73
  fn=process_prompt,