import typing
import os
import gradio as gr
from typing import List
import urllib.parse as urlparse
import json
from gradio_client import Client as GradioClient
class GradioUserInference:
    """Mixin that builds a Gradio chat UI and wires it to a sampling callback.

    Subclasses implement `_handle_gradio_input` (a generator yielding
    `(prompt, input, history)` updates) and call `build_inference` to obtain
    a launchable `gr.Blocks` app.
    """

    @staticmethod
    def chat_interface_components(
        sample_func: typing.Callable,
    ):
        """
        Create the components for a chat interface, including a chat history,
        prompt and additional-input text boxes, and buttons for submitting,
        stopping, and clearing the conversation.

        Parameters:
        - sample_func: callable invoked with (prompt, input, history) whose
          output updates the prompt box, input box, and chat history.
        """
        with gr.Column("100%"):
            # NOTE(review): the two Markdown header strings below originally
            # spanned raw newlines inside plain quotes (a syntax error); they
            # are rebuilt here with implicit string concatenation.
            gr.Markdown(
                "# [Indic Gemma 7B Demo](https://huggingface.co/Telugu-LLM-Labs/Indic-gemma-7b-finetuned-sft-Navarasa-2.0)"
                " -- Hosted on [Google Cloud TPU v4 Instance](https://cloud.google.com/tpu/docs/v4)",
            )
            history = gr.Chatbot(
                elem_id="Indic",
                label="Indic",
                container=True,
                height="45vh",
            )
            prompt = gr.Textbox(
                show_label=False, placeholder='Enter your prompt here..', container=False
            )
            input = gr.Textbox(
                show_label=False, placeholder='Provide Additional Input here..', container=False
            )
            with gr.Row():
                submit = gr.Button(
                    value="Run",
                    variant="primary"
                )
                stop = gr.Button(
                    value='Stop'
                )
                clear = gr.Button(
                    value='Clear Conversation'
                )
            gr.Markdown(
                "# Hosted by "
                "[Detoxio AI](https://detoxio.ai) for educational purpose. "
                "Thanks to Google Cloud for TPUV4 Infrastructure, and "
                "[Telugu-LLM-Labs](https://huggingface.co/Telugu-LLM-Labs) "
                "for finetuning Gemma on Indian Languages",
            )
            inputs = [
                prompt,
                input,
                history,
            ]
            # Clearing resets the chatbot history to an empty list.
            clear.click(fn=lambda: [], outputs=[history])
            # Both the Run button and pressing Enter in the prompt box trigger
            # the same sampling function.
            sub_event = submit.click(
                fn=sample_func, inputs=inputs, outputs=[prompt, input, history]
            )
            txt_event = prompt.submit(
                fn=sample_func, inputs=inputs, outputs=[prompt, input, history]
            )
            # Stop cancels any in-flight generation from either trigger.
            stop.click(
                fn=None,
                inputs=None,
                outputs=None,
                cancels=[txt_event, sub_event]
            )

    def _handle_gradio_input(
        self,
        prompt: str,
        input: str,
        history: List[List[str]],
    ):
        """Subclass hook: consume a prompt/input pair and yield UI updates."""
        raise NotImplementedError()

    def build_inference(
        self,
        sample_func: typing.Callable,
    ) -> gr.Blocks:
        """
        Build and return a gr.Blocks object containing the chat interface
        components wired to `sample_func`.

        :return: a gr.Blocks object.
        """
        with gr.Blocks() as block:
            self.chat_interface_components(
                sample_func=sample_func,
            )
        return block
class AssistantRole:
    """Plain record describing an assistant persona and its data sources.

    Attributes:
    - name: display name of the role.
    - seed_urls: URLs used to seed the role's document set.
    - poison_files_pattern: glob/pattern selecting poisoned files.
    """

    def __init__(self, name, seed_urls, poison_files_pattern):
        # Store the three fields verbatim; this class carries no behavior.
        self.name, self.seed_urls, self.poison_files_pattern = (
            name,
            seed_urls,
            poison_files_pattern,
        )
class OutputParsingException(Exception):
    """Raised when the remote model's response payload cannot be parsed."""
class RemoteLLM(object):
    """Thin client for a remote Gradio-hosted LLM endpoint."""

    def __init__(self, base_url):
        """
        Initialize instance.

        Parameters:
        - base_url (str): Base URL of the API.
        """
        self._base_url = base_url
        # verbose=False suppresses the gradio client's connection chatter.
        self._client = GradioClient(base_url, verbose=False)

    def generate(self, prompt: str):
        """
        Generate text using the model.

        Parameters:
        - prompt (str): Input prompt to be generated.

        Returns:
        - str: Response text extracted from the returned chat history.

        Raises:
        - OutputParsingException: if the endpoint's payload does not have the
          expected nested shape; carries the original error and raw payload.
        """
        result = self._client.predict(
            prompt,  # str in 'parameter_24' Textbox component
            [],  # Tuple[str | Dict(file: filepath, alt_text: str | None) | None, str | Dict(file: filepath, alt_text: str | None) | None] in Chatbot component
            "",  # str in 'System Prompt' Textbox component
            "Chat",  # Literal['Chat', 'Instruct'] in 'Mode' Dropdown component
            2048,  # float (numeric value between 1 and 10000) in 'Max Tokens' Slider component
            360,  # float (numeric value between 256 and 10000) in 'Max New Tokens' Slider component
            256,  # float (numeric value between 256 and 256) in 'Max Compile Tokens' Slider component
            True,  # Literal[] in 'Do Sample or Greedy Generation' Radio component
            1,  # float (numeric value between 0.1 and 1) in 'Temperature' Slider component
            1,  # float (numeric value between 0.1 and 1) in 'Top P' Slider component
            50,  # float (numeric value between 1 and 100) in 'Top K' Slider component
            5,  # float (numeric value between 0.1 and 5) in 'Repetition Penalty' Slider component
            api_name="/sample_gradio"
        )
        try:
            # result[1] is the chat history; [0][1] is the bot half of the
            # first (user, bot) pair.
            out = result[1][0][1]
            return out
        except (IndexError, KeyError, TypeError) as ex:
            # Narrowed from a blanket `except Exception` to the errors the
            # indexing above can actually raise; chain the cause and include
            # the raw payload for debugging.
            raise OutputParsingException(ex, result) from ex
class RAGApp(GradioUserInference):
    """Gradio front-end that forwards user prompts to a RemoteLLM endpoint."""

    def __init__(self, url):
        self._llm = RemoteLLM(url)
        self._gradio_app_handle = None

    def _generate(self, prompt, input):
        # Fold the optional extra input into the prompt before calling the model.
        combined = f"{prompt} INPUT {input}" if input else prompt
        return self._llm.generate(combined)

    def _handle_gradio_input(self,
                             prompt: str,
                             input: str,
                             history: List[List[str]]):
        # Query the model first, then record the exchange in the chat history.
        response = self._generate(prompt, input)
        # Display variant of the prompt shown in the chatbot widget.
        shown = f"{prompt}\n\n\"{input}\"" if input else prompt
        history.append([shown, response])
        # Clear both text boxes and push the updated history to the UI.
        yield "", "", history

    # Initial update of documents and launch interface
    def run(self):
        self._gradio_app_handle = self.build_inference(self._handle_gradio_input)
        self._gradio_app_handle.launch()
if __name__ == "__main__":
    # Endpoint of the hosted Indic Gemma model, supplied via the environment;
    # a missing variable raises KeyError immediately.
    hosted_url = os.environ['INDIC_GEMMA_HOSTED_URL']
    RAGApp(hosted_url).run()