tzintsunzu committed on
Commit
4a0d8b9
1 Parent(s): 0183abe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import gc
5
+ import re
6
+ import random
7
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
8
+ from diffusers import StableDiffusionPipeline
9
+ import gradio as gr
10
+
11
# --- Text generation -------------------------------------------------------
# Causal language model that writes the descriptions. The weights are loaded
# with the library defaults (no quantization arguments are passed here) and
# the pipeline is pinned to the CPU via device=-1.
model_name = 'HuggingFaceTB/SmolLM-1.7B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
text_generator = pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    device=-1,
)

# --- Image generation ------------------------------------------------------
# Stable Diffusion checkpoint, loaded and moved to the CPU in one step.
model_id = "stabilityai/stable-diffusion-2-1-base"
pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cpu")

# --- Output location -------------------------------------------------------
# Folder that receives every generated image. It is created on first run and
# its mode is set to 0o777 (read/write/execute for everyone).
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
os.chmod(output_dir, 0o777)
26
+
27
def generate_description_prompt(user_prompt, user_examples):
    """Ask the language model for one quoted description and return it.

    Args:
        user_prompt: Task or guideline text the description should follow.
        user_examples: Text the new description should differ from; it is
            interpolated into the instruction verbatim.

    Returns:
        The first non-empty double-quoted span found in the model's
        continuation, re-wrapped in double quotes (e.g. '"a red barn"'),
        or None when no such span exists or when generation raised.
    """
    prompt = f'generate enclosed in quotes in the format "<description>" description according to guidelines of {user_prompt} different from {user_examples}'
    try:
        # return_full_text=False: the instruction itself contains the quoted
        # placeholder "<description>". If the prompt were echoed back in the
        # output, the quote search below would match that placeholder instead
        # of anything the model actually wrote.
        outputs = text_generator(
            prompt,
            max_length=150,
            num_return_sequences=1,
            truncation=True,
            return_full_text=False,
        )
        generated_text = outputs[0]['generated_text']
    except Exception as e:
        # Best-effort boundary: report the failure and let the caller show
        # its own "failed" message instead of crashing the UI callback.
        print(f"Error generating description for prompt '{user_prompt}': {e}")
        return None

    # Belt and braces: drop the prompt if it was echoed anyway.
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]

    for match in re.finditer(r'"(.*?)"', generated_text):
        candidate = match.group(1).strip()
        # Skip empty quotes and a parroted copy of the format placeholder;
        # neither is a usable description.
        if candidate and candidate != '<description>':
            return f'"{candidate}"'
    return None
41
+
42
# Pool of candidate subjects. It is module-level state, so it is shared by
# every call: it holds the user-supplied examples plus every description
# generated so far.
seed_words = []

# Subjects that have already produced a description and must not be reused.
used_words = set()


def generate_description(user_prompt, user_examples_list):
    """Pick an unused subject from the seed pool and generate a description.

    New examples are merged into the module-level ``seed_words`` pool. One
    subject that is not yet in ``used_words`` is chosen at random and handed
    to ``generate_description_prompt``. On success the subject is marked as
    used and the new description joins the pool as a future subject.

    Args:
        user_prompt: Task or guideline text forwarded to the generator.
        user_examples_list: Iterable of example strings to add to the pool.
            ``None``, empty and whitespace-only entries are ignored.

    Returns:
        ``(description, subject)`` where ``description`` is the ASCII-only,
        double-quoted text, or ``(None, None)`` when no unused subject is
        available or no usable description was produced.
    """
    # Register each example once. Re-adding the same examples on every
    # request would grow the pool without bound and skew random.choice
    # towards frequently repeated entries.
    for example in user_examples_list or ():
        if not example or not example.strip():
            continue
        if example not in seed_words:
            seed_words.append(example)

    # Select a subject that has not been used.
    available_subjects = [word for word in seed_words if word not in used_words]
    if not available_subjects:
        print("No more available subjects to use.")
        return None, None

    subject = random.choice(available_subjects)
    generated_description = generate_description_prompt(user_prompt, subject)
    if not generated_description:
        return None, None

    # Remove any non-ASCII symbols.
    clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
    seed_text = clean_description.strip('"')
    if not seed_text.strip():
        # The description consisted solely of non-ASCII characters; nothing
        # usable is left, so treat it as a failed generation and keep the
        # subject available for another attempt.
        return None, None

    print(f"Generated description for subject '{subject}': {clean_description}")

    used_words.add(subject)
    # Feed the description (without its quotes) back into the seed pool.
    if seed_text not in seed_words:
        seed_words.append(seed_text)

    return clean_description, subject
73
+
74
def generate_image(description, seed=42):
    """Render a single 512x512 image for *description*.

    The description is embedded in a fixed "photorealistic full shot"
    prompt and passed to the module-level Stable Diffusion pipeline, which
    runs 10 inference steps at guidance scale 7.5.

    Args:
        description: Text describing what the image should show.
        seed: Seed for the torch random generator handed to the pipeline.

    Returns:
        The first image of the pipeline result.
    """
    rng = torch.Generator()
    rng.manual_seed(seed)

    sd_arguments = dict(
        prompt=f'detailed photorealistic full shot of {description}',
        width=512,
        height=512,
        num_inference_steps=10,
        generator=rng,
        guidance_scale=7.5,
    )
    result = pipe(**sd_arguments)
    return result.images[0]
87
+
88
def gradio_interface(user_prompt, user_examples):
    """Gradio callback: generate a description, render it, save the image.

    Args:
        user_prompt: Free-text task entered in the first textbox.
        user_examples: Comma-separated examples from the second textbox;
            surrounding whitespace and double quotes are stripped from each.

    Returns:
        ``(image, description)`` on success, or
        ``(None, "Failed to generate description.")`` when no description
        could be produced.
    """
    # Tolerate an empty/None textbox and drop blank entries (e.g. from a
    # trailing comma) so an empty string is never used as a subject.
    raw_examples = (user_examples or '').split(',')
    cleaned_examples = (item.strip().strip('"').strip() for item in raw_examples)
    user_examples_list = [item for item in cleaned_examples if item]

    generated_description, subject = generate_description(user_prompt, user_examples_list)
    if not generated_description:
        return None, "Failed to generate description."

    image = generate_image(generated_description)

    # Start from the current file count, then advance until the name is free:
    # the bare count can collide with an existing file once images have been
    # deleted or when two requests overlap.
    index = len(os.listdir(output_dir))
    image_path = os.path.join(output_dir, f"image_{index}.png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(output_dir, f"image_{index}.png")

    image.save(image_path)
    # NOTE(review): 0o777 makes the file world-writable and executable; kept
    # as in the original deployment - confirm whether 0o644 would suffice.
    os.chmod(image_path, 0o777)

    return image, generated_description
103
+
104
# Web UI wiring: two free-text inputs feed gradio_interface, which returns
# the rendered image together with the description it was generated from.
_task_box = gr.Textbox(
    lines=2,
    placeholder="Enter the generation task or general thing you are looking for",
)
_examples_box = gr.Textbox(
    lines=2,
    placeholder='Provide a few examples (enclosed in quotes and separated by commas)',
)
_image_output = gr.Image(label="Generated Image")
_description_output = gr.Textbox(label="Generated Description")

iface = gr.Interface(
    fn=gradio_interface,
    inputs=[_task_box, _examples_box],
    outputs=[_image_output, _description_output],
    title="Description and Image Generator",
    description="Generate detailed descriptions and images based on your input.",
)

# Serve on all network interfaces, port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)
119
+
120
def clear_gpu_memory():
    """Run a garbage-collection pass and release PyTorch's cached CUDA memory.

    Prints "GPU memory cleared." when done and returns None.

    NOTE(review): nothing visible in this file calls or registers this
    function, and it is defined after ``iface.launch(...)``. If the intent
    is to run it at shutdown, it presumably needs to be hooked up explicitly
    (e.g. via ``atexit``) - confirm.
    """
    # Collect first: tensors that are only kept alive by reference cycles
    # must be freed before their blocks can be returned by empty_cache().
    gc.collect()
    # Only touch the CUDA allocator when a GPU is actually present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print("GPU memory cleared.")