import gradio as gr
import numpy as np
import random
from diffusers import DiffusionPipeline
import torch
from huggingface_hub import InferenceClient
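# Assumed dependencies (not pinned here): gradio, torch, diffusers, transformers,
# and huggingface_hub; accelerate is usually needed by diffusers, and xformers is
# optional for the GPU path below.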

# Initialize the Hugging Face client for chat
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
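# This assumes the hosted Inference API serves the model; if the endpoint
# requires authentication, set the HF_TOKEN environment variable (read by
# huggingface_hub) before launching.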

# Initialize the DiffusionPipeline for image generation
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.is_available():
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
    )
    try:
        # xformers is an optional dependency; fall back to default attention if it is missing
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass
else:
    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image
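# Note: sdxl-turbo is a distilled model and is typically run with
# guidance_scale=0.0 and 1-4 inference steps, e.g. (hypothetical call):
#   image = infer("a watercolor fox", "", seed=0, randomize_seed=True,
#                 width=512, height=512, guidance_scale=0.0, num_inference_steps=2)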

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Check for an image generation request
    if "generate an image" in message.lower():
        prompt = message.replace("generate an image", "").strip()
        image = infer(
            prompt=prompt,
            negative_prompt="",
            seed=0,
            randomize_seed=True,
            width=512,
            height=512,
            guidance_scale=7.5,
            num_inference_steps=50,
        )
        # This function is a generator (it yields below), so this branch must also
        # yield; a bare `return value` inside a generator is discarded. Recent
        # Gradio releases render a component yielded from a chat function.
        yield gr.Image(value=image)
        return

    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""
    for message in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content
        response += token
        yield response

# Define Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat and Image Generation")
    
    with gr.Row():
        with gr.Column():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        label="Top-p (nucleus sampling)",
                    ),
                ],
            )
        with gr.Column():
            # Standalone image-generation panel so the examples below target
            # real, rendered components
            prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt")
            generate_btn = gr.Button("Generate image")
            image_output = gr.Image(label="Generated image")

    def process_image_request(prompt):
        # gr.Image displays a PIL image directly, so no PNG byte round-trip is needed
        return infer(
            prompt=prompt,
            negative_prompt="",
            seed=0,
            randomize_seed=True,
            width=512,
            height=512,
            guidance_scale=7.5,
            num_inference_steps=50,
        )

    generate_btn.click(process_image_request, inputs=prompt_box, outputs=image_output)

    gr.Examples(
        examples=[
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
        ],
        fn=process_image_request,
        inputs=prompt_box,
        outputs=image_output,
    )

demo.queue().launch()