Spaces:
Running
Running
import gradio as gr | |
import torch | |
from peft import PeftModel, PeftConfig, LoraConfig | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from datasets import load_dataset | |
from trl import SFTTrainer | |
# import torch | |
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList | |
# Checkpoint used for both the language model and its tokenizer.
_CHECKPOINT = "w601sxs/b1ade-1b-bf16"

# Load the causal LM with bfloat16 weights and put it in evaluation mode.
ref_model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT, torch_dtype=torch.bfloat16)
ref_model.eval()
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a stop token."""

    def __init__(self, keywords_ids: list):
        # Token ids that end generation as soon as one of them is produced.
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        # Only the first sequence of the batch is inspected.
        newest_token = input_ids[0][-1]
        return newest_token in self.keywords
# Stop strings: a ">" on its own, plus variants with a leading or trailing
# space.
stop_words = ['>', ' >', '> ']
# Keep only the first token id of each encoded stop string.
stop_ids = [encoded[0] for encoded in map(tokenizer.encode, stop_words)]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
import numpy as np | |
# NOTE(review): the original indentation was lost; this assignment is assumed
# to be unconditional (outside the `if` below) — confirm against the source.
ref_model.config.pad_token_id = ref_model.config.eos_token_id

# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Color-coding labels: a token whose probability is >= a threshold gets the
# matching label. Both tuples are ordered by descending probability, and the
# first threshold reached wins, so the order matters.
_LABEL_THRESHOLDS = (0.99, 0.95, 0.9, 0.5, 0.1, 0.01)
_LABEL_NAMES = ("99%", "95%", "90%", "50%", "10%", "1%")
probs_to_label = list(zip(_LABEL_THRESHOLDS, _LABEL_NAMES))
import numpy as np | |
def get_tokens_and_labels(prompt):
    """Generate a continuation of ``prompt`` and label each new token by probability.

    Args:
        prompt: Text fed to the model.

    Returns:
        A list of ``(token, label)`` tuples. For decoder-only models the
        prompt tokens come first with label ``None``. Each generated token
        follows, labelled with the first entry of ``probs_to_label`` whose
        threshold its probability reaches, or ``None`` when it is below
        every threshold.
    """
    inputs = tokenizer([prompt], return_tensors="pt")
    # The sampling options are arguments of generate(). They used to be passed
    # to StoppingCriteriaList(...), which only takes the list of criteria, so
    # they either raised TypeError or were dropped without effect.
    outputs = ref_model.generate(
        **inputs,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        output_scores=True,
        do_sample=True,
        top_p=0.2,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
    )
    # normalize_logits=True is required to obtain normalized probabilities
    # (i.e. sum(p) = 1).
    transition_scores = ref_model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    # Scores are log-probabilities; exponentiate to get probabilities.
    transition_proba = np.exp(transition_scores.double().cpu())

    # Scores exist only for the generated tokens, so drop the prompt tokens.
    input_length = inputs.input_ids.shape[1]
    generated_ids = outputs.sequences[:, input_length:]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])

    # The tokenizer's word-boundary marker may need replacing to get correct
    # spacing in the final output ("▁" is handled here).
    if ref_model.config.is_encoder_decoder:
        highlighted_out = []
    else:
        # Decoder-only: echo the prompt tokens, unlabelled.
        input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]

    # Build the (decoded_token, label) pairs for the generated tokens.
    for token, proba in zip(generated_tokens, transition_proba[0]):
        assert 0. <= proba <= 1.0
        # probs_to_label is sorted in descending order: first match wins.
        this_label = next(
            (label for min_proba, label in probs_to_label if proba >= min_proba),
            None,
        )
        highlighted_out.append((token.replace("▁", " "), this_label))
    return highlighted_out
import spacy | |
from spacy import displacy | |
from spacy.tokens import Span | |
from spacy.tokens import Doc | |
def render_output(context, question):
    """Answer ``question`` given ``context`` and render the tokens color-coded.

    Args:
        context: Text placed in the ``context:<...>`` slot of the prompt.
        question: Text placed in the ``question:<...>`` slot of the prompt.

    Returns:
        The result of ``displacy.render`` in ``"span"`` style, where every
        labelled token is a one-token span colored by its probability label.
    """
    output = get_tokens_and_labels(f"context:<{context}>\nquestion:<{question}>\nanswer:<")
    nlp = spacy.blank("en")

    # Replace the 'Ġ' and 'Ċ' markers in the raw tokens with a space and a
    # newline respectively. The Doc is built directly from these tokens, so
    # span indices line up with `output`.
    words = [token.replace('Ġ', ' ').replace('Ċ', '\n') for token, _ in output]
    doc = Doc(nlp.vocab, words=words)

    # One single-token span per labelled token; unlabelled tokens get none.
    doc.spans["sc"] = [
        Span(doc, index, index + 1, label)
        for index, (_, label) in enumerate(output)
        if label is not None
    ]

    options = {'colors': {
        '99%': '#44ce1b',
        '95%': '#bbdb44',
        '90%': '#f7e379',
        '50%': '#fec12a',
        '10%': '#f2a134',
        '1%': '#e51f1f',
        '': '#e51f1f',
    }}
    return displacy.render(doc, style="span", options=options)
def predict(text):
    """Generate up to 256 new tokens for ``text`` and return the trailing answer.

    The decoded output is cut after the last ``"answer:"`` marker and then
    after the last occurrence of ``text`` itself.
    """
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # NOTE(review): top_p is passed without do_sample; presumably it has
        # no effect unless the model's generation config enables sampling —
        # confirm.
        generated = ref_model.generate(input_ids=encoded["input_ids"], max_new_tokens=256, top_p=0.2)
    decoded = tokenizer.batch_decode(generated.detach().cpu().numpy(), skip_special_tokens=True)
    after_marker = decoded[0].rpartition("answer:")[2]
    return after_marker.rpartition(text)[2]
# Default context, used both as the textbox default and in the first two examples.
_DEFAULT_CONTEXT = 'As an AI assistant provide helpful, accurate and detailed answers to user questions'

# Example [context, question] pairs offered in the UI.
_EXAMPLES = [
    [
        _DEFAULT_CONTEXT,
        'Given the fact that Inhaling, or breathing in, increases the size of the chest, which decreases air pressure inside the lungs. If Mona is done with a race and her chest contracts, what happens to the amount of air pressure in her lungs increases or decreases?',
    ],
    [
        _DEFAULT_CONTEXT,
        'In this task, you\'re given the title of a five-sentence story, the first four sentences, and two options for the fifth sentence as a and b. Your job is to pick the sentence option that seamlessly connects with the rest of the story, indicating your choice as "a" or "b". If both sentences are plausible, pick the one that makes more sense. Title: Missing Radio. Sentence 1: Josh was very sad to find out he could not find his radio. Sentence 2: He searched all day and night. Sentence 3: He even went back to school to find his radio. Sentence 4: Later on, someone turned in his radio to the lost and found. Choices: a. Once he got his new car, Stuart was very happy and relieved. b. Now, James was able to listen to the game on his radio',
    ],
    [
        'Given a statement and question, generate the answer to the question such that the answer is contained in the statement.',
        'statement: internet is used for signals, question: Internet is used for what?',
    ],
    [
        'The center contact of the bulb typically connects to the medium-power filament, and the ring connects to the low-power filament. Thus, if a 3-way bulb is screwed into a standard light socket that has only a center contact, only the medium-power filament operates. In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a regular lamp socket will result in it behaving like a normal 100W bulb.',
        'Question: Do 3 way light bulbs work in any lamp?',
    ],
    [
        # The original literal contained a raw line break before the closing
        # quote, which is not valid in a single-quoted string; it is written
        # as an explicit "\n" escape here.
        'In this task, find the most appropriate number to replace the blank (indicated with _ ) and express it in words. \n',
        'Quarks can have _ colors, red, green and blue.',
    ],
    [
        'The private companies responsible for the most emissions during this period, according to the database, are from the United States: ExxonMobil, Chevron and Peabody.The largest emitter amongst state-owned companies in the Americas is Mexican company Pemex, followed by Venezuelan company Petróleos de Venezuela, S.A.',
        'Which private companies in the Americas are the largest GHG emitters according to the Carbon Majors database?',
    ],
]

# Two text inputs (context, question) are passed to render_output, whose
# return value is displayed as HTML.
demo = gr.Interface(
    fn=render_output,
    inputs=[
        gr.Textbox(label='context', value=_DEFAULT_CONTEXT),
        gr.Textbox(label='question'),
    ],
    outputs='html',
    examples=_EXAMPLES,
)
demo.launch()