import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer from threading import Thread import json import torch import os import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import to_hex import itertools import transformers import time transformers.logging.set_verbosity_error() # Variable to define number of instances n_instances = 1 gpu_name = "CPU" for i in range(torch.cuda.device_count()): gpu_name = torch.cuda.get_device_properties(i).name # Reusing the original MLP class and other functions (unchanged) except those specific to Streamlit class MLP(torch.nn.Module): def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False): super(MLP, self).__init__() self.fc1 = torch.nn.Linear(input_dim, hidden_dim) self.fc3 = torch.nn.Linear(hidden_dim, output_dim) self.layer_id = layer_id if cuda: self.device = "cuda" else: self.device = "cpu" self.to(self.device) def forward(self, x): x = torch.flatten(x, start_dim=1) x = torch.relu(self.fc1(x)) x = self.fc3(x) return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach() def map_value_to_color(value, colormap_name='tab20c'): value = np.clip(value, 0.0, 1.0) colormap = plt.get_cmap(colormap_name) rgba_color = colormap(value) css_color = to_hex(rgba_color) return css_color # Caching functions for model and classifier model_cache = {} def get_multiple_model_and_tokenizer(name, n_instances): model_instances = [] for _ in range(n_instances): tok = AutoTokenizer.from_pretrained(name, token=os.getenv('HF_TOKEN'), pad_token_id=128001) model = AutoModelForCausalLM.from_pretrained(name, token=os.getenv('HF_TOKEN'), torch_dtype="bfloat16", pad_token_id=128001, device_map="auto") if torch.cuda.is_available(): model.cuda() model_instances.append((model, tok)) return model_instances def get_classifiers_for_model(att_size, emb_size, device, config_paths): config = { "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")), "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r")) } layer_id = config["classifier_token"]["layer"] classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device) classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device, weights_only=True)) classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device) classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device, weights_only=True)) return classifier_span, classifier_token, config["classifier_token"]["label_map"] def find_datasets_and_model_ids(root_dir): datasets = {} for root, dirs, files in os.walk(root_dir): if 'config.json' in files and 'stoke_config.json' in files: config_path = os.path.join(root, 'config.json') stoke_config_path = os.path.join(root, 'stoke_config.json') with open(config_path, 'r') as f: config_data = json.load(f) model_id = config_data.get('model_id') if model_id: dataset_name = os.path.basename(os.path.dirname(config_path)) with open(stoke_config_path, 'r') as f: stoke_config_data = json.load(f) if model_id: dataset_name = os.path.basename(os.path.dirname(stoke_config_path)) datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data return datasets def filter_spans(spans_and_values): if spans_and_values == []: return [], [] # Create a dictionary to store spans based on their second index values span_dict = {} spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values] # Iterate through the spans and update the dictionary with the highest value for span, value in zip(spans, values): start, end = span if start > end or end - start > 15 or start == 0: continue current_value = span_dict.get(end, None) if current_value is None or current_value[1] < value: span_dict[end] = (span, value) if span_dict == {}: return [], [] # Extract the filtered spans and values filtered_spans, filtered_values = zip(*span_dict.values()) return list(filtered_spans), list(filtered_values) def remove_overlapping_spans(spans): # Sort the spans based on their end points sorted_spans = sorted(spans, key=lambda x: x[0][1]) non_overlapping_spans = [] last_end = float('-inf') # Iterate through the sorted spans for span in sorted_spans: start, end = span[0] value = span[1] # If the current span does not overlap with the previous one if start >= last_end: non_overlapping_spans.append(span) last_end = end else: # If it overlaps, choose the one with the highest value existing_span_index = -1 for i, existing_span in enumerate(non_overlapping_spans): if existing_span[0][1] <= start: existing_span_index = i break if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value: non_overlapping_spans[existing_span_index] = span return non_overlapping_spans def generate_html_no_overlap(tokenized_text, spans): current_index = 0 html_content = "" for (span_start, span_end), value in spans: # Add text before the span html_content += "".join(tokenized_text[current_index:span_start]) # Add the span with underlining html_content += "" html_content += "".join(tokenized_text[span_start:span_end]) html_content += " " current_index = span_end # Add any remaining text after the last span html_content += "".join(tokenized_text[current_index:]) return html_content css = """ """ def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer, new_tags): # spanwise annotated text annotated = [] span_ends = -1 in_span = False out_of_span_tokens = [] for i in reversed(range(len(tokenwise_preds))): if in_span: if i >= span_ends: continue else: in_span = False predicted_class = "" style = "" span = None for s in spans: if s[1] == i+1: span = s if tokenwise_preds[i] != 0 and span is not None: predicted_class = f"highlight spanhighlight" style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}" if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "): annotated.append("Ġ") span_opener = f"Ġ".replace(" ", "Ġ") span_end = f"{new_tags[tokenwise_preds[i]]}" annotated.extend(out_of_span_tokens) out_of_span_tokens = [] span_ends = span[0] in_span = True annotated.append(span_end) annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))]) annotated.append(span_opener) else: out_of_span_tokens.append(token_strings[i]) annotated.extend(out_of_span_tokens) return [x for x in reversed(annotated)] def gen_json(input_text, max_new_tokens): streamer = STOKEStreamer(tok, classifier_token, classifier_span) new_tags = label_map inputs = tok([f" {input_text}"], return_tensors="pt").to(model.device) generation_kwargs = dict( inputs, streamer=streamer, max_new_tokens=max_new_tokens, repetition_penalty=1.2, do_sample=False ) def generate_async(): model.generate(**generation_kwargs) thread = Thread(target=generate_async) thread.start() # Display generated text as it becomes available output_text = "" text_tokenwise = "" text_spans = "" removed_spans = "" tags = [] spans = [] for new_text in streamer: if new_text[1] is not None and new_text[2] != ['']: text_tokenwise = "" output_text = "" tags.extend(new_text[1]) spans.extend(new_text[-1]) # Tokenwise Classification for tk, pred in zip(new_text[2],tags): if pred != 0: style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}" if tk.startswith(" "): text_tokenwise += " " text_tokenwise += f"{tk}" output_text += tk else: text_tokenwise += tk output_text += tk # Span Classification text_spans = "" if len(spans) > 0: filtered_spans = remove_overlapping_spans(spans) text_spans = generate_html_no_overlap(new_text[2], filtered_spans) if len(spans) - len(filtered_spans) > 0: removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap." else: for tk in new_text[2]: text_spans += f"{tk}" # Spanwise Classification annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags) generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "") output = f"{css}
" output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n
" #output += "
Show tokenwise classification
\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "") #output += "
Show spans\n" + text_spans.replace("\n", " ").replace("$", "\\$") #if removed_spans != "": # output += f"

({removed_spans})" list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"] out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "".strip()), "entites": list_of_spans} yield out_dict return # Gradio app function to generate text using the assigned model instance def generate_text(input_text, max_new_tokens=2): if input_text == "": yield "Please enter some text first." return # Select the next model instance in a round-robin manner model, tok = next(model_round_robin) streamer = STOKEStreamer(tok, classifier_token, classifier_span) new_tags = label_map inputs = tok([f"{input_text[:200]}"], return_tensors="pt").to(model.device) generation_kwargs = dict( inputs, streamer=streamer, max_new_tokens=max_new_tokens, repetition_penalty=1.2, do_sample=False, temperature=None, top_p=None ) def generate_async(): model.generate(**generation_kwargs) thread = Thread(target=generate_async) thread.start() # Display generated text as it becomes available output_text = "" text_tokenwise = "" text_spans = "" removed_spans = "" tags = [] spans = [] for new_text in streamer: if new_text[1] is not None and new_text[2] != ['']: text_tokenwise = "" output_text = "" tags.extend(new_text[1]) spans.extend(new_text[-1]) # Tokenwise Classification for tk, pred in zip(new_text[2],tags): if pred != 0: style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}" if tk.startswith(" "): text_tokenwise += " " text_tokenwise += f"{tk}" output_text += tk else: text_tokenwise += tk output_text += tk # Span Classification text_spans = "" if len(spans) > 0: filtered_spans = remove_overlapping_spans(spans) text_spans = generate_html_no_overlap(new_text[2], filtered_spans) if len(spans) - len(filtered_spans) > 0: removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap." else: for tk in new_text[2]: text_spans += f"{tk}" # Spanwise Classification annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags) generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "") output = f"{css}

" output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n
" list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"] out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans} output_tokenwise = f"""{css}
""" output_tokenwise += """""" for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])): span = "" if i in [x[0][1]-2 for x in spans] and pred != 0: top_span = [x for x in spans if x[0][1]-2 == i][0] spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]]) color = map_value_to_color((pred-1)/(len(new_tags)-1)) + "88" span = f"{spanstring}{new_tags[pred]}" output_tokenwise += f"" else: output_tokenwise += f"" output_tokenwise += "" output_tokenwise += """""" for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[:])): span = "" if i in [x[0][1]-1 for x in spans]: top_span = [x for x in spans if x[0][1]-1 == i][0] spanstring = ''.join(new_text[2][top_span[0][0]:top_span[0][1]]) span = f"{spanstring}" output_tokenwise += f"" else: output_tokenwise += f"" output_tokenwise += "" output_tokenwise += """""" for tk, pred in zip(new_text[2][1:],tags[1:]): style = "background-color: lightgrey;" if pred != 0: style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))};" output_tokenwise += f"" else: output_tokenwise += "" #output_tokenwise += f"" output_tokenwise += "" for tk, pred in zip(new_text[2][1:],tags[1:]): output_tokenwise += f"" output_tokenwise += "" for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])): style = "border-color: lightgray;background-color: transparent;" if i in [x[0][1]-1 for x in spans]: style = "background-color: yellow;" output_tokenwise += f"" output_tokenwise += "" for tk, pred in zip(new_text[2][1:],tags[1:]): if pred != 0: style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}" output_tokenwise += f"" else: output_tokenwise += f"" output_tokenwise += "" for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])): style = "border-color: lightgray;background-color: transparent;" if i in [x[0][1]-1 for x in spans]: style = "background-color: yellow;" output_tokenwise += f"" output_tokenwise += "" for tk, pred in zip(new_text[2][1:],tags[1:]): output_tokenwise += f"" output_tokenwise += "" for i, (tk, pred) in enumerate(zip(new_text[2][1:],tags[1:])): style = "border-color: lightgray;background-color: transparent;" if i in [x[0][1]-1 for x in spans]: style = "background-color: yellow;" output_tokenwise += f"" output_tokenwise += "" for tk, pred in zip(new_text[2][1:],tags[1:]): output_tokenwise += f"" output_tokenwise += "" #yield output + "" yield output_tokenwise + "
Span detection + label propagation{span}
Span detection{span}
Tokenwise
entity typing
{new_tags[pred]}
{tk}
" #time.sleep(0.5) return # Load datasets and models for the Gradio app datasets = find_datasets_and_model_ids("data/") available_models = list(datasets.keys()) available_datasets = {model: list(datasets[model].keys()) for model in available_models} available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models} def update_datasets(model_name): return available_datasets[model_name] def update_configs(model_name, dataset_name): return available_configs[model_name][dataset_name] # Load datasets and models for the Gradio app datasets = find_datasets_and_model_ids("data/") available_models = list(datasets.keys()) available_datasets = {model: list(datasets[model].keys()) for model in available_models} available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models} # Set the model ID and data configurations model_id = "meta-llama/Llama-3.2-1B" data_id = "STOKE_100" config_id = "default" # Load n_instances separate instances of the model and tokenizer model_instances = get_multiple_model_and_tokenizer(model_id, n_instances) # Set up the round-robin iterator to distribute the requests across model instances model_round_robin = itertools.cycle(model_instances) # Load model classifiers try: classifier_span, classifier_token, label_map = get_classifiers_for_model( model_instances[0][0].config.n_head * model_instances[0][0].config.n_layer, model_instances[0][0].config.n_embd, model_instances[0][0].device, datasets[model_id][data_id][config_id] ) except: classifier_span, classifier_token, label_map = get_classifiers_for_model( model_instances[0][0].config.num_attention_heads * model_instances[0][0].config.num_hidden_layers, model_instances[0][0].config.hidden_size, model_instances[0][0].device, datasets[model_id][data_id][config_id] ) initial_output = (css+"""
Span detection + label propagationThe New York Film FestivalEVENT
Span detectionThe New York Film Festival
Tokenwise
entity typing
GPEORGORGEVENT
The New York Film Festival is an annual
""", {'text': 'Miami is a city in the U.S. state of Florida, and it\'s also known as "The Magic City." It was founded by Henry Flagler on October 28th, 1896.', 'entites': [{'name': 'Miami', 'type': 'GPE'}, {'name': 'U.S.', 'type': 'GPE'}, {'name': 'Florida', 'type': 'GPE'}, {'name': 'The Magic City', 'type': 'WORK_OF_ART'}, {'name': 'Henry Flagler', 'type': 'PERSON'}, {'name': 'October 28th, 1896', 'type': 'DATE'}]}) with gr.Blocks(css="footer{display:none !important} .gradio-container {padding: 0!important; height:400px;}", fill_width=True, fill_height=True) as demo: with gr.Tab("EMBER Demo"): with gr.Row(): output_text = gr.HTML(label="Generated Text", value=initial_output[0]) with gr.Group(): with gr.Row(): input_text = gr.Textbox(label="Try with your own text!", value="The New York Film Festival is an", max_length=40, submit_btn=True) # New HTML output for model info model_info_html = gr.HTML( label="Model Info", value=f'
{model_id} running on {gpu_name}
' ) input_text.submit( fn=generate_text, inputs=[input_text], outputs=[output_text], concurrency_limit=n_instances, concurrency_id="queue" ) # Function to refresh the model info HTML def refresh_model_info(): return f'
{model_id} running on {gpu_name}
' # Update the model info HTML on button click input_text.submit( fn=refresh_model_info, inputs=[], outputs=[model_info_html], queue=False ) demo.queue() demo.launch()