|
|
|
|
|
import spaces |
|
import requests |
|
import gradio as gr |
|
from bs4 import BeautifulSoup |
|
from transformers import pipeline |
|
|
|
from kvpress import ( |
|
ExpectedAttentionPress, |
|
KnormPress, |
|
RandomPress, |
|
SnapKVPress, |
|
StreamingLLMPress, |
|
TOVAPress, |
|
) |
|
|
|
# Registry of the press classes selectable in the UI, keyed by the name shown in the dropdown.
press_dict = dict(
    ExpectedAttentionPress=ExpectedAttentionPress,
    KnormPress=KnormPress,
    RandomPress=RandomPress,
    SnapKVPress=SnapKVPress,
    StreamingLLMPress=StreamingLLMPress,
    TOVAPress=TOVAPress,
)
|
|
|
# One "kv-press-text-generation" pipeline per supported checkpoint, keyed by checkpoint name.
# All pipelines are created at import time on device "cuda:0".
pipe_dict = {
    checkpoint: pipeline(
        "kv-press-text-generation",
        model=checkpoint,
        device="cuda:0",
        torch_dtype="auto",
    )
    for checkpoint in (
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "Qwen/Qwen2.5-7B-Instruct-1M",
    )
}
|
|
|
@spaces.GPU
def process_request(url, question, press_name, pipe_name, compression_ratio):
    """Answer a question about a web article using a compressed KV cache.

    Downloads ``url``, joins the text of every ``<p>`` element into a single
    context string, and runs the selected pipeline on that context with the
    selected press.

    Args:
        url: Address of the article to download.
        question: Question forwarded to the pipeline.
        press_name: Key into ``press_dict`` selecting the press class.
        pipe_name: Key into ``pipe_dict`` selecting the pipeline.
        compression_ratio: Value passed to the press constructor; also used to
            estimate the token count after compression.

    Returns:
        A 3-tuple ``(text, tokens_before, tokens_after)``. On success ``text``
        is the pipeline's answer, ``tokens_before`` is the tokenized context
        length and ``tokens_after`` is
        ``int(tokens_before * (1 - compression_ratio))``. On any failure
        ``text`` is an error message and both counts are ``-1``. The tuple
        always has three items because the UI binds exactly three outputs.
    """
    if press_name not in press_dict:
        return f"Invalid press selected: {press_name}", -1, -1
    if pipe_name not in pipe_dict:
        return f"Invalid model selected: {pipe_name}", -1, -1

    try:
        # A timeout keeps a stalled server from blocking the GPU worker
        # indefinitely; raise_for_status() stops HTTP error pages from being
        # fed to the model as if they were the article.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        return f"Error fetching the Wikipedia article: {str(e)}", -1, -1

    try:
        soup = BeautifulSoup(response.content, "html.parser")
        context = "".join(p.text for p in soup.find_all("p")) + "\n\n"

        pipe = pipe_dict[pipe_name]
        press = press_dict[press_name](compression_ratio)
        num_tokens = pipe.tokenizer(context, return_tensors="pt")["input_ids"].shape[1]
        pred_answer = pipe(context, question=question, press=press)["answer"]

        return pred_answer, num_tokens, int(num_tokens * (1 - compression_ratio))
    except Exception as e:
        # UI boundary: report the failure in the output box instead of raising.
        if "CUDA out of memory" in str(e):
            return (
                "Error: CUDA out of memory. Try using a smaller article or a lower compression ratio.",
                -1,
                -1,
            )
        return str(e), -1, -1
|
|
|
|
|
def gradio_interface():
    """Assemble the Gradio Blocks UI for the kvpress question-answering demo.

    The layout is: an introductory markdown panel, a row with the URL and
    question boxes, a row with the model / press / compression-ratio controls,
    three output fields, a submit button and a set of clickable examples.
    Clicking the button calls ``process_request``.

    Returns:
        The constructed ``gr.Blocks`` app; launching it is left to the caller.
    """
    # Each row fills (url, question, press, compression ratio) — the model dropdown is not part of it.
    example_rows = [
        [
            "https://en.wikipedia.org/wiki/Nvidia",
            "Complete this sentence: In May 2017, the program had 1,300 companies. As of March 2018, there were ",
            "ExpectedAttentionPress",
            0.5,
        ],
        [
            "https://en.wikipedia.org/wiki/Hugging_Face",
            "What was the original name of the transformers library ?",
            "ExpectedAttentionPress",
            0.5,
        ],
        [
            "https://en.wikipedia.org/wiki/World_Chess_Championship_2024",
            "On which move did the world chess championship end?",
            "ExpectedAttentionPress",
            0.5,
        ],
    ]

    with gr.Blocks() as app:
        gr.Markdown(
            """
            # Wikipedia Article Question Answering with kvpress
            This demo answers questions about any given Wikipedia article.
            Under the hood, [kvpress](https://github.com/NVIDIA/kvpress) *compresses the key-value (KV) cache* associated with the article, helping reduce memory usage and accelerate decoding.
            **How to use:**
            1. Enter a Wikipedia article URL
            2. Type your question
            3. Select a model, a press and the desired compression ratio
            4. Press "Submit" to see the answer, along with token statistics before and after compression
            """
        )

        # Row 1: what to read and what to ask.
        with gr.Row():
            url_box = gr.Textbox(label="Wikipedia Article URL", placeholder="Enter the Wikipedia article URL here")
            question_box = gr.Textbox(label="Question", placeholder="Type your question here")

        # Row 2: which model, which press, and how hard to compress.
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(pipe_dict),
                value="meta-llama/Meta-Llama-3.1-8B-Instruct",
                label="Select Model",
            )
            press_dropdown = gr.Dropdown(
                choices=list(press_dict),
                value="ExpectedAttentionPress",
                label="Select Press",
            )
            ratio_slider = gr.Slider(minimum=0.0, maximum=0.9, step=0.1, value=0.5, label="Compression Ratio")

        # Result widgets, in the same order as the tuple returned by process_request.
        answer_box = gr.Textbox(label="Output", lines=10)
        tokens_before = gr.Number(label="Number of tokens before compression", interactive=False)
        tokens_after = gr.Number(label="Number of tokens after compression", interactive=False)

        submit = gr.Button("Submit")

        gr.Examples(
            examples=example_rows,
            inputs=[url_box, question_box, press_dropdown, ratio_slider],
        )

        # Input order must match process_request's positional parameters.
        submit.click(
            process_request,
            inputs=[url_box, question_box, press_dropdown, model_dropdown, ratio_slider],
            outputs=[answer_box, tokens_before, tokens_after],
        )

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Blocks app, then start serving it.
    demo = gradio_interface()
    demo.launch()
|
|