pecore / app.py
gsarti's picture
Updated
cf3d1b1
raw
history blame
12.2 kB
import json
import os
import gradio as gr
import spaces
from contents import (
citation,
description,
examples,
how_it_works,
how_to_use,
subtitle,
title,
)
from gradio_highlightedtextbox import HighlightedTextbox
from style import custom_css
from utils import get_tuples_from_output
from inseq import list_feature_attribution_methods, list_step_functions
from inseq.commands.attribute_context.attribute_context import (
AttributeContextArgs,
attribute_context,
)
def _parse_json_kwargs(raw: str | None, field_name: str) -> dict:
    """Decode the content of one of the JSON "kwargs" editors.

    Args:
        raw: Text typed by the user. ``None`` or a blank string means that no
            extra keyword arguments were provided.
        field_name: Human-readable name of the editor, used in error messages.

    Returns:
        The decoded JSON object, or an empty dict for blank input.

    Raises:
        gr.Error: If the text is not valid JSON or is not a JSON object, so
            that the UI shows a readable message instead of a raw traceback.
    """
    if raw is None or not raw.strip():
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError as err:
        raise gr.Error(f"{field_name} must be valid JSON: {err}") from err
    if not isinstance(parsed, dict):
        raise gr.Error(f"{field_name} must be a JSON object such as {{}}.")
    return parsed


def _positive_or_none(value: int | float | None) -> int | None:
    """Return ``value`` as an int when it is positive, ``None`` otherwise.

    The UI uses 0 to mean "no top-k filtering" and a cleared number field
    yields ``None``; both cases are forwarded as ``None``.
    """
    if value is None or value <= 0:
        return None
    return int(value)


@spaces.GPU()
def pecore(
    input_current_text: str,
    input_context_text: str,
    output_current_text: str,
    output_context_text: str,
    model_name_or_path: str,
    attribution_method: str,
    attributed_fn: str | None,
    context_sensitivity_metric: str,
    context_sensitivity_std_threshold: float,
    context_sensitivity_topk: int,
    attribution_std_threshold: float,
    attribution_topk: int,
    input_template: str,
    input_current_text_template: str,
    output_template: str,
    special_tokens_to_keep: str | list[str] | None,
    model_kwargs: str,
    tokenizer_kwargs: str,
    generation_kwargs: str,
    attribution_kwargs: str,
):
    """Run ``attribute_context`` with the settings collected from the UI.

    Args:
        input_current_text: Query text; substituted for ``{current}`` in
            ``input_current_text_template`` before being passed on.
        input_context_text: Input context; an empty value is sent as ``None``.
        output_current_text: Optional output text; empty is sent as ``None``.
        output_context_text: Optional output context; empty is sent as ``None``.
        model_name_or_path: Model identifier forwarded unchanged.
        attribution_method: Attribution method identifier forwarded unchanged.
        attributed_fn: Attributed function identifier forwarded unchanged.
        context_sensitivity_metric: Metric identifier forwarded unchanged.
        context_sensitivity_std_threshold: Forwarded unchanged.
        context_sensitivity_topk: Top-k value; non-positive or missing values
            are sent as ``None``.
        attribution_std_threshold: Forwarded unchanged.
        attribution_topk: Top-k value; non-positive or missing values are sent
            as ``None``.
        input_template: Input template forwarded unchanged.
        input_current_text_template: ``str.format`` template with a
            ``{current}`` placeholder applied to ``input_current_text``.
        output_template: Output template forwarded unchanged.
        special_tokens_to_keep: Forwarded unchanged.
        model_kwargs: JSON object (as text) with extra model arguments.
        tokenizer_kwargs: JSON object (as text) with extra tokenizer arguments.
        generation_kwargs: JSON object (as text) with extra generation arguments.
        attribution_kwargs: JSON object (as text) with extra attribution
            arguments.

    Returns:
        A 3-tuple: the result of ``get_tuples_from_output`` on the attribution
        output, followed by two ``gr.Button(visible=True)`` updates (one per
        download button wired as output of the click handler).

    Raises:
        gr.Error: If one of the kwargs fields does not contain a JSON object.
    """
    formatted_input_current_text = input_current_text_template.format(
        current=input_current_text
    )

    # Results are written next to this file; make sure the target directory
    # exists before it is handed over as save/visualization destination.
    output_dir = os.path.join(os.path.dirname(__file__), "outputs")
    os.makedirs(output_dir, exist_ok=True)

    pecore_args = AttributeContextArgs(
        show_intermediate_outputs=False,
        save_path=os.path.join(output_dir, "output.json"),
        add_output_info=True,
        viz_path=os.path.join(output_dir, "output.html"),
        show_viz=False,
        model_name_or_path=model_name_or_path,
        attribution_method=attribution_method,
        attributed_fn=attributed_fn,
        attribution_selectors=None,
        attribution_aggregators=None,
        normalize_attributions=True,
        model_kwargs=_parse_json_kwargs(model_kwargs, "Model kwargs"),
        tokenizer_kwargs=_parse_json_kwargs(tokenizer_kwargs, "Tokenizer kwargs"),
        generation_kwargs=_parse_json_kwargs(generation_kwargs, "Generation kwargs"),
        attribution_kwargs=_parse_json_kwargs(
            attribution_kwargs, "Attribution kwargs"
        ),
        context_sensitivity_metric=context_sensitivity_metric,
        align_output_context_auto=False,
        prompt_user_for_contextless_output_next_tokens=False,
        special_tokens_to_keep=special_tokens_to_keep,
        context_sensitivity_std_threshold=context_sensitivity_std_threshold,
        context_sensitivity_topk=_positive_or_none(context_sensitivity_topk),
        attribution_std_threshold=attribution_std_threshold,
        attribution_topk=_positive_or_none(attribution_topk),
        input_current_text=formatted_input_current_text,
        # Empty text boxes are forwarded as None rather than "".
        input_context_text=input_context_text if input_context_text else None,
        input_template=input_template,
        output_current_text=output_current_text if output_current_text else None,
        output_context_text=output_context_text if output_context_text else None,
        output_template=output_template,
    )
    out = attribute_context(pecore_args)
    # The two button updates reveal the download buttons once a run finished.
    return get_tuples_from_output(out), gr.Button(visible=True), gr.Button(visible=True)
# Build the demo UI: a tab with the query/context inputs and the highlighted
# output, a tab with all tunable parameters, and the explanatory sections.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown(title)
    gr.Markdown(subtitle)
    gr.Markdown(description)
    with gr.Tab("🐑 Attributing Context"):
        with gr.Row():
            with gr.Column():
                input_current_text = gr.Textbox(
                    label="Input query", placeholder="Your input query..."
                )
                input_context_text = gr.Textbox(
                    label="Input context", lines=4, placeholder="Your input context..."
                )
                attribute_input_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                # Placeholder content doubles as a legend for the two labels.
                pecore_output_highlights = HighlightedTextbox(
                    value=[
                        ("This output will contain ", None),
                        ("context sensitive", "Context sensitive"),
                        (" generated tokens and ", None),
                        ("influential context", "Influential context"),
                        (" tokens.", None),
                    ],
                    color_map={
                        "Context sensitive": "green",
                        "Influential context": "blue",
                    },
                    show_legend=True,
                    label="PECoRe Output",
                    combine_adjacent=True,
                    interactive=False,
                )
                # Download buttons start hidden; pecore() returns visible
                # replacements for them after a run. The links are URLs, not
                # filesystem paths, so they are written literally (joining a
                # rooted segment with os.path.join drops the prefix on POSIX
                # and yields a drive-prefixed, invalid URL on Windows).
                with gr.Row(equal_height=True):
                    download_output_file_button = gr.Button(
                        "⇓ Download output",
                        visible=False,
                        link="/file=outputs/output.json",
                    )
                    download_output_html_button = gr.Button(
                        "🔍 Download HTML",
                        visible=False,
                        link="/file=outputs/output.html",
                    )
        attribute_input_examples = gr.Examples(
            examples,
            inputs=[input_current_text, input_context_text],
            outputs=pecore_output_highlights,
        )
    with gr.Tab("⚙️ Parameters"):
        gr.Markdown("## ⚙️ PECoRe Parameters")
        with gr.Row(equal_height=True):
            model_name_or_path = gr.Textbox(
                value="gsarti/cora_mgen",
                label="Model",
                info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
                interactive=True,
            )
            context_sensitivity_metric = gr.Dropdown(
                value="kl_divergence",
                label="Context sensitivity metric",
                info="Metric to use to measure context sensitivity of generated tokens.",
                choices=list_step_functions(),
                interactive=True,
            )
            attribution_method = gr.Dropdown(
                value="saliency",
                label="Attribution method",
                info="Attribution method identifier to identify relevant context tokens.",
                choices=list_feature_attribution_methods(),
                interactive=True,
            )
            attributed_fn = gr.Dropdown(
                value="contrast_prob_diff",
                label="Attributed function",
                info="Function of model logits to use as target for the attribution method.",
                choices=list_step_functions(),
                interactive=True,
            )
        gr.Markdown("#### Results Selection Parameters")
        with gr.Row(equal_height=True):
            context_sensitivity_std_threshold = gr.Number(
                value=1.0,
                label="Context sensitivity threshold",
                info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            context_sensitivity_topk = gr.Number(
                value=0,
                label="Context sensitivity top-k",
                info="Select N to keep top N context sensitive tokens. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=10,
            )
            attribution_std_threshold = gr.Number(
                value=1.0,
                label="Attribution threshold",
                info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
                precision=1,
                minimum=0.0,
                maximum=5.0,
                step=0.5,
                interactive=True,
            )
            attribution_topk = gr.Number(
                value=0,
                label="Attribution top-k",
                info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
                interactive=True,
                precision=0,
                minimum=0,
                maximum=50,
            )
        gr.Markdown("#### Text Format Parameters")
        with gr.Row(equal_height=True):
            input_template = gr.Textbox(
                value="{current} <P>:{context}",
                label="Input template",
                info="Template to format the input for the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            output_template = gr.Textbox(
                value="{current}",
                label="Output template",
                info="Template to format the output from the model. Use {current} and {context} placeholders.",
                interactive=True,
            )
            input_current_text_template = gr.Textbox(
                value="<Q>:{current}",
                label="Input current text template",
                info="Template to format the input query for the model. Use {current} placeholder.",
                interactive=True,
            )
            special_tokens_to_keep = gr.Dropdown(
                label="Special tokens to keep",
                info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
                value=None,
                multiselect=True,
                allow_custom_value=True,
            )
        gr.Markdown("## ⚙️ Generation Parameters")
        with gr.Row(equal_height=True):
            output_current_text = gr.Textbox(
                label="Generation output",
                info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
                interactive=True,
            )
            output_context_text = gr.Textbox(
                label="Generation context",
                info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
                interactive=True,
            )
            generation_kwargs = gr.Code(
                value="{}",
                language="json",
                label="Generation kwargs",
                interactive=True,
                lines=1,
            )
        gr.Markdown("## ⚙️ Other Parameters")
        with gr.Row(equal_height=True):
            model_kwargs = gr.Code(
                value="{}",
                language="json",
                label="Model kwargs",
                interactive=True,
                lines=1,
            )
            tokenizer_kwargs = gr.Code(
                value="{}",
                language="json",
                label="Tokenizer kwargs",
                interactive=True,
                lines=1,
            )
            attribution_kwargs = gr.Code(
                value="{}",
                language="json",
                label="Attribution kwargs",
                interactive=True,
                lines=1,
            )
    gr.Markdown(how_it_works)
    gr.Markdown(how_to_use)
    gr.Markdown(citation)

    # The order of `inputs` must match the positional parameters of pecore();
    # the order of `outputs` must match the tuple it returns.
    attribute_input_button.click(
        pecore,
        inputs=[
            input_current_text,
            input_context_text,
            output_current_text,
            output_context_text,
            model_name_or_path,
            attribution_method,
            attributed_fn,
            context_sensitivity_metric,
            context_sensitivity_std_threshold,
            context_sensitivity_topk,
            attribution_std_threshold,
            attribution_topk,
            input_template,
            input_current_text_template,
            output_template,
            special_tokens_to_keep,
            model_kwargs,
            tokenizer_kwargs,
            generation_kwargs,
            attribution_kwargs,
        ],
        outputs=[
            pecore_output_highlights,
            download_output_file_button,
            download_output_html_button,
        ],
    )

# `outputs/` is whitelisted so the download links above can be served.
demo.launch(allowed_paths=["outputs/"])