import gradio as gr
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
# Load the model and tokenizer; bfloat16 keeps the memory footprint low.
ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
if torch.cuda.is_available():
    ref_model = ref_model.to("cuda")
tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")
ref_model.eval()
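# Custom stopping criterion: halt generation as soon as the last generated token id is one of the supplied keyword ids.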
class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords_ids: list):
        super().__init__()
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        return input_ids[0][-1] in self.keywords
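# Generation stops once a '>' token is produced; only the first token id of each variant is registered.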
stop_words = ['>', ' >','> ']
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
ref_model.config.pad_token_id = ref_model.config.eos_token_id
# Define your color-coding labels; if prob > x, then label = y; Sorted in descending probability order!
probs_to_label = [
(0.99, "99%"),
(0.95, "95%"),
(0.9, "90%"),
(0.5, "50%"),
(0.1, "10%"),
(0.01, "1%"),
]
def get_tokens_and_labels(prompt):
"""
Given the prompt (text), return a list of tuples (decoded_token, label)
"""
    inputs = tokenizer([prompt], return_tensors="pt").to(ref_model.device)
outputs = ref_model.generate(
**inputs,
max_new_tokens=1000,
return_dict_in_generate=True,
output_scores=True,
stopping_criteria=StoppingCriteriaList([stop_criteria])
)
# Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
    transition_scores = ref_model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
transition_proba = np.exp(transition_scores.double().cpu())
# We only have scores for the generated tokens, so pop out the prompt tokens
input_length = inputs.input_ids.shape[1]
generated_ids = outputs.sequences[:, input_length:]
generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])
# Important: you might need to find a tokenization character to replace (e.g. "Ġ" for BPE) and get the correct
# spacing into the final output 👼
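    # For this model the generated tokens carry byte-level BPE markers ("Ġ" for a leading space, "Ċ" for a newline); render_output below maps them back to plain text.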
if ref_model.config.is_encoder_decoder:
highlighted_out = []
else:
input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]
# Get the (decoded_token, label) pairs for the generated tokens
for token, proba in zip(generated_tokens, transition_proba[0]):
this_label = None
assert 0. <= proba <= 1.0
for min_proba, label in probs_to_label:
if proba >= min_proba:
this_label = label
break
highlighted_out.append((token.replace("▁", " "), this_label))
return highlighted_out
import spacy
from spacy import displacy
from spacy.tokens import Doc, Span
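# spaCy's displacy span visualizer is used purely for presentation: each generated token becomes a one-token span labeled with its confidence bucket and colored accordingly.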
def render_output(prompt):
    output = get_tokens_and_labels(prompt)
    nlp = spacy.blank("en")
    # Rebuild the text as a spaCy Doc, mapping the byte-level BPE markers back to spaces and newlines.
    words = [a[0].replace('Ġ', ' ').replace('Ċ', '\n') for a in output]
    doc = Doc(nlp.vocab, words=words)
    doc.spans["sc"] = []
    # Attach a one-token span carrying the confidence label for every generated token that has one.
    for c, outs in enumerate(output):
        if outs[1] is not None:
            doc.spans["sc"].append(Span(doc, c, c + 1, outs[1]))
options = {'colors' : {
'99%': '#44ce1b',
'95%': '#bbdb44',
'90%': '#f7e379',
'50%': '#fec12a',
'10%': '#f2a134',
'1%': '#e51f1f',
'': '#e51f1f',
}}
return displacy.render(doc, style="span", options = options)
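# Plain text-completion helper; it is not wired into the Gradio interface below, which uses render_output instead.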
def predict(text):
    inputs = tokenizer(text, return_tensors="pt").to(ref_model.device)
with torch.no_grad():
outputs = ref_model.generate(input_ids=inputs["input_ids"], max_new_tokens=128)
out_text = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0].split("answer:")[-1]
return out_text.split(text)[-1]
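# displacy.render returns an HTML string, so the interface renders it with an HTML output component.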
demo = gr.Interface(
    fn=render_output,
    inputs="text",
    outputs="html",
)
demo.launch()