|
import gradio as gr |
|
from llm_inference import LLMInferenceNode |
|
import random |
|
from PIL import Image |
|
import io |
|
|
|
# Header HTML rendered at the top of the app via gr.HTML.
# The link row and tagline are wrapped in one centered <div> rather than
# nesting block elements inside a <p> (which HTML parsers split apart,
# leaving stray empty paragraphs). rel="noopener noreferrer" keeps the
# target="_blank" links from getting a handle on this page.
title = """<h1 align="center">SD 3.5 Prompt Generator</h1>
<div style="text-align: center;">
<p>
<a href="https://x.com/gokayfem" target="_blank" rel="noopener noreferrer">[X gokaygokay]</a>
<a href="https://github.com/gokayfem" target="_blank" rel="noopener noreferrer">[Github gokayfem]</a>
</p>
<p>Generate random prompts using powerful LLMs from Hugging Face and SambaNova.</p>
</div>
"""
|
|
|
# Prompt types offered in the UI. "Random" is a meta-choice that is resolved
# to one of the concrete types whenever it is used.
_PROMPT_TYPES = ["Random", "Long", "Short", "Medium", "OnlyObjects", "NoFigure", "Landscape", "Fantasy"]

# Models selectable for each LLM provider; the first entry of each list is the
# default selection. Both the initial Model dropdown and the provider-change
# handler read this single table so the two can no longer drift apart.
_PROVIDER_MODELS = {
    "Hugging Face": [
        "Qwen/Qwen2.5-72B-Instruct",
        "meta-llama/Meta-Llama-3.1-70B-Instruct",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "mistralai/Mistral-7B-Instruct-v0.3",
    ],
    "SambaNova": [
        "Meta-Llama-3.1-70B-Instruct",
        "Meta-Llama-3.1-405B-Instruct",
        "Meta-Llama-3.1-8B-Instruct",
    ],
}

_DEFAULT_PROVIDER = "Hugging Face"


def _resolve_prompt_type(prompt_type):
    """Return a concrete prompt type, replacing "Random" with a random pick.

    Any value other than "Random" is returned unchanged. When "Random" is
    resolved, the chosen type is printed for debugging.
    """
    if prompt_type != "Random":
        return prompt_type
    chosen = random.choice([t for t in _PROMPT_TYPES if t != "Random"])
    print(f"Random prompt type selected: {chosen}")
    return chosen


def create_interface():
    """Build and return the Gradio Blocks UI for the SD 3.5 prompt generator.

    The UI collects an optional custom input prompt, a prompt type, LLM
    generation options, a provider and a model, and on "Generate Prompt"
    asks ``LLMInferenceNode`` first for a seed prompt and then for the final
    text. The generate handler is exposed under the API name
    ``generate_random_prompt_with_llm``.

    Returns:
        gr.Blocks: The assembled (not yet launched) demo.
    """
    llm_node = LLMInferenceNode()

    with gr.Blocks(theme='bethecloud/storj_theme') as demo:
        gr.HTML(title)

        with gr.Row():
            with gr.Column(scale=2):
                custom = gr.Textbox(label="Custom Input Prompt (optional)", lines=3)

                prompt_type = gr.Dropdown(
                    choices=list(_PROMPT_TYPES),
                    label="Prompt Type",
                    value="Random",
                    interactive=True,
                )

                # Holds the concrete type last resolved from the dropdown. It is
                # still wired into the generate endpoint so that endpoint's
                # inputs stay unchanged for API callers.
                prompt_type_state = gr.State("Random")

                def update_prompt_type(value, state):
                    """Store the selected prompt type, resolving "Random" immediately."""
                    if value != "Random":
                        print(f"Updated prompt type: {value}")
                    return _resolve_prompt_type(value)

                # Only the state is an output: echoing the value back into the
                # dropdown that fired this event was a pointless self-update.
                prompt_type.change(
                    update_prompt_type,
                    inputs=[prompt_type, prompt_type_state],
                    outputs=[prompt_type_state],
                )

            with gr.Column(scale=2):
                with gr.Accordion("LLM Prompt Generation", open=False):
                    long_talk = gr.Checkbox(label="Long Talk", value=True)
                    compress = gr.Checkbox(label="Compress", value=True)
                    compression_level = gr.Dropdown(
                        choices=["soft", "medium", "hard"],
                        label="Compression Level",
                        value="hard",
                    )
                    custom_base_prompt = gr.Textbox(label="Custom Base Prompt", lines=5)

                default_models = _PROVIDER_MODELS[_DEFAULT_PROVIDER]
                llm_provider = gr.Dropdown(
                    choices=list(_PROVIDER_MODELS),
                    label="LLM Provider",
                    value=_DEFAULT_PROVIDER,
                )
                api_key = gr.Textbox(label="API Key", type="password", visible=False)
                model = gr.Dropdown(
                    label="Model",
                    choices=list(default_models),
                    value=default_models[0],
                )

                with gr.Row():
                    generate_button = gr.Button("Generate Prompt")

                with gr.Row():
                    text_output = gr.Textbox(label="LLM Generated Text", lines=10, show_copy_button=True)

        def update_model_choices(provider):
            """Return a Model dropdown listing *provider*'s models, first one selected."""
            models = list(_PROVIDER_MODELS.get(provider, []))
            return gr.Dropdown(choices=models, value=models[0] if models else "")

        def update_api_key_visibility(provider):
            """Keep the API key field hidden regardless of *provider*.

            NOTE(review): credentials presumably come from the environment
            inside LLMInferenceNode — confirm before exposing this field.
            """
            return gr.update(visible=False)

        llm_provider.change(
            update_model_choices,
            inputs=[llm_provider],
            outputs=[model],
        )
        llm_provider.change(
            update_api_key_visibility,
            inputs=[llm_provider],
            outputs=[api_key],
        )

        def generate_random_prompt_with_llm(custom_input, prompt_type, long_talk, compress, compression_level, custom_base_prompt, provider, api_key, model_selected, prompt_type_state):
            """Generate prompt text with the selected LLM.

            A fresh random seed is drawn per call and "Random" is re-resolved
            on every click. ``prompt_type_state`` is accepted only to keep the
            endpoint's inputs stable; it is not used.

            Returns:
                The text produced by ``llm_node.generate``, or an error message
                string if anything raises (so the UI shows it instead of failing).
            """
            try:
                dynamic_seed = random.randint(0, 1000000)
                prompt_type = _resolve_prompt_type(prompt_type)

                if custom_input and custom_input.strip():
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, custom_input)
                    print("Using Custom Input Prompt.")
                else:
                    prompt = llm_node.generate_prompt(dynamic_seed, prompt_type, f"Create a random prompt based on the '{prompt_type}' type.")
                    print(f"No Custom Input Prompt provided. Generated prompt based on prompt_type: {prompt_type}")

                print(f"Generated Prompt: {prompt}")

                result = llm_node.generate(
                    input_text=prompt,
                    long_talk=long_talk,
                    compress=compress,
                    compression_level=compression_level,
                    poster=False,
                    prompt_type=prompt_type,
                    custom_base_prompt=custom_base_prompt,
                    provider=provider,
                    api_key=api_key,
                    model=model_selected,
                )
                print(f"Generated Text: {result}")
                return result

            except Exception as e:
                # UI boundary: report the failure in the output box, but keep
                # the traceback in the server log for debugging.
                import traceback
                traceback.print_exc()
                print(f"An error occurred: {e}")
                return f"Error occurred while processing the request: {str(e)}"

        generate_button.click(
            generate_random_prompt_with_llm,
            inputs=[custom, prompt_type, long_talk, compress, compression_level, custom_base_prompt, llm_provider, api_key, model, prompt_type_state],
            outputs=[text_output],
            api_name="generate_random_prompt_with_llm",
        )

    return demo
|
|
|
# Script entry point: build the UI and start the Gradio server.
# The module-level name `demo` is kept as-is (Gradio tooling commonly looks
# for a variable with this name).
if __name__ == "__main__":
    demo = create_interface()
    # share=True asks Gradio to also create a public share link.
    # NOTE(review): that link exposes the app (and whatever provider
    # credentials LLMInferenceNode uses) publicly — confirm this is intended.
    demo.launch(share=True)