b1ade-1b / app.py
w601sxs's picture
Update app.py
f11c41a verified
import gradio as gr
import torch
from peft import PeftModel, PeftConfig, LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTTrainer
# import torch
from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
# Hugging Face Hub id of the checkpoint served by this demo (model + tokenizer).
MODEL_NAME = "w601sxs/b1ade-1b-bf16"

ref_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
# Inference only: put the module in evaluation mode (disables dropout etc.).
ref_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the newest token is a keyword id.

    Only the first sequence of the batch (row 0) is inspected.
    """

    def __init__(self, keywords_ids: list):
        # Token ids whose generation should end decoding.
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        newest_token = input_ids[0][-1]
        return newest_token in self.keywords
# Presumably generation should end when the model closes the "answer:<...>"
# bracket opened by the prompt template, so each spelling of ">" is a stop word.
# Only the first token id of each variant is kept. add_special_tokens=False
# matters for tokenizers that prepend a BOS token: without it, index 0 would be
# that BOS id for every word instead of the ">" token itself.
stop_words = ['>', ' >', '> ']
stop_ids = [tokenizer.encode(w, add_special_tokens=False)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
import numpy as np
if tokenizer.pad_token_id is None:
    # No padding token defined: reuse the end-of-sequence id, on both the
    # tokenizer and the model config.
    tokenizer.pad_token_id, ref_model.config.pad_token_id = (
        tokenizer.eos_token_id,
        ref_model.config.eos_token_id,
    )
# Colour-coding thresholds, highest first: a token whose probability is
# >= a threshold gets that threshold's label. The scan stops at the first
# match, so this list must stay sorted in descending order. Each label is the
# threshold written as a whole percentage (0.95 -> "95%").
probs_to_label = [
    (threshold, f"{threshold:.0%}")
    for threshold in (0.99, 0.95, 0.9, 0.5, 0.1, 0.01)
]
import numpy as np
def _probability_label(proba):
    """Return the label of the first ``probs_to_label`` threshold ``proba`` meets, else ``None``."""
    for min_proba, label in probs_to_label:
        if proba >= min_proba:
            return label
    return None


def get_tokens_and_labels(prompt):
    """
    Generate a continuation of ``prompt`` and label each token by confidence.

    Returns a list of ``(token_text, label)`` tuples. For decoder-only models,
    the prompt tokens come first with label ``None``; encoder-decoder models
    omit them. Each generated token follows, labelled through
    ``probs_to_label`` from the probability the model assigned to it. Tokens
    below every threshold get ``None``. The SentencePiece marker "▁" is
    replaced with a space in every token text.
    """
    inputs = tokenizer([prompt], return_tensors="pt")
    outputs = ref_model.generate(
        **inputs,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        output_scores=True,
        # Sampling options are generate() arguments. StoppingCriteriaList is a
        # plain list subclass and does not accept keyword arguments.
        do_sample=True,
        top_p=0.2,
        stopping_criteria=StoppingCriteriaList([stop_criteria]),
    )
    # normalize_logits=True log-softmaxes the scores, so exp() gives
    # per-token probabilities that sum to 1 over the vocabulary.
    transition_scores = ref_model.compute_transition_scores(
        outputs.sequences, outputs.scores, normalize_logits=True
    )
    transition_proba = np.exp(transition_scores.double().cpu())

    # Scores exist only for generated tokens, so skip the non-generated prefix.
    # Decoder-only output repeats the whole prompt. Encoder-decoder output
    # starts with a single decoder-start token instead.
    if ref_model.config.is_encoder_decoder:
        input_length = 1
    else:
        input_length = inputs.input_ids.shape[1]
    generated_ids = outputs.sequences[:, input_length:]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])

    # Prompt tokens are only shown for decoder-only models (see docstring).
    if ref_model.config.is_encoder_decoder:
        highlighted_out = []
    else:
        input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]

    for token, proba in zip(generated_tokens, transition_proba[0]):
        # Internal invariant: exp of a log-probability lies in [0, 1].
        assert 0. <= proba <= 1.0
        highlighted_out.append((token.replace("▁", " "), _probability_label(proba)))
    return highlighted_out
import spacy
from spacy import displacy
from spacy.tokens import Span
from spacy.tokens import Doc
def render_output(context, question):
    """
    Answer ``question`` about ``context`` and return confidence-coloured HTML.

    The prompt is built from a ``context:<...>`` / ``question:<...>`` /
    ``answer:<`` template. ``get_tokens_and_labels`` generates the answer and
    labels every token. Each labelled token becomes a one-token span, which
    displaCy's span visualizer colours according to ``options['colors']``.

    Returns the HTML string produced by ``displacy.render``.
    """
    output = get_tokens_and_labels(f"context:<{context}>\nquestion:<{question}>\nanswer:<")
    nlp = spacy.blank("en")
    # Byte-level BPE encodes a leading space as "Ġ" and a newline as "Ċ".
    # Map them back so the rendered tokens read naturally.
    words = [token.replace('Ġ', ' ').replace('Ċ', '\n') for token, _ in output]
    doc = Doc(nlp.vocab, words=words)
    # One single-token span per labelled token. Unlabelled tokens (the prompt
    # and those below every threshold) stay uncoloured.
    doc.spans["sc"] = [
        Span(doc, i, i + 1, label)
        for i, (_, label) in enumerate(output)
        if label is not None
    ]
    options = {'colors': {
        '99%': '#44ce1b',
        '95%': '#bbdb44',
        '90%': '#f7e379',
        '50%': '#fec12a',
        '10%': '#f2a134',
        '1%': '#e51f1f',
        '': '#e51f1f',
    }}
    return displacy.render(doc, style="span", options=options)
def predict(text):
    """
    Generate up to 256 new tokens for ``text`` and return the answer part.

    The decoded output (special tokens skipped) is cut after its last
    ``"answer:"`` marker. Any remaining text up to the last occurrence of
    ``text`` itself is then dropped.
    """
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = ref_model.generate(
            input_ids=inputs["input_ids"],
            # Forward the mask so padding/EOS-as-pad is not attended to.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=256,
            # top_p only takes effect when sampling is enabled. This matches
            # the settings used by get_tokens_and_labels.
            do_sample=True,
            top_p=0.2,
        )
    out_text = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0].split("answer:")[-1]
    return out_text.split(text)[-1]
# Default context shown in the "context" box and reused by several examples.
DEFAULT_CONTEXT = 'As an AI assistant provide helpful, accurate and detailed answers to user questions'

# Each example is a [context, question] pair matching the two input boxes.
EXAMPLES = [
    [DEFAULT_CONTEXT, 'Given the fact that Inhaling, or breathing in, increases the size of the chest, which decreases air pressure inside the lungs. If Mona is done with a race and her chest contracts, what happens to the amount of air pressure in her lungs increases or decreases?'],
    [DEFAULT_CONTEXT, 'In this task, you\'re given the title of a five-sentence story, the first four sentences, and two options for the fifth sentence as a and b. Your job is to pick the sentence option that seamlessly connects with the rest of the story, indicating your choice as "a" or "b". If both sentences are plausible, pick the one that makes more sense. Title: Missing Radio. Sentence 1: Josh was very sad to find out he could not find his radio. Sentence 2: He searched all day and night. Sentence 3: He even went back to school to find his radio. Sentence 4: Later on, someone turned in his radio to the lost and found. Choices: a. Once he got his new car, Stuart was very happy and relieved. b. Now, James was able to listen to the game on his radio'],
    ['Given a statement and question, generate the answer to the question such that the answer is contained in the statement.', 'statement: internet is used for signals, question: Internet is used for what?'],
    ['The center contact of the bulb typically connects to the medium-power filament, and the ring connects to the low-power filament. Thus, if a 3-way bulb is screwed into a standard light socket that has only a center contact, only the medium-power filament operates. In the case of the 50 W / 100 W / 150 W bulb, putting this bulb in a regular lamp socket will result in it behaving like a normal 100W bulb.', 'Question: Do 3 way light bulbs work in any lamp?'],
    # The trailing newline was previously a raw line break inside a
    # single-quoted literal (a SyntaxError); it is now written as "\n".
    ['In this task, find the most appropriate number to replace the blank (indicated with _ ) and express it in words. \n', 'Quarks can have _ colors, red, green and blue.'],
    ['The private companies responsible for the most emissions during this period, according to the database, are from the United States: ExxonMobil, Chevron and Peabody.The largest emitter amongst state-owned companies in the Americas is Mexican company Pemex, followed by Venezuelan company Petróleos de Venezuela, S.A.', 'Which private companies in the Americas are the largest GHG emitters according to the Carbon Majors database?'],
]

demo = gr.Interface(
    fn=render_output,
    inputs=[
        gr.Textbox(label='context', value=DEFAULT_CONTEXT),
        gr.Textbox(label='question'),
    ],
    outputs='html',
    examples=EXAMPLES,
)
demo.launch()