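"""Gradio demo that compares two streamed chat responses side by side.

Two similar prompts are answered by the same model via
`algorithms.apoc_streaming` (an all-prefix-optimal coupling of the two
sampling runs), and the longest shared prefix of each response pair is
highlighted in green as tokens stream in. Relies on the local `util`,
`load`, and `algorithms` modules from this repository.
"""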
import logging
import time
import html

import numpy as np
import gradio as gr

import util


def make_html_table(headers, data):
    # Render a fixed-layout, two-column HTML table: one header row, then one row per entry.
    rows = ['<tr>' + ''.join(f'<th style="width: 50%">{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td style="width: 50%; font-family: monospace; white-space: pre-wrap;">{v}</td>' for v in row) + '</tr>\n')
    return '<table style="width: 100%; table-layout: fixed">\n' + ''.join(rows) + '</table>\n'


def highlight_prefix(tokens, prefix_len):
    # Decode the full token sequence and its prefix, then highlight the part of the
    # decoded string they share. Uses the module-level `tokenizer` defined below.
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    # Character-level common prefix, since decoding the token prefix may not yield an
    # exact string prefix of the full decode.
    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    highlight_style = 'background-color: #90FF90;'

    return f'<span style="{highlight_style}">{prefix_html}</span>{suffix_html}'


def format_response_pair(tokens_a, tokens_b):
    # Highlight, in each response, the longest token prefix the two responses share.
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)


HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"

# With DRY_RUN enabled, the demo streams a canned response pair instead of loading the model.
DRY_RUN = False

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        # Stream a canned response pair so the UI can be exercised without the model.
        rows = [[''] * 2 for _ in range(num_responses)]

        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 & 6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            # One extra step so the final, fully decoded pair is rendered as well.
            steps = 1 + max(len(response_tok_a), len(response_tok_b))

            for i in range(steps):
                time.sleep(0.1)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)

                rows[j][0] = content_a
                rows[j][1] = content_b

                yield make_html_table(HEADERS, rows)
else:
    from load import load_model
    import algorithms

    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
    algorithms.logger.setLevel(logging.INFO)

    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
        return [
            {'role': 'system', 'content': system_msg},
            {'role': 'user', 'content': prompt},
        ]

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [[''] * 2 for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            system_msg = "You are a helpful assistant."

            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

            # apoc_streaming yields (token_a, token_b) pairs; the loop below treats
            # None as "no new token for that side this step".
            gen = algorithms.apoc_streaming(
                'cpu',
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True

                # Only re-render the table when at least one response actually grew.
                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))

                    rows[j][0] = content_a
                    rows[j][1] = content_b

                    yield make_html_table(HEADERS, rows)


demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, step=1, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
    ],
    outputs=[
        gr.HTML(),
    ],
    title='All-Prefix-Optimal Coupling',
    description='Try two similar prompts to see how the difference between them affects the responses. '
                f'Model: `{repo_id}`.',
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
    ],
    cache_examples=False,
)

demo.launch()
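
# Gradio serves the demo locally (http://127.0.0.1:7860 by default); pass
# `share=True` to `demo.launch()` for a temporary public link.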