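# Gradio demo of all-prefix-optimal coupling: the same model answers two similar
# prompts, and the table highlights how far the two responses stay identical
# (their longest common token prefix).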
import spaces
import logging
import time
import html
import os
import numpy as np
import gradio as gr
import util

import huggingface_hub
import torch
import transformers
import accelerate

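# Print dependency versions at startup to make environment issues easier to debug.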
print('Dependency versions:')
print(f'huggingface_hub=={huggingface_hub.__version__}')
print(f'numpy=={np.__version__}')
print(f'torch=={torch.__version__}')
print(f'transformers=={transformers.__version__}')
print(f'accelerate=={accelerate.__version__}')
print()

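# Root logging stays at WARNING; this module's own logger is raised to INFO.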
logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s', level=logging.WARNING)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


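# Styling for the two-column response table; the .highlight class marks the
# decoded text that the two responses share.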
css = '''
.response-table {
    width: 100%;
    table-layout: fixed;
}
.response-table th, .response-table td {
    width: 50%;
}
.response-table td {
    font-family: monospace;
    white-space: pre-wrap;
    text-align: left;
    vertical-align: top;
}
.highlight {
    background-color: #90FF90;
}
'''


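# Render headers and rows as an HTML table. Cell values are used verbatim, so they
# must already be escaped/safe HTML (highlight_prefix takes care of that below).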
def make_html_table(headers, data):
    rows = ['<tr>' + ''.join(f'<th>{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td>{v}</td>' for v in row) + '</tr>\n')
    return '<table class="response-table">\n' + ''.join(rows) + '</table>\n'


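# Decode `tokens` and wrap the part covered by the first `prefix_len` tokens in a
# highlight span. The comparison is redone at character level because a token
# boundary does not always fall on a clean character boundary after decoding.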
def highlight_prefix(tokens, prefix_len):
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    return f'<span class="highlight">{prefix_html}</span>{suffix_html}'


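# Highlight the longest common token prefix of two responses in both of them.
# util.longest_common_prefix is assumed to return the length of the shared prefix
# of two 1-D arrays.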
def format_response_pair(tokens_a, tokens_b):
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)


HEADERS = ['Response (Left)', 'Response (Right)']
repo_id = "Qwen/Qwen2-0.5B-Instruct"

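# DRY_RUN=1 skips model loading and streams canned responses (handy for local UI testing).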
DRY_RUN = os.environ.get('DRY_RUN') == '1'

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

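    # Dry-run stand-in for fn: streams two fixed responses token by token so the
    # table fills in just like the real generation path.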
    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        logger.info('Starting generation...')
        generation_start = time.perf_counter()

        rows = [[''] * 2 for i in range(num_responses)]

        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 &\n\n\n\n6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            steps = 1 + max(len(response_tok_a), len(response_tok_b))

            for i in range(steps):
                time.sleep(0.01)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)

                rows[j][0] = content_a
                rows[j][1] = content_b

                yield make_html_table(HEADERS, rows)

        generation_end = time.perf_counter()
        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')
else:
    from load import load_model
    import algorithms

    model, tokenizer = load_model(repo_id)

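    # Build a simple system + user chat for a single prompt.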
    def make_chat(system_msg, prompt):
        chat = [
            {
                'role': 'system',
                'content': system_msg,
            },
            {
                'role': 'user',
                'content': prompt,
            },
        ]
        return chat

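    # @spaces.GPU requests a GPU for the duration of each call on ZeroGPU Spaces.
    # Both chats are decoded by the same model under an all-prefix-optimal coupling,
    # and partial responses are streamed into the table as tokens arrive.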
    @spaces.GPU
    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        logger.info('Starting generation...')
        generation_start = time.perf_counter()

        torch.use_deterministic_algorithms(True)

        rows = [[''] * 2 for i in range(num_responses)]
        yield make_html_table(HEADERS, rows)

        for j in range(num_responses):
            system_msg = "You are a helpful assistant."

            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

            gen = algorithms.apoc_streaming(
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
            # The coupled stream yields (token_a, token_b) pairs; either side may be
            # None, in which case that response is left unchanged for this step.
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True

                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))

                    rows[j][0] = content_a
                    rows[j][1] = content_b

                    yield make_html_table(HEADERS, rows)

        generation_end = time.perf_counter()
        logger.info(f'Generation took {(generation_end - generation_start):.3f} s')


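# fn is a generator, so gr.Interface streams each yielded HTML table to the page.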
demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
    ],
    outputs=[
        gr.HTML(),
    ],
    css=css,
    title='All-Prefix-Optimal Coupling',
    description='Try similar prompts to see the effect of the difference between them. '
                f'Model: `{repo_id}`.',
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
        [48, 8, "1 3 5. What number is next?", "4 5 6. What number is next?"],
    ],
    cache_examples=False,
)

demo.launch()
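# Local usage sketch (the module name below is an assumption):
#   DRY_RUN=1 python app.py   # UI only, canned responses, no model download
#   python app.py             # full generation via load.load_model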