RelaxxOfficial committed
Commit
07dadcd
1 Parent(s): 4e24657

Create app.py

Files changed (1): app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
import gradio as gr
from transformers import pipeline, AutoProcessor, LlavaForConditionalGeneration
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
import torch
import cv2
import base64
import soundfile as sf
import time
from PIL import Image  # used by capture_image() and the webcam loop, missing from the original

# --- Set up Models ---

# Stable Diffusion for image generation
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    scheduler=scheduler,
    torch_dtype=torch.float16,
).to("cuda")

# LLaVA for vision-based language understanding. A LLaVA checkpoint needs its
# multimodal processor and the LlavaForConditionalGeneration class; the plain
# AutoTokenizer/AutoModelForCausalLM pair used originally cannot consume images.
processor = AutoProcessor.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
model = LlavaForConditionalGeneration.from_pretrained(
    "xtuner/llava-llama-3-8b-v1_1-transformers",
    torch_dtype=torch.float16,  # the 8B model needs roughly 16 GB of VRAM in fp16
).to("cuda")

# Open-source language model for text generation (e.g., GPT-Neo)
gpt_neo_pipe = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B")

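# Note (added): the pipeline above runs on CPU by default; pass `device=0` to
# place GPT-Neo on the first GPU alongside the other models, e.g.:
# gpt_neo_pipe = pipeline("text-generation", model="EleutherAI/gpt-neo-1.3B", device=0)
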
# Text-to-Speech. The original referenced "espnet/fastspeech2_en_ljspeech", but
# ESPnet checkpoints are not loadable through the transformers pipeline; Bark
# (suno/bark-small) is substituted here as a checkpoint the "text-to-speech"
# pipeline is known to support.
text_to_speech = pipeline(
    "text-to-speech", model="suno/bark-small"
)

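# Sketch (added): the text-to-speech pipeline returns a dict of the form
# {"audio": <numpy array>, "sampling_rate": <int>}, which is what the handlers
# below rely on, e.g.:
#   out = text_to_speech("Hello there")
#   sf.write("test.wav", out["audio"].squeeze(), out["sampling_rate"])
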
# --- Functions ---

def process_image(image, chat_history):
    """Describes a PIL image with LLaVA and speaks the response."""
    if image is None:
        return chat_history, None

    # Prompt layout assumes the Llama 3 chat template from the model card; the
    # <image> token marks where the processor inserts the image features. The
    # original embedded a base64 string in plain text, which the model cannot read.
    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n<image>\n"
        "What do you see in this image?<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)

    # Generate a response with LLaVA (a single forward pass plus argmax, as in
    # the original, does not decode a full sequence)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    # Decode only the newly generated tokens, skipping the prompt
    response = processor.decode(
        output_ids[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )

    # Generate speech from the response
    audio = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, audio["audio"].squeeze(), audio["sampling_rate"])

    # Update chat history
    chat_history += "You: Image\n"
    chat_history += "Model: " + response + "\n"

    return chat_history, audio_path

def generate_image(prompt, chat_history):
    """Generates an image using Stable Diffusion based on a prompt."""
    image = pipe(
        prompt=prompt,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]

    # Update chat history
    chat_history += "You: " + prompt + "\n"
    chat_history += "Model: Image\n"

    return chat_history, image

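# Note (added): 50 steps is a quality-oriented setting; with the Euler scheduler,
# fewer steps (e.g. num_inference_steps=25) often gives acceptable previews in
# roughly half the time.
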
def process_text(text, chat_history):
    """Processes text, generates a response using GPT-Neo, and generates speech."""
    # Generate response using GPT-Neo
    response = gpt_neo_pipe(
        text,
        max_length=100,
        num_return_sequences=1,
    )[0]["generated_text"]

    # Generate speech from the response (the pipeline returns a dict, not a tensor)
    audio = text_to_speech(response)
    audio_path = "generated_audio.wav"
    sf.write(audio_path, audio["audio"].squeeze(), audio["sampling_rate"])

    # Update chat history
    chat_history += "You: " + text + "\n"
    chat_history += "Model: " + response + "\n"

    return chat_history, audio_path

# --- Webcam Capture ---

def capture_image():
    """Captures a single frame from the default webcam as a PIL image."""
    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    # OpenCV delivers BGR; convert to RGB before building the PIL image
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

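# Optional helper (added): the original base64-encoded the captured frame before
# prompting the model; kept as a utility in case a downstream API expects base64.
def image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string."""
    from io import BytesIO
    buffer = BytesIO()
    image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
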
# --- Gradio Interface ---

with gr.Blocks() as demo:
    gr.Markdown("## Llama-LLaVA Vision Speech Assistant")
    chat_history = gr.Textbox(label="Chat History", lines=10, interactive=False)
    # type="pil" so capture_image() and process_image() exchange PIL images;
    # the duplicate top-level webcam component lives in the "Webcam" tab below
    image_input = gr.Image(label="Uploaded Image", type="pil")
    text_input = gr.Textbox(label="Enter Text")
    audio_output = gr.Audio(label="Audio Response")

    # Screenshot button
    screenshot_button = gr.Button("Capture Screenshot")
    screenshot_button.click(fn=capture_image, outputs=image_input)

    # Image processing (LLaVA)
    image_input.change(fn=process_image, inputs=[image_input, chat_history], outputs=[chat_history, audio_output])

    # Text processing (GPT-Neo)
    text_input.submit(fn=process_text, inputs=[text_input, chat_history], outputs=[chat_history, audio_output])

    # Image generation (Stable Diffusion)
    with gr.Tab("Image Generation"):
        image_prompt = gr.Textbox(label="Enter image prompt:")
        image_generation_output = gr.Image(label="Generated Image")
        generate_image_button = gr.Button("Generate Image")
        generate_image_button.click(
            fn=generate_image, inputs=[image_prompt, chat_history], outputs=[chat_history, image_generation_output]
        )

    # Webcam stream
    with gr.Tab("Webcam"):
        webcam_output = gr.Image(label="Webcam Feed", interactive=False)

        # Stream one frame per second from the server-side webcam. Assigning a
        # generator to webcam_output.source, as the original did, has no effect;
        # Gradio streams the frames a generator yields when it is wired to an
        # event, so the app's load event drives the loop instead.
        def update_webcam():
            cap = cv2.VideoCapture(0)
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    # Convert BGR (OpenCV) to RGB before displaying
                    yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                    time.sleep(1)  # Update every second
            finally:
                cap.release()

        demo.load(fn=update_webcam, outputs=webcam_output)

demo.launch(share=True)
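# Usage note (added): run with `python app.py`; share=True also exposes a public
# Gradio link. Assumed dependencies: gradio, transformers, diffusers, torch,
# opencv-python, soundfile, and Pillow, plus a CUDA GPU with enough memory to
# hold Stable Diffusion 2.1, the 8B LLaVA model, and Bark simultaneously.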