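"""Gradio Space that pairs a Zephyr-7B chat assistant with SDXL-Turbo image generation.

Messages containing the phrase "generate an image" are routed to the diffusion
pipeline; all other messages are streamed to the HuggingFaceH4/zephyr-7b-beta
chat model.
"""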
import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
from huggingface_hub import InferenceClient
from PIL import Image
from io import BytesIO
# Initialize the Hugging Face client for chat
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Initialize the DiffusionPipeline for image generation
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    torch.cuda.max_memory_allocated(device=device)
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
    )
    pipe.enable_xformers_memory_efficient_attention()
    pipe = pipe.to(device)
else:
    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
    pipe = pipe.to(device)
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
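# MAX_SEED caps randomized seeds; MAX_IMAGE_SIZE is defined for completeness but not referenced below.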

def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Check for an image generation request
    if "generate an image" in message.lower():
        prompt = message.replace("generate an image", "").strip()
        # Note: these settings mirror the original defaults; SDXL-Turbo is normally
        # run with guidance_scale=0.0 and only 1-4 inference steps.
        image = infer(
            prompt=prompt,
            negative_prompt="",
            seed=0,
            randomize_seed=True,
            width=512,
            height=512,
            guidance_scale=7.5,
            num_inference_steps=50,
        )
        # Yield (not return) so the value is actually delivered by this generator;
        # recent Gradio releases can render component responses such as gr.Image
        # directly in the chat.
        yield gr.Image(value=image)
        return

    # Build the chat history for the Zephyr model
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Stream the completion token by token
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response

# Define the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat and Image Generation")
    with gr.Row():
        with gr.Column():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        label="Top-p (nucleus sampling)",
                    ),
                ],
            )

    def process_image_request(prompt):
        # Return the PIL image directly; gr.Image does not accept raw PNG bytes.
        return infer(
            prompt=prompt,
            negative_prompt="",
            seed=0,
            randomize_seed=True,
            width=512,
            height=512,
            guidance_scale=7.5,
            num_inference_steps=50,
        )

    # Wire the examples to the image generator so clicking one produces an image
    gr.Examples(
        examples=[
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
        ],
        fn=process_image_request,
        inputs=[gr.Textbox(label="Prompt", placeholder="Enter your prompt")],
        outputs=[gr.Image()],
        run_on_click=True,
    )

demo.queue().launch()