import gradio as gr
import numpy as np
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Keep the model and all inputs on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16).to(device)

tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")

ref_model.eval()

class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation as soon as the most recently generated token is one of `keywords_ids`."""

    def __init__(self, keywords_ids: list):
        self.keywords = keywords_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids is (batch, seq_len); only the last token of the first sequence is checked.
        if input_ids[0][-1] in self.keywords:
            return True
        return False


stop_words = ['>', ' >','> ']
stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
stop_criteria = KeywordsStoppingCriteria(stop_ids)
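
# Optional sanity check (a sketch): tokenizer.encode(w)[0] keeps only the
# first sub-token of each stop word, so log what each stop id decodes to.
for w, i in zip(stop_words, stop_ids):
    print(f"stop word {w!r} -> id {i} decodes to {tokenizer.decode([i])!r}")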


if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    ref_model.config.pad_token_id = ref_model.config.eos_token_id

# Define your color-coding labels; if prob >= x, then label = y. Sorted in descending probability order!
probs_to_label = [
    (0.99, "99%"),
    (0.95, "95%"),
    (0.9, "90%"),
    (0.5, "50%"),
    (0.1, "10%"),
    (0.01, "1%"),
]
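
# Minimal sketch (not used by the app itself) of how a token probability maps
# to a label: the first threshold the probability meets, otherwise None.
def label_for(proba):
    for min_proba, label in probs_to_label:
        if proba >= min_proba:
            return label
    return None

# e.g. label_for(0.97) -> "95%", label_for(0.005) -> None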




def get_tokens_and_labels(prompt):
    """
    Given the prompt (text), return a list of tuples (decoded_token, label)
    """
    inputs = tokenizer([prompt], return_tensors="pt").to(device)
    outputs = ref_model.generate(
        **inputs,
        max_new_tokens=1000,
        return_dict_in_generate=True,
        output_scores=True,
        stopping_criteria=StoppingCriteriaList([stop_criteria])
    )
    # Important: don't forget to set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
    transition_scores = ref_model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
    transition_proba = np.exp(transition_scores.double().cpu())
    
    # We only have scores for the generated tokens, so pop out the prompt tokens
    input_length = inputs.input_ids.shape[1]
    generated_ids = outputs.sequences[:, input_length:]
    
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])

    # Important: you might need to find a tokenization character to replace (e.g. "Ġ" for BPE) and get the correct
    # spacing into the final output 👼
    if ref_model.config.is_encoder_decoder:
        highlighted_out = []
    else:
        input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]
    # Get the (decoded_token, label) pairs for the generated tokens
    for token, proba in zip(generated_tokens, transition_proba[0]):
        this_label = None
        assert 0. <= proba <= 1.0
        for min_proba, label in probs_to_label:
            if proba >= min_proba:
                this_label = label
                break
        highlighted_out.append((token.replace("▁", " "), this_label))

    return highlighted_out
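
# Example (sketch): each returned pair is (decoded_token, probability_label);
# prompt tokens get label None, and generated tokens get the matching
# probs_to_label bucket, e.g. [('The', None), ('Ġanswer', '90%'), ('Ġ>', '50%')].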



import spacy
from spacy import displacy
from spacy.tokens import Span
from spacy.tokens import Doc



def render_output(prompt):
    output = get_tokens_and_labels(prompt)
    nlp = spacy.blank("en")
    # Build one spaCy token per model token so that span indices line up 1:1.
    words = [a[0].replace('Ġ', ' ').replace('Ċ', '\n') for a in output]
    doc = Doc(nlp.vocab, words=words)

    doc.spans["sc"]=[]
    c = 0

    for outs in output:
        # Only generated tokens carry a probability label; prompt tokens have None.
        if outs[1] is not None:
            doc.spans["sc"].append(Span(doc, c, c + 1, outs[1]))

        c += 1


    options = {'colors' : {
            '99%': '#44ce1b',
            '95%': '#bbdb44',
            '90%': '#f7e379',
            '50%': '#fec12a',
            '10%': '#f2a134',
            '1%': '#e51f1f',
            '': '#e51f1f',
    }}

    return displacy.render(doc, style="span", options = options)
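
# Example usage (sketch): render_output returns an HTML string from displacy,
# which is why the interface below uses an "html" output component:
#
#   html = render_output("context: ... question: ... answer:")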




def predict(text):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = ref_model.generate(input_ids=inputs["input_ids"], max_new_tokens=128)
        out_text = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0].split("answer:")[-1]
    
    return out_text.split(text)[-1]
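
# Note: predict() is a plain-text alternative to render_output and is not wired
# into the interface below. A sketch of calling it directly (the
# "context:/question:/answer:" prompt format is assumed from the split on
# "answer:" above):
#
#   print(predict("context: ... question: ... answer:"))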
    

demo = gr.Interface(
    fn=render_output,
    inputs="text",
    outputs="html",
)

demo.launch()