vid2persona / app.py
chansung's picture
Update app.py
89fdef5 verified
raw
history blame
7.93 kB
import gradio as gr
from vid2persona import init
from vid2persona.pipeline import vlm
from vid2persona.pipeline import llm
# Module-level setup, executed once when app.py is run. The call order is kept
# exactly as written.
# NOTE(review): confirm whether auth_gcp() depends on anything get_env_vars() loads.
# Presumably initializes the chat LLM; the same model ID is the default value of
# the (hidden) "Model ID" dropdown in the UI.
init.init_model("HuggingFaceH4/zephyr-7b-beta")
# Presumably authenticates with Google Cloud for the Vertex AI trait extraction.
init.auth_gcp()
# Presumably populates init.gcp_project_id, init.gcp_project_location and
# init.hf_access_token, which are read later in this file — verify in vid2persona.init.
init.get_env_vars()
# Prompt-template directory passed to both vlm.get_traits and llm.chat.
prompt_tpl_path = "vid2persona/prompts"
async def extract_traits(video_path):
    """Ask the vision-language model for the traits of the character in a video.

    Args:
        video_path: Path of the video selected in the ``gr.Video`` component.

    Returns:
        Six output updates, in this order: the extracted traits (shown in
        the JSON component), an emptied chat history, an empty interactive
        textbox, and interactive clear / regenerate / stop buttons.

    Raises:
        gr.Error: if the model response nests characters under
            ``"characters"`` but that list is empty.
    """
    traits = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path,
    )

    # When the response nests characters under "characters", only the first
    # one becomes the persona. Guard the empty case so the user gets a
    # readable message instead of a bare IndexError.
    if 'characters' in traits:
        characters = traits['characters']
        if not characters:
            raise gr.Error("No characters could be extracted from the video.")
        traits = characters[0]

    return [
        traits, [],
        gr.Textbox("", interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
    ]
async def conversation(
    message: str, messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Stream the persona's reply to a new user message into the chat history.

    Args:
        message: Text the user just submitted.
        messages: Current chat history as ``[user, assistant]`` pairs.
        traits: Character traits produced by ``extract_traits``.
        model_id, max_input_token_length, max_new_tokens, temperature,
        top_p, top_k, repetition_penalty: Forwarded unchanged to ``llm.chat``.

    Yields:
        ``[chatbot, user_input, clear_button, regen_button]`` updates. The
        buttons are disabled while tokens stream in and are re-enabled
        afterwards.
    """
    # A blank submission would add an empty user turn and trigger a full LLM
    # generation. Leave the history unchanged and set the buttons according
    # to whether there is anything to clear or regenerate.
    if not message.strip():
        has_history = bool(messages)
        yield [messages, "", gr.Button(interactive=has_history), gr.Button(interactive=has_history)]
        return

    # Work on a new list and append the new turn with an empty reply slot.
    messages = messages + [[message, ""]]
    # Keep the typed text in the textbox until the first chunk arrives.
    yield [messages, message, gr.Button(interactive=False), gr.Button(interactive=False)]

    # NOTE(review): the new [message, ""] pair is already in `messages` when it
    # is passed to llm.chat together with `message`; confirm llm.chat does not
    # count the current turn twice.
    async for partial_response in llm.chat(
        message, messages, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    ):
        # Append each streamed chunk to the assistant half of the newest turn.
        messages[-1][1] += partial_response
        yield [messages, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
async def regen_conversation(
    messages: list, traits: dict,
    model_id: str, max_input_token_length: int,
    max_new_tokens: int, temperature: float,
    top_p: float, top_k: float, repetition_penalty: float,
):
    """Replace the last assistant reply with a freshly streamed one.

    The most recent user message is sent back through ``llm.chat`` with the
    same traits and generation settings. When the history is empty, nothing
    is generated and only the final "re-enable buttons" update is yielded.

    Yields:
        ``[chatbot, user_input, clear_button, regen_button]`` updates.
    """
    if len(messages) == 0:
        yield [messages, "", gr.Button(interactive=True), gr.Button(interactive=True)]
        return

    last_user_text = messages[-1][0]
    # Rebuild the final turn with an empty reply slot to stream into.
    history = messages[:-1] + [[last_user_text, ""]]
    yield [history, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    stream = llm.chat(
        last_user_text, history, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token,
    )
    async for chunk in stream:
        history[-1][1] += chunk
        yield [history, "", gr.Button(interactive=False), gr.Button(interactive=False)]

    yield [history, "", gr.Button(interactive=True), gr.Button(interactive=True)]
# Build the Gradio UI: header text, video + trait extraction, chat area with
# generation controls, footer badges, and the event wiring between them.
with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    # --- Header -----------------------------------------------------------
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])
    gr.Markdown("This project breathes life into video characters by using AI to describe their personality and then chat with you as them. "
                "[Gemini 1.0 Pro Vision model on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/overview) is used "
                "to grasp traits of video characters, then [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) model "
                "is used to make conversation with them.",)
    gr.Markdown("This space is modified to be working on Hugging Face [ZeroGPU](https://huggingface.co/zero-gpu-explorers). If you wish to run "
                "the same application on your own machine, please check out the [project repository](https://github.com/deep-diver/Vid2Persona). "
                "You can interact with other LLMs to make conversation besides HuggingFaceH4/zephyr-7b-beta by running them locally, or by "
                "connecting them through remotely hosted within Text Generation Inference framework as [Hugging Face PRO](https://huggingface.co/blog/inference-pro) user.")

    # --- Video input and extracted traits ---------------------------------
    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip", max_length=180)
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate traits")

    # --- Chat area: controls start disabled until traits are generated ----
    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot([], label="chatbot", elem_id="chatbot", elem_classes=["chatbot-no-label"])

        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)

        user_input = gr.Textbox(placeholder="ask anything", interactive=False, elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"])

        with gr.Accordion("parameters' control pane", open=False):
            # Hidden, but its value is still passed to the chat handlers.
            model_id = gr.Dropdown(choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS, value="HuggingFaceH4/zephyr-7b-beta", label="Model ID", visible=False)

            with gr.Row():
                max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=4096, label="max-input-tokens")
                max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=256, label="max-new-tokens")

            with gr.Row():
                temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="temperature")
                # top-p is a cumulative probability mass, so it is capped at 1.
                top_p = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="top-p")
                # top-k is a count of candidate tokens: integer steps, and a
                # range that actually contains the default of 50.
                top_k = gr.Slider(minimum=1, maximum=100, step=1, value=50, label="top-k")
                repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty")

    # --- Footer badges ----------------------------------------------------
    with gr.Row():
        gr.Markdown(
            "[![GitHub Repo](https://img.shields.io/badge/GitHub%20Repo-gray?style=for-the-badge&logo=github&link=https://github.com/deep-diver/Vid2Persona)](https://github.com/deep-diver/Vid2Persona) "
            "[![Chansung](https://img.shields.io/badge/Chansung-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/algo_diver)](https://twitter.com/algo_diver) "
            "[![Sayak](https://img.shields.io/badge/Sayak-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/RisingSayak)](https://twitter.com/RisingSayak )",
            elem_id="bottom-md"
        )

    # --- Event wiring -----------------------------------------------------
    # Extracting traits resets the chat and enables the chat controls.
    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop],
        concurrency_limit=5,
    )

    # Submitting a message streams the persona's reply.
    conv = user_input.submit(
        conversation,
        [
            user_input, chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Clearing empties the chat and disables clear/regenerate again.
    clear.click(
        lambda: [
            gr.Chatbot([]),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
        ],
        None, [chatbot, clear, regen],
        concurrency_limit=5,
    )

    # Regenerating replays the last user message.
    conv_regen = regen.click(
        regen_conversation,
        [
            chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Stop cancels any in-flight generation and re-enables the buttons.
    stop.click(
        lambda: [
            gr.Button(interactive=True),
            gr.Button(interactive=True),
            gr.Button(interactive=True),
        ], None, [clear, regen, stop],
        cancels=[conv, conv_regen],
        concurrency_limit=5,
    )

    # Sample videos; cache_examples=True precomputes extract_traits outputs for them.
    gr.Examples(
        [["assets/sample1.mp4"], ["assets/sample2.mp4"], ["assets/sample3.mp4"], ["assets/sample4.mp4"]],
        video,
        [traits, chatbot, user_input, clear, regen, stop],
        extract_traits,
        cache_examples=True
    )
# Enable the request queue (max_size=256 pending events), then start the server in debug mode.
demo.queue(max_size=256).launch(debug=True)