import os
import pandas as pd
import torch
import gc
import re
import random
import atexit
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
import gradio as gr
# Initialize the text generation pipeline (runs on CPU; no quantization is applied here)
model_name = 'HuggingFaceTB/SmolLM-1.7B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
text_generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)  # device=-1 forces CPU
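# Optional sketch: if a CUDA GPU and the bitsandbytes package were available, the model
# could instead be loaded in 8-bit, which is what the original "pre-quantized 8-bit"
# comment hinted at. Left commented out because this Space runs on CPU only.
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     trust_remote_code=True,
#     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
#     device_map="auto",
# )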
# Load the Stable Diffusion model
model_id = "stabilityai/stable-diffusion-2-1-base"  # Smaller base checkpoint
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cpu")  # Run inference on CPU
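# Optional sketch: on a machine with a CUDA GPU (not the case for this CPU Space),
# loading in float16 and moving the pipeline to the GPU would speed up generation.
# pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# pipe = pipe.to("cuda")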
# Create a directory to save the generated images
output_dir = 'generated_images'
os.makedirs(output_dir, exist_ok=True)
os.chmod(output_dir, 0o777)  # World-writable so the Space process can always save images
# Function to generate a detailed visual description prompt
def generate_description_prompt(user_prompt, user_examples):
    prompt = f'generate enclosed in quotes in the format "<description>" description according to guidelines of {user_prompt} different from {user_examples}'
    try:
        # return_full_text=False keeps the prompt (which itself contains quotes) out of the
        # output, so the regex below matches the model's description rather than the prompt
        generated_text = text_generator(prompt, max_length=150, num_return_sequences=1, truncation=True, return_full_text=False)[0]['generated_text']
        match = re.search(r'"(.*?)"', generated_text)
        if match:
            generated_description = match.group(1).strip()  # Capture the description between quotes
            return f'"{generated_description}"'
        else:
            return None
    except Exception as e:
        print(f"Error generating description for prompt '{user_prompt}': {e}")
        return None
# Seed words pool
seed_words = []
used_words = set()
def generate_description(user_prompt, user_examples_list):
    seed_words.extend(user_examples_list)
    # Select a subject that has not been used yet
    available_subjects = [word for word in seed_words if word not in used_words]
    if not available_subjects:
        print("No more available subjects to use.")
        return None, None
    subject = random.choice(available_subjects)
    generated_description = generate_description_prompt(user_prompt, subject)
    if generated_description:
        # Strip any non-ASCII symbols from the description
        clean_description = generated_description.encode('ascii', 'ignore').decode('ascii')
        # Log the generated description to the command line
        print(f"Generated description for subject '{subject}': {clean_description}")
        # Update used words and add the new description (without quotes) to the seed pool
        used_words.add(subject)
        seed_words.append(clean_description.strip('"'))
        return clean_description, subject
    else:
        return None, None
# Function to generate an image based on the description
def generate_image(description, seed=42):
    prompt = f'detailed photorealistic full shot of {description}'
    generator = torch.Generator().manual_seed(seed)  # Fixed seed for reproducible output
    image = pipe(
        prompt=prompt,
        width=512,
        height=512,
        num_inference_steps=10,  # Few steps to keep CPU inference time manageable
        generator=generator,
        guidance_scale=7.5,
    ).images[0]
    return image
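# Optional sanity check with a hypothetical prompt: uncomment to test image generation
# directly, without going through the Gradio interface.
# test_image = generate_image('"a red vintage bicycle leaning against a brick wall"')
# test_image.save(os.path.join(output_dir, "test.png"))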
# Gradio interface
def gradio_interface(user_prompt, user_examples):
    # Split the comma-separated examples and drop surrounding quotes/whitespace
    user_examples_list = [example.strip().strip('"') for example in user_examples.split(',')]
    generated_description, subject = generate_description(user_prompt, user_examples_list)
    if generated_description:
        # Generate and save the image with a simple incrementing filename
        image = generate_image(generated_description)
        image_path = os.path.join(output_dir, f"image_{len(os.listdir(output_dir))}.png")
        image.save(image_path)
        os.chmod(image_path, 0o777)
        return image, generated_description
    else:
        return None, "Failed to generate description."
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter the generation task or general thing you are looking for"),
        gr.Textbox(lines=2, placeholder='Provide a few examples (enclosed in quotes and separated by commas)')
    ],
    outputs=[
        gr.Image(label="Generated Image"),
        gr.Textbox(label="Generated Description")
    ],
    title="Description and Image Generator",
    description="Generate detailed descriptions and images based on your input."
)
# Free model memory when the process exits (the CUDA call is skipped on this CPU-only Space)
def clear_gpu_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    print("GPU memory cleared.")

atexit.register(clear_gpu_memory)

iface.launch(server_name="0.0.0.0", server_port=7860)