w601sxs committed
Commit e477c81 · Parent(s): 8785832

Update app.py

Files changed (1): app.py (+132, −7)
app.py CHANGED
@@ -4,27 +4,152 @@ from peft import PeftModel, PeftConfig, LoraConfig
  from transformers import AutoTokenizer, AutoModelForCausalLM
  from datasets import load_dataset
  from trl import SFTTrainer
+ # import torch
+ from transformers import StoppingCriteria, AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
+
+ class KeywordsStoppingCriteria(StoppingCriteria):
+     """Stop generation as soon as the newest token is one of the given keyword ids."""
+     def __init__(self, keywords_ids: list):
+         self.keywords = keywords_ids
+
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         if input_ids[0][-1] in self.keywords:
+             return True
+         return False
+
+ import numpy as np
+
+ # Color-coding labels: if prob >= x, then label = y; sorted in descending probability order!
+ probs_to_label = [
+     (0.99, "99%"),
+     (0.95, "95%"),
+     (0.9, "90%"),
+     (0.5, "50%"),
+     (0.1, "10%"),
+     (0.01, "1%"),
+ ]
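Note: because probs_to_label is sorted in descending order, the first threshold a token's probability clears determines its label. A minimal sketch of that lookup (label_for is a hypothetical helper, not part of the commit):

    # Hypothetical helper mirroring the loop used in get_tokens_and_labels below
    def label_for(proba, probs_to_label):
        for min_proba, label in probs_to_label:
            if proba >= min_proba:
                return label
        return None  # below the 1% threshold: the token stays unhighlighted

    assert label_for(0.93, probs_to_label) == "90%"   # >= 0.9 but < 0.95
    assert label_for(0.005, probs_to_label) is None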
 
  ref_model = AutoModelForCausalLM.from_pretrained("w601sxs/b1ade-1b", torch_dtype=torch.bfloat16)
- peft_model_id = "w601sxs/b1ade-1b-orca-chkpt-506k"
-
- config = PeftConfig.from_pretrained(peft_model_id)
- model = PeftModel.from_pretrained(ref_model, peft_model_id)
- tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+ tokenizer = AutoTokenizer.from_pretrained("w601sxs/b1ade-1b")
 
- model.eval()
+ ref_model.eval()
+
+ if tokenizer.pad_token_id is None:
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+     ref_model.config.pad_token_id = ref_model.config.eos_token_id
+
+ # The stop ids depend on the tokenizer, so they are computed after it is loaded
+ stop_words = ['>', ' >', '> ']
+ stop_ids = [tokenizer.encode(w)[0] for w in stop_words]
+ stop_criteria = KeywordsStoppingCriteria(stop_ids)
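Note: tokenizer.encode(w)[0] keeps only the first id of each stop word, and KeywordsStoppingCriteria compares just the last generated token against those ids. A quick sanity check (a sketch, not part of the commit) that each stop word really is a single token:

    # Sketch: warn if a stop word tokenizes to more than one id, since only the
    # first id would ever trigger the stopping criteria
    for w in stop_words:
        ids = tokenizer.encode(w)
        if len(ids) > 1:
            print(f"warning: {w!r} -> {ids}; only {ids[0]} will stop generation")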
+
+ def get_tokens_and_labels(prompt):
+     """
+     Given the prompt (text), return a list of tuples (decoded_token, label)
+     """
+     inputs = tokenizer([prompt], return_tensors="pt").to(ref_model.device)
+     outputs = ref_model.generate(
+         **inputs,
+         max_new_tokens=1000,
+         return_dict_in_generate=True,
+         output_scores=True,
+         stopping_criteria=StoppingCriteriaList([stop_criteria]),
+     )
+     # Important: set `normalize_logits=True` to obtain normalized probabilities (i.e. sum(p) = 1)
+     transition_scores = ref_model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
+     transition_proba = np.exp(transition_scores.double().cpu())
+
+     # We only have scores for the generated tokens, so pop out the prompt tokens
+     input_length = inputs.input_ids.shape[1]
+     generated_ids = outputs.sequences[:, input_length:]
+     generated_tokens = tokenizer.convert_ids_to_tokens(generated_ids[0])
+
+     # You might need to replace a tokenization character (e.g. "Ġ" for BPE) to get
+     # the correct spacing in the final output
+     if ref_model.config.is_encoder_decoder:
+         highlighted_out = []
+     else:
+         input_tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+         highlighted_out = [(token.replace("▁", " "), None) for token in input_tokens]
+     # Get the (decoded_token, label) pairs for the generated tokens
+     for token, proba in zip(generated_tokens, transition_proba[0]):
+         this_label = None
+         assert 0.0 <= proba <= 1.0
+         for min_proba, label in probs_to_label:
+             if proba >= min_proba:
+                 this_label = label
+                 break
+         highlighted_out.append((token.replace("▁", " "), this_label))
+
+     return highlighted_out
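Since compute_transition_scores is called with normalize_logits=True, each row of transition_proba holds one probability per generated token, aligned with generated_tokens. A hedged usage sketch (the "question: ... answer:" prompt format is inferred from the split("answer:") in predict below, not confirmed by the commit):

    # Sketch: print the confidence label attached to the last few tokens
    pairs = get_tokens_and_labels("question: What color is the sky? answer:")
    for token, label in pairs[-10:]:
        print(f"{token!r:>15} {label}")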
+
+ import spacy
+ from spacy import displacy
+ from spacy.tokens import Span
+ from spacy.tokens import Doc
+
+ def render_output(prompt):
+     output = get_tokens_and_labels(prompt)
+     nlp = spacy.blank("en")
+     # One spaCy token per model token, with BPE whitespace markers mapped back
+     words = [a[0].replace('Ġ', ' ').replace('Ċ', '\n') for a in output]
+     doc = Doc(nlp.vocab, words=words)
+
+     doc.spans["sc"] = []
+     for c, outs in enumerate(output):
+         if outs[1] is not None:
+             doc.spans["sc"].append(Span(doc, c, c + 1, outs[1]))
+
+     options = {'colors': {
+         '99%': '#44ce1b',
+         '95%': '#bbdb44',
+         '90%': '#f7e379',
+         '50%': '#fec12a',
+         '10%': '#f2a134',
+         '1%': '#e51f1f',
+         '': '#e51f1f',
+     }}
+
+     return displacy.render(doc, style="span", options=options)
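Note: displacy.render returns an HTML string, so with outputs='text' below the markup is displayed verbatim rather than as colored spans. One possible wiring (an assumption about intent, not what this commit ships) is an HTML output component:

    # Sketch: let Gradio render the displaCy markup as HTML
    demo = gr.Interface(
        fn=render_output,
        inputs='text',
        outputs='html',
    )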
 
  def predict(text):
      inputs = tokenizer(text, return_tensors="pt")
      with torch.no_grad():
-         outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=128)
+         outputs = ref_model.generate(input_ids=inputs["input_ids"], max_new_tokens=128)
      out_text = tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0].split("answer:")[-1]
 
      return out_text.split(text)[-1]
 
 
  demo = gr.Interface(
-     fn=predict,
+     fn=render_output,
      inputs='text',
      outputs='text',
  )
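The hunk ends at the interface definition; anything past new line 155 (such as a demo.launch() call) is outside this diff. A local smoke test of the updated file might look like this (a sketch; it assumes the module-level code above has already run):

    # Sketch: exercise both entry points once, then serve the demo
    if __name__ == "__main__":
        print(predict("question: Name a primary color. answer:"))
        demo.launch()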