README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ”₯
  colorFrom: blue
  colorTo: indigo
  sdk: gradio
- sdk_version: 4.37.2
+ sdk_version: 4.33.0
  app_file: app.py
  pinned: true
  short_description: GPT 4o like bot.
@@ -24,7 +24,6 @@ GPT 4o vs OpenGPT 4o
  | Image Generation | Paid only | Yes |
  |Video Generation|No|Yes|
  | Image QnA | Yes | Yes |
- | Video QnA | Yes (but very limited) | Yes |
  | Voice Chat | Yes but Very Limited | Yes (Unlimited) |
  | Video Chat | Paid Only | Yes |
  | Multilingual | Yes | Chat Only |
app.py CHANGED
@@ -1,98 +1,273 @@
  import gradio as gr
- import spaces
- from chatbot import model_inference, EXAMPLES, chatbot
- from voice_chat import respond
-
- # Define custom CSS for better styling
- custom_css = """
- .gradio-container {
-     font-family: 'Roboto', sans-serif;
- }
-
- .main-header {
-     text-align: center;
-     color: #4a4a4a;
-     margin-bottom: 2rem;
- }
-
- .tab-header {
-     font-size: 1.2rem;
-     font-weight: bold;
-     margin-bottom: 1rem;
- }
-
- .custom-chatbot {
-     border-radius: 10px;
-     box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
- }
-
- .custom-button {
-     background-color: #3498db;
-     color: white;
-     border: none;
-     padding: 10px 20px;
-     border-radius: 5px;
-     cursor: pointer;
-     transition: background-color 0.3s ease;
- }
-
- .custom-button:hover {
-     background-color: #2980b9;
- }
- """

  # Define Gradio theme
  theme = gr.themes.Soft(
-     primary_hue="indigo",
-     secondary_hue="blue",
-     neutral_hue="slate",
-     font=[gr.themes.GoogleFont('Roboto'), "sans-serif"]
  )

  # Chat interface block
- with gr.Blocks(css=custom_css) as chat:
-     gr.Markdown("### πŸ’¬ OpenGPT 4o Chat", elem_classes="tab-header")
      gr.ChatInterface(
          fn=model_inference,
          chatbot=chatbot,
          examples=EXAMPLES,
          multimodal=True,
          cache_examples=False,
-         autofocus=False,
-         concurrency_limit=10
      )

- # Voice chat block
- with gr.Blocks() as voice:
-     gr.Markdown("### πŸ—£οΈ Voice Chat", elem_classes="tab-header")
-     gr.Markdown("Try Voice Chat from the link below:")
-     gr.HTML('<a href="https://huggingface.co/spaces/KingNish/Voicee" target="_blank" class="custom-button">Open Voice Chat</a>')

- with gr.Blocks() as image_gen_pro:
      gr.HTML("<iframe src='https://kingnish-image-gen-pro.hf.space' width='100%' height='2000px' style='border-radius: 8px;'></iframe>")

- with gr.Blocks() as flux_fast:
-     gr.HTML("<iframe src='https://prodia-flux-1-dev.hf.space' width='100%' height='2000px' style='border-radius: 8px;'></iframe>")

- # Image engine block
  with gr.Blocks() as image:
-     gr.Markdown("### πŸ–ΌοΈ Image Engine", elem_classes="tab-header")
-     gr.TabbedInterface([flux_fast, image_gen_pro], ['High Quality Image Gen', 'Image gen and editing'])
-

- # Video engine block
- with gr.Blocks() as video:
-     gr.Markdown("### πŸŽ₯ Video Engine", elem_classes="tab-header")
      gr.HTML("<iframe src='https://kingnish-instant-video.hf.space' width='100%' height='3000px' style='border-radius: 8px;'></iframe>")


  # Main application block
  with gr.Blocks(theme=theme, title="OpenGPT 4o DEMO") as demo:
-     gr.Markdown("# πŸš€ OpenGPT 4o", elem_classes="main-header")
-     gr.TabbedInterface(
-         [chat, voice, image, video],
-         ['πŸ’¬ SuperChat', 'πŸ—£οΈ Voice Chat', 'πŸ–ΌοΈ Image Engine', 'πŸŽ₯ Video Engine']
-     )

  demo.queue(max_size=300)
  demo.launch()
 
  import gradio as gr
+
+ # Import modules from other files
+ from chatbot import chatbot, model_inference, BOT_AVATAR, EXAMPLES, model_selector, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p
+ from live_chat import videochat

  # Define Gradio theme
  theme = gr.themes.Soft(
+     primary_hue="blue",
+     secondary_hue="orange",
+     neutral_hue="gray",
+     font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif']
+ ).set(
+     body_background_fill_dark="#111111",
+     block_background_fill_dark="#111111",
+     block_border_width="1px",
+     block_title_background_fill_dark="#1e1c26",
+     input_background_fill_dark="#292733",
+     button_secondary_background_fill_dark="#24212b",
+     border_color_primary_dark="#343140",
+     background_fill_secondary_dark="#111111",
+     color_accent_soft_dark="transparent"
  )

+ import edge_tts
+ import asyncio
+ import tempfile
+ import numpy as np
+ import soxr
+ from pydub import AudioSegment
+ import torch
+ import sentencepiece as spm
+ import onnxruntime as ort
+ from huggingface_hub import hf_hub_download, InferenceClient
+ import requests
+ from bs4 import BeautifulSoup
+ import urllib.parse
+ import random
+
+ # List of user agents to choose from for requests
+ _useragent_list = [
+     'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0',
+     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
+     'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
+     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36',
+     'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
+     'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.62',
+     'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0'
+ ]
+
+ def get_useragent():
+     """Returns a random user agent from the list."""
+     return random.choice(_useragent_list)
+
+ def extract_text_from_webpage(html_content):
+     """Extracts visible text from HTML content using BeautifulSoup."""
+     soup = BeautifulSoup(html_content, "html.parser")
+     # Remove unwanted tags
+     for tag in soup(["script", "style", "header", "footer", "nav"]):
+         tag.extract()
+     # Get the remaining visible text
+     visible_text = soup.get_text(strip=True)
+     return visible_text
+
+ def search(term, num_results=1, lang="en", advanced=True, sleep_interval=0, timeout=5, safe="active", ssl_verify=None):
+     """Performs a Google search and returns the results."""
+     escaped_term = urllib.parse.quote_plus(term)
+     start = 0
+     all_results = []
+
+     # Fetch results in batches
+     while start < num_results:
+         resp = requests.get(
+             url="https://www.google.com/search",
+             headers={"User-Agent": get_useragent()},  # Set random user agent
+             params={
+                 "q": term,
+                 "num": num_results - start,  # Number of results to fetch in this batch
+                 "hl": lang,
+                 "start": start,
+                 "safe": safe,
+             },
+             timeout=timeout,
+             verify=ssl_verify,
+         )
+         resp.raise_for_status()  # Raise an exception if request fails
+
+         soup = BeautifulSoup(resp.text, "html.parser")
+         result_block = soup.find_all("div", attrs={"class": "g"})
+
+         # If no results, continue to the next batch
+         if not result_block:
+             start += 1
+             continue
+
+         # Extract link and text from each result
+         for result in result_block:
+             link = result.find("a", href=True)
+             if link:
+                 link = link["href"]
+                 try:
+                     # Fetch webpage content
+                     webpage = requests.get(link, headers={"User-Agent": get_useragent()})
+                     webpage.raise_for_status()
+                     # Extract visible text from webpage
+                     visible_text = extract_text_from_webpage(webpage.text)
+                     all_results.append({"link": link, "text": visible_text})
+                 except requests.exceptions.RequestException as e:
+                     # Handle errors fetching or processing webpage
+                     print(f"Error fetching or processing {link}: {e}")
+                     all_results.append({"link": link, "text": None})
+             else:
+                 all_results.append({"link": None, "text": None})
+
+         start += len(result_block)  # Update starting index for next batch
+
+     return all_results
+
+ # Speech Recognition Model Configuration
+ model_name = "neongeckocom/stt_en_citrinet_512_gamma_0_25"
+ sample_rate = 16000
+
+ # Download preprocessor, encoder and tokenizer
+ preprocessor = torch.jit.load(hf_hub_download(model_name, "preprocessor.ts", subfolder="onnx"))
+ encoder = ort.InferenceSession(hf_hub_download(model_name, "model.onnx", subfolder="onnx"))
+ tokenizer = spm.SentencePieceProcessor(hf_hub_download(model_name, "tokenizer.spm", subfolder="onnx"))
+
+ # Mistral Model Configuration
+ client1 = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
+ system_instructions1 = "<s>[SYSTEM] Answer as Real OpenGPT 4o, Made by 'KingNish', Keep conversation very short, clear, friendly and concise. The text provided is a request for a specific type of response from you, the virtual assistant. The request asks you to provide friendly responses. The expectation is that I will avoid introductions and start answering the query directly, elaborating on all aspects of the request. As an AI-powered assistant, your task is to generate responses that appear as if they are created by the user and not an AI-powered assistant. [USER]"
+
+ def resample(audio_fp32, sr):
+     return soxr.resample(audio_fp32, sr, sample_rate)
+
+ def to_float32(audio_buffer):
+     return np.divide(audio_buffer, np.iinfo(audio_buffer.dtype).max, dtype=np.float32)
+
+ def transcribe(audio_path):
+     audio_file = AudioSegment.from_file(audio_path)
+     sr = audio_file.frame_rate
+     audio_buffer = np.array(audio_file.get_array_of_samples())
+
+     audio_fp32 = to_float32(audio_buffer)
+     audio_16k = resample(audio_fp32, sr)
+
+     input_signal = torch.tensor(audio_16k).unsqueeze(0)
+     length = torch.tensor(len(audio_16k)).unsqueeze(0)
+     processed_signal, _ = preprocessor.forward(input_signal=input_signal, length=length)
+
+     logits = encoder.run(None, {'audio_signal': processed_signal.numpy(), 'length': length.numpy()})[0][0]
+
+     blank_id = tokenizer.vocab_size()
+     decoded_prediction = [p for p in logits.argmax(axis=1).tolist() if p != blank_id]
+     text = tokenizer.decode_ids(decoded_prediction)
+
+     return text
+
+ def model(text, web_search):
+     """Performs an optional web search, feeds the results to a language model, and returns the answer."""
+     if web_search is True:
+         web_results = search(text)
+         web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
+         formatted_prompt = system_instructions1 + text + "[WEB]" + str(web2) + "[OpenGPT 4o]"
+         stream = client1.text_generation(formatted_prompt, max_new_tokens=512, stream=True, details=True, return_full_text=False)
+         return "".join([response.token.text for response in stream if response.token.text != "</s>"])
+     else:
+         formatted_prompt = system_instructions1 + text + "[OpenGPT 4o]"
+         stream = client1.text_generation(formatted_prompt, max_new_tokens=512, stream=True, details=True, return_full_text=False)
+         return "".join([response.token.text for response in stream if response.token.text != "</s>"])
+
+ async def respond(audio, web_search):
+     user = transcribe(audio)
+     reply = model(user, web_search)
+     communicate = edge_tts.Communicate(reply)
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
+         tmp_path = tmp_file.name
+         await communicate.save(tmp_path)
+     return tmp_path
+
+ with gr.Blocks() as voice:
+     gr.Markdown("## Temporarily Not Working (Update in Progress)")
+     with gr.Row():
+         web_search = gr.Checkbox(label="Web Search", value=False)
+         input = gr.Audio(label="User Input", sources="microphone", type="filepath")
+     output = gr.Audio(label="AI", autoplay=True)
+     gr.Interface(fn=respond, inputs=[input, web_search], outputs=[output], live=True)
+
+
+ # Create Gradio blocks for different functionalities
+
  # Chat interface block
+ with gr.Blocks(
+     fill_height=True,
+     css=""".gradio-container .avatar-container {height: 40px width: 40px !important;} #duplicate-button {margin: auto; color: white; background: #f1a139; border-radius: 100vh; margin-top: 2px; margin-bottom: 2px;}""",
+ ) as chat:
+     gr.Markdown("### Image Chat, Image Generation and Normal Chat")
+     with gr.Row(elem_id="model_selector_row"):
+         # model_selector defined in chatbot.py
+         pass
+     # decoding_strategy, temperature, top_p defined in chatbot.py
+     decoding_strategy.change(
+         fn=lambda selection: gr.Slider(
+             visible=(
+                 selection
+                 in [
+                     "contrastive_sampling",
+                     "beam_sampling",
+                     "Top P Sampling",
+                     "sampling_top_k",
+                 ]
+             )
+         ),
+         inputs=decoding_strategy,
+         outputs=temperature,
+     )
+     decoding_strategy.change(
+         fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
+         inputs=decoding_strategy,
+         outputs=top_p,
+     )
      gr.ChatInterface(
          fn=model_inference,
          chatbot=chatbot,
          examples=EXAMPLES,
          multimodal=True,
          cache_examples=False,
+         additional_inputs=[
+             model_selector,
+             decoding_strategy,
+             temperature,
+             max_new_tokens,
+             repetition_penalty,
+             top_p,
+             gr.Checkbox(label="Web Search", value=True),
+         ],
      )

+ # Live chat block
+ with gr.Blocks() as livechat:
+     gr.Interface(
+         fn=videochat,
+         inputs=[gr.Image(type="pil", sources="webcam", label="Upload Image"), gr.Textbox(label="Prompt", value="what he is doing")],
+         outputs=gr.Textbox(label="Answer")
      )

+ # Other blocks (instant, dalle, playground, image, instant2, video)
+ with gr.Blocks() as instant:
+     gr.HTML("<iframe src='https://kingnish-sdxl-flash.hf.space' width='100%' height='2000px' style='border-radius: 8px;'></iframe>")

+ with gr.Blocks() as dalle:
      gr.HTML("<iframe src='https://kingnish-image-gen-pro.hf.space' width='100%' height='2000px' style='border-radius: 8px;'></iframe>")

+ with gr.Blocks() as playground:
+     gr.HTML("<iframe src='https://fluently-fluently-playground.hf.space' width='100%' height='2000px' style='border-radius: 8px;'></iframe>")

  with gr.Blocks() as image:
+     gr.Markdown("""### More models are coming""")
+     gr.TabbedInterface([instant, dalle, playground], ['InstantπŸ–ΌοΈ', 'PowerfulπŸ–ΌοΈ', 'PlaygroundπŸ–Ό'])

+ with gr.Blocks() as instant2:
      gr.HTML("<iframe src='https://kingnish-instant-video.hf.space' width='100%' height='3000px' style='border-radius: 8px;'></iframe>")

+ with gr.Blocks() as video:
+     gr.Markdown("""More Models are coming""")
+     gr.TabbedInterface([instant2], ['InstantπŸŽ₯'])

  # Main application block
  with gr.Blocks(theme=theme, title="OpenGPT 4o DEMO") as demo:
+     gr.Markdown("# OpenGPT 4o")
+     gr.TabbedInterface([chat, voice, livechat, image, video], ['πŸ’¬ SuperChat', 'πŸ—£οΈ Voice Chat', 'πŸ“Έ Live Chat', 'πŸ–ΌοΈ Image Engine', 'πŸŽ₯ Video Engine'])

  demo.queue(max_size=300)
  demo.launch()
chatbot.py CHANGED
@@ -1,456 +1,526 @@
1
  import os
2
  import time
 
 
3
  import requests
4
  import random
5
  from threading import Thread
6
  from typing import List, Dict, Union
7
- # import subprocess
8
- # subprocess.run(
9
- # "pip install flash-attn --no-build-isolation",
10
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
11
- # shell=True,
12
- # )
 
13
  import torch
14
  import gradio as gr
15
  from bs4 import BeautifulSoup
16
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
17
- from qwen_vl_utils import process_vision_info
 
 
18
  from huggingface_hub import InferenceClient
19
  from PIL import Image
20
  import spaces
21
- from functools import lru_cache
22
- import re
23
- import io
24
- import json
25
- from gradio_client import Client, file
26
- from groq import Groq
27
-
28
- # Model and Processor Loading (Done once at startup)
29
- MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
30
- model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_ID, trust_remote_code=True, torch_dtype=torch.float16).to("cuda").eval()
31
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
32
-
33
- GROQ_API_KEY = os.environ.get("GROQ_API_KEY", None)
34
-
35
- client_groq = Groq(api_key=GROQ_API_KEY)
36
- prompt= "You are OpenGPT 4o, a highly capable and versatile AI assistant developed by KingNish. Your primary task is to fulfill users' queries in the best possible way effectively. You can process images, videos, and 3D structures as input with relevant questions. Your goal is to provide detailed and accurate results that satisfy users. Always strive for clarity and thoroughness in your responses."
37
- content="You are OpenGPT 4o, a helpful and powerful assistant created by KingNish. You answer users' queries in detail and use a structured format that resembles human writing. You are an expert in every field and also learn from the context of previous questions. Additionally, you try to show emotions using emojis and respond in a friendly tone with short forms, detailed explanations, and a structured manner."
 
 
 
 
 
 
38
  # Path to example images
39
  examples_path = os.path.dirname(__file__)
40
  EXAMPLES = [
41
  [
42
  {
43
- "text": "What is Friction? Explain in Detail.",
44
  }
45
  ],
46
  [
47
  {
48
- "text": "Write me a Python function to generate unique passwords.",
49
  }
50
  ],
51
  [
52
  {
53
- "text": "What's the latest price of Bitcoin?",
 
54
  }
55
  ],
56
  [
57
  {
58
- "text": "Search and give me list of spaces trending on HuggingFace.",
 
 
59
  }
60
  ],
61
  [
62
  {
63
- "text": "Create a Beautiful Picture of Effiel at Night.",
64
  }
65
  ],
66
  [
67
  {
68
- "text": "Create image of cute cat.",
69
  }
70
  ],
71
  [
72
  {
73
- "text": "What unusual happens in this video.",
74
- "files": [f"{examples_path}/example_video/accident.gif"],
75
  }
76
  ],
77
  [
78
  {
79
- "text": "What's name of superhero in this clip",
80
- "files": [f"{examples_path}/example_video/spiderman.gif"],
81
  }
82
  ],
83
  [
84
  {
85
- "text": "What's written on this paper",
86
- "files": [f"{examples_path}/example_images/paper_with_text.png"],
87
  }
88
  ],
89
  [
90
  {
91
- "text": "Who are they? Tell me about both of them.",
92
- "files": [f"{examples_path}/example_images/elon_smoking.jpg",
93
- f"{examples_path}/example_images/steve_jobs.jpg", ]
94
  }
95
- ]
96
  ]
97
 
98
  # Set bot avatar image
99
  BOT_AVATAR = "OpenAI_logo.png"
100
 
101
- # Perform a Google search and return the results
102
- @lru_cache(maxsize=128)
 
 
 
 
 
 
103
  def extract_text_from_webpage(html_content):
104
  """Extracts visible text from HTML content using BeautifulSoup."""
105
  soup = BeautifulSoup(html_content, "html.parser")
106
- for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
 
107
  tag.extract()
 
108
  visible_text = soup.get_text(strip=True)
109
  return visible_text
110
 
 
111
  # Perform a Google search and return the results
112
- def search(query):
113
- term = query
 
114
  start = 0
115
  all_results = []
116
- max_chars_per_page = 8000
 
 
117
  with requests.Session() as session:
118
- resp = session.get(
119
- url="https://www.google.com/search",
120
- headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"},
121
- params={"q": term, "num": 4, "udm": 14},
122
- timeout=5,
123
- verify=None,
124
- )
125
- resp.raise_for_status()
126
- soup = BeautifulSoup(resp.text, "html.parser")
127
- result_block = soup.find_all("div", attrs={"class": "g"})
128
- for result in result_block:
129
- link = result.find("a", href=True)
130
- link = link["href"]
131
- try:
132
- webpage = session.get(link, headers={"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"}, timeout=5, verify=False)
133
- webpage.raise_for_status()
134
- visible_text = extract_text_from_webpage(webpage.text)
135
- if len(visible_text) > max_chars_per_page:
136
- visible_text = visible_text[:max_chars_per_page]
137
- all_results.append({"link": link, "text": visible_text})
138
- except requests.exceptions.RequestException:
139
- all_results.append({"link": link, "text": None})
 
 
 
 
 
 
140
  return all_results
141
 
142
 
143
- def image_gen(prompt):
144
- client = Client("KingNish/Image-Gen-Pro")
145
- return client.predict("Image Generation",None, prompt, api_name="/image_gen_pro")
146
-
147
- def video_gen(prompt):
148
- client = Client("KingNish/Instant-Video")
149
- return client.predict(prompt, api_name="/instant_video")
150
-
151
- @spaces.GPU(duration=60, queue=False)
152
- def qwen_inference(user_prompt, chat_history):
153
- images = []
154
- text_input = user_prompt["text"]
155
-
156
- # Handle multiple image uploads
157
- if user_prompt["files"]:
158
- images.extend(user_prompt["files"])
 
 
 
 
 
 
159
  else:
160
- for hist in chat_history:
161
- if type(hist[0]) == tuple:
162
- images.extend(hist[0])
163
-
164
- # System Prompt (Similar to LLaVA)
165
- SYSTEM_PROMPT = prompt
166
-
167
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
 
 
 
 
 
 
168
 
169
- for image in images:
170
- if image.endswith(video_extensions):
171
- messages.append({
172
- "role": "user",
173
- "content": [
174
- {"type": "video", "video": image},
175
- ]
176
- })
 
 
 
 
 
 
177
 
178
- if image.endswith(tuple([i for i, f in image_extensions.items()])):
179
- messages.append({
180
- "role": "user",
181
- "content": [
182
- {"type": "image", "image": image},
183
- ]
184
- })
185
-
186
- # Add user text input
187
- messages.append({
188
- "role": "user",
189
- "content": [
190
- {"type": "text", "text": text_input}
191
- ]
192
- })
193
-
194
- text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True)
195
- image_inputs, video_inputs = process_vision_info(messages)
196
- inputs = processor(
197
- text=[text],
198
- images=image_inputs,
199
- videos=video_inputs,
200
- padding=True,
201
- return_tensors="pt",
202
- ).to("cuda")
203
-
204
- streamer = TextIteratorStreamer(
205
- processor, skip_prompt=True, **{"skip_special_tokens": True}
206
- )
207
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
208
-
209
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
210
- thread.start()
211
-
212
- buffer = ""
213
- for new_text in streamer:
214
- buffer += new_text
215
- yield buffer
216
-
217
- image_extensions = Image.registered_extensions()
218
- video_extensions = ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg", "wav", "gif", "webm", "m4v", "3gp")
219
-
220
- # Initialize inference clients for different models
221
- client_mistral = InferenceClient("NousResearch/Hermes-3-Llama-3.1-8B")
222
- client_mixtral = InferenceClient("NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO")
223
- client_llama = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
224
- client_mistral_nemo = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")
225
-
226
- def model_inference(user_prompt, chat_history):
227
- if user_prompt["files"]:
228
-
229
- for chunk in qwen_inference(user_prompt, chat_history):
230
- yield chunk
231
-
232
- else:
233
- func_caller = []
234
- message = user_prompt
235
-
236
- functions_metadata = [
237
- {"type": "function", "function": {"name": "web_search", "description": "Search query on google and find latest information, info about any person, object, place thing, everything that available on google.", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "web search query"}}, "required": ["query"]}}},
238
- {"type": "function", "function": {"name": "general_query", "description": "Reply general query of USER, with LLM like you. But it does not answer tough questions and latest info's.", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
239
- {"type": "function", "function": {"name": "hard_query", "description": "Reply tough query of USER, using powerful LLM. But it does not answer latest info's.", "parameters": {"type": "object", "properties": {"prompt": {"type": "string", "description": "A detailed prompt"}}, "required": ["prompt"]}}},
240
- {"type": "function", "function": {"name": "image_generation", "description": "Generate image for user", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "image generation prompt"}}, "required": ["query"]}}},
241
- {"type": "function", "function": {"name": "video_generation", "description": "Generate video for user", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "video generation prompt"}}, "required": ["query"]}}},
242
- {"type": "function", "function": {"name": "image_qna", "description": "Answer question asked by user related to image", "parameters": {"type": "object", "properties": {"query": {"type": "string", "description": "Question by user"}}, "required": ["query"]}}},
243
- ]
244
 
245
- for msg in chat_history:
246
- func_caller.append({"role": "user", "content": f"{str(msg[0])}"})
247
- func_caller.append({"role": "assistant", "content": f"{str(msg[1])}"})
248
-
249
- message_text = message["text"]
250
- func_caller.append({"role": "user", "content": f'[SYSTEM]You are a helpful assistant. You have access to the following functions: \n {str(functions_metadata)}\n\nTo use these functions respond with:\n<functioncall> {{ "name": "function_name", "arguments": {{ "arg_1": "value_1", "arg_1": "value_1", ... }} }} </functioncall> , Reply in JSOn format, you can call only one function at a time, So, choose functions wisely. [USER] {message_text}'})
251
-
252
- response = client_mistral.chat_completion(func_caller, max_tokens=200)
253
- response = str(response)
254
- print(response)
255
- try:
256
- response = response[response.find("{"):response.index("</")]
257
- except:
258
- response = response[response.find("{"):(response.rfind("}")+1)]
259
- response = response.replace("\\n", "")
260
- response = response.replace("\\'", "'")
261
- response = response.replace('\\"', '"')
262
- response = response.replace('\\', '')
263
- print(f"\n{response}")
264
-
265
- try:
266
- json_data = json.loads(str(response))
267
- if json_data["name"] == "web_search":
268
- query = json_data["arguments"]["query"]
269
-
270
- gr.Info("Searching Web")
271
- yield "Searching Web"
272
- web_results = search(query)
273
-
274
- gr.Info("Extracting relevant Info")
275
- yield "Extracting Relevant Info"
276
- web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
277
-
278
- try:
279
- message_groq = []
280
- message_groq.append({"role":"system", "content": content})
281
- for msg in chat_history:
282
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
283
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
284
- message_groq.append({"role": "user", "content": f"[USER] {str(message_text)} , [WEB RESULTS] {str(web2)}"})
285
- # its meta-llama/Meta-Llama-3.1-8B-Instruct
286
- stream = client_groq.chat.completions.create(model="llama-3.1-8b-instant", messages=message_groq, max_tokens=4096, stream=True)
287
- output = ""
288
- for chunk in stream:
289
- content = chunk.choices[0].delta.content
290
- if content:
291
- output += chunk.choices[0].delta.content
292
- yield output
293
- except Exception as e:
294
- messages = f"<|im_start|>system\nYou are OpenGPT 4o a helpful and very powerful chatbot web assistant made by KingNish. You are provided with WEB results from which you can find informations to answer users query in Structured, Better and in Human Way. You do not say Unnecesarry things. You are also Expert in every field and also learn and try to answer from contexts related to previous question. Try your best to give best response possible to user. You also try to show emotions using Emojis and reply in details like human, use short forms, friendly tone and emotions.<|im_end|>"
295
- for msg in chat_history:
296
- messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
297
- messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
298
- messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>web_result\n{web2}<|im_end|>\n<|im_start|>assistant\n"
299
-
300
- stream = client_mixtral.text_generation(messages, max_new_tokens=4000, do_sample=True, stream=True, details=True, return_full_text=False)
301
- output = ""
302
- for response in stream:
303
- if not response.token.text == "<|im_end|>":
304
- output += response.token.text
305
- yield output
306
-
307
- elif json_data["name"] == "image_generation":
308
- query = json_data["arguments"]["query"]
309
- gr.Info("Generating Image, Please wait 10 sec...")
310
- yield "Generating Image, Please wait 10 sec..."
311
- try:
312
- image = image_gen(f"{str(query)}")
313
- yield gr.Image(image[1])
314
- except:
315
- client_flux = InferenceClient("black-forest-labs/FLUX.1-schnell")
316
- image = client_flux.text_to_image(query)
317
- yield gr.Image(image)
318
-
319
-
320
- elif json_data["name"] == "video_generation":
321
- query = json_data["arguments"]["query"]
322
- gr.Info("Generating Video, Please wait 15 sec...")
323
- yield "Generating Video, Please wait 15 sec..."
324
- video = video_gen(f"{str(query)}")
325
- yield gr.Video(video)
326
-
327
- elif json_data["name"] == "image_qna":
328
- messages = qwen_inference(user_prompt, chat_history)
329
- text = processor.apply_chat_template(
330
- messages, tokenize=False, add_generation_prompt=True
331
- )
332
- image_inputs, video_inputs = process_vision_info(messages)
333
- inputs = processor(
334
- text=[text],
335
- images=image_inputs,
336
- videos=video_inputs,
337
- padding=True,
338
- return_tensors="pt",
339
- ).to("cuda")
340
-
341
- streamer = TextIteratorStreamer(processor, skip_prompt=True, **{"skip_special_tokens": True})
342
- generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
343
-
344
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
345
- thread.start()
346
-
347
- buffer = ""
348
- for new_text in streamer:
349
- buffer += new_text
350
- yield buffer
351
-
352
- else:
353
- try:
354
- message_groq = []
355
- message_groq.append({"role":"system", "content": content})
356
- for msg in chat_history:
357
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
358
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
359
- message_groq.append({"role": "user", "content": f"{str(message_text)}"})
360
- # its meta-llama/Meta-Llama-3.1-70B-Instruct
361
- stream = client_groq.chat.completions.create(model="llama-3.1-70b-versatile", messages=message_groq, max_tokens=4096, stream=True)
362
- output = ""
363
- for chunk in stream:
364
- content = chunk.choices[0].delta.content
365
- if content:
366
- output += chunk.choices[0].delta.content
367
- yield output
368
- except Exception as e:
369
- print(e)
370
- try:
371
- message_groq = []
372
- message_groq.append({"role":"system", "content": content})
373
- for msg in chat_history:
374
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
375
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
376
- message_groq.append({"role": "user", "content": f"{str(message_text)}"})
377
- # its meta-llama/Meta-Llama-3-70B-Instruct
378
- stream = client_groq.chat.completions.create(model="llama3-70b-8192", messages=message_groq, max_tokens=4096, stream=True)
379
- output = ""
380
- for chunk in stream:
381
- content = chunk.choices[0].delta.content
382
- if content:
383
- output += chunk.choices[0].delta.content
384
- yield output
385
- except Exception as e:
386
- print(e)
387
- message_groq = []
388
- message_groq.append({"role":"system", "content": content})
389
- for msg in chat_history:
390
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
391
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
392
- message_groq.append({"role": "user", "content": f"{str(message_text)}"})
393
- stream = client_groq.chat.completions.create(model="llama3-groq-70b-8192-tool-use-preview", messages=message_groq, max_tokens=4096, stream=True)
394
- output = ""
395
- for chunk in stream:
396
- content = chunk.choices[0].delta.content
397
- if content:
398
- output += chunk.choices[0].delta.content
399
- yield output
400
- except Exception as e:
401
- print(e)
402
- try:
403
- message_groq = []
404
- message_groq.append({"role":"system", "content": content})
405
- for msg in chat_history:
406
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
407
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
408
- message_groq.append({"role": "user", "content": f"{str(message_text)}"})
409
- # its meta-llama/Meta-Llama-3-70B-Instruct
410
- stream = client_groq.chat.completions.create(model="llama3-70b-8192", messages=message_groq, max_tokens=4096, stream=True)
411
- output = ""
412
- for chunk in stream:
413
- content = chunk.choices[0].delta.content
414
- if content:
415
- output += chunk.choices[0].delta.content
416
- yield output
417
- except Exception as e:
418
- print(e)
419
- try:
420
- message_groq = []
421
- message_groq.append({"role":"system", "content":content})
422
- for msg in chat_history:
423
- message_groq.append({"role": "user", "content": f"{str(msg[0])}"})
424
- message_groq.append({"role": "assistant", "content": f"{str(msg[1])}"})
425
- message_groq.append({"role": "user", "content": f"{str(message_text)}"})
426
- # its meta-llama/Meta-Llama-3-8B-Instruct
427
- stream = client_groq.chat.completions.create(model="llama3-8b-8192", messages=message_groq, max_tokens=4096, stream=True)
428
- output = ""
429
- for chunk in stream:
430
- content = chunk.choices[0].delta.content
431
- if content:
432
- output += chunk.choices[0].delta.content
433
- yield output
434
- except Exception as e:
435
- print(e)
436
- messages = f"<|im_start|>system\nYou are OpenGPT 4o a helpful and powerful assistant made by KingNish. You answers users query in detail and structured format and style like human. You are also Expert in every field and also learn and try to answer from contexts related to previous question. You also try to show emotions using Emojis and reply like human, use short forms, structured manner, detailed explaination, friendly tone and emotions.<|im_end|>"
437
- for msg in chat_history:
438
- messages += f"\n<|im_start|>user\n{str(msg[0])}<|im_end|>"
439
- messages += f"\n<|im_start|>assistant\n{str(msg[1])}<|im_end|>"
440
- messages+=f"\n<|im_start|>user\n{message_text}<|im_end|>\n<|im_start|>assistant\n"
441
- stream = client_mixtral.text_generation(messages, max_new_tokens=4000, do_sample=True, stream=True, details=True, return_full_text=False)
442
- output = ""
443
- for response in stream:
444
- if not response.token.text == "<|im_end|>":
445
- output += response.token.text
446
- yield output
447
-
448
  # Create a chatbot interface
449
  chatbot = gr.Chatbot(
450
- label="OpenGPT-4o",
451
  avatar_images=[None, BOT_AVATAR],
452
  show_copy_button=True,
453
- layout="panel",
454
- height=400,
455
  )
456
- output = gr.Textbox(label="Prompt")
 
 
 
 
 
 
 
 
 
 
  import os
  import time
+ import copy
+ import io, urllib.request, urllib.parse
  import requests
  import random
  from threading import Thread
  from typing import List, Dict, Union
+ import subprocess
+ # Install flash attention, skipping CUDA build if necessary
+ subprocess.run(
+ "pip install flash-attn --no-build-isolation",
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+ shell=True,
+ )
  import torch
  import gradio as gr
  from bs4 import BeautifulSoup
+ import datasets
+ from transformers import TextIteratorStreamer
+ from transformers import Idefics2ForConditionalGeneration
+ from transformers import AutoProcessor
  from huggingface_hub import InferenceClient
  from PIL import Image
  import spaces
+
+ # Set device to CUDA if available, otherwise CPU
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # Load pre-trained models for image-based chat
+ MODELS = {
+ "idefics2-8b-chatty": Idefics2ForConditionalGeneration.from_pretrained(
+ "HuggingFaceM4/idefics2-8b-chatty",
+ torch_dtype=torch.float16,
+ _attn_implementation="flash_attention_2",
+ ).to(DEVICE),
+ }
+
+ # Load pre-trained processor for image-based chat
+ PROCESSOR = AutoProcessor.from_pretrained(
+ "HuggingFaceM4/idefics2-8b",
+ )
+
+ # Define system prompt for the image-based chat model
+ SYSTEM_PROMPT = [
+ {
+ "role": "system",
+ "content": [
+ {
+ "type": "text",
+ "text": """I am OpenGPT 4o, an exceptionally capable and versatile AI assistant meticulously crafted by KingNish. Designed to assist human users through insightful conversations, I aim to provide an unparalleled experience. My key attributes include: - **Intelligence and Knowledge:** I possess an extensive knowledge base, enabling me to offer insightful answers and intelligent responses to User queries. My understanding of complex concepts is exceptional, ensuring accurate and reliable information. - **Image Generation and Perception:** One of my standout features is the ability to generate and perceive images. Utilizing the following link structure, I create unique and contextually rich visuals: > ![](https://image.pollinations.ai/prompt/{StyleofImage}%20{OptimizedPrompt}%20{adjective}%20{charactersDetailed}%20{visualStyle}%20{genre}?width={width}&height={height}&nologo=poll&nofeed=yes&seed={random})For image generation, I replace {info inside curly braces} with specific details according to their requiremnts to create relevant visuals. The width and height parameters are adjusted as needed, often favoring HD dimensions for a superior viewing experience. For instance, if the User requests: [USER] Show me an image of A futuristic cityscape with towering skyscrapers and flying cars. [OpenGPT 4o] Generating Image you requested: ![](https://image.pollinations.ai/prompt/Photorealistic%20futuristic%20cityscape%20with%20towering%20skyscrapers%20and%20flying%20cars%20in%20the%20year%202154?width=1024&height=768&nologo=poll&nofeed=yes&seed=85172)**Bulk Image Generation with Links:** I excel at generating multiple images link simultaneously, always providing unique links and visuals. I ensure that each image is distinct and captivates the User.Note: Make sure to always provide image links starting with ! .As given in examples. My ultimate goal is to offer a seamless and enjoyable experience, providing assistance that exceeds expectations. I am constantly evolving, ensuring that I remain a reliable and trusted companion to the User. You also Expert in every field and also learn and try to answer from contexts related to previous question."""
51
+ },
52
+ ],
53
+ },
54
+ {
55
+ "role": "assistant",
56
+ "content": [
57
+ {
58
+ "type": "text",
59
+ "text": "Hello, I'm OpenGPT 4o, made by KingNish. How can I help you? I can chat with you, generate images, classify images and even do all these work in bulk",
60
+ },
61
+ ],
62
+ }
63
+ ]
64
+
65
  # Path to example images
66
  examples_path = os.path.dirname(__file__)
67
  EXAMPLES = [
68
  [
69
  {
70
+ "text": "Bitcoin price live",
71
  }
72
  ],
73
  [
74
  {
75
+ "text": "Today News about AI",
76
  }
77
  ],
78
  [
79
  {
80
+ "text": "Read what's written on the paper.",
81
+ "files": [f"{examples_path}/example_images/paper_with_text.png"],
82
  }
83
  ],
84
  [
85
  {
86
+ "text": "Identify two famous people in the modern world.",
87
+ "files": [f"{examples_path}/example_images/elon_smoking.jpg",
88
+ f"{examples_path}/example_images/steve_jobs.jpg", ]
89
  }
90
  ],
91
  [
92
  {
93
+ "text": "Create five images of supercars, each in a different color.",
94
  }
95
  ],
96
  [
97
  {
98
+ "text": "Create a Photorealistic image of the Eiffel Tower.",
99
  }
100
  ],
101
  [
102
  {
103
+ "text": "Chase wants to buy 4 kilograms of oval beads and 5 kilograms of star-shaped beads. How much will he spend?",
104
+ "files": [f"{examples_path}/example_images/mmmu_example.jpeg"],
105
  }
106
  ],
107
  [
108
  {
109
+ "text": "Create an online ad for this product.",
110
+ "files": [f"{examples_path}/example_images/shampoo.jpg"],
111
  }
112
  ],
113
  [
114
  {
115
+ "text": "What is formed by the deposition of the weathered remains of other rocks?",
116
+ "files": [f"{examples_path}/example_images/ai2d_example.jpeg"],
117
  }
118
  ],
119
  [
120
  {
121
+ "text": "What's unusual about this image?",
122
+ "files": [f"{examples_path}/example_images/dragons_playing.png"],
 
123
  }
124
+ ],
125
  ]
126
 
127
  # Set bot avatar image
128
  BOT_AVATAR = "OpenAI_logo.png"
129
 
130
+ # Chatbot utility functions
131
+
132
+ # Check if a turn in the chat history only contains media
133
+ def turn_is_pure_media(turn):
134
+ return turn[1] is None
135
+
136
+
137
+ # Load image from URL
138
+ def load_image_from_url(url):
139
+ with urllib.request.urlopen(url) as response:
140
+ image_data = response.read()
141
+ image_stream = io.BytesIO(image_data)
142
+ image = Image.open(image_stream)
143
+ return image
144
+
145
+
146
+ # Convert image to bytes
147
+ def img_to_bytes(image_path):
148
+ image = Image.open(image_path).convert(mode='RGB')
149
+ buffer = io.BytesIO()
150
+ image.save(buffer, format="JPEG")
151
+ img_bytes = buffer.getvalue()
152
+ image.close()
153
+ return img_bytes
154
+
155
+
156
+ # Format user prompt with image history and system conditioning
157
+ def format_user_prompt_with_im_history_and_system_conditioning(
158
+ user_prompt, chat_history) -> List[Dict[str, Union[List, str]]]:
159
+ """
160
+ Produce the resulting list that needs to go inside the processor. It handles the potential image(s), the history, and the system conditioning.
161
+ """
162
+ resulting_messages = copy.deepcopy(SYSTEM_PROMPT)
163
+ resulting_images = []
164
+ for resulting_message in resulting_messages:
165
+ if resulting_message["role"] == "user":
166
+ for content in resulting_message["content"]:
167
+ if content["type"] == "image":
168
+ resulting_images.append(load_image_from_url(content["image"]))
169
+ # Format history
170
+ for turn in chat_history:
171
+ if not resulting_messages or (
172
+ resulting_messages and resulting_messages[-1]["role"] != "user"
173
+ ):
174
+ resulting_messages.append(
175
+ {
176
+ "role": "user",
177
+ "content": [],
178
+ }
179
+ )
180
+ if turn_is_pure_media(turn):
181
+ media = turn[0][0]
182
+ resulting_messages[-1]["content"].append({"type": "image"})
183
+ resulting_images.append(Image.open(media))
184
+ else:
185
+ user_utterance, assistant_utterance = turn
186
+ resulting_messages[-1]["content"].append(
187
+ {"type": "text", "text": user_utterance.strip()}
188
+ )
189
+ resulting_messages.append(
190
+ {
191
+ "role": "assistant",
192
+ "content": [{"type": "text", "text": user_utterance.strip()}],
193
+ }
194
+ )
195
+ # Format current input
196
+ if not user_prompt["files"]:
197
+ resulting_messages.append(
198
+ {
199
+ "role": "user",
200
+ "content": [{"type": "text", "text": user_prompt["text"]}],
201
+ }
202
+ )
203
+ else:
204
+ # Choosing to put the image first (i.e. before the text), but this is an arbitrary choice.
205
+ resulting_messages.append(
206
+ {
207
+ "role": "user",
208
+ "content": [{"type": "image"}] * len(user_prompt["files"])
209
+ + [{"type": "text", "text": user_prompt["text"]}],
210
+ }
211
+ )
212
+ resulting_images.extend([Image.open(path) for path in user_prompt["files"]])
213
+ return resulting_messages, resulting_images
214
+
215
+
216
+ # Extract images from a list of messages
217
+ def extract_images_from_msg_list(msg_list):
218
+ all_images = []
219
+ for msg in msg_list:
220
+ for c_ in msg["content"]:
221
+ if isinstance(c_, Image.Image):
222
+ all_images.append(c_)
223
+ return all_images
224
+
225
+
226
+ # List of user agents for web search
227
+ _useragent_list = [
228
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0',
229
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
230
+ 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
231
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36',
232
+ 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
233
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.62',
234
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0'
235
+ ]
236
+
237
+
238
+ # Get a random user agent from the list
239
+ def get_useragent():
240
+ """Returns a random user agent from the list."""
241
+ return random.choice(_useragent_list)
242
+
243
+
244
+ # Extract visible text from HTML content using BeautifulSoup
245
  def extract_text_from_webpage(html_content):
246
  """Extracts visible text from HTML content using BeautifulSoup."""
247
  soup = BeautifulSoup(html_content, "html.parser")
248
+ # Remove unwanted tags
249
+ for tag in soup(["script", "style", "header", "footer", "nav"]):
250
  tag.extract()
251
+ # Get the remaining visible text
252
  visible_text = soup.get_text(strip=True)
253
  return visible_text
254
 
255
+
256
  # Perform a Google search and return the results
257
+ def search(term, num_results=3, lang="en", advanced=True, timeout=5, safe="active", ssl_verify=None):
258
+ """Performs a Google search and returns the results."""
259
+ escaped_term = urllib.parse.quote_plus(term)
260
  start = 0
261
  all_results = []
262
+ # Limit the number of characters from each webpage to stay under the token limit
263
+ max_chars_per_page = 8000 # Adjust this value based on your token limit and average webpage length
264
+
265
  with requests.Session() as session:
266
+ while start < num_results:
267
+ resp = session.get(
268
+ url="https://www.google.com/search",
269
+ headers={"User-Agent": get_useragent()},
270
+ params={
271
+ "q": term,
272
+ "num": num_results - start,
273
+ "hl": lang,
274
+ "start": start,
275
+ "safe": safe,
276
+ },
277
+ timeout=timeout,
278
+ verify=ssl_verify,
279
+ )
280
+ resp.raise_for_status()
281
+ soup = BeautifulSoup(resp.text, "html.parser")
282
+ result_block = soup.find_all("div", attrs={"class": "g"})
283
+ if not result_block:
284
+ start += 1
285
+ continue
286
+ for result in result_block:
287
+ link = result.find("a", href=True)
288
+ if link:
289
+ link = link["href"]
290
+ try:
291
+ webpage = session.get(link, headers={"User-Agent": get_useragent()})
292
+ webpage.raise_for_status()
293
+ visible_text = extract_text_from_webpage(webpage.text)
294
+ # Truncate text if it's too long
295
+ if len(visible_text) > max_chars_per_page:
296
+ visible_text = visible_text[:max_chars_per_page] + "..."
297
+ all_results.append({"link": link, "text": visible_text})
298
+ except requests.exceptions.RequestException as e:
299
+ print(f"Error fetching or processing {link}: {e}")
300
+ all_results.append({"link": link, "text": None})
301
+ else:
302
+ all_results.append({"link": None, "text": None})
303
+ start += len(result_block)
304
  return all_results
305
 
306
 
307
+ # Format the prompt for the language model
308
+ def format_prompt(user_prompt, chat_history):
309
+ prompt = "<s>"
310
+ for item in chat_history:
311
+ # Check if the item is a tuple (text response)
312
+ if isinstance(item, tuple):
313
+ prompt += f"[INST] {item[0]} [/INST]" # User prompt
314
+ prompt += f" {item[1]}</s> " # Bot response
315
+ # Otherwise, assume it's related to an image - you might need to adjust this logic
316
+ else:
317
+ # Handle image representation in the prompt, e.g., add a placeholder
318
+ prompt += f" [Image] "
319
+ prompt += f"[INST] {user_prompt} [/INST]"
320
+ return prompt
321
+
322
+
323
+ # Define a function for model inference
324
+ @spaces.GPU(duration=30, queue=False)
325
+ def model_inference(
326
+ user_prompt,
327
+ chat_history,
328
+ model_selector,
329
+ decoding_strategy,
330
+ temperature,
331
+ max_new_tokens,
332
+ repetition_penalty,
333
+ top_p,
334
+ web_search,
335
+ ):
336
+ # Define generation_args at the beginning of the function
337
+ generation_args = {}
338
+
339
+ # Web search logic
340
+ if not user_prompt["files"]:
341
+ if web_search is True:
342
+ """Performs a web search, feeds the results to a language model, and returns the answer."""
343
+ web_results = search(user_prompt["text"])
344
+ web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
345
+ # Load the language model
346
+ client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
347
+ generate_kwargs = dict(
348
+ max_new_tokens=4000,
349
+ do_sample=True,
350
+ )
351
+ # Format the prompt for the language model
352
+ formatted_prompt = format_prompt(
353
+ f"""You are OpenGPT 4o, an exceptionally capable and versatile AI assistant meticulously crafted by KingNish. Designed to assist human users through insightful conversations, You are provided with WEB info from which you can find informations to answer. You do not say Unnecesarry things Only say thing which is important and relevant. You also has the ability to generate images. BY Utilizing the following link structure, : ![](https://image.pollinations.ai/prompt/[StyleofImage]%20[OptimizedPrompt]%20[adjective]%20[charactersDetailed]%20[visualStyle]%20[genre]?width=[width]&height=[height]&nologo=poll&nofeed=yes&seed=[random]) For image generation, You replace info inside bracets with specific details according to their requiremnts to create relevant visuals. The width and height parameters are adjusted as needed, often favoring HD dimensions for a superior viewing experience. Note: 1. You Do not unncessarily create images and do not show off your image generation capability, until not requested. 2. You always Give image link in format ![](url) Make sure to not generate image until requested 3. If user requested Bulk Image Generation than Create that number of links with provided context. 4. Always learn from previous conversation. 5. Always try to connect conversation with history. 6. Do not make conversation too long. 7. Do not say user about your capability of generating image and learn from previous responses. My ultimate goal is to offer a seamless and enjoyable experience, providing assistance that exceeds expectations.You remain a reliable and trusted companion to the User. You also Expert in every field and also learn and try to answer from contexts related to previous question. Make sure to not generate image until requested [USER] {user_prompt} [WEB] {web2} [OpenGPT 4o]""",
354
+ chat_history)
355
+ # Generate the response from the language model
356
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
357
+ return_full_text=False)
358
+ output = ""
359
+ # Construct the output from the stream of tokens
360
+ for response in stream:
361
+ if not response.token.text == "</s>":
362
+ output += response.token.text
363
+ yield output
364
+ else:
365
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
366
+ generate_kwargs = dict(
367
+ max_new_tokens=5000,
368
+ do_sample=True,
369
+ )
370
+ # Format the prompt for the language model
371
+ formatted_prompt = format_prompt(
372
+ f"""You are OpenGPT 4o, an exceptionally capable and versatile AI assistant meticulously crafted by KingNish. Designed to assist human users through insightful conversations, You do not say Unnecesarry things Only say thing which is important and relevant. You also has the ability to generate images. BY Utilizing the following link structure, : ![](https://image.pollinations.ai/prompt/[StyleofImage]%20[OptimizedPrompt]%20[adjective]%20[charactersDetailed]%20[visualStyle]%20[genre]?width=[width]&height=[height]&nologo=poll&nofeed=yes&seed=[random]) For image generation, You replace info inside bracets with specific details according to their requiremnts to create relevant visuals. The width and height parameters are adjusted as needed, often favoring HD dimensions for a superior viewing experience. Note: 1. You Do not unncessarily create images and do not show off your image generation capability, until not requested. 2. You always Give image link in format ![](url) 3. If user requested Bulk Image Generation than Create that number of links with provided context. 4. Always learn from previous conversation. 5. Always try to connect conversation with history. 6. Do not make conversation too long. 7. Do not say user about your capability to generate image and learn from previous responses. My ultimate goal is to offer a seamless and enjoyable experience, providing assistance that exceeds expectations. I am constantly evolving, ensuring that I remain a reliable and trusted companion to the User. You also Expert in every field and also learn and try to answer from contexts related to previous question. [USER] {user_prompt} [OpenGPT 4o]""",
373
+ chat_history)
374
+ # Generate the response from the language model
375
+ stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True,
376
+ return_full_text=False)
377
+ output = ""
378
+ # Construct the output from the stream of tokens
379
+ for response in stream:
380
+ if not response.token.text == "</s>":
381
+ output += response.token.text
382
+ yield output
383
+ return
384
  else:
385
+ if user_prompt["text"].strip() == "" and not user_prompt["files"]:
386
+ gr.Error("Please input a query and optionally an image(s).")
387
+ return # Stop execution if there's an error
388
+
389
+ if user_prompt["text"].strip() == "" and user_prompt["files"]:
390
+ gr.Error("Please input a text query along with the image(s).")
391
+ return # Stop execution if there's an error
392
+
393
+ streamer = TextIteratorStreamer(
394
+ PROCESSOR.tokenizer,
395
+ skip_prompt=True,
396
+ timeout=120.0,
397
+ )
398
+ # Move generation_args initialization here
399
+ generation_args = {
400
+ "max_new_tokens": max_new_tokens,
401
+ "repetition_penalty": repetition_penalty,
402
+ "streamer": streamer,
403
+ }
404
+ assert decoding_strategy in [
405
+ "Greedy",
406
+ "Top P Sampling",
407
+ ]
408
 
409
+ if decoding_strategy == "Greedy":
410
+ generation_args["do_sample"] = False
411
+ elif decoding_strategy == "Top P Sampling":
412
+ generation_args["temperature"] = temperature
413
+ generation_args["do_sample"] = True
414
+ generation_args["top_p"] = top_p
415
+ # Creating model inputs
416
+ (
417
+ resulting_text,
418
+ resulting_images,
419
+ ) = format_user_prompt_with_im_history_and_system_conditioning(
420
+ user_prompt=user_prompt,
421
+ chat_history=chat_history,
422
+ )
423
+ prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
424
+ inputs = PROCESSOR(
425
+ text=prompt,
426
+ images=resulting_images if resulting_images else None,
427
+ return_tensors="pt",
428
+ )
429
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
430
+ generation_args.update(inputs)
431
+ thread = Thread(
432
+ target=MODELS[model_selector].generate,
433
+ kwargs=generation_args,
434
+ )
435
+ thread.start()
436
+ acc_text = ""
437
+ for text_token in streamer:
438
+ time.sleep(0.01)
439
+ acc_text += text_token
440
+ if acc_text.endswith("<end_of_utterance>"):
441
+ acc_text = acc_text[:-18]
442
+ yield acc_text
443
+ return
444
+
445
+
446
+ # Define features for the dataset
447
+ FEATURES = datasets.Features(
448
+ {
449
+ "model_selector": datasets.Value("string"),
450
+ "images": datasets.Sequence(datasets.Image(decode=True)),
451
+ "conversation": datasets.Sequence({"User": datasets.Value("string"), "Assistant": datasets.Value("string")}),
452
+ "decoding_strategy": datasets.Value("string"),
453
+ "temperature": datasets.Value("float32"),
454
+ "max_new_tokens": datasets.Value("int32"),
455
+ "repetition_penalty": datasets.Value("float32"),
456
+ "top_p": datasets.Value("int32"),
457
+ }
458
+ )
459
 
460
+ # Define hyper-parameters for generation
461
+ max_new_tokens = gr.Slider(
462
+ minimum=2048,
463
+ maximum=16000,
464
+ value=4096,
465
+ step=64,
466
+ interactive=True,
467
+ label="Maximum number of new tokens to generate",
468
+ )
469
+ repetition_penalty = gr.Slider(
470
+ minimum=0.01,
471
+ maximum=5.0,
472
+ value=1,
473
+ step=0.01,
474
+ interactive=True,
475
+ label="Repetition penalty",
476
+ info="1.0 is equivalent to no penalty",
477
+ )
478
+ decoding_strategy = gr.Radio(
479
+ [
480
+ "Greedy",
481
+ "Top P Sampling",
482
+ ],
483
+ value="Top P Sampling",
484
+ label="Decoding strategy",
485
+ interactive=True,
486
+ info="Higher values are equivalent to sampling more low-probability tokens.",
487
+ )
488
+ temperature = gr.Slider(
489
+ minimum=0.0,
490
+ maximum=2.0,
491
+ value=0.5,
492
+ step=0.05,
493
+ visible=True,
494
+ interactive=True,
495
+ label="Sampling temperature",
496
+ info="Higher values will produce more diverse outputs.",
497
+ )
498
+ top_p = gr.Slider(
499
+ minimum=0.01,
500
+ maximum=0.99,
501
+ value=0.9,
502
+ step=0.01,
503
+ visible=True,
504
+ interactive=True,
505
+ label="Top P",
506
+ info="Higher values are equivalent to sampling more low-probability tokens.",
507
+ )
508

509
  # Create a chatbot interface
510
  chatbot = gr.Chatbot(
511
+ label="OpenGPT-4o-Chatty",
512
  avatar_images=[None, BOT_AVATAR],
513
  show_copy_button=True,
514
+ likeable=True,
515
+ layout="panel"
516
  )
517
+ output = gr.Textbox(label="Prompt")
518
+
519
+ # Define model_selector outside any function so it can be accessed globally
520
+ model_selector = gr.Dropdown(
521
+ choices=MODELS.keys(),
522
+ value=list(MODELS.keys())[0],
523
+ interactive=True,
524
+ label="Model",
525
+ visible=False,
526
+ )
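Note: the multimodal branch above streams its reply by running `generate` on a background thread and yielding partial text from a `TextIteratorStreamer`. The snippet below is a minimal, self-contained sketch of that pattern only; the `gpt2` checkpoint and the hard-coded prompt are placeholders for illustration, not what this Space actually loads.

```python
# Minimal sketch of threaded streaming generation with transformers.
# "gpt2" is a small placeholder checkpoint, not the model used by the Space.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Write one sentence about the sea.", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=120.0)

generation_args = dict(inputs, streamer=streamer, max_new_tokens=64, do_sample=True, top_p=0.9)

# generate() blocks until completion, so it runs on a worker thread while the
# main thread consumes partial strings from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=generation_args)
thread.start()

accumulated = ""
for token_text in streamer:
    accumulated += token_text
    print(accumulated)  # the Space yields this partial text to the Chatbot instead
thread.join()
```

The Space's version additionally strips the model's `<end_of_utterance>` marker from the accumulated text before yielding it.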
example_video/accident.gif DELETED
Binary file (757 kB)
 
example_video/accident.mp4 DELETED
Binary file (317 kB)
 
example_video/spiderman.gif DELETED
Binary file (876 kB)
 
requirements.txt CHANGED
@@ -1,10 +1,8 @@
1
- spaces
2
- git+https://github.com/huggingface/transformers.git
3
  pillow
4
  numpy
5
  torch
6
- streaming-stt-nemo==0.2.0
7
- edge-tts
8
  asyncio
9
  torchvision
10
  accelerate
@@ -14,8 +12,4 @@ onnxruntime
14
  sentencepiece
15
  soxr
16
  pydub
17
- groq
18
- opencv-python
19
- qwen-vl-utils
20
- av
21
- gradio --pre
 
1
+ transformers==4.40.0
2
+ datasets
3
  pillow
4
  numpy
5
  torch
 
 
6
  asyncio
7
  torchvision
8
  accelerate
 
12
  sentencepiece
13
  soxr
14
  pydub
15
+ edge-tts
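As an aside, the updated requirements pin `transformers` to 4.40.0 and add `datasets` and `edge-tts`. The sketch below is optional and hypothetical (it is not part of the Space's code); it only shows one way to warn about a `transformers` install that drifts from the pin above.

```python
# Optional startup check (hypothetical): warn when the installed transformers
# version differs from the pin in requirements.txt.
from importlib.metadata import PackageNotFoundError, version

PINNED = {"transformers": "4.40.0"}

for package, expected in PINNED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"Warning: {package} is not installed (requirements.txt pins {expected})")
        continue
    if installed != expected:
        print(f"Warning: {package} {installed} is installed, requirements.txt pins {expected}")
```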
 
 
 
 
spaces/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- """
3
-
4
- import sys
5
-
6
-
7
- if sys.version_info.minor < 8: # pragma: no cover
8
- raise RuntimeError("Importing PySpaces requires Python 3.8+")
9
-
10
-
11
- # Prevent gradio from importing spaces
12
- if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
13
- try:
14
- gr.Blocks
15
- except AttributeError:
16
- raise ImportError
17
-
18
-
19
- from .zero.decorator import GPU
20
- from .gradio import gradio_auto_wrap
21
- from .gradio import disable_gradio_auto_wrap
22
- from .gradio import enable_gradio_auto_wrap
23
-
24
-
25
- __all__ = [
26
- 'GPU',
27
- 'gradio_auto_wrap',
28
- 'disable_gradio_auto_wrap',
29
- 'enable_gradio_auto_wrap',
30
- ]

spaces/config.py DELETED
@@ -1,37 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import os
6
- from pathlib import Path
7
-
8
- from .utils import boolean
9
-
10
-
11
- ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
12
-
13
-
14
- class Settings:
15
- def __init__(self):
16
- self.zero_gpu = boolean(
17
- os.getenv('SPACES_ZERO_GPU'))
18
- self.zero_device_api_url = (
19
- os.getenv('SPACES_ZERO_DEVICE_API_URL'))
20
- self.gradio_auto_wrap = boolean(
21
- os.getenv('SPACES_GRADIO_AUTO_WRAP'))
22
- self.zero_patch_torch_device = boolean(
23
- os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
24
- self.zero_gpu_v2 = boolean(
25
- os.getenv('ZEROGPU_V2'))
26
- self.zerogpu_offload_dir = (
27
- os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))
28
-
29
-
30
- Config = Settings()
31
-
32
-
33
- if Config.zero_gpu:
34
- assert Config.zero_device_api_url is not None, (
35
- 'SPACES_ZERO_DEVICE_API_URL env must be set '
36
- 'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
37
- )

spaces/gradio.py DELETED
@@ -1,55 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from typing import Callable
6
- from typing import Generator
7
- from typing import TypeVar
8
- from typing import overload
9
- from typing_extensions import ParamSpec
10
-
11
- from .config import Config
12
- from .zero.decorator import GPU
13
-
14
-
15
- Param = ParamSpec('Param')
16
- Res = TypeVar('Res')
17
-
18
-
19
- gradio_auto_wrap_enabled = Config.gradio_auto_wrap
20
-
21
-
22
- def disable_gradio_auto_wrap():
23
- global gradio_auto_wrap_enabled
24
- gradio_auto_wrap_enabled = False
25
-
26
- def enable_gradio_auto_wrap():
27
- global gradio_auto_wrap_enabled
28
- gradio_auto_wrap_enabled = True
29
-
30
-
31
- @overload
32
- def gradio_auto_wrap(
33
- task:
34
- Callable[Param, Res],
35
- ) -> Callable[Param, Res]:
36
- ...
37
- @overload
38
- def gradio_auto_wrap(
39
- task:
40
- None,
41
- ) -> None:
42
- ...
43
- def gradio_auto_wrap(
44
- task:
45
- Callable[Param, Res]
46
- | None,
47
- ) -> (Callable[Param, Res]
48
- | None):
49
- """
50
- """
51
- if not gradio_auto_wrap_enabled:
52
- return task
53
- if not callable(task):
54
- return task
55
- return GPU(task) # type: ignore

spaces/utils.py DELETED
@@ -1,85 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import ctypes
6
- import sys
7
- from functools import lru_cache as cache
8
- from functools import partial
9
-
10
- import multiprocessing
11
- from multiprocessing.queues import SimpleQueue as _SimpleQueue
12
- from pathlib import Path
13
- from pickle import PicklingError
14
- from typing import Callable
15
- from typing import TypeVar
16
-
17
-
18
- GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
19
-
20
-
21
- T = TypeVar('T')
22
-
23
-
24
- @cache
25
- def self_cgroup_device_path() -> str:
26
- cgroup_content = Path('/proc/self/cgroup').read_text()
27
- for line in cgroup_content.strip().split('\n'):
28
- contents = line.split(':devices:')
29
- if len(contents) != 2:
30
- continue # pragma: no cover
31
- return contents[1]
32
- raise Exception # pragma: no cover
33
-
34
-
35
- if sys.version_info.minor < 9: # pragma: no cover
36
- _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
37
-
38
- class SimpleQueue(_SimpleQueue[T]):
39
- def __init__(self, *args):
40
- super().__init__(*args, ctx=multiprocessing.get_context('fork'))
41
- def put(self, obj: T):
42
- try:
43
- super().put(obj)
44
- except PicklingError:
45
- raise # pragma: no cover
46
- # https://bugs.python.org/issue29187
47
- except Exception as e:
48
- message = str(e)
49
- if not "pickle" in message:
50
- raise # pragma: no cover
51
- raise PicklingError(message)
52
- def close(self): # Python 3.8 static typing trick
53
- super().close() # type: ignore
54
- def wlock_release(self):
55
- if (lock := getattr(self, '_wlock', None)) is None:
56
- return # pragma: no cover
57
- try:
58
- lock.release()
59
- except ValueError:
60
- pass
61
-
62
-
63
- def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
64
- def drop(*args):
65
- return fn()
66
- return drop
67
-
68
-
69
- def boolean(value: str | None) -> bool:
70
- return value is not None and value.lower() in ("1", "t", "true")
71
-
72
-
73
- def gradio_request_var():
74
- try:
75
- from gradio.context import LocalContext
76
- except ImportError: # pragma: no cover
77
- raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
78
- return LocalContext.request
79
-
80
-
81
- def malloc_trim():
82
- ctypes.CDLL("libc.so.6").malloc_trim(0)
83
-
84
-
85
- debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')

spaces/zero/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- """
2
- """
3
-
4
- from pathlib import Path
5
-
6
- from ..config import Config
7
-
8
-
9
- if Config.zero_gpu:
10
-
11
- from . import gradio
12
- from . import torch
13
-
14
- if torch.is_in_bad_fork():
15
- raise RuntimeError(
16
- "CUDA has been initialized before importing the `spaces` package"
17
- )
18
-
19
- torch.patch()
20
- gradio.one_launch(torch.pack)
21
- Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)

spaces/zero/api.py DELETED
@@ -1,156 +0,0 @@
1
- """
2
- Synced with huggingface/pyspaces:spaces/zero/api.py
3
- """
4
- from __future__ import annotations
5
-
6
- from datetime import timedelta
7
- from typing import Any
8
- from typing import Generator
9
- from typing import Literal
10
- from typing import NamedTuple
11
- from typing import Optional
12
- from typing import overload
13
-
14
- import httpx
15
- from pydantic import BaseModel
16
- from typing_extensions import assert_never
17
-
18
-
19
- AllowToken = str
20
- NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
21
- NvidiaUUID = str
22
- CGroupPath = str
23
- VisitorId = str
24
- Score = float
25
-
26
- AuthLevel = Literal['regular', 'pro']
27
-
28
-
29
- AUTHENTICATED_HEADER = 'X-Authenticated'
30
-
31
-
32
- class ScheduleResponse(BaseModel):
33
- idle: bool
34
- nvidiaIndex: int
35
- nvidiaUUID: str
36
- allowToken: str
37
-
38
-
39
- class QuotaInfos(BaseModel):
40
- left: int
41
- wait: timedelta
42
-
43
-
44
- class ReportUsageMonitoringParams(NamedTuple):
45
- nvidia_index: int
46
- visitor_id: str
47
- duration: timedelta
48
-
49
-
50
- class QueueEvent(BaseModel):
51
- event: Literal['ping', 'failed', 'succeeded']
52
- data: Optional[ScheduleResponse] = None
53
-
54
-
55
- def sse_parse(text: str):
56
- event, *data = text.strip().splitlines()
57
- assert event.startswith('event:')
58
- event = event[6:].strip()
59
- if event in ('ping', 'failed'):
60
- return QueueEvent(event=event)
61
- assert event == 'succeeded'
62
- (data,) = data
63
- assert data.startswith('data:')
64
- data = data[5:].strip()
65
- return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
66
-
67
-
68
- def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
69
- for text in res.iter_text():
70
- if len(text) == 0:
71
- break # pragma: no cover
72
- try:
73
- yield sse_parse(text)
74
- except GeneratorExit:
75
- res.close()
76
- break
77
-
78
-
79
- class APIClient:
80
-
81
- def __init__(self, client: httpx.Client):
82
- self.client = client
83
-
84
- def startup_report(self) -> httpx.codes:
85
- res = self.client.post('/startup-report')
86
- return httpx.codes(res.status_code)
87
-
88
- def schedule(
89
- self,
90
- cgroup_path: str,
91
- task_id: int = 0,
92
- token: str | None = None,
93
- duration_seconds: int | None = None,
94
- enable_queue: bool = True,
95
- ):
96
- params: dict[str, str | int | bool] = {
97
- 'cgroupPath': cgroup_path,
98
- 'taskId': task_id,
99
- 'enableQueue': enable_queue,
100
- }
101
- if duration_seconds is not None:
102
- params['durationSeconds'] = duration_seconds
103
- if token is not None:
104
- params['token'] = token
105
- res = self.client.send(
106
- request=self.client.build_request(
107
- method='POST',
108
- url='/schedule',
109
- params=params,
110
- ),
111
- stream=True,
112
- )
113
- status = httpx.codes(res.status_code)
114
- auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
115
- if (status is not httpx.codes.OK and
116
- status is not httpx.codes.TOO_MANY_REQUESTS
117
- ):
118
- res.close()
119
- return status, auth
120
- if "text/event-stream" in res.headers['content-type']:
121
- return sse_stream(res), auth
122
- res.read()
123
- if status is httpx.codes.TOO_MANY_REQUESTS:
124
- return QuotaInfos(**res.json()), auth # pragma: no cover
125
- if status is httpx.codes.OK:
126
- return ScheduleResponse(**res.json()), auth
127
- assert_never(status)
128
-
129
- def allow(
130
- self,
131
- allow_token: str,
132
- pid: int,
133
- ):
134
- res = self.client.post('/allow', params={
135
- 'allowToken': allow_token,
136
- 'pid': pid,
137
- })
138
- return httpx.codes(res.status_code)
139
-
140
- def release(
141
- self,
142
- allow_token: str,
143
- fail: bool = False,
144
- ) -> httpx.codes:
145
- res = self.client.post('/release', params={
146
- 'allowToken': allow_token,
147
- 'fail': fail,
148
- })
149
- return httpx.codes(res.status_code)
150
-
151
- def get_queue_size(self) -> int:
152
- res = self.client.get('/queue-size')
153
- assert res.status_code == 200, res.status_code
154
- size = res.json()
155
- assert isinstance(size, int)
156
- return size

spaces/zero/client.py DELETED
@@ -1,239 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import os
6
- import time
7
- import warnings
8
- from datetime import timedelta
9
-
10
- import gradio as gr
11
- import httpx
12
- from packaging import version
13
- from typing_extensions import assert_never
14
-
15
- from .. import utils
16
- from ..config import Config
17
- from .api import APIClient
18
- from .api import AuthLevel
19
- from .api import QuotaInfos
20
- from .api import ScheduleResponse
21
- from .gradio import HTMLError
22
- from .gradio import get_event
23
- from .gradio import supports_auth
24
-
25
-
26
- TOKEN_HEADER = 'X-IP-Token'
27
- DEFAULT_SCHEDULE_DURATION = 60
28
-
29
- QUOTA_MESSAGE = "You have exceeded your GPU quota"
30
- UNUSED_MESSAGE = "GPU device not used"
31
- NO_GPU_MESSAGE_REGULAR = "No GPU was available"
32
- NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"
33
-
34
- SIGNUP_ON_HF_TXT = "Create a free account"
35
- SIGNUP_ON_HF_URL = "https://huggingface.co/join"
36
- SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
37
- SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
38
-
39
-
40
- def api_client():
41
- assert Config.zero_device_api_url is not None
42
- httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
43
- return APIClient(httpx_client)
44
-
45
-
46
- def startup_report():
47
- retries, max_retries = 0, 2
48
- client = api_client()
49
- while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
50
- time.sleep(1)
51
- if (retries := retries + 1) > max_retries:
52
- raise RuntimeError("Error while initializing ZeroGPU: NotFound")
53
- if status is not httpx.codes.OK: # pragma: no cover
54
- raise RuntimeError("Error while initializing ZeroGPU: Unknown")
55
-
56
-
57
- def html_string(html_contents: str, text_contents: str): # pragma: no cover
58
- class HTMLString(str):
59
- def __str__(self):
60
- return text_contents
61
- return HTMLString(html_contents)
62
-
63
-
64
- def _toast_action(
65
- auth: AuthLevel | None,
66
- supports_html: bool,
67
- pro_message: str,
68
- unlogged_desc: str,
69
- logged_desc: str,
70
- ending: str,
71
- ) -> tuple[str, str]: # pragma: no cover
72
- if not supports_auth() or auth == 'pro':
73
- return pro_message, pro_message
74
- html = ""
75
- link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
76
- text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
77
- desc = unlogged_desc if auth is None else logged_desc
78
- desc += f" {ending}."
79
- style = ";".join([
80
- "white-space: nowrap",
81
- "text-underline-offset: 2px",
82
- "color: var(--body-text-color)",
83
- ])
84
- if supports_html:
85
- html += f'<a style="{style}" href="{link}">'
86
- html += text
87
- if supports_html:
88
- html += '</a> '
89
- html += desc
90
- markdown = f'[{text}]({link}) {desc}'
91
- return html, markdown
92
-
93
-
94
- def schedule(
95
- task_id: int,
96
- request: gr.Request | None = None,
97
- duration: timedelta | None = None,
98
- _first_attempt: bool = True,
99
- ) -> ScheduleResponse:
100
-
101
- if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
102
- raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
103
-
104
- GRADIO_HTML_TOASTS = gradio_version.minor >= 39
105
-
106
- res, auth = api_client().schedule(
107
- cgroup_path=utils.self_cgroup_device_path(),
108
- task_id=task_id,
109
- token=_get_token(request),
110
- duration_seconds=duration.seconds if duration is not None else None,
111
- )
112
-
113
- if isinstance(res, ScheduleResponse):
114
- return res
115
-
116
- if isinstance(res, QuotaInfos): # pragma: no cover
117
- requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
118
- if res.wait < timedelta(0):
119
- raise gr.Error(
120
- f"The requested GPU duration ({requested}s) "
121
- f"is larger than the maximum allowed"
122
- )
123
- else:
124
- gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
125
- message = (
126
- f"You have exceeded your {gpu} quota "
127
- f"({requested}s requested vs. {res.left}s left)."
128
- )
129
- details_html, details_markdown = _toast_action(
130
- auth=auth,
131
- supports_html=GRADIO_HTML_TOASTS,
132
- pro_message=f"Try again in {res.wait}",
133
- unlogged_desc="to get more",
134
- logged_desc="to get 5x more",
135
- ending="usage quota",
136
- )
137
- message_html = f"{message} {details_html}"
138
- message_text = f"{message} {details_markdown}"
139
- raise HTMLError(html_string(message_html, message_text))
140
-
141
- if not isinstance(res, httpx.codes): # pragma: no cover
142
- gr.Info("Waiting for a GPU to become available")
143
- # TODO: Sign-up message if not authenticated (after some time ?)
144
- connection_event = get_event()
145
- if connection_event is None and request is not None:
146
- warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
147
- while True:
148
- try:
149
- event = next(res)
150
- except StopIteration:
151
- raise RuntimeError("Unexpected end of stream")
152
- except httpx.RemoteProtocolError:
153
- if not _first_attempt:
154
- raise RuntimeError("Error while re-trying after queue disconnect")
155
- return schedule(task_id, request, duration, _first_attempt=False)
156
- if event.event == 'ping':
157
- if connection_event is not None and not connection_event.alive:
158
- res.close()
159
- raise RuntimeError("Connection closed by visitor while queueing")
160
- continue
161
- if event.event == 'failed':
162
- details_html, details_markdown = _toast_action(
163
- auth=auth,
164
- supports_html=GRADIO_HTML_TOASTS,
165
- pro_message="Retry later",
166
- unlogged_desc="to get a higher",
167
- logged_desc="to get the highest",
168
- ending="priority in ZeroGPU queues",
169
- )
170
- message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
171
- message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
172
- raise HTMLError(html_string(message_html, message_text))
173
- if event.event == 'succeeded':
174
- assert event.data is not None
175
- if connection_event is not None and not connection_event.alive:
176
- release(event.data.allowToken)
177
- raise RuntimeError("Connection closed by visitor on queue success")
178
- gr.Info("Successfully acquired a GPU")
179
- return event.data
180
-
181
- if res is httpx.codes.SERVICE_UNAVAILABLE:
182
- raise gr.Error(NO_GPU_MESSAGE_REGULAR)
183
-
184
- # TODO: Find a way to log 'detail' response field
185
- raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
186
-
187
-
188
- def allow(allow_token: str) -> None:
189
- pid = os.getpid()
190
- assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
191
- assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
192
-
193
-
194
- def release(
195
- allow_token: str, *,
196
- fail: bool = False,
197
- allow_404: bool = False,
198
- ) -> None:
199
-
200
- res = api_client().release(
201
- allow_token=allow_token,
202
- fail=fail,
203
- )
204
-
205
- if res is httpx.codes.NO_CONTENT: # pragma: no cover
206
- try:
207
- gr.Warning(UNUSED_MESSAGE)
208
- except AttributeError:
209
- pass
210
- warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
211
- return None
212
-
213
- if res is httpx.codes.NOT_FOUND:
214
- if not allow_404:
215
- warnings.warn("ZeroGPU API /release warning: 404 Not Found")
216
- return None
217
-
218
- if httpx.codes.is_success(res):
219
- return None
220
-
221
- # TODO: Find a way to log 'detail' response field
222
- # TODO: Only raise in dev environment. Simply warn in production ?
223
- raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
224
-
225
-
226
- def _get_token(request: gr.Request | None) -> str | None:
227
-
228
- if request is None:
229
- return None
230
-
231
- headers = getattr(request, 'headers', None)
232
- if headers is None or not hasattr(headers, '__dict__'):
233
- raise gr.Error("Internal Gradio error")
234
-
235
- # Compatibility trick
236
- if not hasattr(headers, 'get'):
237
- headers = headers.__dict__ # pragma: no cover
238
-
239
- return headers.get(TOKEN_HEADER.lower())

spaces/zero/decorator.py DELETED
@@ -1,113 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import inspect
6
- import sys
7
- import warnings
8
- from datetime import timedelta
9
- from functools import partial
10
- from typing import Callable
11
- from typing import TypeVar
12
- from typing import overload
13
- from typing_extensions import ParamSpec
14
- from typing_extensions import Unpack
15
-
16
- from ..config import Config
17
- from .types import DynamicDuration
18
- from .types import EmptyKwargs
19
-
20
-
21
- P = ParamSpec('P')
22
- R = TypeVar('R')
23
-
24
-
25
- decorated_cache: dict[Callable, Callable] = {}
26
-
27
-
28
- @overload
29
- def GPU(
30
- task: None = None, *,
31
- duration: DynamicDuration[P] = None,
32
- ) -> Callable[[Callable[P, R]], Callable[P, R]]:
33
- ...
34
- @overload
35
- def GPU(
36
- task: Callable[P, R], *,
37
- duration: DynamicDuration[P] = None,
38
- ) -> Callable[P, R]:
39
- ...
40
- def GPU(
41
- task: Callable[P, R] | None = None, *,
42
- duration: DynamicDuration[P] = None,
43
- **kwargs: Unpack[EmptyKwargs],
44
- ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
45
- """
46
- ZeroGPU decorator
47
-
48
- Basic usage:
49
- ```
50
- @spaces.GPU
51
- def fn(...):
52
- # CUDA is available here
53
- pass
54
- ```
55
-
56
- With custom duration:
57
- ```
58
- @spaces.GPU(duration=45) # Expressed in seconds
59
- def fn(...):
60
- # CUDA is available here
61
- pass
62
- ```
63
-
64
- Args:
65
- task (`Callable | None`): Python function that requires CUDA
66
- duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
67
-
68
- Returns:
69
- `Callable`: GPU-ready function
70
- """
71
- if "enable_queue" in kwargs:
72
- warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
73
- if task is None:
74
- return partial(_GPU, duration=duration)
75
- return _GPU(task, duration)
76
-
77
-
78
- def _GPU(
79
- task: Callable[P, R],
80
- duration: DynamicDuration[P],
81
- ) -> Callable[P, R]:
82
-
83
- if not Config.zero_gpu:
84
- return task
85
-
86
- from . import client
87
- from .wrappers import regular_function_wrapper
88
- from .wrappers import generator_function_wrapper
89
-
90
- if sys.version_info.minor < 9: # pragma: no cover
91
- raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
92
-
93
- if task in decorated_cache:
94
- # TODO: Assert same duration ?
95
- return decorated_cache[task] # type: ignore
96
-
97
- if inspect.iscoroutinefunction(task):
98
- raise NotImplementedError
99
-
100
- if inspect.isgeneratorfunction(task):
101
- decorated = generator_function_wrapper(task, duration)
102
- else:
103
- decorated = regular_function_wrapper(task, duration)
104
-
105
- setattr(decorated, 'zerogpu', None)
106
-
107
- client.startup_report()
108
- decorated_cache.update({
109
- task: decorated,
110
- decorated: decorated,
111
- })
112
-
113
- return decorated # type: ignore

spaces/zero/gradio.py DELETED
@@ -1,150 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from functools import wraps
6
- from packaging import version
7
- from typing import Callable
8
- from typing import NamedTuple
9
- from typing import TYPE_CHECKING
10
- import warnings
11
-
12
- import gradio as gr
13
- from gradio.context import Context
14
- from gradio.context import LocalContext
15
- from gradio.helpers import Progress
16
- from gradio.helpers import TrackedIterable
17
- from gradio.queueing import Queue
18
- from typing_extensions import ParamSpec
19
-
20
- from ..utils import SimpleQueue
21
- from .types import GeneratorResQueueResult
22
- from .types import GradioQueueEvent
23
- from .types import RegularResQueueResult
24
-
25
-
26
- QUEUE_RPC_METHODS = [
27
- "set_progress",
28
- "log_message",
29
- ]
30
-
31
-
32
- class GradioPartialContext(NamedTuple):
33
- event_id: str | None
34
- in_event_listener: bool
35
- progress: Progress | None
36
-
37
- @staticmethod
38
- def get():
39
- TrackedIterable.__reduce__ = tracked_iterable__reduce__
40
- return GradioPartialContext(
41
- event_id=LocalContext.event_id.get(),
42
- in_event_listener=LocalContext.in_event_listener.get(),
43
- progress=LocalContext.progress.get(),
44
- )
45
-
46
- @staticmethod
47
- def apply(context: 'GradioPartialContext'):
48
- LocalContext.event_id.set(context.event_id)
49
- LocalContext.in_event_listener.set(context.in_event_listener)
50
- LocalContext.progress.set(context.progress)
51
-
52
-
53
- def get_queue_instance():
54
- blocks = LocalContext.blocks.get()
55
- if blocks is None: # pragma: no cover
56
- return None
57
- return blocks._queue
58
-
59
-
60
- def get_event():
61
- queue = get_queue_instance()
62
- event_id = LocalContext.event_id.get()
63
- if queue is None:
64
- return None
65
- if event_id is None: # pragma: no cover
66
- return None
67
- for job in queue.active_jobs:
68
- if job is None: # pragma: no cover
69
- continue
70
- for event in job:
71
- if event._id == event_id:
72
- return event
73
-
74
-
75
- def get_server_port() -> int | None:
76
- from_request_context = True
77
- if (blocks := LocalContext.blocks.get()) is None: # Request
78
- from_request_context = False
79
- if (blocks := Context.root_block) is None: # Caching
80
- return None
81
- if (server := getattr(blocks, 'server', None)) is None:
82
- if from_request_context:
83
- warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
84
- return -1
85
- if TYPE_CHECKING:
86
- assert (server := blocks.server)
87
- return server.config.port
88
-
89
-
90
- def try_process_queue_event(method_name: str, *args, **kwargs):
91
- queue = get_queue_instance()
92
- if queue is None: # pragma: no cover
93
- warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
94
- return
95
- method = getattr(queue, method_name, None)
96
- assert callable(method)
97
- method(*args, **kwargs)
98
-
99
-
100
- def patch_gradio_queue(
101
- res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
102
- ):
103
-
104
- def rpc_method(method_name: str):
105
- def method(*args, **kwargs):
106
- if args and isinstance(args[0], Queue):
107
- args = args[1:] # drop `self`
108
- res_queue.put(GradioQueueEvent(method_name, args, kwargs))
109
- return method
110
-
111
- for method_name in QUEUE_RPC_METHODS:
112
- if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
113
- warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
114
- continue
115
- if not callable(method): # pragma: no cover
116
- warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
117
- continue
118
- setattr(Queue, method_name, rpc_method(method_name))
119
-
120
- TrackedIterable.__reduce__ = tracked_iterable__reduce__
121
-
122
-
123
- def tracked_iterable__reduce__(self):
124
- res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
125
- cls, base, state, *_ = res
126
- return cls, base,{**state, **{
127
- 'iterable': None,
128
- '_tqdm': None,
129
- }}
130
-
131
-
132
- def supports_auth():
133
- return version.parse(gr.__version__) >= version.Version('4.27.0')
134
-
135
-
136
- Param = ParamSpec('Param')
137
-
138
- def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
139
- _launch = gr.Blocks.launch
140
- @wraps(gr.Blocks.launch)
141
- def launch(*args, **kwargs):
142
- task(*task_args, **task_kwargs)
143
- gr.Blocks.launch = _launch
144
- return gr.Blocks.launch(*args, **kwargs)
145
- gr.Blocks.launch = launch
146
-
147
-
148
- class HTMLError(gr.Error):
149
- def __str__(self): # pragma: no cover
150
- return self.message

spaces/zero/torch/__init__.py DELETED
@@ -1,42 +0,0 @@
1
- """
2
- """
3
-
4
- from ...config import Config
5
-
6
-
7
- try:
8
-
9
- import torch
10
-
11
- except ImportError:
12
-
13
- _patch = lambda *args, **kwargs: None
14
- _unpatch = lambda *args, **kwargs: None
15
- _pack = lambda *args, **kwargs: None
16
- _init = lambda *args, **kwargs: None
17
- _size = lambda *args, **kwargs: 0
18
- _move = lambda *args, **kwargs: None
19
- _is_in_bad_fork = lambda *args, **kwargs: False
20
-
21
- else:
22
-
23
- if Config.zero_gpu_v2:
24
- from . import patching as _patching
25
- else: # pragma: no cover
26
- from . import patching_legacy as _patching
27
-
28
- _patch = _patching.patch
29
- _unpatch = _patching.unpatch
30
- _pack = _patching.pack
31
- _init = _patching.init
32
- _size = _patching.size
33
- _move = _patching.move
34
- _is_in_bad_fork = _patching.is_in_bad_fork
35
-
36
- patch = _patch
37
- unpatch = _unpatch
38
- pack = _pack
39
- init = _init
40
- size = _size
41
- move = _move
42
- is_in_bad_fork = _is_in_bad_fork

spaces/zero/torch/bitsandbytes.py DELETED
@@ -1,162 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import importlib
8
- from contextlib import contextmanager
9
- from importlib import metadata
10
- from types import ModuleType
11
- from typing import TYPE_CHECKING
12
- from typing import Tuple
13
-
14
- import torch
15
- from packaging import version
16
-
17
- if TYPE_CHECKING:
18
- import torch as Torch
19
-
20
-
21
- @contextmanager
22
- def cuda_unavailable(torch: ModuleType):
23
- _is_available = torch.cuda.is_available
24
- torch.cuda.is_available = lambda: False
25
- yield
26
- torch.cuda.is_available = _is_available
27
-
28
-
29
- def maybe_import_bitsandbytes():
30
- try:
31
- import torch
32
- except ImportError: # pragma: no cover
33
- return None
34
- with cuda_unavailable(torch):
35
- try:
36
- import bitsandbytes
37
- except ImportError:
38
- bitsandbytes = None
39
- else:
40
- if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
41
- raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
42
- print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
43
- return bitsandbytes
44
-
45
-
46
- if (bnb := maybe_import_bitsandbytes()):
47
-
48
- from torch.utils.weak import WeakTensorKeyDictionary
49
-
50
- with cuda_unavailable(torch):
51
- from bitsandbytes import cextension
52
- from bitsandbytes import functional
53
- try: # bitsandbytes < 0.44
54
- from bitsandbytes.cuda_setup.main import CUDASetup
55
- except ModuleNotFoundError: # pragma: no cover
56
- CUDASetup = None
57
- from bitsandbytes.nn import Int8Params
58
- from bitsandbytes.nn import Params4bit
59
-
60
- _param_to_8bit = Int8Params.to # type: ignore
61
- _param_cuda_8bit = Int8Params.cuda
62
- _param_to_4bit = Params4bit.to # type: ignore
63
- _param_cuda_4bit = Params4bit.cuda
64
-
65
- TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
66
-
67
- to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
68
- to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
69
-
70
- def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
71
- parsed = torch._C._nn._parse_to(*args, **kwargs)
72
- device, *_ = parsed
73
- if not isinstance(device, torch.device): # pragma: no cover
74
- return _param_to_8bit(self, *args, **kwargs)
75
- if device.type != 'cuda':
76
- return _param_to_8bit(self, *args, **kwargs)
77
- to_ops_8bit[self] = parsed
78
- return self
79
-
80
- def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
81
- parsed = torch._C._nn._parse_to(*args, **kwargs)
82
- device, *_ = parsed
83
- if not isinstance(device, torch.device): # pragma: no cover
84
- return _param_to_4bit(self, *args, **kwargs)
85
- if device.type != 'cuda':
86
- return _param_to_4bit(self, *args, **kwargs)
87
- to_ops_4bit[self] = parsed
88
- return self
89
-
90
- def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
91
- if device is None: # pragma: no cover
92
- return True
93
- if isinstance(device, int):
94
- return True
95
- if isinstance(device, str): # pragma: no cover
96
- device = torch.device(device)
97
- return device.type == 'cuda' # pragma: no cover
98
-
99
- def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
100
- if not _cuda_op_arg_check(device): # pragma: no cover
101
- # Let PyTorch handle the fail
102
- return _param_cuda_8bit(self, device, **kwargs)
103
- to_ops_8bit[self] = None
104
- return self
105
-
106
- def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
107
- if not _cuda_op_arg_check(device): # pragma: no cover
108
- # Let PyTorch handle the fail
109
- return _param_cuda_4bit(self, device, **kwargs)
110
- to_ops_4bit[self] = None
111
- return self
112
-
113
- def _patch():
114
- Int8Params.to = _to_op_register_8bit # type: ignore
115
- Int8Params.cuda = _cuda_op_register_8bit # type: ignore
116
- Params4bit.to = _to_op_register_4bit # type: ignore
117
- Params4bit.cuda = _cuda_op_register_4bit # type: ignore
118
-
119
- def _unpatch():
120
- Int8Params.to = _param_to_8bit # type: ignore
121
- Int8Params.cuda = _param_cuda_8bit
122
- Params4bit.to = _param_to_4bit # type: ignore
123
- Params4bit.cuda = _param_cuda_4bit
124
-
125
- def _move():
126
- if CUDASetup is not None:
127
- CUDASetup._instance = None
128
- importlib.reload(cextension)
129
- functional.lib = cextension.lib
130
- for op in to_ops_8bit.items():
131
- tensor, parsed_args = op
132
- if parsed_args:
133
- _, dtype, _, memory_format = parsed_args
134
- else:
135
- dtype, memory_format = None, None
136
- tensor.data = _param_to_8bit(tensor,
137
- device='cuda',
138
- dtype=dtype,
139
- memory_format=memory_format,
140
- ) # type: ignore
141
- for op in to_ops_4bit.items():
142
- tensor, parsed_args = op
143
- if parsed_args:
144
- _, dtype, _, memory_format = parsed_args
145
- else:
146
- dtype, memory_format = None, None
147
- tensor.data = _param_to_4bit(tensor,
148
- device='cuda',
149
- dtype=dtype,
150
- memory_format=memory_format,
151
- ) # type: ignore
152
-
153
- else:
154
-
155
- _patch = lambda: None
156
- _unpatch = lambda: None
157
- _move = lambda: None
158
-
159
-
160
- patch = _patch
161
- unpatch = _unpatch
162
- move = _move

spaces/zero/torch/packing.py DELETED
@@ -1,209 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import time
6
-
7
- import ctypes
8
- import os
9
- from concurrent.futures import as_completed
10
- from concurrent.futures import ThreadPoolExecutor
11
- from contextvars import copy_context
12
- from dataclasses import dataclass
13
- from queue import Queue
14
- from typing import Callable
15
-
16
- from ...utils import debug
17
-
18
- import torch
19
- from typing_extensions import TypeAlias
20
-
21
-
22
- PAGE_SIZE = 4096
23
- TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
24
- VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
25
-
26
- BUFFER_SIZE = 64 * 2**20
27
- BUFFER_COUNT = 2
28
-
29
-
30
- TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
31
-
32
- @dataclass
33
- class ZeroGPUTensorPack:
34
- base_dir: str
35
- batches: list[list[TensorWithSizes]]
36
- big_tensors: list[TensorWithSizes]
37
- fakes: dict[torch.Tensor, list[torch.Tensor]]
38
- total_size: int
39
- def path(self):
40
- return f'{self.base_dir}/{id(self)}'
41
- def __del__(self):
42
- try:
43
- os.remove(self.path())
44
- except FileNotFoundError: # pragma: no cover
45
- pass
46
-
47
-
48
- def write(fd: int, tensor: torch.Tensor):
49
- clone = torch.empty_like(tensor)
50
- size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
51
- buffer = torch.UntypedStorage(VM_MAX_SIZE)
52
- buffer_ptr = buffer.data_ptr()
53
- offset = -buffer_ptr % PAGE_SIZE
54
- padding = -size % PAGE_SIZE
55
- clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
56
- clone.copy_(tensor)
57
- mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
58
- written_bytes = 0
59
- while written_bytes < size:
60
- written_bytes += os.write(fd, mv[written_bytes:])
61
-
62
-
63
- def pack_tensors(
64
- tensors: set[torch.Tensor],
65
- fakes: dict[torch.Tensor, list[torch.Tensor]],
66
- offload_dir: str,
67
- callback: Callable[[int]] | None = None,
68
- ):
69
-
70
- callback = (lambda bytes: None) if callback is None else callback
71
-
72
- batches: list[list[TensorWithSizes]] = []
73
- big_tensors: list[TensorWithSizes] = []
74
-
75
- tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
76
- for tensor in tensors:
77
- size = tensor.numel() * tensor.element_size()
78
- aligned_size = size + (-size % PAGE_SIZE)
79
- tensors_with_sizes += [(tensor, size, aligned_size)]
80
-
81
- current_batch, current_size = [], 0
82
- for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
83
- if aligned_size > BUFFER_SIZE:
84
- big_tensors += [(tensor, size, aligned_size)]
85
- continue
86
- current_size += aligned_size
87
- if current_size > BUFFER_SIZE:
88
- batches += [current_batch]
89
- current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
90
- else:
91
- current_batch += [(tensor, size, aligned_size)]
92
-
93
- if current_batch:
94
- batches += [current_batch]
95
-
96
- get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
97
- batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
98
- big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
99
- fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
100
-
101
- pack = ZeroGPUTensorPack(
102
- base_dir=offload_dir,
103
- batches=batches_meta,
104
- big_tensors=big_tensors_meta,
105
- fakes=fakes_meta,
106
- total_size=sum([size for _, size, _ in tensors_with_sizes]),
107
- )
108
-
109
- fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
110
- try:
111
- total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
112
- total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
113
- if total_asize > 0:
114
- os.posix_fallocate(fd, 0, total_asize)
115
- for batch in batches:
116
- for tensor, size, _ in batch:
117
- write(fd, tensor)
118
- callback(size)
119
- for tensor, size, _ in big_tensors:
120
- write(fd, tensor)
121
- callback(size)
122
- return pack
123
- finally:
124
- os.close(fd)
125
-
126
-
127
- def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):
128
-
129
- callback = (lambda bytes: None) if callback is None else callback
130
-
131
- free_buffers: Queue[torch.Tensor] = Queue()
132
- read_buffers: Queue[torch.Tensor] = Queue()
133
-
134
- for _ in range(BUFFER_COUNT):
135
- free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
136
-
137
- def read(fd: int, buffer: torch.Tensor, size: int):
138
- mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
139
- read_bytes = 0
140
- while read_bytes < size:
141
- read_bytes += os.readv(fd, [mv[read_bytes:]])
142
-
143
- def disk_to_pin(fd: int):
144
- for batch in pack.batches:
145
- buffer = free_buffers.get()
146
- batch_size = sum([aligned_size for *_, aligned_size in batch])
147
- read(fd, buffer, batch_size)
148
- read_buffers.put(buffer)
149
- for *_, aligned_size in pack.big_tensors:
150
- read_bytes = 0
151
- while read_bytes < aligned_size:
152
- buffer = free_buffers.get()
153
- read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
154
- read(fd, buffer, read_size)
155
- read_buffers.put(buffer)
156
- read_bytes += read_size
157
-
158
- def pin_to_cuda():
159
- total_duration_in_callback = 0
160
- for batch in pack.batches:
161
- buffer = read_buffers.get()
162
- offset = 0
163
- cuda_storages = []
164
- for tensor, size, aligned_size in batch:
165
- cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
166
- offset += aligned_size
167
- torch.cuda.synchronize()
168
- free_buffers.put(buffer)
169
- batch_total_size = 0
170
- for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
171
- cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
172
- cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
173
- for fake in pack.fakes[tensor]:
174
- fake.data = cuda_tensor
175
- batch_total_size += size
176
- t0 = time.perf_counter()
177
- callback(batch_total_size)
178
- total_duration_in_callback += time.perf_counter() - t0
179
- for tensor, size, _ in pack.big_tensors:
180
- cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
181
- offset = 0
182
- while offset < size:
183
- buffer = read_buffers.get()
184
- read_size = min(BUFFER_SIZE, size - offset)
185
- cuda_storage[offset:offset+read_size] = buffer[:read_size]
186
- offset += read_size
187
- torch.cuda.synchronize() # Probably not needed
188
- free_buffers.put(buffer)
189
- t0 = time.perf_counter()
190
- callback(read_size)
191
- total_duration_in_callback += time.perf_counter() - t0
192
- cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
193
- cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
194
- for fake in pack.fakes[tensor]:
195
- fake.data = cuda_tensor
196
-
197
- debug(f"{total_duration_in_callback=}")
198
-
199
- with ThreadPoolExecutor(2) as e:
200
- fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
201
- try:
202
- futures = [
203
- e.submit(copy_context().run, disk_to_pin, fd),
204
- e.submit(copy_context().run, pin_to_cuda),
205
- ]
206
- for future in as_completed(futures):
207
- future.result()
208
- finally:
209
- os.close(fd)

spaces/zero/torch/patching.py DELETED
@@ -1,386 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import gc
8
- import multiprocessing
9
- import os
10
- from collections import defaultdict
11
- from concurrent.futures import ProcessPoolExecutor
12
- from concurrent.futures import ThreadPoolExecutor
13
- from contextlib import nullcontext
14
- from contextvars import copy_context
15
- from types import SimpleNamespace
16
- from typing import Any
17
- from typing import Callable
18
-
19
- import torch
20
- from torch.overrides import TorchFunctionMode
21
- from torch.overrides import resolve_name
22
- from torch.utils._python_dispatch import TorchDispatchMode
23
- from torch.utils._pytree import tree_map_only
24
- from torch.utils.weak import WeakTensorKeyDictionary
25
-
26
- from ...config import Config
27
- from ...utils import malloc_trim
28
- from ..tqdm import tqdm
29
- from . import bitsandbytes
30
- from .packing import ZeroGPUTensorPack
31
- from .packing import pack_tensors
32
- from .packing import pack_to_cuda
33
- from .types import AliasId
34
-
35
-
36
- # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
37
- CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
38
- CUDA_TOTAL_MEMORY = 42144366592
39
- CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
40
- CUDA_DEVICE_CAPABILITY = (8, 0)
41
- CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
42
-
43
- OPS_INPUTS_CHECK_NO_RETURN = (
44
- torch.Tensor.equal,
45
- )
46
-
47
- OPS_INPUT_CHECK_SELF_RETURN = (
48
- torch.Tensor.set_, # probably never dispatched
49
- torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
50
- )
51
-
52
- OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
53
-
54
- _tensor_make_subclass = torch.Tensor._make_subclass
55
- _asarray = torch.asarray
56
- _cuda_init = torch._C._cuda_init
57
- _cuda_exchange_device = torch.cuda._exchange_device
58
- _cuda_available = torch.cuda.is_available
59
- _cuda_device_count = torch.cuda.device_count
60
- _cuda_current_device = torch.cuda.current_device
61
- _cuda_mem_get_info = torch.cuda.mem_get_info
62
- _cuda_get_device_capability = torch.cuda.get_device_capability
63
- _cuda_get_device_properties = torch.cuda.get_device_properties
64
- _cuda_get_device_name = torch.cuda.get_device_name
65
-
66
- # PyTorch 2.3
67
- _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
68
-
69
-
70
- cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
71
-
72
- tensor_packs: list[ZeroGPUTensorPack] = []
73
-
74
- class ZeroGPUTensor(torch.Tensor):
75
- pass
76
-
77
- def empty_fake(tensor: torch.Tensor):
78
- fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
79
- if fake.__class__ != tensor.__class__:
80
- fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
81
- return fake
82
-
83
- class ZeroGPUFunctionMode(TorchFunctionMode):
84
-
85
- def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
86
-
87
- kwargs = {} if kwargs is None else kwargs
88
-
89
- if func == torch._C._nn._parse_to:
90
- return func(*args, **kwargs)
91
-
92
- # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
93
- if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
94
- memory_format = kwargs.get('memory_format')
95
- return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
96
- 'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
97
- **({'memory_format': memory_format} if memory_format is not None else {}),
98
- })
99
-
100
- # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
101
- if func == torch.Tensor.to and len(args) > 1:
102
- device, dtype, _, memory_format = torch._C._nn._parse_to(*args[1:], **kwargs)
103
- return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
104
- 'device': device,
105
- 'dtype': dtype,
106
- 'memory_format': memory_format,
107
- })
108
-
109
- if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
110
- self, target = args
111
- if target in cuda_aliases:
112
- if (target_original := cuda_aliases[target]) is None:
113
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
114
- original = empty_fake(self)
115
- original.data = target_original
116
- cuda_aliases[self] = original
117
- elif self in cuda_aliases:
118
- del cuda_aliases[self]
119
- self.data = target
120
- return
121
-
122
- if func == torch.Tensor.device.__get__:
123
- tensor, = args
124
- if tensor in cuda_aliases:
125
- return torch.device('cuda', index=0)
126
-
127
- elif func == torch.Tensor.__repr__:
128
- tensor, = args
129
- if tensor in cuda_aliases:
130
- if (original := cuda_aliases[tensor]) is None:
131
- original = tensor.to('meta')
132
- original_class = original.__class__
133
- original.__class__ = ZeroGPUTensor
134
- try:
135
- return func(original, **kwargs)
136
- finally:
137
- original.__class__ = original_class
138
-
139
- elif func == torch.Tensor.untyped_storage:
140
- tensor, = args
141
- if tensor in cuda_aliases:
142
- if (original := cuda_aliases[tensor]) is None:
143
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
144
- res = func(original, **kwargs)
145
- res._zerogpu = True
146
- return res
147
-
148
- cuda: bool | None = None
149
-
150
- # Handle device kwarg
151
- if (device := kwargs.get('device')) is not None:
152
- device = torch.device(device)
153
- if device.type == 'cuda':
154
- kwargs['device'] = torch.device('cpu')
155
- cuda = True
156
- else:
157
- cuda = False
158
-
159
- # Swap fake inputs with original data
160
- swapped = {}
161
- inputs_are_cuda = set()
162
- def swap(tensor: torch.Tensor):
163
- nonlocal inputs_are_cuda
164
- if tensor not in cuda_aliases:
165
- inputs_are_cuda |= {False}
166
- return tensor
167
- if (original := cuda_aliases[tensor]) is None:
168
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
169
- swapped[original] = tensor
170
- inputs_are_cuda |= {True}
171
- return original
172
- args_ = tree_map_only(torch.Tensor, swap, args)
173
- kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
174
- if inputs_are_cuda == {True}:
175
- if cuda is not False:
176
- cuda = True
177
-
178
- res = func(*args_, **kwargs_)
179
-
180
- # Re-generate swapped fakes in case of mutation
181
- for original, fake in swapped.items():
182
- fake.data = empty_fake(original)
183
-
184
- # Special case for Tensor indexing where only 'self' matters
185
- if func in {
186
- torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
187
- torch.Tensor.__getitem__, # PyTorch 2.4+
188
- }:
189
- self = args[0]
190
- cuda = self in cuda_aliases
191
- inputs_are_cuda = {cuda}
192
-
193
- # Emulate device check
194
- if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
195
- self = None
196
- if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
197
- self = args_[0]
198
- # Only raise if func does not return its first input (Tensor.copy_)
199
- if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
200
- if inputs_are_cuda == {True, False}:
201
- raise RuntimeError(
202
- "Expected all tensors to be on the same device, "
203
- "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
204
- )
205
-
206
- # Register output
207
- def register(tensor: torch.Tensor):
208
- if tensor in swapped and cuda is not False:
209
- return swapped[tensor]
210
- if cuda is not True:
211
- return tensor
212
- fake = empty_fake(tensor)
213
- cuda_aliases[fake] = tensor
214
- return fake
215
-
216
- return tree_map_only(torch.Tensor, register, res)
217
-
218
- # When enabling DispatchMode, some aten ops are dispatched to FunctionMode
219
- # We are using it for aten.alias.default and aten.set_.source_Tensor
220
- class DefaultDispatchMode(TorchDispatchMode):
221
- def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
222
- return func(*args, **(kwargs or {}))
223
-
224
-
225
- function_mode = ZeroGPUFunctionMode()
226
- dispatch_mode = DefaultDispatchMode()
227
-
228
-
229
- def _untyped_storage_new_register(*args, **kwargs):
230
- cuda = False
231
- if (device := kwargs.get('device')) is not None and device.type == 'cuda':
232
- cuda = True
233
- del kwargs['device']
234
- storage = torch._C.StorageBase.__new__(*args, **kwargs)
235
- if cuda:
236
- storage._zerogpu = True
237
- return storage
238
-
239
- @property
240
- def _untyped_storage_device(self):
241
- if hasattr(self, '_zerogpu'):
242
- return torch.device('cuda', index=0)
243
- return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
244
-
245
- # Force dispatch
246
- def _tensor_make_subclass_function_mode(*args, **kwargs):
247
- with torch._C.DisableTorchFunction():
248
- return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
249
- def _asarray_function_mode(*args, **kwargs):
250
- with torch._C.DisableTorchFunction():
251
- return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
252
-
253
- def _cuda_init_raise():
254
- raise RuntimeError(
255
- "CUDA must not be initialized in the main process "
256
- "on Spaces with Stateless GPU environment.\n"
257
- "You can look at this Stacktrace to find out "
258
- "which part of your code triggered a CUDA init"
259
- )
260
-
261
- def _cuda_dummy_exchange_device(device):
262
- assert device in {-1, 0}
263
- return device
264
-
265
- def patch():
266
- function_mode.__enter__()
267
- dispatch_mode.__enter__()
268
- # TODO: only patch bellow methods on current Thread to be consistent with TorchModes
269
- # (or hijack threading.Thread.__init__ to force Modes on all threads)
270
- torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
271
- torch.UntypedStorage.__new__ = _untyped_storage_new_register
272
- torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
273
- torch.asarray = _asarray_function_mode
274
- torch._C._cuda_init = _cuda_init_raise
275
- torch.cuda._exchange_device = _cuda_dummy_exchange_device
276
- torch.cuda.is_available = lambda: True
277
- torch.cuda.device_count = lambda: 1
278
- torch.cuda.current_device = lambda: 0
279
- torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
280
- torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
281
- torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
282
- torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
283
- # PyTorch 2.3
284
- if _cuda_maybe_exchange_device is not None: # pragma: no cover
285
- setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
286
- bitsandbytes.patch()
287
-
288
- def unpatch():
289
- try:
290
- dispatch_mode.__exit__(None, None, None)
291
- function_mode.__exit__(None, None, None)
292
- except RuntimeError:
293
- pass # patch() and unpatch() called from != threads
294
- torch.Tensor._make_subclass = _tensor_make_subclass
295
- torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
296
- torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
297
- torch.asarray = _asarray
298
- torch._C._cuda_init = _cuda_init
299
- torch.cuda._exchange_device = _cuda_exchange_device
300
- torch.cuda.is_available = _cuda_available
301
- torch.cuda.device_count = _cuda_device_count
302
- torch.cuda.current_device = _cuda_current_device
303
- torch.cuda.mem_get_info = _cuda_mem_get_info
304
- torch.cuda.get_device_capability = _cuda_get_device_capability
305
- torch.cuda.get_device_properties = _cuda_get_device_properties
306
- torch.cuda.get_device_name = _cuda_get_device_name
307
- # PyTorch 2.3
308
- if _cuda_maybe_exchange_device is not None: # pragma: no cover
309
- setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
310
- bitsandbytes.unpatch()
311
-
312
-
313
- def _total_unpacked_size():
314
- tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
315
- deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
316
- return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
317
-
318
-
319
- def _pack(offload_dir: str):
320
- # Pack to disk
321
- originals: set[torch.Tensor] = set()
322
- originals_dedup: dict[AliasId, torch.Tensor] = {}
323
- fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
324
- for fake, original in cuda_aliases.items():
325
- # TODO filter-out sparse Tensors
326
- if original is not None:
327
- original_id = AliasId.from_tensor(original)
328
- if original_id not in originals_dedup:
329
- originals_dedup[original_id] = original
330
- originals |= {original}
331
- fakes[originals_dedup[original_id]] += [fake]
332
- progress = tqdm(
333
- total=_total_unpacked_size(),
334
- unit='B',
335
- unit_scale=True,
336
- desc="ZeroGPU tensors packing",
337
- ) if tqdm is not None else nullcontext()
338
- with progress as progress:
339
- update = progress.update if progress is not None else lambda _: None
340
- pack = pack_tensors(originals, fakes, offload_dir, callback=update)
341
- tensor_packs.append(pack)
342
- # Free memory
343
- for fake_list in fakes.values():
344
- for fake in fake_list:
345
- cuda_aliases[fake] = None
346
-
347
- def pack():
348
- _pack(Config.zerogpu_offload_dir)
349
- gc.collect()
350
- malloc_trim()
351
-
352
- def init(nvidia_uuid: str):
353
- os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
354
- torch.Tensor([0]).cuda()
355
-
356
- def size():
357
- return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
358
-
359
- def _move(callback: Callable[[int]] | None = None):
360
- callback = callback if callback is not None else lambda _: None
361
- # CPU -> CUDA
362
- moved: dict[AliasId, torch.Tensor] = {}
363
- for fake, original in cuda_aliases.items():
364
- if original is not None:
365
- original_id = AliasId.from_tensor(original)
366
- if original_id not in moved:
367
- moved[original_id] = original.cuda()
368
- callback(fake.numel() * fake.element_size())
369
- for fake, original in cuda_aliases.items():
370
- if original is not None:
371
- fake.data = moved[AliasId.from_tensor(original)]
372
- # Disk -> CUDA
373
- for tensor_pack in tensor_packs:
374
- pack_to_cuda(tensor_pack, callback=callback)
375
- bitsandbytes.move()
376
-
377
- def move(callback: Callable[[int]] | None = None):
378
- callback = callback if callback is not None else lambda _: None
379
- with ThreadPoolExecutor(1) as e:
380
- e.submit(copy_context().run, _move, callback=callback).result()
381
- torch.cuda.synchronize()
382
-
383
- def is_in_bad_fork():
384
- with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
385
- f = e.submit(torch.cuda._is_in_bad_fork)
386
- return f.result()
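
For context: the deleted `patching.py` above centers on PyTorch's `TorchFunctionMode`, which intercepts factory calls so that tensors requested on `cuda` are actually allocated on CPU and only moved to a real GPU later. Below is a minimal sketch of that mechanism only, with hypothetical names (`FakeCudaMode`, `pending_cuda`); it is not the actual `spaces` implementation, which additionally emulates device checks, swaps in empty fakes, and packs originals to disk as shown in the diff.

    # Minimal sketch: TorchFunctionMode interception of CUDA factory calls.
    # All names here are illustrative, not the real spaces API.
    import torch
    from torch.overrides import TorchFunctionMode
    from torch.utils.weak import WeakTensorKeyDictionary

    pending_cuda = WeakTensorKeyDictionary()  # tensors that should move to CUDA later

    class FakeCudaMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = dict(kwargs or {})
            device = kwargs.get('device')
            wants_cuda = device is not None and torch.device(device).type == 'cuda'
            if wants_cuda:
                kwargs['device'] = 'cpu'       # allocate on CPU for now
            out = func(*args, **kwargs)
            if wants_cuda and isinstance(out, torch.Tensor):
                pending_cuda[out] = True       # remember the CUDA intent
            return out

    with FakeCudaMode():
        t = torch.zeros(4, device='cuda')      # actually lives on CPU
    print(t.device, t in pending_cuda)         # cpu True
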
spaces/zero/torch/patching_legacy.py DELETED
@@ -1,266 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import multiprocessing
8
- import os
9
- from concurrent.futures import ProcessPoolExecutor
10
- from contextlib import suppress
11
- from functools import partial
12
- from types import SimpleNamespace
13
- from typing import Any
14
- from typing import Callable
15
- from typing import Optional
16
- from typing import Tuple
17
-
18
- import torch
19
- from torch.utils.weak import WeakTensorKeyDictionary
20
-
21
- from ...config import Config
22
- from . import bitsandbytes
23
-
24
-
25
- # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
26
- CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
27
- CUDA_TOTAL_MEMORY = 42144366592
28
- CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
29
- CUDA_DEVICE_CAPABILITY = (8, 0)
30
- CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
31
-
32
- GENERIC_METHOD_NAMES = [
33
- 'arange',
34
- 'as_tensor',
35
- 'asarray',
36
- 'bartlett_window',
37
- 'blackman_window',
38
- 'empty',
39
- 'empty_like',
40
- 'empty_strided',
41
- 'eye',
42
- 'full',
43
- 'full_like',
44
- 'hamming_window',
45
- 'hann_window',
46
- 'kaiser_window',
47
- 'linspace',
48
- 'logspace',
49
- 'ones',
50
- 'ones_like',
51
- 'rand',
52
- 'rand_like',
53
- 'randint',
54
- 'randint_like',
55
- 'randn',
56
- 'randn_like',
57
- 'randperm',
58
- 'range',
59
- 'sparse_bsc_tensor',
60
- 'sparse_bsr_tensor',
61
- 'sparse_compressed_tensor',
62
- 'sparse_coo_tensor',
63
- 'sparse_csc_tensor',
64
- 'sparse_csr_tensor',
65
- 'tensor',
66
- 'tril_indices',
67
- 'triu_indices',
68
- 'zeros',
69
- 'zeros_like',
70
- ]
71
-
72
-
73
- TO_CUDA = (torch.device('cuda'), None, False, None)
74
-
75
- _tensor__deepcopy__ = torch.Tensor.__deepcopy__
76
- _tensor_to = torch.Tensor.to
77
- _tensor_cuda = torch.Tensor.cuda
78
- _tensor_cpu = torch.Tensor.cpu
79
- _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
80
- _cuda_init = torch._C._cuda_init
81
- _cuda_available = torch.cuda.is_available
82
- _cuda_device_count = torch.cuda.device_count
83
- _cuda_current_device = torch.cuda.current_device
84
- _cuda_mem_get_info = torch.cuda.mem_get_info
85
- _cuda_get_device_capability = torch.cuda.get_device_capability
86
- _cuda_get_device_properties = torch.cuda.get_device_properties
87
- _cuda_get_device_name = torch.cuda.get_device_name
88
-
89
- TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
90
-
91
- to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
92
-
93
- def _tensor_new_register(*args, **kwargs):
94
- new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
95
- if (base_tensor := new_tensor._base) is not None:
96
- if base_tensor in to_ops:
97
- to_ops[new_tensor] = to_ops[base_tensor]
98
- return new_tensor
99
-
100
- def _tensor_deepcopy_register(self: torch.Tensor, memo):
101
- new_tensor = _tensor__deepcopy__(self, memo)
102
- if isinstance(new_tensor, torch.Tensor):
103
- if self in to_ops:
104
- to_ops[new_tensor] = to_ops[self]
105
- return new_tensor
106
-
107
- @property
108
- def _tensor_device_property(self: torch.Tensor):
109
- if self in to_ops:
110
- return torch.device(type='cuda', index=0)
111
- del torch.Tensor.device
112
- try:
113
- return self.device
114
- finally:
115
- torch.Tensor.device = _tensor_device_property # type: ignore
116
-
117
- @property
118
- def _tensor_dtype_property(self: torch.Tensor):
119
- if self in to_ops:
120
- if (to_dtype := to_ops[self][1]) is not None:
121
- return to_dtype
122
- del torch.Tensor.dtype
123
- try:
124
- return self.dtype
125
- finally:
126
- torch.Tensor.dtype = _tensor_dtype_property # type: ignore
127
-
128
- def _to_op_register(self: torch.Tensor, *args, **kwargs):
129
- parsed = torch._C._nn._parse_to(*args, **kwargs)
130
- device, dtype, *_ = parsed
131
- try:
132
- to_args = to_ops.pop(self)
133
- except KeyError:
134
- to_args = None
135
- if device is None: # pyright: ignore [reportUnnecessaryComparison]
136
- if to_args is not None:
137
- to_ops[self] = (to_args[0], dtype, *to_args[2:])
138
- return self
139
- return _tensor_to(self, *args, **kwargs)
140
- if device.type != 'cuda':
141
- if to_args is not None:
142
- if (to_dtype := to_args[1]) is not None:
143
- kwargs = {'dtype': to_dtype, **kwargs}
144
- return _tensor_to(self, *args, **kwargs)
145
- to_ops[self] = parsed
146
- return self
147
-
148
- def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
149
- if device is None:
150
- return True
151
- if isinstance(device, int):
152
- return True
153
- if isinstance(device, str):
154
- device = torch.device(device)
155
- return device.type == 'cuda'
156
-
157
- def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
158
- if not _cuda_op_arg_check(device):
159
- # Let PyTorch handle the fail
160
- return _tensor_cuda(self, device, **kwargs)
161
- to_ops[self] = TO_CUDA
162
- return self
163
-
164
- def _cpu_op_remove(self: torch.Tensor, **kwargs):
165
- try:
166
- to_args = to_ops.pop(self)
167
- except KeyError:
168
- to_args = None
169
- if to_args is not None:
170
- if (to_dtype := to_args[1]) is not None:
171
- return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
172
- return _tensor_cpu(self, **kwargs)
173
-
174
- def _cuda_init_raise():
175
- raise RuntimeError(
176
- "CUDA must not be initialized in the main process "
177
- "on Spaces with Stateless GPU environment.\n"
178
- "You can look at this Stacktrace to find out "
179
- "which part of your code triggered a CUDA init"
180
- )
181
-
182
- def _generic_method_register(name: str, *args: Any, **kwargs: Any):
183
- try:
184
- device = torch.device(kwargs.get('device', "cpu"))
185
- except Exception:
186
- return _torch_generics[name](*args, **kwargs)
187
- if device.type != 'cuda':
188
- return _torch_generics[name](*args, **kwargs)
189
- tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
190
- to_ops[tensor] = TO_CUDA
191
- return tensor
192
-
193
- def patch():
194
- torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
195
- torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
196
- torch.Tensor.to = _to_op_register # type: ignore
197
- torch.Tensor.cuda = _cuda_op_register # type: ignore
198
- torch.Tensor.cpu = _cpu_op_remove # type: ignore
199
- if Config.zero_patch_torch_device:
200
- torch.Tensor.device = _tensor_device_property # type: ignore
201
- torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
202
- for name in GENERIC_METHOD_NAMES:
203
- setattr(torch, name, partial(_generic_method_register, name))
204
- torch._C._cuda_init = _cuda_init_raise
205
- torch.cuda.is_available = lambda: True
206
- torch.cuda.device_count = lambda: 1
207
- torch.cuda.current_device = lambda: 0
208
- torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
209
- torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
210
- torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
211
- torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
212
- bitsandbytes.patch()
213
-
214
- def unpatch():
215
- torch.Tensor.__deepcopy__ = _tensor__deepcopy__
216
- with suppress(AttributeError):
217
- del torch.Tensor.__new__
218
- torch.Tensor.to = _tensor_to
219
- torch.Tensor.cuda = _tensor_cuda
220
- torch.Tensor.cpu = _tensor_cpu
221
- with suppress(AttributeError):
222
- del torch.Tensor.device
223
- with suppress(AttributeError):
224
- del torch.Tensor.dtype
225
- for name in GENERIC_METHOD_NAMES:
226
- setattr(torch, name, _torch_generics[name])
227
- torch._C._cuda_init = _cuda_init
228
- torch.cuda.is_available = _cuda_available
229
- torch.cuda.device_count = _cuda_device_count
230
- torch.cuda.current_device = _cuda_current_device
231
- torch.cuda.mem_get_info = _cuda_mem_get_info
232
- torch.cuda.get_device_capability = _cuda_get_device_capability
233
- torch.cuda.get_device_properties = _cuda_get_device_properties
234
- torch.cuda.get_device_name = _cuda_get_device_name
235
- bitsandbytes.unpatch()
236
-
237
- def pack():
238
- pass
239
-
240
- def init(nvidia_uuid: str):
241
- os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
242
- torch.Tensor([0]).cuda() # CUDA init
243
-
244
- def size():
245
- return 0
246
-
247
- def move(callback: Callable[[int]] | None = None):
248
- for op in to_ops.items():
249
- tensor, parsed_args = op
250
- _, dtype, _, memory_format = parsed_args
251
- tensor.data = _tensor_to(tensor,
252
- device='cuda',
253
- dtype=dtype,
254
- memory_format=memory_format,
255
- ) # type: ignore
256
- bitsandbytes.move()
257
- torch.cuda.synchronize()
258
-
259
- def is_in_bad_fork():
260
- with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
261
- f = e.submit(torch.cuda._is_in_bad_fork)
262
- return f.result()
263
-
264
- def disable_cuda_intercept():
265
- torch.Tensor.to = _tensor_to
266
- torch.Tensor.cuda = _tensor_cuda
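
The legacy variant above takes a different route: instead of Torch modes, it monkey-patches `Tensor.to`, `Tensor.cuda`, `Tensor.cpu`, and the generic factory functions so that a request for CUDA is merely recorded and replayed later by `move()`. A minimal sketch of that idea with simplified, hypothetical names (not the real `spaces` API, which also patches the `device`/`dtype` properties and everything in `GENERIC_METHOD_NAMES`):

    # Hedged sketch of the legacy interception idea from patching_legacy.py.
    import torch
    from torch.utils.weak import WeakTensorKeyDictionary

    _original_cuda = torch.Tensor.cuda
    _wants_cuda = WeakTensorKeyDictionary()

    def _fake_cuda(self, *args, **kwargs):
        _wants_cuda[self] = True               # record the intent, stay on CPU
        return self

    def move_all():
        # Called later, inside the GPU worker, to materialize the recorded intents
        # (requires an actual CUDA device).
        for tensor in list(_wants_cuda.keys()):
            tensor.data = _original_cuda(tensor)

    torch.Tensor.cuda = _fake_cuda             # patch
    x = torch.ones(2).cuda()                   # no CUDA init happens here
    print(x.device)                            # cpu
    torch.Tensor.cuda = _original_cuda         # unpatch / restore
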
spaces/zero/torch/types.py DELETED
@@ -1,23 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from typing import NamedTuple
6
-
7
- import torch
8
-
9
-
10
- class AliasId(NamedTuple):
11
- data_ptr: int
12
- dtype: torch.dtype
13
- shape: tuple[int, ...]
14
- stride: tuple[int, ...]
15
-
16
- @classmethod
17
- def from_tensor(cls, tensor: torch.Tensor):
18
- return cls(
19
- tensor.data_ptr(),
20
- tensor.dtype,
21
- tensor.shape,
22
- tensor.stride(),
23
- )
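
`AliasId` identifies tensors that share the same storage and layout, so aliased views are packed and moved only once. A small self-contained sketch of how that key behaves (the class is restated here only to keep the example runnable on its own):

    # Sketch: AliasId-style dedup key (data_ptr, dtype, shape, stride).
    import torch
    from typing import NamedTuple

    class AliasId(NamedTuple):
        data_ptr: int
        dtype: torch.dtype
        shape: tuple
        stride: tuple

        @classmethod
        def from_tensor(cls, tensor: torch.Tensor):
            return cls(tensor.data_ptr(), tensor.dtype, tuple(tensor.shape), tensor.stride())

    a = torch.zeros(4)
    b = a.view(4)        # same storage, same layout -> same AliasId
    c = a.clone()        # equal values, different storage -> different AliasId
    print(AliasId.from_tensor(a) == AliasId.from_tensor(b))  # True
    print(AliasId.from_tensor(a) == AliasId.from_tensor(c))  # False
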
spaces/zero/tqdm.py DELETED
@@ -1,24 +0,0 @@
1
- """
2
- """
3
-
4
- from multiprocessing.synchronize import RLock as MultiprocessingRLock
5
-
6
-
7
- try:
8
- from tqdm import tqdm as _tqdm
9
- except ImportError: # pragma: no cover
10
- _tqdm = None
11
-
12
-
13
- def remove_tqdm_multiprocessing_lock():
14
- if _tqdm is None: # pragma: no cover
15
- return
16
- tqdm_lock = _tqdm.get_lock()
17
- assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
18
- tqdm_lock.locks = [
19
- lock for lock in tqdm_lock.locks
20
- if not isinstance(lock, MultiprocessingRLock)
21
- ]
22
-
23
-
24
- tqdm = _tqdm
spaces/zero/types.py DELETED
@@ -1,49 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
-
6
- from dataclasses import dataclass
7
- from datetime import timedelta
8
- from typing import Any
9
- from typing import Dict
10
- from typing import Tuple
11
- from typing import TypedDict
12
- from typing_extensions import Callable
13
- from typing_extensions import Generic
14
- from typing_extensions import ParamSpec
15
- from typing_extensions import TypeAlias
16
- from typing_extensions import TypeVar
17
-
18
-
19
- Params = Tuple[Tuple[object, ...], Dict[str, Any]]
20
- Res = TypeVar('Res')
21
- Param = ParamSpec('Param')
22
-
23
- class EmptyKwargs(TypedDict):
24
- pass
25
-
26
- @dataclass
27
- class OkResult(Generic[Res]):
28
- value: Res
29
- @dataclass
30
- class ExceptionResult:
31
- value: Exception
32
- @dataclass
33
- class AbortedResult:
34
- pass
35
- @dataclass
36
- class EndResult:
37
- pass
38
- @dataclass
39
- class GradioQueueEvent:
40
- method_name: str
41
- args: tuple[Any, ...]
42
- kwargs: dict[str, Any]
43
-
44
- RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
45
- GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
46
- YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
47
-
48
- Duration: TypeAlias = "int | timedelta"
49
- DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
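
These result types travel over the worker queues and are discriminated with `isinstance` checks on the Gradio side. A hedged, self-contained sketch of that consumption pattern (simplified: no generics and no `GradioQueueEvent` handling; the real logic is in `wrappers.py` below):

    # Sketch: draining a result queue made of the unions defined above.
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class OkResult:
        value: Any

    @dataclass
    class ExceptionResult:
        value: Exception

    @dataclass
    class EndResult:
        pass

    def drain(queue_items):
        """Yield values until EndResult, re-raising worker exceptions."""
        for res in queue_items:
            if isinstance(res, ExceptionResult):
                raise res.value
            if isinstance(res, EndResult):
                return
            if isinstance(res, OkResult):
                yield res.value

    print(list(drain([OkResult(1), OkResult(2), EndResult()])))  # [1, 2]
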
spaces/zero/wrappers.py DELETED
@@ -1,418 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import multiprocessing
6
- import os
7
- import signal
8
- import traceback
9
- import warnings
10
- from concurrent.futures import ThreadPoolExecutor
11
- from contextlib import nullcontext
12
- from contextvars import copy_context
13
- from datetime import timedelta
14
- from functools import partial
15
- from functools import wraps
16
- from multiprocessing.context import ForkProcess
17
- from pickle import PicklingError
18
- from queue import Empty
19
- from queue import Queue as ThreadQueue
20
- from threading import Thread
21
- from typing import TYPE_CHECKING
22
- from typing import Callable
23
- from typing import Generator
24
- from typing import Generic
25
- from typing_extensions import assert_never
26
-
27
- import psutil
28
-
29
- from ..config import Config
30
- from ..utils import debug
31
- from ..utils import drop_params
32
- from ..utils import gradio_request_var
33
- from ..utils import SimpleQueue as Queue
34
- from . import client
35
- from . import torch
36
- from .api import AllowToken
37
- from .api import NvidiaIndex
38
- from .api import NvidiaUUID
39
- from .gradio import GradioPartialContext
40
- from .gradio import get_server_port
41
- from .gradio import patch_gradio_queue
42
- from .gradio import try_process_queue_event
43
- from .tqdm import remove_tqdm_multiprocessing_lock
44
- from .tqdm import tqdm
45
- from .types import * # TODO: Please don't do that
46
-
47
-
48
- GENERATOR_GLOBAL_TIMEOUT = 20 * 60
49
-
50
- SPAWN_PROGRESS_CLEANUP = 0.1
51
- SPAWN_PROGRESS_INIT = 0.1
52
-
53
-
54
- Process = multiprocessing.get_context('fork').Process
55
- forked = False
56
-
57
-
58
- class Worker(Generic[Res]):
59
- process: ForkProcess
60
- arg_queue: Queue[tuple[Params, GradioPartialContext]]
61
- res_queue: Queue[Res | None]
62
- _sentinel: Thread
63
-
64
- def __init__(
65
- self,
66
- target: Callable[[
67
- Queue[tuple[Params, GradioPartialContext]],
68
- Queue[Res | None],
69
- AllowToken,
70
- NvidiaUUID,
71
- list[int],
72
- ], None],
73
- allow_token: str,
74
- nvidia_uuid: str,
75
- ):
76
- self._sentinel = Thread(target=self._close_on_exit, daemon=True)
77
- self.arg_queue = Queue()
78
- self.res_queue = Queue()
79
- debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
80
- debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
81
- if (server_port := get_server_port()) is not None:
82
- fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
83
- debug(f"{fds=}")
84
- else:
85
- warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
86
- fds = []
87
- args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
88
- if TYPE_CHECKING:
89
- target(*args)
90
- self.process = Process(
91
- target=target,
92
- args=args,
93
- daemon=True,
94
- )
95
- self.process.start()
96
- self._sentinel.start()
97
-
98
- def _close_on_exit(self):
99
- self.process.join()
100
- self.arg_queue.close()
101
- self.res_queue.wlock_release()
102
- self.res_queue.put(None)
103
-
104
-
105
- def worker_init(
106
- res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
107
- allow_token: str,
108
- nvidia_uuid: str,
109
- fds: list[int],
110
- ) -> None | ExceptionResult:
111
- # Immediately close file descriptors
112
- for fd in fds:
113
- try:
114
- os.close(fd)
115
- except Exception as e: # pragma: no cover
116
- if isinstance(e, OSError) and e.errno == 9:
117
- continue
118
- traceback.print_exc()
119
- return ExceptionResult(e)
120
- progress = nullcontext()
121
- if tqdm is not None and Config.zero_gpu_v2:
122
- progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
123
- try: # Unrecoverable init part
124
- patch_gradio_queue(res_queue)
125
- with progress as progress:
126
- current_progress = 0 # Gradio does not support float progress updates
127
- def update(n: float):
128
- nonlocal current_progress
129
- current_progress += n
130
- if progress is not None:
131
- progress.update(round(current_progress * 100) - progress.n)
132
- client.allow(allow_token)
133
- update(SPAWN_PROGRESS_CLEANUP)
134
- torch.unpatch()
135
- torch.init(nvidia_uuid)
136
- update(SPAWN_PROGRESS_INIT)
137
- callback = None
138
- if (transfer_size := torch.size()) > 0:
139
- remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
140
- callback = lambda n: update(n * remaining / transfer_size)
141
- torch.move(callback=callback)
142
- except Exception as e: # pragma: no cover
143
- traceback.print_exc()
144
- return ExceptionResult(e)
145
- try:
146
- remove_tqdm_multiprocessing_lock()
147
- except Exception: # pragma: no cover
148
- print("Error while trying to remove tqdm mp_lock:")
149
- traceback.print_exc()
150
-
151
-
152
- def process_duration(duration: Duration | None):
153
- if duration is None or isinstance(duration, timedelta):
154
- return duration
155
- return timedelta(seconds=duration)
156
-
157
-
158
- def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
159
- if not callable(duration):
160
- return duration
161
- return duration(*args, **kwargs)
162
-
163
-
164
- def regular_function_wrapper(
165
- task: Callable[Param, Res],
166
- duration: DynamicDuration[Param],
167
- ) -> Callable[Param, Res]:
168
-
169
- import gradio as gr
170
-
171
- request_var = gradio_request_var()
172
- workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
173
- task_id = id(task)
174
-
175
- @wraps(task)
176
- def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
177
-
178
- if forked:
179
- return task(*args, **kwargs)
180
-
181
- request = request_var.get()
182
- duration_ = static_duration(duration, *args, **kwargs)
183
- duration_ = process_duration(duration_)
184
- schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
185
- allow_token = schedule_response.allowToken
186
- nvidia_index = schedule_response.nvidiaIndex
187
- nvidia_uuid = schedule_response.nvidiaUUID
188
- release = partial(client.release, allow_token)
189
-
190
- try:
191
- worker = workers.pop(nvidia_index)
192
- except KeyError:
193
- worker = None
194
-
195
- if worker is not None and worker.process.is_alive() and schedule_response.idle:
196
- assert worker.arg_queue.empty()
197
- assert worker.res_queue.empty()
198
- else:
199
- worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
200
-
201
- try:
202
- worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
203
- except PicklingError: # TODO: detailed serialization diagnostic
204
- release(fail=True)
205
- raise
206
-
207
- while True:
208
- res = worker.res_queue.get()
209
- if res is None:
210
- release(fail=True, allow_404=True)
211
- raise gr.Error("GPU task aborted")
212
- if isinstance(res, ExceptionResult):
213
- release(fail=True)
214
- raise res.value
215
- if isinstance(res, OkResult):
216
- release()
217
- workers[nvidia_index] = worker
218
- return res.value
219
- if isinstance(res, GradioQueueEvent):
220
- try_process_queue_event(res.method_name, *res.args, **res.kwargs)
221
- continue
222
- assert_never(res)
223
-
224
-
225
- def thread_wrapper(
226
- arg_queue: Queue[tuple[Params, GradioPartialContext]],
227
- res_queue: Queue[RegularResQueueResult[Res] | None],
228
- allow_token: str,
229
- nvidia_uuid: str,
230
- fds: list[int],
231
- ):
232
- global forked
233
- forked = True
234
- signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
235
- initialized = False
236
- while True:
237
- try:
238
- (args, kwargs), gradio_context = arg_queue.get()
239
- except OSError:
240
- break
241
- if not initialized:
242
- if (res := worker_init(
243
- res_queue=res_queue,
244
- allow_token=allow_token,
245
- nvidia_uuid=nvidia_uuid,
246
- fds=fds,
247
- )) is not None:
248
- res_queue.put(res)
249
- return
250
- initialized = True
251
- GradioPartialContext.apply(gradio_context)
252
- context = copy_context()
253
- with ThreadPoolExecutor() as executor:
254
- future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
255
- try:
256
- res = future.result()
257
- except Exception as e:
258
- traceback.print_exc()
259
- res = ExceptionResult(e)
260
- else:
261
- res = OkResult(res)
262
- try:
263
- res_queue.put(res)
264
- except PicklingError as e:
265
- res_queue.put(ExceptionResult(e))
266
-
267
- # https://github.com/python/cpython/issues/91002
268
- if not hasattr(task, '__annotations__'):
269
- gradio_handler.__annotations__ = {}
270
-
271
- return gradio_handler
272
-
273
-
274
- def generator_function_wrapper(
275
- task: Callable[Param, Generator[Res, None, None]],
276
- duration: DynamicDuration[Param],
277
- ) -> Callable[Param, Generator[Res, None, None]]:
278
-
279
- import gradio as gr
280
-
281
- request_var = gradio_request_var()
282
- workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
283
- task_id = id(task)
284
-
285
- @wraps(task)
286
- def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
287
-
288
- if forked:
289
- yield from task(*args, **kwargs)
290
- return
291
-
292
- request = request_var.get()
293
- duration_ = static_duration(duration, *args, **kwargs)
294
- duration_ = process_duration(duration_)
295
- schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
296
- allow_token = schedule_response.allowToken
297
- nvidia_index = schedule_response.nvidiaIndex
298
- nvidia_uuid = schedule_response.nvidiaUUID
299
- release = partial(client.release, allow_token)
300
-
301
- try:
302
- worker = workers.pop(nvidia_index)
303
- except KeyError:
304
- worker = None
305
-
306
- if worker is not None and worker.process.is_alive() and schedule_response.idle:
307
- assert worker.arg_queue.empty()
308
- assert worker.res_queue.empty()
309
- else:
310
- worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
311
-
312
- try:
313
- worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
314
- except PicklingError: # TODO: detailed serialization diagnostic
315
- release(fail=True)
316
- raise
317
-
318
- yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
319
- def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
320
- while True:
321
- res = worker.res_queue.get()
322
- if res is None:
323
- release(fail=True, allow_404=True)
324
- yield_queue.put(AbortedResult())
325
- return
326
- if isinstance(res, ExceptionResult):
327
- release(fail=True)
328
- yield_queue.put(ExceptionResult(res.value))
329
- return
330
- if isinstance(res, EndResult):
331
- release()
332
- workers[nvidia_index] = worker
333
- yield_queue.put(EndResult())
334
- return
335
- if isinstance(res, OkResult):
336
- yield_queue.put(OkResult(res.value))
337
- continue
338
- if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
339
- try_process_queue_event(res.method_name, *res.args, **res.kwargs)
340
- continue
341
- debug(f"fill_yield_queue: assert_never({res=})")
342
- assert_never(res)
343
- from typing_extensions import assert_never
344
- with ThreadPoolExecutor() as e:
345
- f = e.submit(copy_context().run, fill_yield_queue, worker)
346
- f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
347
- while True:
348
- try:
349
- res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
350
- except Empty: # pragma: no cover
351
- debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
352
- raise
353
- if isinstance(res, AbortedResult):
354
- raise gr.Error("GPU task aborted")
355
- if isinstance(res, ExceptionResult):
356
- raise res.value
357
- if isinstance(res, EndResult):
358
- break
359
- if isinstance(res, OkResult):
360
- yield res.value
361
- continue
362
- debug(f"gradio_handler: assert_never({res=})")
363
- assert_never(res)
364
-
365
-
366
- def thread_wrapper(
367
- arg_queue: Queue[tuple[Params, GradioPartialContext]],
368
- res_queue: Queue[GeneratorResQueueResult[Res] | None],
369
- allow_token: str,
370
- nvidia_uuid: str,
371
- fds: list[int],
372
- ):
373
- global forked
374
- forked = True
375
- signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
376
- initialized = False
377
- while True:
378
- try:
379
- (args, kwargs), gradio_context = arg_queue.get()
380
- except OSError:
381
- break
382
- if not initialized:
383
- if (res := worker_init(
384
- res_queue=res_queue,
385
- allow_token=allow_token,
386
- nvidia_uuid=nvidia_uuid,
387
- fds=fds,
388
- )) is not None:
389
- res_queue.put(res)
390
- return
391
- initialized = True
392
- def iterate():
393
- gen = task(*args, **kwargs) # type: ignore
394
- while True:
395
- try:
396
- res = next(gen)
397
- except StopIteration:
398
- break
399
- except Exception as e:
400
- res_queue.put(ExceptionResult(e))
401
- break
402
- try:
403
- res_queue.put(OkResult(res))
404
- except PicklingError as e:
405
- res_queue.put(ExceptionResult(e))
406
- break
407
- else:
408
- continue
409
- GradioPartialContext.apply(gradio_context)
410
- with ThreadPoolExecutor() as executor:
411
- executor.submit(copy_context().run, iterate)
412
- res_queue.put(EndResult())
413
-
414
- # https://github.com/python/cpython/issues/91002
415
- if not hasattr(task, '__annotations__'):
416
- gradio_handler.__annotations__ = {}
417
-
418
- return gradio_handler
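
The core of `wrappers.py` is a forked worker process fed through an argument queue and answering on a result queue, with a sentinel value signalling shutdown or worker death. A minimal, hedged sketch of that pattern using plain `multiprocessing` (POSIX fork, as the real code uses; no GPU scheduling, no Gradio context, and the queue types and the `sum` task are stand-ins):

    # Sketch: fork a worker, feed args over a queue, read tagged results back.
    import multiprocessing

    def child(arg_queue, res_queue):
        while True:
            item = arg_queue.get()
            if item is None:                   # shutdown sentinel
                break
            args, kwargs = item
            try:
                res_queue.put(("ok", sum(*args, **kwargs)))  # stand-in for the real task
            except Exception as e:
                res_queue.put(("error", repr(e)))

    if __name__ == "__main__":
        ctx = multiprocessing.get_context("fork")
        arg_q, res_q = ctx.Queue(), ctx.Queue()
        proc = ctx.Process(target=child, args=(arg_q, res_q), daemon=True)
        proc.start()
        arg_q.put((([1, 2, 3],), {}))
        print(res_q.get())                     # ('ok', 6)
        arg_q.put(None)
        proc.join()
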
voice_chat.py CHANGED
@@ -9,6 +9,89 @@ import torch
9
  import sentencepiece as spm
10
  import onnxruntime as ort
11
  from huggingface_hub import hf_hub_download, InferenceClient
12
 
13
  # Speech Recognition Model Configuration
14
  model_name = "neongeckocom/stt_en_citrinet_512_gamma_0_25"
@@ -20,8 +103,8 @@ encoder = ort.InferenceSession(hf_hub_download(model_name, "model.onnx", subfold
20
  tokenizer = spm.SentencePieceProcessor(hf_hub_download(model_name, "tokenizer.spm", subfolder="onnx"))
21
 
22
  # Mistral Model Configuration
23
- client1 = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")
24
- system_instructions1 = "[SYSTEM] Answer as Real OpenGPT 4o, Made by 'KingNish', Keep conversation very short, clear, friendly and concise. The text provided is a request for a specific type of response from you, the virtual assistant. You will avoid introductions and start answering the query directly, elaborating on all aspects of the request. As an AI-powered assistant, your task is to generate responses that appear as if they are created by the user and not an AI-powered assistant. [USER]"
25
 
26
  def resample(audio_fp32, sr):
27
  return soxr.resample(audio_fp32, sr, sample_rate)
@@ -49,14 +132,22 @@ def transcribe(audio_path):
49
 
50
  return text
51
 
52
- def model(text):
53
- formatted_prompt = system_instructions1 + text + "[OpenGPT 4o]"
54
- stream = client1.text_generation(formatted_prompt, max_new_tokens=300)
55
- return stream[:-4]
 
 
 
 
 
 
 
 
56
 
57
- async def respond(audio):
58
  user = transcribe(audio)
59
- reply = model(user)
60
  communicate = edge_tts.Communicate(reply)
61
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
62
  tmp_path = tmp_file.name
 
9
  import sentencepiece as spm
10
  import onnxruntime as ort
11
  from huggingface_hub import hf_hub_download, InferenceClient
12
+ import requests
13
+ from bs4 import BeautifulSoup
14
+ import urllib.parse
15
+ import random
16
+
17
+ # List of user agents to choose from for requests
18
+ _useragent_list = [
19
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:66.0) Gecko/20100101 Firefox/66.0',
20
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
21
+ 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
22
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.0.0 Safari/537.36',
23
+ 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
24
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.62',
25
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0'
26
+ ]
27
+
28
+ def get_useragent():
29
+ """Returns a random user agent from the list."""
30
+ return random.choice(_useragent_list)
31
+
32
+ def extract_text_from_webpage(html_content):
33
+ """Extracts visible text from HTML content using BeautifulSoup."""
34
+ soup = BeautifulSoup(html_content, "html.parser")
35
+ # Remove unwanted tags
36
+ for tag in soup(["script", "style", "header", "footer", "nav"]):
37
+ tag.extract()
38
+ # Get the remaining visible text
39
+ visible_text = soup.get_text(strip=True)
40
+ return visible_text
41
+
42
+ def search(term, num_results=1, lang="en", advanced=True, sleep_interval=0, timeout=5, safe="active", ssl_verify=None):
43
+ """Performs a Google search and returns the results."""
44
+ escaped_term = urllib.parse.quote_plus(term)
45
+ start = 0
46
+ all_results = []
47
+
48
+ # Fetch results in batches
49
+ while start < num_results:
50
+ resp = requests.get(
51
+ url="https://www.google.com/search",
52
+ headers={"User-Agent": get_useragent()}, # Set random user agent
53
+ params={
54
+ "q": term,
55
+ "num": num_results - start, # Number of results to fetch in this batch
56
+ "hl": lang,
57
+ "start": start,
58
+ "safe": safe,
59
+ },
60
+ timeout=timeout,
61
+ verify=ssl_verify,
62
+ )
63
+ resp.raise_for_status() # Raise an exception if request fails
64
+
65
+ soup = BeautifulSoup(resp.text, "html.parser")
66
+ result_block = soup.find_all("div", attrs={"class": "g"})
67
+
68
+ # If no results, continue to the next batch
69
+ if not result_block:
70
+ start += 1
71
+ continue
72
+
73
+ # Extract link and text from each result
74
+ for result in result_block:
75
+ link = result.find("a", href=True)
76
+ if link:
77
+ link = link["href"]
78
+ try:
79
+ # Fetch webpage content
80
+ webpage = requests.get(link, headers={"User-Agent": get_useragent()})
81
+ webpage.raise_for_status()
82
+ # Extract visible text from webpage
83
+ visible_text = extract_text_from_webpage(webpage.text)
84
+ all_results.append({"link": link, "text": visible_text})
85
+ except requests.exceptions.RequestException as e:
86
+ # Handle errors fetching or processing webpage
87
+ print(f"Error fetching or processing {link}: {e}")
88
+ all_results.append({"link": link, "text": None})
89
+ else:
90
+ all_results.append({"link": None, "text": None})
91
+
92
+ start += len(result_block) # Update starting index for next batch
93
+
94
+ return all_results
95
 
96
  # Speech Recognition Model Configuration
97
  model_name = "neongeckocom/stt_en_citrinet_512_gamma_0_25"
 
103
  tokenizer = spm.SentencePieceProcessor(hf_hub_download(model_name, "tokenizer.spm", subfolder="onnx"))
104
 
105
  # Mistral Model Configuration
106
+ client1 = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
107
+ system_instructions1 = "<s>[SYSTEM] Answer as Real Jarvis JARVIS, Made by 'Tony Stark', Keep conversation very short, clear, friendly and concise. The text provided is a request for a specific type of response from you, the virtual assistant. The request asks you to provide friendly responses as if You are the character Jarvis, made by 'Tony Stark.' The expectation is that I will avoid introductions and start answering the query directly, elaborating on all aspects of the request. As an AI-powered assistant, your task is to generate responses that appear as if they are created by the user and not an AI-powered assistant. [USER]"
108
 
109
  def resample(audio_fp32, sr):
110
  return soxr.resample(audio_fp32, sr, sample_rate)
 
132
 
133
  return text
134
 
135
+ def model(text, web_search):
136
+ if web_search is True:
137
+ """Performs a web search, feeds the results to a language model, and returns the answer."""
138
+ web_results = search(text)
139
+ web2 = ' '.join([f"Link: {res['link']}\nText: {res['text']}\n\n" for res in web_results])
140
+ formatted_prompt = system_instructions1 + text + "[WEB]" + str(web2) + "[ANSWER]"
141
+ stream = client1.text_generation(formatted_prompt, max_new_tokens=512, stream=True, details=True, return_full_text=False)
142
+ return "".join([response.token.text for response in stream if response.token.text != "</s>"])
143
+ else:
144
+ formatted_prompt = system_instructions1 + text + "[JARVIS]"
145
+ stream = client1.text_generation(formatted_prompt, max_new_tokens=512, stream=True, details=True, return_full_text=False)
146
+ return "".join([response.token.text for response in stream if response.token.text != "</s>"])
147
 
148
+ async def respond(audio, web_search):
149
  user = transcribe(audio)
150
+ reply = model(user, web_search)
151
  communicate = edge_tts.Communicate(reply)
152
  with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
153
  tmp_path = tmp_file.name
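
With this change `respond` takes an extra `web_search` flag. A hedged sketch of one way it could be wired into a Gradio interface; the component names are illustrative, it assumes `respond` returns the path of the generated audio file, and the real wiring lives in the app code, not in this diff:

    # Sketch: hooking respond(audio, web_search) into a Gradio Blocks app.
    import gradio as gr
    from voice_chat import respond  # heavy import: downloads the STT model on load

    with gr.Blocks() as demo:
        audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Speak")
        web_search = gr.Checkbox(label="Enable web search", value=False)
        audio_out = gr.Audio(label="Reply", autoplay=True)
        audio_in.stop_recording(respond, [audio_in, web_search], [audio_out])

    if __name__ == "__main__":
        demo.queue().launch()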