import inspect
import json
import os
import random
from typing import Literal, cast

import gradio as gr
import torch
from PIL import Image
from gradio.data_classes import InterfaceTypes
from gradio.flagging import CSVLogger
from torchvision import transforms
from transformers import AutoTokenizer, LlamaForCausalLM

from trace_exec import run_program_with_trace, CompileTimeError
from vision_processes import load_models

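
# Load the vision foundation models once at startup.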
print("-" * 10, "Loading models...") |
|
load_models() |
|
|
|
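
# In-context prompt used to generate visual programs; "INSERT_QUERY_HERE" is replaced with the actual query in predict().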
with open('joint.prompt') as f:
    prompt_template = f.read().strip()

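
# Every generated program uses the fixed signature `def execute_command(image) -> str:`.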
INPUT_TYPE = 'image'
OUTPUT_TYPE = 'str'
SIGNATURE = f'def execute_command({INPUT_TYPE}) -> {OUTPUT_TYPE}:'

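
# Generate with the given HF causal LM: load it on demand (precision set by CODELLAMA_DTYPE),
# decode greedily until a blank line, then free the model to release GPU memory.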
def generate(model, input_text):
    torch.cuda.empty_cache()
    print("-" * 10, "Before loading LLM:")
    print(torch.cuda.memory_summary())

    dtype = os.environ.get("CODELLAMA_DTYPE")
    assert dtype in ['bfloat16', '8bit', '4bit', ]
    tokenizer = AutoTokenizer.from_pretrained(model)
    model = LlamaForCausalLM.from_pretrained(
        model,
        device_map="auto",
        load_in_8bit=dtype == "8bit",
        load_in_4bit=dtype == "4bit",
        torch_dtype=torch.bfloat16 if dtype == "bfloat16" else None,
    )
    print("-" * 10, "LLM loaded:")
    print(model)
    print(torch.cuda.memory_summary())

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(
        input_ids.to('cuda'), max_new_tokens=256, stop_strings=["\n\n"], do_sample=False, tokenizer=tokenizer
    )
    generated_ids = generated_ids[0][input_ids.shape[1]:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    del model
    torch.cuda.empty_cache()
    print("-" * 10, "After unloading LLM:")
    print(torch.cuda.memory_summary())

    return text

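
# Format an execution result and its trace into the textual form fed to the critic and refiner prompts.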
def to_custom_trace(result, error, traced):
    if traced is None:
        assert isinstance(error, CompileTimeError)
        traced = 'Compile Error'
    return "-> {}\n\n--- Trace\n\n{}".format(result, traced)

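
# The final answer is the text following "->" on the first line of the formatted trace.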
def answer_from_trace(x):
    assert x.startswith("->")
    return x[2:].splitlines()[0].strip()

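
# One critic-refiner round: the critic judges the program (output starting with "wrong" means incorrect);
# if so, the refiner rewrites the program and it is re-executed with tracing.
# Yields incremental updates for the six output textboxes.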
def debug(image, question, code, traced_info):
    prompt = f"# Given an image: {question}\n{code}\n\n{traced_info}\n\n# Program is"
    print("--- For debug: critic prompt is ---")
    print(prompt)
    print("---\n")
    critic_out = generate("VDebugger/VDebugger-critic-generalist-7B", prompt)
    incorrect = critic_out.strip().startswith('wrong')
    critic_out = "# Program is" + critic_out

    if not incorrect:
        yield code, traced_info, critic_out, "N/A", "N/A", answer_from_trace(traced_info)
        return
    else:
        yield code, traced_info, critic_out, "RUNNING IN PROGRESS...", "", ""

    # Extract the program echoed by the critic and drop any trailing "# Program is" marker.
    critic_code = ('def execute_command' + critic_out.split('def execute_command')[1]).strip()
    if '# Program is' in critic_code:
        critic_code = critic_code.split("# Program is")[0].strip()
prompt = f"# Given an image: {question}\n{critic_code}\n\n{traced_info}\n\n# Correction" |
|
print("--- For debug: refiner prompt is ---") |
|
print(prompt) |
|
print("---\n") |
|
refiner_out = generate("VDebugger/VDebugger-refiner-generalist-7B", prompt).strip() |
|
yield code, traced_info, critic_out, refiner_out, "RUNNING IN PROGRESS...", "" |
|
|
|
|
|
result, error, traced = run_program_with_trace(refiner_out, image, INPUT_TYPE, OUTPUT_TYPE) |
|
traced_info_2 = to_custom_trace(result, error, traced) |
|
|
|
yield code, traced_info, critic_out, refiner_out, traced_info_2, answer_from_trace(traced_info_2) |
|
|
|
|
|
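
# Full pipeline for a new question: CodeLlama writes the program, it is executed with tracing,
# then one VDebugger round runs on the result. Streams intermediate states to the UI.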
def predict(image, question):
    if image is None:
        gr.Warning("Please provide an image", duration=5)
        return
    image = transforms.Compose([transforms.ToTensor()])(image)

    question = question.strip()
    if question == "":
        gr.Warning("Please provide a question", duration=5)
        return

    prompt = prompt_template.replace("INSERT_QUERY_HERE", f"Given an image: {question}\n{SIGNATURE}")
    code = generate("codellama/CodeLlama-7b-Python-hf", prompt)
    code = (SIGNATURE + code).strip()
    yield code, "RUNNING IN PROGRESS...", "", "", "", ""

    result, error, traced = run_program_with_trace(code, image, INPUT_TYPE, OUTPUT_TYPE)
    traced_info = to_custom_trace(result, error, traced)
    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""

    for tup in debug(image, question, code, traced_info):
        yield tup
    return

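
# Run an additional debugging round on the previous refined program and its execution trace (the o4/o5 textboxes).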
def re_debug(image, question, code, traced_info):
    if code is None or code == "" or traced_info is None or traced_info == "":
        gr.Warning("No prior debugging round", duration=5)
        return

    yield code, traced_info, "RUNNING IN PROGRESS...", "", "", ""
    for tup in debug(image, question, code, traced_info):
        yield tup
    return

DESCRIPTION = """# VDebugger |
|
|
|
| [Paper](https://arxiv.org/abs/2406.13444) | [Project](https://shirley-wu.github.io/vdebugger/) | [Code](https://github.com/shirley-wu/vdebugger/) | [Models and Data](https://huggingface.co/VDebugger) | |
|
|
|
**VDebugger** is a novel critic-refiner framework trained to localize and debug *visual programs* by tracking execution step by step. In this demo, we show the visual programs, the outputs from both the critic and the refiner, as well as the final result. |
|
|
|
**Warning:** Reduced performance and accuracy may be observed. Due to resource limitation of huggingface spaces, this demo runs Llama inference in 4-bit quantization and uses smaller foundation VLMs. For full capacity, please use the original code.""" |
|
|
|
|
|
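
# A customized gr.Interface: Interface.__init__ is bypassed (the Blocks constructor is called directly)
# so that the layout and the two generator endpoints with their Stop buttons can be wired by hand.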
class MyInterface(gr.Interface):
    def __init__(self):
        super(gr.Interface, self).__init__(
            title=None,
            theme=None,
            analytics_enabled=None,
            mode="tabbed_interface",
            css=None,
            js=None,
            head=None,
        )
        self.interface_type = InterfaceTypes.STANDARD
        self.description = DESCRIPTION
        self.cache_examples = None
        self.examples_per_page = 5
        self.example_labels = None
        self.batch = False
        self.live = False
        self.api_name = "predict"
        self.max_batch_size = 4
        self.concurrency_limit = 'default'
        self.show_progress = "full"
        self.allow_flagging = 'auto'
        self.flagging_options = [("Flag", ""), ]
        self.flagging_callback = CSVLogger()
        self.flagging_dir = 'flagged'

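
        # Example image/question pairs listed in examples/questions.json.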
        with open('examples/questions.json') as f:
            example_questions = json.load(f)
        self.examples = []
        for question in example_questions:
            self.examples.append([
                Image.open('examples/{}.jpg'.format(question['imageId'])), question['question'],
            ])

        def load_random_example():
            image, question = random.choice(self.examples)
            return image, question, "", "", "", "", "", ""

        with self:
            self.render_title_description()

            with gr.Row():
                image = gr.Image(label="Image", type="pil", width="30%", scale=1)
                question = gr.Textbox(label="Question", scale=2)

            with gr.Row():
                _clear_btn = gr.ClearButton(value="Clear", variant="secondary")
                _random_eg_btn = gr.Button("Random Example Input")
                _submit_btn = gr.Button("Submit", variant="primary")
                if inspect.isgeneratorfunction(predict) or inspect.isasyncgenfunction(predict):
                    _stop1_btn = gr.Button("Stop", variant="stop", visible=False)
                _redebug_btn = gr.Button("Debug for Another Round", variant="primary")
                if inspect.isgeneratorfunction(re_debug) or inspect.isasyncgenfunction(re_debug):
                    _stop2_btn = gr.Button("Stop", variant="stop", visible=False)

            with gr.Row():
                o1 = gr.Textbox(label="No debugging: program")
                o2 = gr.Textbox(label="No debugging: execution")

            with gr.Row():
                o3 = gr.Textbox(label="VDebugger: critic")
                o4 = gr.Textbox(label="VDebugger: refiner")

            with gr.Row():
                o5 = gr.Textbox(label="VDebugger: execution")
                o6 = gr.Textbox(label="VDebugger: final answer")

            question.submit(fn=predict, inputs=[image, question], outputs=[o1, o2, o3, o4, o5, o6])
            _random_eg_btn.click(fn=load_random_example, outputs=[image, question, o1, o2, o3, o4, o5, o6])

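
            # Wiring for "Debug for Another Round": hide the button and show Stop while re_debug streams
            # its outputs, restore the buttons when it finishes, and let Stop cancel the running event.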
            async def cleanup():
                return [gr.Button(visible=True), gr.Button(visible=False)]

            triggers = [_redebug_btn.click, ]
            extra_output = [_redebug_btn, _stop2_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                queue=False,
                show_api=False,
            ).then(
                re_debug,
                [image, question, o4, o5],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            redebug_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,
                queue=False,
                show_api=False,
            )
            _stop2_btn.click(
                cleanup,
                inputs=None,
                outputs=[_redebug_btn, _stop2_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

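
            # Same pattern for the Submit button and pressing Enter in the question box, driving the full predict pipeline.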
            triggers = [_submit_btn.click, question.submit, ]
            extra_output = [_submit_btn, _stop1_btn]
            predict_event = gr.on(
                triggers,
                gr.utils.async_lambda(
                    lambda: (
                        gr.Button(visible=False),
                        gr.Button(visible=True),
                    )
                ),
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                queue=False,
                show_api=False,
            ).then(
                predict,
                [image, question],
                [o1, o2, o3, o4, o5, o6],
                api_name=self.api_name,
                scroll_to_output=False,
                preprocess=not (self.api_mode),
                postprocess=not (self.api_mode),
                batch=self.batch,
                max_batch_size=self.max_batch_size,
                concurrency_limit=self.concurrency_limit,
                show_progress=cast(
                    Literal["full", "minimal", "hidden"], self.show_progress
                ),
            )
            submit_event = predict_event.then(
                cleanup,
                inputs=None,
                outputs=extra_output,
                queue=False,
                show_api=False,
            )
            _stop1_btn.click(
                cleanup,
                inputs=None,
                outputs=[_submit_btn, _stop1_btn],
                cancels=predict_event,
                queue=False,
                show_api=False,
            )

            self.input_components = [image, question]
            self.output_components = [o1, o2, o3, o4, o5, o6]
            self.fn = predict
            self.attach_clear_events(_clear_btn, None)
            self.render_examples()

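
# Launch the demo; set the SHARE environment variable to a non-empty value to create a public share link.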
if __name__ == "__main__":
    MyInterface().launch(share=os.environ.get("SHARE", '') != "")