# Gradio demo of streaming generation of multiple LLM response pairs.

import logging
import time
import html

import numpy as np
import gradio as gr

import util


# gr.DataFrame is currently bugged for updating values,
# so we must use raw HTML.
# https://github.com/gradio-app/gradio/issues/8160
def make_html_table(headers, data):
    rows = ['<tr>' + ''.join(f'<th>{h}</th>' for h in headers) + '</tr>\n']
    for row in data:
        rows.append('<tr>' + ''.join(f'<td>{v}</td>' for v in row) + '</tr>\n')
    return '<table>\n' + ''.join(rows) + '</table>\n'
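# Example of the markup this produces (cell values are inserted as raw HTML and
# are not escaped here, so they must already be escaped/trusted by the caller):
#   make_html_table(['A', 'B'], [['1', '2']])
#   == '<table>\n<tr><th>A</th><th>B</th></tr>\n<tr><td>1</td><td>2</td></tr>\n</table>\n'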

def highlight_prefix(tokens, prefix_len):
    prefix_tokens = tokens[:prefix_len]

    s = tokenizer.decode(tokens, skip_special_tokens=True)
    prefix_s = tokenizer.decode(prefix_tokens, skip_special_tokens=True)

    # Highlight the longest common character prefix of the decoded strings,
    # since a token boundary does not always fall on a character boundary.
    s_lcp_len = util.longest_common_prefix(np.array(list(s)), np.array(list(prefix_s)))

    prefix_html = html.escape(s[:s_lcp_len])
    suffix_html = html.escape(s[s_lcp_len:])

    #highlight_style = 'background-color: #FFFFAE;'
    #highlight_style = 'text-decoration: underline;'
    highlight_style = 'background-color: #90FF90;'

    return f'<span style="{highlight_style}">{prefix_html}</span>{suffix_html}'


def format_response_pair(tokens_a, tokens_b):
    # This is slightly convoluted, so as to properly handle grapheme clusters
    # that span token boundaries.
    token_lcp_len = util.longest_common_prefix(tokens_a, tokens_b)
    return highlight_prefix(tokens_a, token_lcp_len), highlight_prefix(tokens_b, token_lcp_len)


HEADERS = ['Response (Left)', 'Response (Right)']

repo_id = "Qwen/Qwen2-0.5B-Instruct"

DRY_RUN = False

if DRY_RUN:
    from load import load_tokenizer

    tokenizer = load_tokenizer(repo_id)

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [[''] * 2 for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)
        for j in range(num_responses):
            response_raw_a = 'Sure!\n\n1 2 3 4 & 5.'
            response_raw_b = 'Sure!\n\n1 2 3 4 5 & 6.'

            response_tok_a = tokenizer.encode(response_raw_a, add_special_tokens=False, return_tensors='np')[0]
            response_tok_b = tokenizer.encode(response_raw_b, add_special_tokens=False, return_tensors='np')[0]

            steps = 1 + max(len(response_tok_a), len(response_tok_b))
            for i in range(steps):
                time.sleep(0.1)
                prefix_tok_a = response_tok_a[:i]
                prefix_tok_b = response_tok_b[:i]

                content_a, content_b = format_response_pair(prefix_tok_a, prefix_tok_b)
                rows[j][0] = content_a
                rows[j][1] = content_b
                yield make_html_table(HEADERS, rows)
else:
    from load import load_model
    import algorithms

    logging.basicConfig(format='%(levelname)s:%(name)s: %(message)s')
    algorithms.logger.setLevel(logging.INFO)

    model, tokenizer = load_model(repo_id)

    def make_chat(system_msg, prompt):
        chat = [
            {
                'role': 'system',
                'content': system_msg,
            },
            {
                'role': 'user',
                'content': prompt,
            },
        ]
        return chat

    def fn(max_tokens, num_responses, prompt_x, prompt_y):
        rows = [[''] * 2 for _ in range(num_responses)]
        yield make_html_table(HEADERS, rows)
        for j in range(num_responses):
            system_msg = "You are a helpful assistant."
            chat_x = make_chat(system_msg, prompt_x)
            chat_y = make_chat(system_msg, prompt_y)

            gen = algorithms.apoc_streaming(
                'cpu',
                model,
                model,
                tokenizer,
                chat_x,
                chat_y,
                max_tokens=max_tokens,
            )
            response_a_L = []
            response_b_L = []
            for token_a, token_b in gen:
                dirty = False
                if token_a is not None:
                    response_a_L.append(token_a)
                    dirty = True
                if token_b is not None:
                    response_b_L.append(token_b)
                    dirty = True
                if dirty:
                    content_a, content_b = format_response_pair(np.array(response_a_L), np.array(response_b_L))
                    rows[j][0] = content_a
                    rows[j][1] = content_b
                    yield make_html_table(HEADERS, rows)
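# Note: fn is a generator, so Gradio re-renders the gr.HTML output on every
# yield and the table fills in as tokens stream. Rough sketch of the
# highlighting (token IDs below are made up, not from a real tokenizer):
#   format_response_pair(np.array([1, 2, 3]), np.array([1, 2, 9]))
# decodes both responses and wraps the text of the shared token prefix
# [1, 2] in a '<span style="background-color: #90FF90;">...</span>'.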

demo = gr.Interface(
    fn=fn,
    inputs=[
        gr.Slider(1, 512, label='Max Tokens', value=48),
        gr.Slider(1, 16, step=1, label='Num Responses', value=8),
        gr.Textbox(label='Prompt (Left)'),
        gr.Textbox(label='Prompt (Right)'),
    ],
    outputs=[
        gr.HTML(),
    ],
    title='All-Prefix-Optimal Coupling',
    description='Try similar prompts to see the effect of the difference between them. '
                f'Model: `{repo_id}`.',
    examples=[
        [48, 8, 'Count from 1 to 5.', 'Count from 1 to 6.'],
        # This would be a good example, but Qwen2-0.5B occasionally goes off-color.
        #[48, 8, 'Tell me a joke.', 'Tell me a funny joke.'],
        [48, 8, 'Calculate 3 + 4', 'Calculate 3 + 5'],
        [48, 8, "What's the capital of Canada?", "What's the capital of France?"],
    ],
    # In HuggingFace Spaces, this defaults to True, which makes startup
    # take a very long time.
    cache_examples=False,
)

demo.launch()