"""Helpers that build Gradio apps around a Replicate/Cog prediction API.

``build_gradio_inputs`` / ``build_gradio_outputs_replicate`` turn a model's
schema into Gradio components (plus equivalent Python source),
``create_dynamic_gradio_app`` wires those components to the API as a live
``gr.Interface``, and ``create_gradio_app_script`` emits the source of a
standalone script that does the same.
"""

import base64
import io
import os
import time
import uuid
from urllib.parse import urlparse

import gradio as gr
import requests
from PIL import Image

_AUDIO_EXTENSIONS = frozenset({".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"})
_IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".svg", ".webp"}
)
_VIDEO_EXTENSIONS = frozenset(
    {".mp4", ".mov", ".wmv", ".flv", ".avi", ".avchd", ".mkv", ".webm"}
)

# JSON-schema composition keywords whose sub-schemas are merged into one.
_MERGE_KEYWORDS = ("allOf", "anyOf", "oneOf")

# Output type -> (Gradio component class name, constructor kwargs).
# Types not listed here fall back to _FALLBACK_OUTPUT_COMPONENT.
_OUTPUT_COMPONENTS = {
    "image": ("Image", {}),
    "audio": ("Audio", {"type": "filepath"}),
    "video": ("Video", {}),
    "string": ("Textbox", {}),
    "json": ("JSON", {}),
    "list": ("JSON", {}),
}
_FALLBACK_OUTPUT_COMPONENT = ("JSON", {})


def extract_property_info(prop):
    """Flatten a JSON-schema property that is composed with allOf/anyOf/oneOf.

    Every sub-schema listed under ``allOf``, ``anyOf`` or ``oneOf`` is merged
    into one dict (later entries overwrite earlier ones). If that merge yields
    nothing, the property itself is copied instead, minus those keywords. The
    property's own ``description`` and ``default`` are then applied on top.

    ``prop`` is not modified, so the same schema can safely be processed
    more than once.

    Args:
        prop: A JSON-schema property dict.

    Returns:
        A new dict describing the flattened property.
    """
    combined_prop = {}
    for keyword in _MERGE_KEYWORDS:
        for subprop in prop.get(keyword, ()):
            combined_prop.update(subprop)

    if not combined_prop:
        combined_prop = {
            key: value for key, value in prop.items() if key not in _MERGE_KEYWORDS
        }

    for key in ("description", "default"):
        if key in prop:
            combined_prop[key] = prop[key]
    return combined_prop


def _extension_source(value):
    """Return the part of ``value`` whose file extension should be inspected.

    For URLs (scheme and host both present) this is the URL path only, so a
    query string or fragment (e.g. a signed ``?token=...``) is ignored.
    Anything else is returned unchanged.
    """
    try:
        parsed = urlparse(value)
    except ValueError:  # e.g. a malformed IPv6 host: not a usable URL
        return value
    if parsed.scheme and parsed.netloc:
        return parsed.path
    return value


def detect_file_type(filename):
    """Classify an example/output value by its file extension.

    The extension is everything from the last ``.`` onwards, compared
    case-insensitively.

    Args:
        filename: A path or URL string, or a list of values.

    Returns:
        ``"audio"``, ``"image"`` or ``"video"`` for a recognised extension,
        ``"string"`` for any other string, ``"list"`` for a list and ``None``
        for any other type.
    """
    if isinstance(filename, list):
        return "list"
    if not isinstance(filename, str):
        return None

    path = _extension_source(filename)
    dot = path.rfind(".")
    extension = path[dot:].lower() if dot != -1 else ""

    if extension in _AUDIO_EXTENSIONS:
        return "audio"
    if extension in _IMAGE_EXTENSIONS:
        return "image"
    if extension in _VIDEO_EXTENSIONS:
        return "video"
    return "string"


def _component_code(collection, component, kwargs):
    """Return source that appends ``gr.<component>(**kwargs)`` to ``collection``.

    Argument values are rendered with ``repr`` so strings containing quotes,
    newlines or backslashes (and ``None``) become valid Python that builds
    the same component as the live object created from ``kwargs``.
    """
    arguments = ", ".join(f"{key}={value!r}" for key, value in kwargs.items())
    return f"{collection}.append(gr.{component}({arguments}))\n"


def _input_component_spec(name, prop, example_inputs):
    """Choose the Gradio input component for one flattened schema property.

    Returns:
        ``(component_class_name, constructor_kwargs)``.
    """
    title = prop.get("title")
    description = prop.get("description")
    default = prop.get("default")
    prop_type = prop.get("type")
    minimum = prop.get("minimum")
    maximum = prop.get("maximum")
    # A bound of 0 is a real bound, so test against None rather than truthiness.
    has_range = minimum is not None and maximum is not None

    if "enum" in prop:
        return "Dropdown", {
            "choices": prop["enum"],
            "label": title,
            "info": description,
            "value": default,
        }

    if prop_type in ("integer", "number"):
        if has_range:
            kwargs = {
                "label": title,
                "info": description,
                "value": default,
                "minimum": minimum,
                "maximum": maximum,
            }
            if prop_type == "integer":
                kwargs["step"] = 1
            return "Slider", kwargs
        kwargs = {"label": title, "info": description, "value": default}
        if prop_type == "integer":
            # precision=0 makes gr.Number hand back an int instead of a float.
            kwargs["precision"] = 0
        return "Number", kwargs

    if prop_type == "boolean":
        return "Checkbox", {"label": title, "info": description, "value": default}

    if prop_type == "string" and prop.get("format") == "uri" and example_inputs:
        example = example_inputs.get(name, None)
        file_type = detect_file_type(example) if example else None
        if file_type == "image":
            return "Image", {"label": title, "type": "filepath"}
        if file_type == "audio":
            return "Audio", {"label": title, "type": "filepath"}
        if file_type == "video":
            return "Video", {"label": title}
        return "File", {"label": title}

    return "Textbox", {"label": title, "info": description}


def build_gradio_inputs(ordered_input_schema, example_inputs=None):
    """Create Gradio input components for a model's input schema.

    Args:
        ordered_input_schema: Iterable of ``(name, property)`` pairs in
            display order; each property is a JSON-schema dict.
        example_inputs: Optional mapping of input name to an example value.
            For ``format: uri`` string inputs the example's extension selects
            an Image/Audio/Video/File widget; without it those inputs become
            textboxes.

    Returns:
        ``(inputs, input_field_strings, names)``: the live components; Python
        source that rebuilds the same components into a list named
        ``inputs`` followed by a ``names = [...]`` assignment; and the input
        names in order.
    """
    inputs = []
    input_field_strings = "inputs = []\n"
    names = []
    for name, prop in ordered_input_schema:
        names.append(name)
        component, kwargs = _input_component_spec(
            name, extract_property_info(prop), example_inputs
        )
        # Live component and generated source come from the same kwargs, so
        # they cannot drift apart.
        inputs.append(getattr(gr, component)(**kwargs))
        input_field_strings += _component_code("inputs", component, kwargs) + "\n"
    input_field_strings += f"names = {names}\n"
    return inputs, input_field_strings, names


def build_gradio_outputs_replicate(output_types):
    """Create Gradio output components for a list of output types.

    Args:
        output_types: Sequence of type names such as ``"image"``, ``"audio"``,
            ``"video"``, ``"string"``, ``"json"`` or ``"list"``. Any other
            value (including ``None``) gets a JSON viewer. When empty or
            ``None``, a single JSON viewer is created.

    Returns:
        ``(outputs, output_field_strings)``: the live components and Python
        source that rebuilds them into a list named ``outputs``.
    """
    outputs = []
    output_field_strings = "outputs = []\n"
    for output_type in output_types or ["json"]:
        component, kwargs = _OUTPUT_COMPONENTS.get(
            output_type, _FALLBACK_OUTPUT_COMPONENT
        )
        outputs.append(getattr(gr, component)(**kwargs))
        output_field_strings += _component_code("outputs", component, kwargs)
    return outputs, output_field_strings


def build_gradio_outputs_cog():
    """Placeholder for building outputs from a Cog schema.

    Not implemented: it does nothing and returns ``None``.
    """
    # TODO: implement; callers currently receive None.
    return None


def _decode_data_uri(uri):
    """Return the decoded payload of a ``data:...;base64,<payload>`` URI.

    Assumes the payload is base64-encoded; everything after the first comma
    is decoded.
    """
    return base64.b64decode(uri.split(",", 1)[1])


def _is_empty_output(value):
    """Return True for outputs with nothing to show: None or an empty str/container.

    Numbers and booleans (including ``0`` and ``False``) are real values.
    """
    if value is None:
        return True
    if isinstance(value, (str, bytes, list, tuple, dict, set)):
        return not value
    return False


def process_outputs(outputs):
    """Convert flattened prediction outputs into values Gradio components accept.

    - ``data:image...`` URIs are decoded into ``PIL.Image`` objects.
    - ``data:audio...`` / ``data:video...`` URIs are decoded and written to a
      new ``<uuid4>.wav`` / ``<uuid4>.mp4`` file in the current working
      directory, and that filename is returned. The extension is fixed and
      does not follow the URI's actual MIME subtype.
    - ``None`` and empty strings/containers are dropped.
    - Every other value (including ``0`` and ``False``) passes through.

    Args:
        outputs: Iterable of output values, e.g. from ``parse_outputs``.

    Returns:
        A new list of converted values.
    """
    output_values = []
    for output in outputs:
        if _is_empty_output(output):
            continue
        if isinstance(output, str) and output.startswith("data:image"):
            output_values.append(Image.open(io.BytesIO(_decode_data_uri(output))))
        elif isinstance(output, str) and output.startswith(("data:audio", "data:video")):
            suffix = ".wav" if output.startswith("data:audio") else ".mp4"
            filename = f"{uuid.uuid4()}{suffix}"
            with open(filename, "wb") as media_file:
                media_file.write(_decode_data_uri(output))
            output_values.append(filename)
        else:
            output_values.append(output)
    return output_values


def parse_outputs(data):
    """Flatten a prediction's ``output`` value into a list of leaf values.

    - A list is flattened: the parsed values of all items are concatenated.
    - A dict contributes the parsed values of each of its values in order,
      except that a value which is itself a list is kept as one nested list
      (so array-valued fields stay grouped).
    - Any other value becomes a one-element list.
    """
    if isinstance(data, dict):
        values = []
        for value in data.values():
            parsed = parse_outputs(value)
            if isinstance(value, list):
                values.append(parsed)
            else:
                values.extend(parsed)
        return values
    if isinstance(data, list):
        return [leaf for item in data for leaf in parse_outputs(item)]
    return [data]


def _wait_for_prediction(response, headers, poll_interval=1):
    """Poll an asynchronous prediction until it succeeds.

    Args:
        response: The creation response whose JSON has ``urls.get``.
        headers: Request headers (including any auth) reused for polling.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The last polling response, whose status is ``succeeded``.

    Raises:
        gr.Error: If the prediction reports ``failed`` or ``canceled``.
    """
    # TODO: add an overall deadline in case the API never reaches a final state.
    follow_up_url = response.json()["urls"]["get"]
    response = requests.get(follow_up_url, headers=headers)
    while True:
        status = response.json()["status"]
        if status == "succeeded":
            return response
        if status == "failed":
            raise gr.Error("The submission failed!")
        if status == "canceled":
            raise gr.Error("The submission was canceled!")
        time.sleep(poll_interval)
        response = requests.get(follow_up_url, headers=headers)


def _fit_outputs(values, expected_outputs):
    """Pad or truncate ``values`` to exactly ``expected_outputs`` entries.

    Missing entries become ``gr.update(visible=False)`` so the unused
    components are hidden. A single value is returned bare, several as a
    tuple, matching what a Gradio handler returns for one/many outputs.
    """
    # NOTE(review): nothing here makes a hidden component visible again on a
    # later call that returns more values.
    missing = expected_outputs - len(values)
    if missing > 0:
        values = values + [gr.update(visible=False)] * missing
    elif missing < 0:
        values = values[:expected_outputs]
    return tuple(values) if len(values) > 1 else values[0]


def create_dynamic_gradio_app(
    inputs,
    outputs,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    names=None,
    local_base=False,
    hostname="0.0.0.0",
):
    """Build a ``gr.Interface`` that forwards each submission to a prediction API.

    **Payload.** On submit, every input value that is not ``None``/``""`` is
    placed in ``payload["input"]`` under the matching entry of ``names``.
    Values naming an existing local path are first rewritten to
    ``<base_url>/file=<path>`` so the API can download them.

    **Response handling.**

    - A 201 response is treated as an asynchronous prediction and is polled
      through ``urls.get`` until it succeeds.
    - A 200 response's ``output`` is either returned whole (when the first
      output component is JSON) or is flattened and converted. The converted
      result is padded or truncated to one value per output component.
    - Any other status raises ``gr.Error``.

    Args:
        inputs: Gradio input components, in the same order as ``names``.
        outputs: Gradio output components.
        api_url: URL the prediction payload is POSTed to.
        api_id: Optional model version, sent as ``payload["version"]``.
        replicate_token: Optional token sent as ``Authorization: Token ...``.
        title: Interface title.
        model_description: Interface description.
        names: API input names matching ``inputs`` (``None`` means no names).
        local_base: If true, file URLs use ``http://<hostname>:7860`` instead
            of the scheme and host of the incoming request.
        hostname: Host used when ``local_base`` is true.

    Returns:
        The configured ``gr.Interface`` (not launched).
    """
    if names is None:
        names = []
    expected_outputs = len(outputs)

    def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
        payload = {"input": {}}
        if api_id:
            payload["version"] = api_id

        if local_base:
            base_url = f"http://{hostname}:7860"
        else:
            parsed_url = urlparse(str(request.url))
            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        for index, key in enumerate(names):
            value = args[index]
            if value and os.path.exists(str(value)):
                # A local upload: give the API a URL it can fetch it from.
                value = f"{base_url}/file={value}"
            if value is not None and value != "":
                payload["input"][key] = value
        print(payload)

        headers = {"Content-Type": "application/json"}
        if replicate_token:
            headers["Authorization"] = f"Token {replicate_token}"
        # Headers are deliberately not printed: they carry the API token.

        response = requests.post(api_url, headers=headers, json=payload)
        if response.status_code == 201:
            response = _wait_for_prediction(response, headers)

        if response.status_code == 200:
            json_response = response.json()
            # A JSON output component shows the entire raw output.
            if outputs[0].get_config()["name"] == "json":
                return json_response["output"]
            processed_outputs = process_outputs(parse_outputs(json_response["output"]))
            return _fit_outputs(processed_outputs, expected_outputs)

        if response.status_code == 409:
            raise gr.Error(
                "Sorry, the Cog image is still processing. Try again in a bit."
            )
        raise gr.Error(f"The submission failed! Error: {response.status_code}")

    return gr.Interface(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=model_description,
        allow_flagging="never",
    )


def create_gradio_app_script(
    inputs_string,
    outputs_string,
    api_url,
    api_id=None,
    replicate_token=None,
    title="",
    model_description="",
    local_base=False,
    hostname="0.0.0.0",
):
    """Return the source of a standalone script mirroring ``create_dynamic_gradio_app``.

    Args:
        inputs_string: Source from ``build_gradio_inputs`` (defines ``inputs``
            and ``names``).
        outputs_string: Source from ``build_gradio_outputs_replicate``
            (defines ``outputs``).
        api_url, api_id, replicate_token, title, model_description,
        local_base, hostname: As for ``create_dynamic_gradio_app``; their
            values are baked into the script as Python literals.

    Returns:
        The script source. It imports ``parse_outputs`` and
        ``process_outputs`` from ``utils.gradio_helpers`` and ends with
        ``app.launch(share=True)``.
    """
    headers = {"Content-Type": "application/json"}
    if replicate_token:
        # NOTE(review): the token is written into the generated script in
        # plain text; anyone who can read the script can use it.
        headers["Authorization"] = f"Token {replicate_token}"

    if local_base:
        local_url = f"http://{hostname}:7860"
        base_url_code = f"base_url = {local_url!r}"
    else:
        # The second line carries its own indentation inside predict().
        base_url_code = (
            "parsed_url = urlparse(str(request.url))\n"
            '    base_url = parsed_url.scheme + "://" + parsed_url.netloc'
        )

    version_code = f'    payload["version"] = {api_id!r}' if api_id else ""

    # Every user-supplied value below is emitted via repr so that quotes or
    # newlines in it cannot break the generated script.
    return f"""import gradio as gr
from urllib.parse import urlparse
import requests
import time
import os

from utils.gradio_helpers import parse_outputs, process_outputs

{inputs_string}
{outputs_string}
expected_outputs = len(outputs)


def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
    headers = {headers!r}

    payload = {{"input": {{}}}}
{version_code}
    {base_url_code}
    for i, key in enumerate(names):
        value = args[i]
        if value and (os.path.exists(str(value))):
            value = f"{{base_url}}/file={{value}}"
        if value is not None and value != "":
            payload["input"][key] = value

    response = requests.post({api_url!r}, headers=headers, json=payload)

    if response.status_code == 201:
        follow_up_url = response.json()["urls"]["get"]
        response = requests.get(follow_up_url, headers=headers)
        while True:
            status = response.json()["status"]
            if status == "succeeded":
                break
            if status == "failed":
                raise gr.Error("The submission failed!")
            if status == "canceled":
                raise gr.Error("The submission was canceled!")
            time.sleep(1)
            response = requests.get(follow_up_url, headers=headers)

    if response.status_code == 200:
        json_response = response.json()
        # If the output component is JSON return the entire output response
        if outputs[0].get_config()["name"] == "json":
            return json_response["output"]
        predict_outputs = parse_outputs(json_response["output"])
        processed_outputs = process_outputs(predict_outputs)
        difference_outputs = expected_outputs - len(processed_outputs)
        # If less outputs than expected, hide the extra ones
        if difference_outputs > 0:
            extra_outputs = [gr.update(visible=False)] * difference_outputs
            processed_outputs.extend(extra_outputs)
        # If more outputs than expected, cap the outputs to the expected number
        elif difference_outputs < 0:
            processed_outputs = processed_outputs[:difference_outputs]
        return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]

    if response.status_code == 409:
        raise gr.Error("Sorry, the Cog image is still processing. Try again in a bit.")
    raise gr.Error(f"The submission failed! Error: {{response.status_code}}")


title = {title!r}
model_description = {model_description!r}

app = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=model_description,
    allow_flagging="never",
)
app.launch(share=True)
"""