JerryLiJinyi committed on
Commit
c6a14bf
1 Parent(s): 10b912d

Upload 4 files

Files changed (4)
  1. app.py +42 -0
  2. compressor.py +65 -0
  3. llmlingua_compressor_pro.py +1152 -0
  4. longlingua_compressor.py +1150 -0
app.py ADDED
@@ -0,0 +1,42 @@
+ import gradio as gr
+ from compressor import PromptCompressor
+
+
+ def compressit(original_text, compressor1, ratio, maxlength):
+
+     if compressor1 == "Selective Context":
+         compressor = PromptCompressor(type='SCCompressor', lang='en', model='gpt2', device='cuda')
+     elif compressor1 == "LLMLingua":
+         return "Sorry, currently we cannot provide services for LLMLingua due to the Huggingface Token issue. Please try other compressors."
+     elif compressor1 == "LongLLMLingua":
+         return "Sorry, currently we cannot provide services for LongLLMLingua due to the Huggingface Token issue. Please try other compressors."
+     elif compressor1 == "SCRL":
+         compressor = PromptCompressor(type='SCRLCompressor', model_dir="models/gigaword-L8/", device="cuda", tokenizer_dir="sentence-transformers/paraphrase-distilroberta-base-v2")
+     elif compressor1 == "KiS":
+         compressor = PromptCompressor(type='KiSCompressor', device="cuda", model_dir="philippelaban/keep_it_simple")
+     else:
+         compressor = PromptCompressor(type='SCCompressor', lang='en', model='gpt2', device='cuda')
+
+     # ratio and max_length arrive as strings from the Textboxes; empty fields
+     # fall back to compressgo's defaults (ratio=0.5, max_length=256).
+     ratio = float(ratio) if ratio else 0.5
+     maxlength = int(maxlength) if maxlength else 256
+     compressed_prompt = compressor.compressgo(original_prompt=original_text, ratio=ratio, max_length=maxlength)
+     return compressed_prompt["compressed_prompt"]
+
+
+ demo = gr.Interface(
+     fn=compressit,
+     inputs=[
+         gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="input", info="Enter the original prompt here."),
+         gr.Dropdown(
+             ["Selective Context", "LLMLingua", "LongLLMLingua", "SCRL", "KiS"], label="compressor", info="Choose your compressor here. \n Currently, we cannot support the online demo for LLMLingua and LongLLMLingua due to the Huggingface Token issue."
+         ),
+         gr.Textbox(lines=1, placeholder="Enter the compression ratio here...", label="ratio", info="Ratio only works for Selective Context, LLMLingua and LongLLMLingua."),
+         gr.Textbox(lines=1, placeholder="Enter the max_length parameter if you are using SCRL or KiS", label="max_length", info="If you are using SCRL or KiS, fill in the parameter; if not, just ignore this.\n Hint: For SCRL, max_length should be shorter than the length of the original prompt; for KiS, max_length should be longer than it.")
+     ],
+     outputs=[
+         gr.Textbox(lines=1, info="Please note that when the text is very short, LLMLingua and LongLLMLingua will not work.")
+     ]
+ )
+
+ demo.launch(share=False)
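
The Gradio interface above passes the four inputs to compressit positionally; the ratio and max_length fields arrive as strings from the Textboxes, which is why compressit casts them. A minimal sketch of an equivalent direct call, assuming the Selective Context dependencies (the gpt2 model) and a CUDA device are available locally; the prompt text is a placeholder:

    result = compressit(
        original_text="<a reasonably long prompt>",
        compressor1="Selective Context",
        ratio="0.5",   # sent as a string, exactly as the Textbox would
        maxlength="",  # ignored by Selective Context; used by SCRL and KiS
    )
    print(result)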
compressor.py ADDED
@@ -0,0 +1,65 @@
+ from selective_context_compressor import SCCompressor
+ from kis import KiSCompressor
+ from scrl_compressor import SCRLCompressor
+ from llmlingua_compressor_pro import LLMLinguaCompressor
+ from typing import List
+
+
+ class PromptCompressor:
+     def __init__(self, type: str = 'SCCompressor', lang: str = 'en', model='gpt2', device='cuda', model_dir: str = '',
+                  use_auth_token: bool = False, open_api_config: dict = {}, token: str = '',
+                  tokenizer_dir: str = "sentence-transformers/paraphrase-distilroberta-base-v2"):
+         self.type = type
+         if self.type == 'SCCompressor':
+             self.compressor = SCCompressor(lang=lang, model=model, device=device)
+         elif self.type == 'KiSCompressor':
+             self.compressor = KiSCompressor(DEVICE=device, model_dir=model_dir)
+         elif self.type == 'LLMLinguaCompressor':
+             self.compressor = LLMLinguaCompressor(device_map=device, model_name=model_dir, use_auth_token=use_auth_token, open_api_config=open_api_config, token=token)
+         elif self.type == 'LongLLMLinguaCompressor':
+             self.compressor = LLMLinguaCompressor(device_map=device, model_name=model_dir, use_auth_token=use_auth_token, open_api_config=open_api_config, token=token)
+         elif self.type == 'SCRLCompressor':
+             if model_dir:
+                 self.compressor = SCRLCompressor(model_dir=model_dir, device=device, tokenizer_dir=tokenizer_dir)
+             else:
+                 print("model_dir parameter is required")
+
+     def compressgo(self, original_prompt: str = '', ratio: float = 0.5, level: str = 'phrase',
+                    max_length: int = 256, num_beams: int = 4, do_sample: bool = True, num_return_sequences: int = 1,
+                    target_index: int = 0, instruction: str = "", question: str = "", target_token: float = -1,
+                    iterative_size: int = 200, force_context_ids: List[int] = None, force_context_number: int = None,
+                    use_sentence_level_filter: bool = False, use_context_level_filter: bool = True,
+                    use_token_level_filter: bool = True, keep_split: bool = False, keep_first_sentence: int = 0,
+                    keep_last_sentence: int = 0, keep_sentence_number: int = 0, high_priority_bonus: int = 100,
+                    context_budget: str = "+100", token_budget_ratio: float = 1.4, condition_in_question: str = "none",
+                    reorder_context: str = "original", dynamic_context_compression_ratio: float = 0.0,
+                    condition_compare: bool = False, add_instruction: bool = False, rank_method: str = "llmlingua",
+                    concate_question: bool = True):
+         if self.type == 'SCCompressor':
+             return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, level=level)
+         elif self.type == 'KiSCompressor':
+             return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, max_length=max_length, num_beams=num_beams, do_sample=do_sample, num_return_sequences=num_return_sequences, target_index=target_index)
+         elif self.type == 'SCRLCompressor':
+             return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, max_length=max_length)
+         elif self.type == 'LLMLinguaCompressor':
+             return self.compressor.compress(context=original_prompt, ratio=ratio, instruction=instruction, question=question, target_token=target_token,
+                                             iterative_size=iterative_size, force_context_ids=force_context_ids, force_context_number=force_context_number,
+                                             use_token_level_filter=use_token_level_filter, use_context_level_filter=use_context_level_filter,
+                                             use_sentence_level_filter=use_sentence_level_filter, keep_split=keep_split, keep_first_sentence=keep_first_sentence,
+                                             keep_last_sentence=keep_last_sentence, keep_sentence_number=keep_sentence_number, high_priority_bonus=high_priority_bonus,
+                                             context_budget=context_budget, token_budget_ratio=token_budget_ratio, condition_in_question=condition_in_question,
+                                             reorder_context=reorder_context, dynamic_context_compression_ratio=dynamic_context_compression_ratio, condition_compare=condition_compare,
+                                             add_instruction=add_instruction, rank_method=rank_method, concate_question=concate_question)
+         elif self.type == 'LongLLMLinguaCompressor':
+             return self.compressor.compress(context=original_prompt, ratio=ratio, instruction=instruction, question=question, target_token=target_token,
+                                             iterative_size=iterative_size, force_context_ids=force_context_ids, force_context_number=force_context_number,
+                                             use_token_level_filter=use_token_level_filter, use_context_level_filter=use_context_level_filter,
+                                             use_sentence_level_filter=use_sentence_level_filter, keep_split=keep_split, keep_first_sentence=keep_first_sentence,
+                                             keep_last_sentence=keep_last_sentence, keep_sentence_number=keep_sentence_number, high_priority_bonus=high_priority_bonus,
+                                             context_budget=context_budget, token_budget_ratio=token_budget_ratio, condition_in_question=condition_in_question,
+                                             reorder_context=reorder_context, dynamic_context_compression_ratio=dynamic_context_compression_ratio, condition_compare=condition_compare,
+                                             add_instruction=add_instruction, rank_method=rank_method, concate_question=concate_question)
+         else:
+             return self.compressor.compress(original_prompt=original_prompt, ratio=ratio)
+
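
A short usage sketch of the PromptCompressor facade defined above, assuming the SCRL checkpoint referenced in app.py has been downloaded to models/gigaword-L8/ and a CUDA device is available; the prompt text is a placeholder:

    from compressor import PromptCompressor

    scrl = PromptCompressor(
        type='SCRLCompressor',
        model_dir="models/gigaword-L8/",
        device="cuda",
        tokenizer_dir="sentence-transformers/paraphrase-distilroberta-base-v2",
    )
    # Per the hint in app.py, SCRL expects max_length to be shorter than the input prompt.
    out = scrl.compressgo(original_prompt="<a long prompt>", ratio=0.5, max_length=64)
    print(out["compressed_prompt"])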
llmlingua_compressor_pro.py ADDED
@@ -0,0 +1,1152 @@
1
+ from llmlingua import PromptCompressor
2
+ import bisect
3
+ from collections import defaultdict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ import nltk
10
+ import tiktoken
11
+ import re
12
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
13
+ from abs_compressor import AbstractCompressor
14
+
15
+ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
16
+
17
+ class LLMLinguaCompressor(AbstractCompressor):
18
+ def __init__(
19
+ self,
20
+ model_name: str = "meta-llama/Llama-2-7b-chat-hf",
21
+ device_map: str = "cuda",
22
+ use_auth_token: bool = False,
23
+ open_api_config: dict = {},
24
+ token: str = ''
25
+ ):
26
+ self.model_name = model_name
27
+ self.token = token
28
+ self.load_model(model_name, device_map, use_auth_token)
29
+ self.retrieval_model = None
30
+ self.retrieval_model_name = None
31
+ self.open_api_config = open_api_config
32
+ self.cache_bos_num = 10
33
+
34
+ def load_model(
35
+ self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
36
+ ):
37
+ config = AutoConfig.from_pretrained(self.model_name)
38
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
39
+ tokenizer.padding_side = "left"
40
+ tokenizer.pad_token_id = (
41
+ config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
42
+ )
43
+ self.device = (
44
+ device_map if any(key in device_map for key in ["cuda", "cpu"]) else "cuda"
45
+ )
46
+ if "cuda" in device_map or "cpu" in device_map:
47
+ model = AutoModelForCausalLM.from_pretrained(
48
+ model_name,
49
+ torch_dtype="auto" if device_map == "cuda" else torch.float32,
50
+ config=config,
51
+ ignore_mismatched_sizes=True,
52
+ trust_remote_code=True,
53
+ token=self.token
54
+ ).to(device_map)
55
+ else:
56
+ model = AutoModelForCausalLM.from_pretrained(
57
+ model_name,
58
+ device_map=device_map,
59
+ torch_dtype="auto",
60
+ pad_token_id=tokenizer.pad_token_id,
61
+ offload_folder="/tmp/offload",
62
+ offload_state_dict=True,
63
+ cache_dir="/tmp/cache",
64
+ use_auth_token=use_auth_token,
65
+ trust_remote_code=True,
66
+ token=self.token
67
+ )
68
+ self.tokenizer = tokenizer
69
+ self.model = model
70
+ self.context_idxs = []
71
+ self.max_position_embeddings = config.max_position_embeddings
72
+
73
+ def get_ppl(
74
+ self,
75
+ text: str,
76
+ granularity: str = "sentence",
77
+ input_ids=None,
78
+ attention_mask=None,
79
+ past_key_values=None,
80
+ return_kv=False,
81
+ end=None,
82
+ condition_mode: str = "none",
83
+ condition_pos_id: int = 0,
84
+ ):
85
+ if input_ids is None:
86
+ tokenized_text = self.tokenizer(text, return_tensors="pt")
87
+ input_ids = tokenized_text["input_ids"].to(self.device)
88
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
89
+ if past_key_values is not None:
90
+ past_length = past_key_values[0][0].shape[2]
91
+ else:
92
+ past_length = 0
93
+ if end is None:
94
+ end = input_ids.shape[1]
95
+ end = min(end, past_length + self.max_position_embeddings)
96
+ with torch.no_grad():
97
+ response = self.model(
98
+ input_ids[:, past_length:end],
99
+ attention_mask=attention_mask[:, :end],
100
+ past_key_values=past_key_values,
101
+ use_cache=True,
102
+ )
103
+ past_key_values = response.past_key_values
104
+
105
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
106
+ shift_logits = response.logits[..., :-1, :].contiguous()
107
+ shift_labels = input_ids[..., past_length + 1 : end].contiguous()
108
+ # Flatten the tokens
109
+ active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
110
+ active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
111
+ active_labels = shift_labels.view(-1)[active]
112
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
113
+ loss = loss_fct(active_logits, active_labels)
114
+ if condition_mode == "before":
115
+ loss = loss[:condition_pos_id]
116
+ elif condition_mode == "after":
117
+ loss = loss[condition_pos_id:]
118
+ res = loss.mean() if granularity == "sentence" else loss
119
+ return (res, past_key_values) if return_kv else res
120
+
121
+ def __call__(self, *args, **kwargs):
122
+ return self.compress(*args, **kwargs)
123
+
124
+ def compress(
125
+ self,
126
+ context: List[str],
127
+ instruction: str = "",
128
+ question: str = "",
129
+ ratio: float = 0.5,
130
+ target_token: float = -1,
131
+ iterative_size: int = 200,
132
+ force_context_ids: List[int] = None,
133
+ force_context_number: int = None,
134
+ use_sentence_level_filter: bool = False,
135
+ use_context_level_filter: bool = True,
136
+ use_token_level_filter: bool = True,
137
+ keep_split: bool = False,
138
+ keep_first_sentence: int = 0,
139
+ keep_last_sentence: int = 0,
140
+ keep_sentence_number: int = 0,
141
+ high_priority_bonus: int = 100,
142
+ context_budget: str = "+100",
143
+ token_budget_ratio: float = 1.4,
144
+ condition_in_question: str = "none",
145
+ reorder_context: str = "original",
146
+ dynamic_context_compression_ratio: float = 0.0,
147
+ condition_compare: bool = False,
148
+ add_instruction: bool = False,
149
+ rank_method: str = "llmlingua",
150
+ concate_question: bool = True,
151
+ ):
152
+ if isinstance(context, str):
153
+ context = [context]
154
+ assert not (
155
+ rank_method == "longllmlingua" and not question
156
+ ), "In the LongLLMLingua, it is necessary to set a question."
157
+ if condition_compare and "_condition" not in condition_in_question:
158
+ condition_in_question += "_condition"
159
+ if rank_method == "longllmlingua":
160
+ if condition_in_question == "none":
161
+ condition_in_question = "after"
162
+ elif rank_method == "llmlingua":
163
+ condition_in_question = (
164
+ "none"
165
+ if "_condition" not in condition_in_question
166
+ else "none_condition"
167
+ )
168
+ origin_tokens = len(
169
+ encoding.encode("\n\n".join([instruction] + context + [question]).strip())
170
+ )
171
+ context_tokens_length = [self.get_token_length(c) for c in context]
172
+ instruction_tokens_length, question_tokens_length = self.get_token_length(
173
+ instruction
174
+ ), self.get_token_length(question)
175
+ if target_token == -1:
176
+ target_token = (
177
+ (
178
+ instruction_tokens_length
179
+ + question_tokens_length
180
+ + sum(context_tokens_length)
181
+ )
182
+ * (1 - ratio)
183
+ - instruction_tokens_length
184
+ - (question_tokens_length if concate_question else 0)
185
+ )
186
+ condition_flag = "_condition" in condition_in_question
187
+ condition_in_question = condition_in_question.replace("_condition", "")
188
+
189
+ if len(context) > 1 and use_context_level_filter:
190
+ context, dynamic_ratio = self.control_context_budget(
191
+ context,
192
+ context_tokens_length,
193
+ target_token,
194
+ force_context_ids,
195
+ force_context_number,
196
+ question,
197
+ condition_in_question,
198
+ reorder_context=reorder_context,
199
+ dynamic_context_compression_ratio=dynamic_context_compression_ratio,
200
+ rank_method=rank_method,
201
+ context_budget=context_budget,
202
+ )
203
+ else:
204
+ dynamic_ratio = [0.0] * len(context)
205
+
206
+ if use_sentence_level_filter:
207
+ context = self.control_sentence_budget(
208
+ context,
209
+ target_token,
210
+ keep_first_sentence=keep_first_sentence,
211
+ keep_last_sentence=keep_last_sentence,
212
+ keep_sentence_number=keep_sentence_number,
213
+ high_priority_bonus=high_priority_bonus,
214
+ token_budget_ratio=token_budget_ratio,
215
+ question=question,
216
+ condition_in_question=condition_in_question,
217
+ rank_method=rank_method,
218
+ )
219
+
220
+ if condition_flag:
221
+ if add_instruction:
222
+ context = [question + "\n\n" + instruction] + context
223
+ start = self.get_token_length(question + "\n\n" + instruction) + 2
224
+ else:
225
+ context = [question] + context
226
+ start = self.get_token_length(question) + 2
227
+ else:
228
+ start = 0
229
+
230
+ if use_token_level_filter:
231
+ context = self.iterative_compress_prompt(
232
+ context,
233
+ target_token,
234
+ iterative_size=iterative_size,
235
+ keep_split=keep_split,
236
+ start=start,
237
+ dynamic_ratio=dynamic_ratio,
238
+ condition_compare=condition_compare,
239
+ )
240
+ compressed_prompt = (
241
+ self.tokenizer.batch_decode(context[0])[0]
242
+ .replace("<s> ", "")
243
+ .replace("<s>", "")
244
+ )
245
+ else:
246
+ compressed_prompt = "\n\n".join(context)
247
+
248
+ if instruction:
249
+ compressed_prompt = instruction + "\n\n" + compressed_prompt
250
+ if question and concate_question:
251
+ compressed_prompt = compressed_prompt + "\n\n" + question
252
+
253
+ compressed_tokens = len(encoding.encode(compressed_prompt))
254
+ saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
255
+ return {
256
+ "compressed_prompt": compressed_prompt,
257
+ "origin_tokens": origin_tokens,
258
+ "compressed_tokens": compressed_tokens,
259
+ # "ratio": f"{origin_tokens/compressed_tokens:.1f}x",
260
+ "ratio": compressed_tokens / origin_tokens,
261
+ # "saving": f", Saving ${saving:.1f} in GPT-4.",
262
+ }
263
+
264
+ def get_token_length(self, text: str, add_special_tokens: bool = True):
265
+ return len(
266
+ self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
267
+ )
268
+
269
+ def get_condition_ppl(
270
+ self,
271
+ text: str,
272
+ question: str,
273
+ condition_in_question: str = "none",
274
+ granularity: str = "sentence",
275
+ ):
276
+ if condition_in_question == "none":
277
+ return self.get_ppl(text, granularity=granularity)
278
+ elif condition_in_question == "before":
279
+ return self.get_ppl(
280
+ question + text,
281
+ granularity=granularity,
282
+ condition_mode="after",
283
+ condition_pos_id=self.get_token_length(question) - 1,
284
+ )
285
+ elif condition_in_question == "after":
286
+ return self.get_ppl(
287
+ text + question,
288
+ granularity=granularity,
289
+ condition_mode="after",
290
+ condition_pos_id=self.get_token_length(text) - 1,
291
+ )
292
+
293
+ def get_dynamic_compression_ratio(
294
+ self,
295
+ context: list,
296
+ target_token: float,
297
+ iterative_size: int,
298
+ dynamic_ratio: list,
299
+ start: int,
300
+ ):
301
+ def get_ratio(base: float, delta: float):
302
+ return max(min(1, base + delta), 0)
303
+
304
+ context_length = [self.get_token_length(ii, False) + 2 for ii in context]
305
+ if start:
306
+ context_length = context_length[1:]
307
+ tau = target_token / (sum(context_length) + 1)
308
+ res, idx, last, last_target = [], 0, 1, []
309
+ while idx < len(context_length):
310
+ if last + context_length[idx] >= iterative_size:
311
+ last_target.append(
312
+ (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
313
+ )
314
+ res.append(last_target)
315
+ last = last + context_length[idx] - iterative_size
316
+ if last > iterative_size:
317
+ k = last // iterative_size
318
+ res.extend(
319
+ [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
320
+ )
321
+ last -= k * iterative_size
322
+
323
+ last_target = (
324
+ [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
325
+ )
326
+ else:
327
+ last += context_length[idx]
328
+ last_target.append(
329
+ (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
330
+ )
331
+ idx += 1
332
+ if last_target:
333
+ res.append(last_target)
334
+ return res
335
+
336
+ def control_context_budget(
337
+ self,
338
+ context: List[str],
339
+ context_tokens_length: List[int],
340
+ target_token: float,
341
+ force_context_ids: List[int] = None,
342
+ force_context_number: int = None,
343
+ question: str = "",
344
+ condition_in_question: str = "none",
345
+ reorder_context: str = "original",
346
+ dynamic_context_compression_ratio: float = 0.0,
347
+ rank_method: str = "longllmlingua",
348
+ context_budget: str = "+100",
349
+ ):
350
+ if force_context_ids is not None:
351
+ return [context[ii] for ii in force_context_ids]
352
+ demostrations_sort = self.get_rank_results(
353
+ context,
354
+ question,
355
+ rank_method,
356
+ condition_in_question,
357
+ context_tokens_length,
358
+ )
359
+
360
+ if target_token < 0:
361
+ target_token = 100
362
+ target_token = eval("target_token" + context_budget)
363
+ res = []
364
+ used = force_context_ids if force_context_ids is not None else []
365
+
366
+ self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
367
+ for idx, _ in demostrations_sort:
368
+ if idx >= len(context_tokens_length):
369
+ continue
370
+ target_token -= context_tokens_length[idx]
371
+ if idx not in used:
372
+ used.append(idx)
373
+ if target_token < 0 or (
374
+ force_context_number is not None and len(res) >= force_context_number
375
+ ):
376
+ break
377
+ original_used = used
378
+ if reorder_context == "original":
379
+ used = sorted(used)
380
+ elif reorder_context == "two_stage":
381
+ l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
382
+ _ for idx, _ in enumerate(used) if idx % 2 == 1
383
+ ]
384
+ used = l + r[::-1]
385
+
386
+ if dynamic_context_compression_ratio > 0:
387
+ N = len(used)
388
+ if condition_in_question:
389
+ rank = [
390
+ i
391
+ for i, _ in self.get_rank_results(
392
+ context,
393
+ question,
394
+ "longllmlingua",
395
+ "after",
396
+ context_tokens_length,
397
+ )
398
+ ]
399
+ used = sorted(used, key=lambda x: rank.index(x))
400
+ dynamic_ratio = [
401
+ i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
402
+ for i in range(-(N - 1), N, 2)
403
+ ][::-1]
404
+ dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
405
+ dynamic_ratio = [dynamic_ratio_map[i] for i in used]
406
+ else:
407
+ dynamic_ratio = [0.0] * len(used)
408
+
409
+ res = [context[idx] for idx in used if idx < len(context)]
410
+ return res, dynamic_ratio
411
+
412
+ def control_sentence_budget(
413
+ self,
414
+ context: List[str],
415
+ target_token: float,
416
+ keep_first_sentence: int = 0,
417
+ keep_last_sentence: int = 0,
418
+ keep_sentence_number: int = 0,
419
+ high_priority_bonus: int = 100,
420
+ token_budget_ratio: float = 1.4,
421
+ question: str = "",
422
+ condition_in_question: str = "none",
423
+ rank_method: str = "longllmlingua",
424
+ ):
425
+ def keep_sentence(dem_idx: int, sent_keep: int):
426
+ idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
427
+ for idx in idxs:
428
+ sentence_ppl[idx] += high_priority_bonus
429
+
430
+ sentences = [nltk.sent_tokenize(c) for c in context]
431
+ dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
432
+ for idx_d, s in enumerate(sentences):
433
+ for _ in s:
434
+ dem_g[idx_d].add(idx)
435
+ s2de[idx] = idx_d
436
+ idx += 1
437
+
438
+ context_sentences = [s for ii in sentences for s in ii]
439
+ sentence_tokens_length = [
440
+ self.get_token_length(sentence) for sentence in context_sentences
441
+ ]
442
+ N = len(context_sentences)
443
+ flags = list(range(len(context_sentences)))
444
+ if len(sentence_tokens_length) == 1:
445
+ return context
446
+ if rank_method == "longllmlingua":
447
+ sentence_ppl = [
448
+ self.get_condition_ppl(sentence, question, condition_in_question)
449
+ .cpu()
450
+ .numpy()
451
+ .item()
452
+ for sentence in context_sentences
453
+ ]
454
+ if keep_first_sentence:
455
+ sentence_ppl[:keep_first_sentence] = [
456
+ ii + high_priority_bonus
457
+ for ii in sentence_ppl[:keep_first_sentence]
458
+ ]
459
+ if keep_last_sentence:
460
+ sentence_ppl[-keep_last_sentence:] = [
461
+ ii + high_priority_bonus
462
+ for ii in sentence_ppl[-keep_last_sentence:]
463
+ ]
464
+ if keep_sentence_number:
465
+ for dem_idx in range(len(sentences)):
466
+ keep_sentence(dem_idx, keep_sentence_number)
467
+ sort_direct = -1 if condition_in_question == "none" else 1
468
+ sent_sort = sorted(
469
+ enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
470
+ )
471
+ else:
472
+ sent_sort = self.get_rank_results(
473
+ context_sentences,
474
+ question,
475
+ rank_method,
476
+ condition_in_question,
477
+ [0] * len(context_sentences),
478
+ )
479
+
480
+ sentence_flags = [False] * N
481
+ if target_token < 0:
482
+ target_token = 100
483
+ target_token *= token_budget_ratio
484
+ res = []
485
+ for idx, _ in sent_sort:
486
+ idx = flags[idx]
487
+ target_token -= sentence_tokens_length[idx]
488
+ sentence_flags[idx] = True
489
+ if target_token < 0:
490
+ break
491
+ idx = 0
492
+ res = []
493
+ for s in sentences:
494
+ tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
495
+ res.append("\n".join(tmp))
496
+ idx += len(s)
497
+ return res
498
+
499
+ def get_compressed_input(
500
+ self,
501
+ loss,
502
+ input_ids,
503
+ attention_mask,
504
+ end=200,
505
+ iterative_size=200,
506
+ threshold=0.5,
507
+ keep_flag=None,
508
+ split_token_id: int = 13,
509
+ start: int = 0,
510
+ self_loss=None,
511
+ self_input_ids=None,
512
+ self_attention_mask=None,
513
+ ):
514
+ if self_loss is not None:
515
+ need_idx = torch.concat(
516
+ [
517
+ loss[:start] > 0,
518
+ self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
519
+ loss[:1] > 0,
520
+ ]
521
+ )
522
+ else:
523
+ need_idx = torch.concat([loss > threshold, loss[:1] > 0])
524
+ need_idx[end:] = 1
525
+ need_idx[: end - iterative_size] = 1
526
+ loss = loss[need_idx[:-1]]
527
+ if self_loss is not None:
528
+ if need_idx.shape[0] < self_loss.shape[0] + start + 1:
529
+ need_idx = torch.cat(
530
+ [
531
+ need_idx,
532
+ torch.ones(
533
+ self_loss.shape[0] - need_idx.shape[0] + start + 1,
534
+ dtype=torch.bool,
535
+ ).to(need_idx.device),
536
+ ]
537
+ )
538
+ self_loss = self_loss[need_idx[start:-1]]
539
+
540
+ if need_idx.shape[0] < input_ids.shape[1]:
541
+ need_idx = torch.cat(
542
+ [
543
+ need_idx,
544
+ torch.ones(
545
+ input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
546
+ ).to(need_idx.device),
547
+ ]
548
+ )
549
+ elif need_idx.shape[0] > input_ids.shape[1]:
550
+ need_idx = need_idx[: input_ids.shape[1]]
551
+
552
+ if keep_flag is not None:
553
+ need_idx[keep_flag == 1] = 1
554
+ last = -1
555
+ if keep_flag is not None:
556
+ for ii in range(end - iterative_size, end):
557
+ if need_idx[ii] != 1:
558
+ continue
559
+ now = input_ids[0][ii].detach().cpu().item()
560
+ if (
561
+ now == split_token_id
562
+ and last == split_token_id
563
+ and keep_flag[ii].detach().cpu().item() == 0
564
+ ):
565
+ need_idx[ii] = 0
566
+ else:
567
+ last = now
568
+ compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
569
+ compressed_attention_mask = attention_mask[attention_mask == 1][
570
+ need_idx
571
+ ].unsqueeze(0)
572
+
573
+ if self_loss is not None:
574
+ self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
575
+ need_idx[start:]
576
+ ].unsqueeze(0)
577
+ self_compressed_attention_mask = self_attention_mask[
578
+ self_attention_mask == 1
579
+ ][need_idx[start:]].unsqueeze(0)
580
+ else:
581
+ self_compressed_input_ids, self_compressed_attention_mask = None, None
582
+ if keep_flag is not None:
583
+ if len(keep_flag) > len(need_idx):
584
+ keep_flag = torch.cat(
585
+ [
586
+ keep_flag[:start],
587
+ keep_flag[start : len(need_idx) + start][need_idx],
588
+ keep_flag[start + len(need_idx) :],
589
+ ]
590
+ )
591
+ else:
592
+ keep_flag = keep_flag[need_idx]
593
+ end -= (need_idx[:end] == 0).sum()
594
+ return (
595
+ compressed_input_ids,
596
+ compressed_attention_mask,
597
+ keep_flag,
598
+ end,
599
+ loss,
600
+ self_loss,
601
+ self_compressed_input_ids,
602
+ self_compressed_attention_mask,
603
+ )
604
+
605
+ def get_estimate_threshold_base_distribution(
606
+ self, ppl, ratio: float, condition_flag: bool = False
607
+ ):
608
+ ppl = ppl[ppl != 10000]
609
+ target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
610
+ return (
611
+ ppl.sort(descending=not condition_flag)
612
+ .values[target_token]
613
+ .detach()
614
+ .cpu()
615
+ .item()
616
+ )
617
+
618
+ def iterative_compress_prompt(
619
+ self,
620
+ context: List[str],
621
+ target_token: float,
622
+ iterative_size: int = 200,
623
+ keep_split: bool = False,
624
+ split_token_id: int = 13,
625
+ start: int = 0,
626
+ dynamic_ratio: list = None,
627
+ condition_compare: bool = False,
628
+ ):
629
+ iterative_ratios = self.get_dynamic_compression_ratio(
630
+ context, target_token, iterative_size, dynamic_ratio, start
631
+ )
632
+ context = "\n\n".join(context)
633
+ tokenized_text = self.tokenizer(context, return_tensors="pt")
634
+ input_ids = tokenized_text["input_ids"].to(self.device)
635
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
636
+
637
+ N = (attention_mask == 1).sum()
638
+ compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
639
+ if condition_compare:
640
+ self_input_ids, self_attention_mask = (
641
+ input_ids[:, start:],
642
+ attention_mask[:, start:],
643
+ )
644
+ self_compressed_input_ids, self_compressed_attention_mask = (
645
+ self_input_ids,
646
+ self_attention_mask,
647
+ )
648
+
649
+ end = min(iterative_size + start, compressed_input_ids.shape[1])
650
+ threshold, keep_flag = None, None
651
+ if keep_split:
652
+ input_ids_numpy = input_ids.cpu().detach().numpy()[0]
653
+ N = len(input_ids_numpy)
654
+ keep_flag = [
655
+ int(
656
+ (
657
+ ii > 0
658
+ and input_ids_numpy[ii] == split_token_id
659
+ and input_ids_numpy[ii - 1] == split_token_id
660
+ )
661
+ or (
662
+ ii < N - 1
663
+ and input_ids_numpy[ii] == split_token_id
664
+ and input_ids_numpy[ii + 1] == split_token_id
665
+ )
666
+ )
667
+ for ii in range(N)
668
+ ]
669
+ keep_flag = torch.tensor(keep_flag).to(self.device)
670
+ past_key_values, past_loss, ready_end = None, None, 0
671
+ self_past_key_values, self_past_loss, self_ready_end = None, None, 0
672
+ pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
673
+ idx = 0
674
+ while end <= compressed_input_ids.shape[1]:
675
+ if end > self.max_position_embeddings and past_key_values is not None:
676
+ # KV-Cache Compression
677
+ e, s = end - self.max_position_embeddings, self.cache_bos_num
678
+ if pop_compressed_input_ids is None:
679
+ pop_compressed_input_ids = compressed_input_ids[:, :e]
680
+ else:
681
+ pop_compressed_input_ids = torch.cat(
682
+ [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
683
+ )
684
+ compressed_input_ids = compressed_input_ids[:, e:]
685
+ compressed_attention_mask = compressed_attention_mask[:, e:]
686
+ past_key_values = [
687
+ [
688
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
689
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
690
+ ]
691
+ for k, v in past_key_values
692
+ ]
693
+ end, ready_end = end - e, ready_end - e
694
+ if condition_compare:
695
+ self_ready_end -= e
696
+ if pop_self_compressed_input_ids is None:
697
+ pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
698
+ else:
699
+ pop_self_compressed_input_ids = torch.cat(
700
+ [
701
+ pop_self_compressed_input_ids,
702
+ self_compressed_input_ids[:, :e],
703
+ ],
704
+ dim=-1,
705
+ )
706
+ self_compressed_input_ids = self_compressed_input_ids[:, e:]
707
+ self_compressed_attention_mask = self_compressed_attention_mask[
708
+ :, e:
709
+ ]
710
+ self_past_key_values = [
711
+ [
712
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
713
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
714
+ ]
715
+ for k, v in self_past_key_values
716
+ ]
717
+
718
+ loss, past_key_values = self.get_ppl(
719
+ "",
720
+ "token",
721
+ compressed_input_ids,
722
+ compressed_attention_mask,
723
+ past_key_values=past_key_values,
724
+ return_kv=True,
725
+ end=end if idx else None,
726
+ )
727
+ if past_loss is not None:
728
+ if end - 1 > len(past_loss):
729
+ past_loss = torch.cat(
730
+ [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
731
+ )
732
+ past_loss[ready_end : end - 1] = loss
733
+ loss = past_loss
734
+ else:
735
+ past_loss = loss
736
+ if idx:
737
+ past_key_values = [
738
+ [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
739
+ for k, v in past_key_values
740
+ ]
741
+ else:
742
+ past_key_values = None
743
+
744
+ if condition_compare:
745
+ self_loss, self_past_key_values = self.get_ppl(
746
+ "",
747
+ "token",
748
+ self_compressed_input_ids,
749
+ self_compressed_attention_mask,
750
+ past_key_values=self_past_key_values,
751
+ return_kv=True,
752
+ end=end - start if idx else None,
753
+ )
754
+ if self_past_loss is not None:
755
+ if end - start - 1 > len(self_past_loss):
756
+ self_past_loss = torch.cat(
757
+ [
758
+ self_past_loss,
759
+ torch.zeros_like(self_loss)[
760
+ : end - 1 - start - len(self_past_loss)
761
+ ],
762
+ ]
763
+ )
764
+ self_past_loss[self_ready_end : end - start - 1] = self_loss
765
+ self_loss = self_past_loss
766
+ else:
767
+ self_past_loss = self_loss
768
+ if idx:
769
+ self_past_key_values = [
770
+ [
771
+ k[:, :, : end - iterative_size - start],
772
+ v[:, :, : end - iterative_size - start],
773
+ ]
774
+ for k, v in self_past_key_values
775
+ ]
776
+ else:
777
+ self_past_key_values = None
778
+
779
+ self_ready_end = (
780
+ end - start - iterative_size if not (start and idx == 0) else 0
781
+ )
782
+ ready_end = end - iterative_size if not (start and idx == 0) else 0
783
+
784
+ for delta_end, ratio in iterative_ratios[idx]:
785
+ loss = past_loss
786
+ if condition_compare:
787
+ self_loss = self_past_loss
788
+ threshold = self.get_estimate_threshold_base_distribution(
789
+ self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
790
+ )
791
+ else:
792
+ threshold = self.get_estimate_threshold_base_distribution(
793
+ loss, ratio, False
794
+ )
795
+
796
+ (
797
+ compressed_input_ids,
798
+ compressed_attention_mask,
799
+ keep_flag,
800
+ end,
801
+ past_loss,
802
+ self_past_loss,
803
+ self_compressed_input_ids,
804
+ self_compressed_attention_mask,
805
+ ) = self.get_compressed_input(
806
+ loss,
807
+ compressed_input_ids,
808
+ compressed_attention_mask,
809
+ end - iterative_size + delta_end,
810
+ iterative_size=delta_end,
811
+ threshold=threshold,
812
+ keep_flag=keep_flag,
813
+ split_token_id=split_token_id,
814
+ start=start,
815
+ self_loss=self_loss if condition_compare else None,
816
+ self_input_ids=self_compressed_input_ids
817
+ if condition_compare
818
+ else None,
819
+ self_attention_mask=self_compressed_attention_mask
820
+ if condition_compare
821
+ else None,
822
+ )
823
+ end += iterative_size
824
+ idx += 1
825
+ if pop_compressed_input_ids is not None:
826
+ compressed_input_ids = torch.cat(
827
+ [pop_compressed_input_ids, compressed_input_ids], dim=-1
828
+ )
829
+ return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
830
+
831
+ def recover(
832
+ self,
833
+ original_prompt: str,
834
+ compressed_prompt: str,
835
+ response: str,
836
+ ):
837
+ def match_from_compressed(response_word):
838
+ response_input_ids = self.tokenizer(
839
+ response_word, add_special_tokens=False
840
+ )["input_ids"]
841
+ response_set, response_c = set(response_input_ids), defaultdict(list)
842
+ for idx in range(M):
843
+ if original_input_ids[idx] in response_set:
844
+ response_c[original_input_ids[idx]].append(idx)
845
+ res, res_min, res_c = None, float("inf"), 1
846
+ n = len(response_input_ids)
847
+ for l in response_c[response_input_ids[0]]:
848
+ x, y, c = 0, l, 1
849
+ for x in range(1, n):
850
+ idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
851
+ if (
852
+ idx >= len(response_c[response_input_ids[x]])
853
+ or response_c[response_input_ids[x]][idx] - y > 10
854
+ ):
855
+ continue
856
+ c += 1
857
+ y = response_c[response_input_ids[x]][idx]
858
+ if c > res_c:
859
+ res_c = c
860
+ res_min = y - l + 1
861
+ res = (l, y + 1)
862
+ elif c == res_c and y - l + 1 < res_min:
863
+ res_min = y - l + 1
864
+ res = (l, y + 1)
865
+
866
+ if res is None:
867
+ return response_word
868
+ # while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
869
+ # l -= 1
870
+ # while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
871
+ # l -= 1
872
+ return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
873
+
874
+ response_words = response.split(" ")
875
+
876
+ original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
877
+ "input_ids"
878
+ ]
879
+ N, M = len(response_words), len(original_input_ids)
880
+ recovered_response_words = []
881
+ l = 0
882
+ while l < N:
883
+ if response_words[l] not in compressed_prompt:
884
+ recovered_response_words.append(response_words[l])
885
+ l += 1
886
+ continue
887
+ r = l
888
+ while (
889
+ r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
890
+ ):
891
+ r += 1
892
+
893
+ match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
894
+ recovered_response_words.append(match_words)
895
+ l = r + 1
896
+ return " ".join(recovered_response_words)
897
+
898
+ def get_rank_results(
899
+ self,
900
+ context: list,
901
+ question: str,
902
+ rank_method: str,
903
+ condition_in_question: str,
904
+ context_tokens_length: list,
905
+ ):
906
+ def get_distance_bm25(corpus, query):
907
+ from rank_bm25 import BM25Okapi
908
+
909
+ tokenized_corpus = [doc.split(" ") for doc in corpus]
910
+ bm25 = BM25Okapi(tokenized_corpus)
911
+ tokenized_query = query.split(" ")
912
+ doc_scores = bm25.get_scores(tokenized_query)
913
+ idx = [(ii, 0) for ii in (-doc_scores).argsort()]
914
+ return idx
915
+
916
+ def get_distance_gzip(corpus, query):
917
+ def get_score(x, y):
918
+ cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
919
+ cxy = len(gzip.compress(f"{x} {y}".encode()))
920
+ return (cxy - min(cx, cy)) / max(cx, cy)
921
+
922
+ import gzip
923
+
924
+ doc_scores = [get_score(doc, query) for doc in corpus]
925
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
926
+ return idx
927
+
928
+ def get_distance_sentbert(corpus, query):
929
+ from sentence_transformers import SentenceTransformer, util
930
+
931
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
932
+ self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
933
+ self.retrieval_model_name = rank_method
934
+ doc_embeds = self.retrieval_model.encode(corpus)
935
+ query = self.retrieval_model.encode(query)
936
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
937
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
938
+ return idx
939
+
940
+ def get_distance_openai(corpus, query):
941
+ import openai
942
+ from sentence_transformers import util
943
+
944
+ openai.api_key = self.open_api_config.get("api_key", "")
945
+ openai.api_base = self.open_api_config.get(
946
+ "api_base", "https://api.openai.com/v1"
947
+ )
948
+ openai.api_type = self.open_api_config.get("api_type", "open_ai")
949
+ openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
950
+ engine = self.open_api_config.get("engine", "text-embedding-ada-002")
951
+
952
+ def get_embed(text):
953
+ return openai.Embedding.create(
954
+ input=[text.replace("\n", " ")], engine=engine
955
+ )["LongBench"][0]["embedding"]
956
+
957
+ doc_embeds = [get_embed(i) for i in corpus]
958
+ query = get_embed(query)
959
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
960
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
961
+ return idx
962
+
963
+ def get_distance_sentbert_bge(corpus, query):
964
+ from sentence_transformers import SentenceTransformer, util
965
+
966
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
967
+ self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
968
+ self.retrieval_model_name = rank_method
969
+ doc_embeds = self.retrieval_model.encode(
970
+ [i for i in corpus], normalize_embeddings=True
971
+ )
972
+ query = self.retrieval_model.encode(query, normalize_embeddings=True)
973
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
974
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
975
+ return idx
976
+
977
+ def get_distance_bge_ranker(corpus, query):
978
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
979
+
980
+ pairs = [[i, query] for i in corpus]
981
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
982
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
983
+ model = (
984
+ AutoModelForSequenceClassification.from_pretrained(
985
+ "BAAI/bge-reranker-large"
986
+ )
987
+ .eval()
988
+ .to(self.device)
989
+ )
990
+ self.retrieval_model = [tokenizer, model]
991
+ self.retrieval_model_name = rank_method
992
+ with torch.no_grad():
993
+ inputs = self.retrieval_model[0](
994
+ pairs,
995
+ padding=True,
996
+ truncation=True,
997
+ return_tensors="pt",
998
+ max_length=512,
999
+ ).to(self.device)
1000
+ scores = (
1001
+ self.retrieval_model[1](**inputs, return_dict=True)
1002
+ .logits.view(
1003
+ -1,
1004
+ )
1005
+ .float()
1006
+ )
1007
+ idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
1008
+ return idx
1009
+
1010
+ def get_distance_bge_llmembedder(corpus, query):
1011
+ from transformers import AutoModel, AutoTokenizer
1012
+
1013
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
1014
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
1015
+ model = (
1016
+ AutoModel.from_pretrained("BAAI/llm-embedder")
1017
+ .eval()
1018
+ .to(self.device)
1019
+ )
1020
+ self.retrieval_model = [tokenizer, model]
1021
+ self.retrieval_model_name = rank_method
1022
+
1023
+ instruction_qa_query = (
1024
+ "Represent this query for retrieving relevant documents: "
1025
+ )
1026
+ instruction_qa_key = "Represent this document for retrieval: "
1027
+ queries = [instruction_qa_query + query for _ in corpus]
1028
+ keys = [instruction_qa_key + key for key in corpus]
1029
+ with torch.no_grad():
1030
+ query_inputs = self.retrieval_model[0](
1031
+ queries,
1032
+ padding=True,
1033
+ truncation=True,
1034
+ return_tensors="pt",
1035
+ max_length=512,
1036
+ ).to(self.device)
1037
+ key_inputs = self.retrieval_model[0](
1038
+ keys,
1039
+ padding=True,
1040
+ truncation=True,
1041
+ return_tensors="pt",
1042
+ max_length=512,
1043
+ ).to(self.device)
1044
+ query_outputs = self.retrieval_model[1](**query_inputs)
1045
+ key_outputs = self.retrieval_model[1](**key_inputs)
1046
+ # CLS pooling
1047
+ query_embeddings = query_outputs.last_hidden_state[:, 0]
1048
+ key_embeddings = key_outputs.last_hidden_state[:, 0]
1049
+ # Normalize
1050
+ query_embeddings = torch.nn.functional.normalize(
1051
+ query_embeddings, p=2, dim=1
1052
+ )
1053
+ key_embeddings = torch.nn.functional.normalize(
1054
+ key_embeddings, p=2, dim=1
1055
+ )
1056
+ similarity = query_embeddings @ key_embeddings.T
1057
+ idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
1058
+ return idx
1059
+
1060
+ def get_distance_jinza(corpus, query):
1061
+ from numpy.linalg import norm
1062
+
1063
+ from transformers import AutoModel
1064
+
1065
+ def cos_sim(a, b):
1066
+ return (a @ b.T) / (norm(a) * norm(b))
1067
+
1068
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
1069
+ model = (
1070
+ AutoModel.from_pretrained(
1071
+ "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
1072
+ )
1073
+ .eval()
1074
+ .to(self.device)
1075
+ )
1076
+ self.retrieval_model = model
1077
+ self.retrieval_model_name = rank_method
1078
+
1079
+ doc_embeds = self.retrieval_model.encode(corpus)
1080
+ query = self.retrieval_model.encode(query)
1081
+ doc_scores = cos_sim(doc_embeds, query)
1082
+ idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
1083
+ return idx
1084
+
1085
+ def get_distance_voyageai(corpus, query):
1086
+ import voyageai
1087
+ from sentence_transformers import util
1088
+
1089
+ voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
1090
+
1091
+ def get_embed(text):
1092
+ return voyageai.get_embedding(text, model="voyage-01")
1093
+
1094
+ doc_embeds = [get_embed(i) for i in corpus]
1095
+ query = get_embed(query)
1096
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
1097
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
1098
+ return idx
1099
+
1100
+ def get_distance_cohere(corpus, query):
1101
+ import cohere
1102
+
1103
+ api_key = self.open_api_config.get("cohere_api_key", "")
1104
+ co = cohere.Client(api_key)
1105
+ results = co.rerank(
1106
+ model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
1107
+ )
1108
+ c_map = {jj: ii for ii, jj in enumerate(corpus)}
1109
+ doc_rank = [c_map[ii.document["text"]] for ii in results]
1110
+ idx = [(ii, 0) for ii in doc_rank]
1111
+ return idx
1112
+
1113
+ def get_distance_longllmlingua(corpus, query):
1114
+ context_ppl = [
1115
+ self.get_condition_ppl(
1116
+ d,
1117
+ query
1118
+ + " We can get the answer to this question in the given documents.",
1119
+ condition_in_question,
1120
+ )
1121
+ - dl * 2 / 250 * 0
1122
+ for d, dl in zip(corpus, context_tokens_length)
1123
+ ]
1124
+ sort_direct = -1 if condition_in_question == "none" else 1
1125
+ ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
1126
+ return ys
1127
+
1128
+ method = None
1129
+ if rank_method == "bm25":
1130
+ method = get_distance_bm25
1131
+ elif rank_method == "gzip":
1132
+ method = get_distance_gzip
1133
+ elif rank_method == "sentbert":
1134
+ method = get_distance_sentbert
1135
+ elif rank_method == "openai":
1136
+ method = get_distance_openai
1137
+ elif rank_method in ["longllmlingua", "llmlingua"]:
1138
+ method = get_distance_longllmlingua
1139
+ elif rank_method == "bge":
1140
+ method = get_distance_sentbert_bge
1141
+ elif rank_method == "bge_reranker":
1142
+ method = get_distance_bge_ranker
1143
+ elif rank_method == "bge_llmembedder":
1144
+ method = get_distance_bge_llmembedder
1145
+ elif rank_method == "jinza":
1146
+ method = get_distance_jinza
1147
+ elif rank_method == "voyageai":
1148
+ method = get_distance_voyageai
1149
+ elif rank_method == "cohere":
1150
+ method = get_distance_cohere
1151
+ return method(context, question)
1152
+
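
LLMLinguaCompressor can also be driven directly, without the PromptCompressor facade. A minimal sketch, assuming access to the gated meta-llama/Llama-2-7b-chat-hf weights via a Hugging Face access token and a CUDA device; the context and question strings are placeholders:

    from llmlingua_compressor_pro import LLMLinguaCompressor

    comp = LLMLinguaCompressor(
        model_name="meta-llama/Llama-2-7b-chat-hf",
        device_map="cuda",
        token="<your Hugging Face access token>",  # placeholder; needed for the gated Llama-2 weights
    )
    result = comp.compress(
        context=["<long context passage>"],
        question="<optional question>",
        ratio=0.5,
        rank_method="llmlingua",
    )
    # The returned dict holds compressed_prompt, origin_tokens,
    # compressed_tokens, and ratio (compressed / original tokens).
    print(result["compressed_prompt"])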
longlingua_compressor.py ADDED
@@ -0,0 +1,1150 @@
1
+ from llmlingua import PromptCompressor
2
+ import bisect
3
+ from collections import defaultdict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ import nltk
10
+ import tiktoken
11
+ import re
12
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
13
+
14
+ from abs_compressor import AbstractCompressor
15
+
16
+ encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
17
+
18
+ class LongLLMLinguaCompressor(AbstractCompressor):
19
+ def __init__(
20
+ self,
21
+ model_name: str = "meta-llama/Llama-2-7b-chat-hf",
22
+ device_map: str = "cuda",
23
+ use_auth_token: bool = False,
24
+ open_api_config: dict = {},
25
+ ):
26
+ self.load_model(model_name, device_map, use_auth_token)
27
+ self.retrieval_model = None
28
+ self.retrieval_model_name = None
29
+ self.open_api_config = open_api_config
30
+ self.cache_bos_num = 10
31
+
32
+ def load_model(
33
+ self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
34
+ ):
35
+ config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
36
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
37
+ tokenizer.padding_side = "left"
38
+ tokenizer.pad_token_id = (
39
+ config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
40
+ )
41
+ self.device = (
42
+ device_map if any(key in device_map for key in ["cuda", "cpu"]) else "cuda"
43
+ )
44
+ if "cuda" in device_map or "cpu" in device_map:
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ model_name,
47
+ torch_dtype="auto" if device_map == "cuda" else torch.float32,
48
+ config=config,
49
+ ignore_mismatched_sizes=True,
50
+ trust_remote_code=True,
51
+ token="Your Token here"
52
+ ).to(device_map)
53
+ else:
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ model_name,
56
+ device_map=device_map,
57
+ torch_dtype="auto",
58
+ pad_token_id=tokenizer.pad_token_id,
59
+ offload_folder="/tmp/offload",
60
+ offload_state_dict=True,
61
+ cache_dir="/tmp/cache",
62
+ use_auth_token=use_auth_token,
63
+ trust_remote_code=True,
64
+ token="Your Token here"
65
+ )
66
+ self.tokenizer = tokenizer
67
+ self.model = model
68
+ self.context_idxs = []
69
+ self.max_position_embeddings = config.max_position_embeddings
70
+
71
+ def get_ppl(
72
+ self,
73
+ text: str,
74
+ granularity: str = "sentence",
75
+ input_ids=None,
76
+ attention_mask=None,
77
+ past_key_values=None,
78
+ return_kv=False,
79
+ end=None,
80
+ condition_mode: str = "none",
81
+ condition_pos_id: int = 0,
82
+ ):
83
+ if input_ids is None:
84
+ tokenized_text = self.tokenizer(text, return_tensors="pt")
85
+ input_ids = tokenized_text["input_ids"].to(self.device)
86
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
87
+ if past_key_values is not None:
88
+ past_length = past_key_values[0][0].shape[2]
89
+ else:
90
+ past_length = 0
91
+ if end is None:
92
+ end = input_ids.shape[1]
93
+ end = min(end, past_length + self.max_position_embeddings)
94
+ with torch.no_grad():
95
+ response = self.model(
96
+ input_ids[:, past_length:end],
97
+ attention_mask=attention_mask[:, :end],
98
+ past_key_values=past_key_values,
99
+ use_cache=True,
100
+ )
101
+ past_key_values = response.past_key_values
102
+
103
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
104
+ shift_logits = response.logits[..., :-1, :].contiguous()
105
+ shift_labels = input_ids[..., past_length + 1 : end].contiguous()
106
+ # Flatten the tokens
107
+ active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
108
+ active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
109
+ active_labels = shift_labels.view(-1)[active]
110
+ loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
111
+ loss = loss_fct(active_logits, active_labels)
112
+ if condition_mode == "before":
113
+ loss = loss[:condition_pos_id]
114
+ elif condition_mode == "after":
115
+ loss = loss[condition_pos_id:]
116
+ res = loss.mean() if granularity == "sentence" else loss
117
+ return (res, past_key_values) if return_kv else res
118
+
119
+ def __call__(self, *args, **kwargs):
120
+ return self.compress(*args, **kwargs)
121
+
122
+ def compress(
123
+ self,
124
+ context: List[str],
125
+ instruction: str = "",
126
+ question: str = " ",
127
+ ratio: float = 0.5,
128
+ target_token: float = -1,
129
+ iterative_size: int = 200,
130
+ force_context_ids: List[int] = None,
131
+ force_context_number: int = None,
132
+ use_sentence_level_filter: bool = False,
133
+ use_context_level_filter: bool = True,
134
+ use_token_level_filter: bool = True,
135
+ keep_split: bool = False,
136
+ keep_first_sentence: int = 0,
137
+ keep_last_sentence: int = 0,
138
+ keep_sentence_number: int = 0,
139
+ high_priority_bonus: int = 100,
140
+ context_budget: str = "+100",
141
+ token_budget_ratio: float = 1.4,
142
+ condition_in_question: str = "none",
143
+ reorder_context: str = "original",
144
+ dynamic_context_compression_ratio: float = 0.0,
145
+ condition_compare: bool = False,
146
+ add_instruction: bool = False,
147
+ rank_method: str = "longllmlingua",
148
+ concate_question: bool = True,
149
+ ):
150
+ if isinstance(context, str):
151
+ context = [context]
152
+ assert not (
153
+ rank_method == "longllmlingua" and not question
154
+ ), "In the LongLLMLingua, it is necessary to set a question."
155
+ if condition_compare and "_condition" not in condition_in_question:
156
+ condition_in_question += "_condition"
157
+ if rank_method == "longllmlingua":
158
+ if condition_in_question == "none":
159
+ condition_in_question = "after"
160
+ elif rank_method == "llmlingua":
161
+ condition_in_question = (
162
+ "none"
163
+ if "_condition" not in condition_in_question
164
+ else "none_condition"
165
+ )
166
+ origin_tokens = len(
167
+ encoding.encode("\n\n".join([instruction] + context + [question]).strip())
168
+ )
169
+ context_tokens_length = [self.get_token_length(c) for c in context]
170
+ instruction_tokens_length, question_tokens_length = self.get_token_length(
171
+ instruction
172
+ ), self.get_token_length(question)
173
+ if target_token == -1:
174
+ target_token = (
175
+ (
176
+ instruction_tokens_length
177
+ + question_tokens_length
178
+ + sum(context_tokens_length)
179
+ )
180
+ * (1 - ratio)
181
+ - instruction_tokens_length
182
+ - (question_tokens_length if concate_question else 0)
183
+ )
184
+ condition_flag = "_condition" in condition_in_question
185
+ condition_in_question = condition_in_question.replace("_condition", "")
186
+
187
+ if len(context) > 1 and use_context_level_filter:
188
+ context, dynamic_ratio = self.control_context_budget(
189
+ context,
190
+ context_tokens_length,
191
+ target_token,
192
+ force_context_ids,
193
+ force_context_number,
194
+ question,
195
+ condition_in_question,
196
+ reorder_context=reorder_context,
197
+ dynamic_context_compression_ratio=dynamic_context_compression_ratio,
198
+ rank_method=rank_method,
199
+ context_budget=context_budget,
200
+ )
201
+ else:
202
+ dynamic_ratio = [0.0] * len(context)
203
+
204
+ if use_sentence_level_filter:
205
+ context = self.control_sentence_budget(
206
+ context,
207
+ target_token,
208
+ keep_first_sentence=keep_first_sentence,
209
+ keep_last_sentence=keep_last_sentence,
210
+ keep_sentence_number=keep_sentence_number,
211
+ high_priority_bonus=high_priority_bonus,
212
+ token_budget_ratio=token_budget_ratio,
213
+ question=question,
214
+ condition_in_question=condition_in_question,
215
+ rank_method=rank_method,
216
+ )
217
+
218
+ if condition_flag:
219
+ if add_instruction:
220
+ context = [question + "\n\n" + instruction] + context
221
+ start = self.get_token_length(question + "\n\n" + instruction) + 2
222
+ else:
223
+ context = [question] + context
224
+ start = self.get_token_length(question) + 2
225
+ else:
226
+ start = 0
227
+
228
+ if use_token_level_filter:
229
+ context = self.iterative_compress_prompt(
230
+ context,
231
+ target_token,
232
+ iterative_size=iterative_size,
233
+ keep_split=keep_split,
234
+ start=start,
235
+ dynamic_ratio=dynamic_ratio,
236
+ condition_compare=condition_compare,
237
+ )
238
+ compressed_prompt = (
239
+ self.tokenizer.batch_decode(context[0])[0]
240
+ .replace("<s> ", "")
241
+ .replace("<s>", "")
242
+ )
243
+ else:
244
+ compressed_prompt = "\n\n".join(context)
245
+
246
+ if instruction:
247
+ compressed_prompt = instruction + "\n\n" + compressed_prompt
248
+ if question and concate_question:
249
+ compressed_prompt = compressed_prompt + "\n\n" + question
250
+
251
+ compressed_tokens = len(encoding.encode(compressed_prompt))
252
+ saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
253
+ return {
254
+ "compressed_prompt": compressed_prompt,
255
+ "origin_tokens": origin_tokens,
256
+ "compressed_tokens": compressed_tokens,
257
+ # "ratio": f"{origin_tokens/compressed_tokens:.1f}x",
258
+ "ratio": compressed_tokens / origin_tokens,
259
+ # "saving": f", Saving ${saving:.1f} in GPT-4.",
260
+ }
261
+
262
+ def get_token_length(self, text: str, add_special_tokens: bool = True):
263
+ return len(
264
+ self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
265
+ )
266
+
267
+ def get_condition_ppl(
268
+ self,
269
+ text: str,
270
+ question: str,
271
+ condition_in_question: str = "none",
272
+ granularity: str = "sentence",
273
+ ):
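+ # Perplexity of `text`, optionally conditioned on `question`:
+ #   "none"   -> ppl(text)
+ #   "before" -> ppl of the text tokens with the question prepended
+ #   "after"  -> ppl of the question tokens with the text prepended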
274
+ if condition_in_question == "none":
275
+ return self.get_ppl(text, granularity=granularity)
276
+ elif condition_in_question == "before":
277
+ return self.get_ppl(
278
+ question + text,
279
+ granularity=granularity,
280
+ condition_mode="after",
281
+ condition_pos_id=self.get_token_length(question) - 1,
282
+ )
283
+ elif condition_in_question == "after":
284
+ return self.get_ppl(
285
+ text + question,
286
+ granularity=granularity,
287
+ condition_mode="after",
288
+ condition_pos_id=self.get_token_length(text) - 1,
289
+ )
290
+
291
+ def get_dynamic_compression_ratio(
292
+ self,
293
+ context: list,
294
+ target_token: float,
295
+ iterative_size: int,
296
+ dynamic_ratio: list,
297
+ start: int,
298
+ ):
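+ # Split the concatenated contexts into windows of `iterative_size` tokens
+ # and assign each window a keep-ratio: the global ratio
+ # target_token / total_length shifted by the per-context delta in
+ # `dynamic_ratio`, clipped to [0, 1]. Returns, per iteration step, a list
+ # of (chunk_length, ratio) pairs.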
299
+ def get_ratio(base: float, delta: float):
300
+ return max(min(1, base + delta), 0)
301
+
302
+ context_length = [self.get_token_length(ii, False) + 2 for ii in context]
303
+ if start:
304
+ context_length = context_length[1:]
305
+ tau = target_token / (sum(context_length) + 1)
306
+ res, idx, last, last_target = [], 0, 1, []
307
+ while idx < len(context_length):
308
+ if last + context_length[idx] >= iterative_size:
309
+ last_target.append(
310
+ (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
311
+ )
312
+ res.append(last_target)
313
+ last = last + context_length[idx] - iterative_size
314
+ if last > iterative_size:
315
+ k = last // iterative_size
316
+ res.extend(
317
+ [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
318
+ )
319
+ last -= k * iterative_size
320
+
321
+ last_target = (
322
+ [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
323
+ )
324
+ else:
325
+ last += context_length[idx]
326
+ last_target.append(
327
+ (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
328
+ )
329
+ idx += 1
330
+ if last_target:
331
+ res.append(last_target)
332
+ return res
333
+
334
+ def control_context_budget(
335
+ self,
336
+ context: List[str],
337
+ context_tokens_length: List[int],
338
+ target_token: float,
339
+ force_context_ids: List[int] = None,
340
+ force_context_number: int = None,
341
+ question: str = "",
342
+ condition_in_question: str = "none",
343
+ reorder_context: str = "original",
344
+ dynamic_context_compression_ratio: float = 0.0,
345
+ rank_method: str = "longllmlingua",
346
+ context_budget: str = "+100",
347
+ ):
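+ # Coarse-grained, context-level budget control: rank the contexts with
+ # `rank_method`, then greedily keep them until the token budget
+ # (target_token adjusted by the `context_budget` expression, e.g. "+100")
+ # is exhausted. Optionally reorders the kept contexts and derives the
+ # per-context dynamic ratios used by the token-level stage.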
348
+ if force_context_ids is not None:
349
+ return [context[ii] for ii in force_context_ids], [0.0] * len(force_context_ids)
350
+ demostrations_sort = self.get_rank_results(
351
+ context,
352
+ question,
353
+ rank_method,
354
+ condition_in_question,
355
+ context_tokens_length,
356
+ )
357
+
358
+ if target_token < 0:
359
+ target_token = 100
360
+ target_token = eval("target_token" + context_budget)
361
+ res = []
362
+ used = force_context_ids if force_context_ids is not None else []
363
+
364
+ self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
365
+ for idx, _ in demostrations_sort:
366
+ if idx >= len(context_tokens_length):
367
+ continue
368
+ target_token -= context_tokens_length[idx]
369
+ if idx not in used:
370
+ used.append(idx)
371
+ if target_token < 0 or (
372
+ force_context_number is not None and len(res) >= force_context_number
373
+ ):
374
+ break
375
+ original_used = used
376
+ if reorder_context == "original":
377
+ used = sorted(used)
378
+ elif reorder_context == "two_stage":
379
+ l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
380
+ _ for idx, _ in enumerate(used) if idx % 2 == 1
381
+ ]
382
+ used = l + r[::-1]
383
+
384
+ if dynamic_context_compression_ratio > 0:
385
+ N = len(used)
386
+ if condition_in_question:
387
+ rank = [
388
+ i
389
+ for i, _ in self.get_rank_results(
390
+ context,
391
+ question,
392
+ "longllmlingua",
393
+ "after",
394
+ context_tokens_length,
395
+ )
396
+ ]
397
+ used = sorted(used, key=lambda x: rank.index(x))
398
+ dynamic_ratio = [
399
+ i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
400
+ for i in range(-(N - 1), N, 2)
401
+ ][::-1]
402
+ dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
403
+ dynamic_ratio = [dynamic_ratio_map[i] for i in used]
404
+ else:
405
+ dynamic_ratio = [0.0] * len(used)
406
+
407
+ res = [context[idx] for idx in used if idx < len(context)]
408
+ return res, dynamic_ratio
409
+
410
+ def control_sentence_budget(
411
+ self,
412
+ context: List[str],
413
+ target_token: float,
414
+ keep_first_sentence: int = 0,
415
+ keep_last_sentence: int = 0,
416
+ keep_sentence_number: int = 0,
417
+ high_priority_bonus: int = 100,
418
+ token_budget_ratio: float = 1.4,
419
+ question: str = "",
420
+ condition_in_question: str = "none",
421
+ rank_method: str = "longllmlingua",
422
+ ):
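+ # Sentence-level budget control: split each context into sentences and
+ # score them (conditional perplexity for "longllmlingua", otherwise the
+ # selected ranker). In the perplexity branch, sentences flagged by
+ # keep_first_sentence / keep_last_sentence / keep_sentence_number get a
+ # priority bonus. The best-ranked sentences are retained until
+ # target_token * token_budget_ratio is spent.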
423
+ def keep_sentence(dem_idx: int, sent_keep: int):
424
+ idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
425
+ for idx in idxs:
426
+ sentence_ppl[idx] += high_priority_bonus
427
+
428
+ sentences = [nltk.sent_tokenize(c) for c in context]
429
+ dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
430
+ for idx_d, s in enumerate(sentences):
431
+ for _ in s:
432
+ dem_g[idx_d].add(idx)
433
+ s2de[idx] = idx_d
434
+ idx += 1
435
+
436
+ context_sentences = [s for ii in sentences for s in ii]
437
+ sentence_tokens_length = [
438
+ self.get_token_length(sentence) for sentence in context_sentences
439
+ ]
440
+ N = len(context_sentences)
441
+ flags = list(range(len(context_sentences)))
442
+ if len(sentence_tokens_length) == 1:
443
+ return context
444
+ if rank_method == "longllmlingua":
445
+ sentence_ppl = [
446
+ self.get_condition_ppl(sentence, question, condition_in_question)
447
+ .cpu()
448
+ .numpy()
449
+ .item()
450
+ for sentence in context_sentences
451
+ ]
452
+ if keep_first_sentence:
453
+ sentence_ppl[:keep_first_sentence] = [
454
+ ii + high_priority_bonus
455
+ for ii in sentence_ppl[:keep_first_sentence]
456
+ ]
457
+ if keep_last_sentence:
458
+ sentence_ppl[-keep_last_sentence:] = [
459
+ ii + high_priority_bonus
460
+ for ii in sentence_ppl[-keep_last_sentence:]
461
+ ]
462
+ if keep_sentence_number:
463
+ for dem_idx in range(len(sentences)):
464
+ keep_sentence(dem_idx, keep_sentence_number)
465
+ sort_direct = -1 if condition_in_question == "none" else 1
466
+ sent_sort = sorted(
467
+ enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
468
+ )
469
+ else:
470
+ sent_sort = self.get_rank_results(
471
+ context_sentences,
472
+ question,
473
+ rank_method,
474
+ condition_in_question,
475
+ [0] * len(context_sentences),
476
+ )
477
+
478
+ sentence_flags = [False] * N
479
+ if target_token < 0:
480
+ target_token = 100
481
+ target_token *= token_budget_ratio
482
+ res = []
483
+ for idx, _ in sent_sort:
484
+ idx = flags[idx]
485
+ target_token -= sentence_tokens_length[idx]
486
+ sentence_flags[idx] = True
487
+ if target_token < 0:
488
+ break
489
+ idx = 0
490
+ res = []
491
+ for s in sentences:
492
+ tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
493
+ res.append("\n".join(tmp))
494
+ idx += len(s)
495
+ return res
496
+
497
+ def get_compressed_input(
498
+ self,
499
+ loss,
500
+ input_ids,
501
+ attention_mask,
502
+ end=200,
503
+ iterative_size=200,
504
+ threshold=0.5,
505
+ keep_flag=None,
506
+ split_token_id: int = 13,
507
+ start: int = 0,
508
+ self_loss=None,
509
+ self_input_ids=None,
510
+ self_attention_mask=None,
511
+ ):
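+ # Build the boolean keep-mask for the current window from the per-token
+ # losses: keep tokens whose loss exceeds `threshold` (or, when
+ # `self_loss` is given, whose conditional-vs-unconditional loss delta
+ # does), always keep tokens outside the active window and those marked
+ # in `keep_flag`, and drop repeated split tokens. Returns the pruned
+ # ids/attention masks plus the updated losses, keep_flag and window end.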
512
+ if self_loss is not None:
513
+ need_idx = torch.concat(
514
+ [
515
+ loss[:start] > 0,
516
+ self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
517
+ loss[:1] > 0,
518
+ ]
519
+ )
520
+ else:
521
+ need_idx = torch.concat([loss > threshold, loss[:1] > 0])
522
+ need_idx[end:] = 1
523
+ need_idx[: end - iterative_size] = 1
524
+ loss = loss[need_idx[:-1]]
525
+ if self_loss is not None:
526
+ if need_idx.shape[0] < self_loss.shape[0] + start + 1:
527
+ need_idx = torch.cat(
528
+ [
529
+ need_idx,
530
+ torch.ones(
531
+ self_loss.shape[0] - need_idx.shape[0] + start + 1,
532
+ dtype=torch.bool,
533
+ ).to(need_idx.device),
534
+ ]
535
+ )
536
+ self_loss = self_loss[need_idx[start:-1]]
537
+
538
+ if need_idx.shape[0] < input_ids.shape[1]:
539
+ need_idx = torch.cat(
540
+ [
541
+ need_idx,
542
+ torch.ones(
543
+ input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
544
+ ).to(need_idx.device),
545
+ ]
546
+ )
547
+ elif need_idx.shape[0] > input_ids.shape[1]:
548
+ need_idx = need_idx[: input_ids.shape[1]]
549
+
550
+ if keep_flag is not None:
551
+ need_idx[keep_flag == 1] = 1
552
+ last = -1
553
+ if keep_flag is not None:
554
+ for ii in range(end - iterative_size, end):
555
+ if need_idx[ii] != 1:
556
+ continue
557
+ now = input_ids[0][ii].detach().cpu().item()
558
+ if (
559
+ now == split_token_id
560
+ and last == split_token_id
561
+ and keep_flag[ii].detach().cpu().item() == 0
562
+ ):
563
+ need_idx[ii] = 0
564
+ else:
565
+ last = now
566
+ compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
567
+ compressed_attention_mask = attention_mask[attention_mask == 1][
568
+ need_idx
569
+ ].unsqueeze(0)
570
+
571
+ if self_loss is not None:
572
+ self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
573
+ need_idx[start:]
574
+ ].unsqueeze(0)
575
+ self_compressed_attention_mask = self_attention_mask[
576
+ self_attention_mask == 1
577
+ ][need_idx[start:]].unsqueeze(0)
578
+ else:
579
+ self_compressed_input_ids, self_compressed_attention_mask = None, None
580
+ if keep_flag is not None:
581
+ if len(keep_flag) > len(need_idx):
582
+ keep_flag = torch.cat(
583
+ [
584
+ keep_flag[:start],
585
+ keep_flag[start : len(need_idx) + start][need_idx],
586
+ keep_flag[start + len(need_idx) :],
587
+ ]
588
+ )
589
+ else:
590
+ keep_flag = keep_flag[need_idx]
591
+ end -= (need_idx[:end] == 0).sum()
592
+ return (
593
+ compressed_input_ids,
594
+ compressed_attention_mask,
595
+ keep_flag,
596
+ end,
597
+ loss,
598
+ self_loss,
599
+ self_compressed_input_ids,
600
+ self_compressed_attention_mask,
601
+ )
602
+
603
+ def get_estimate_threshold_base_distribution(
604
+ self, ppl, ratio: float, condition_flag: bool = False
605
+ ):
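+ # Estimate the loss threshold that keeps roughly `ratio` of the tokens:
+ # sort the per-token scores (descending unless condition_flag) and read
+ # off the value at rank int(len * ratio) - 1.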
606
+ ppl = ppl[ppl != 10000]
607
+ target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
608
+ return (
609
+ ppl.sort(descending=not condition_flag)
610
+ .values[target_token]
611
+ .detach()
612
+ .cpu()
613
+ .item()
614
+ )
615
+
616
+ def iterative_compress_prompt(
617
+ self,
618
+ context: List[str],
619
+ target_token: float,
620
+ iterative_size: int = 200,
621
+ keep_split: bool = False,
622
+ split_token_id: int = 13,
623
+ start: int = 0,
624
+ dynamic_ratio: list = None,
625
+ condition_compare: bool = False,
626
+ ):
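+ # Token-level compression. The concatenated context is processed in
+ # windows of `iterative_size` tokens: per-token losses are computed while
+ # reusing the KV cache, a threshold matching the window's ratio is
+ # estimated, and low-information tokens are dropped via
+ # get_compressed_input. When the window outgrows max_position_embeddings,
+ # the oldest cache entries are evicted. With `condition_compare`, the loss
+ # delta against a pass without the leading question/instruction (the
+ # first `start` tokens) is used instead of the raw loss.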
627
+ iterative_ratios = self.get_dynamic_compression_ratio(
628
+ context, target_token, iterative_size, dynamic_ratio, start
629
+ )
630
+ context = "\n\n".join(context)
631
+ tokenized_text = self.tokenizer(context, return_tensors="pt")
632
+ input_ids = tokenized_text["input_ids"].to(self.device)
633
+ attention_mask = tokenized_text["attention_mask"].to(self.device)
634
+
635
+ N = (attention_mask == 1).sum()
636
+ compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
637
+ if condition_compare:
638
+ self_input_ids, self_attention_mask = (
639
+ input_ids[:, start:],
640
+ attention_mask[:, start:],
641
+ )
642
+ self_compressed_input_ids, self_compressed_attention_mask = (
643
+ self_input_ids,
644
+ self_attention_mask,
645
+ )
646
+
647
+ end = min(iterative_size + start, compressed_input_ids.shape[1])
648
+ threshold, keep_flag = None, None
649
+ if keep_split:
650
+ input_ids_numpy = input_ids.cpu().detach().numpy()[0]
651
+ N = len(input_ids_numpy)
652
+ keep_flag = [
653
+ int(
654
+ (
655
+ ii > 0
656
+ and input_ids_numpy[ii] == split_token_id
657
+ and input_ids_numpy[ii - 1] == split_token_id
658
+ )
659
+ or (
660
+ ii < N - 1
661
+ and input_ids_numpy[ii] == split_token_id
662
+ and input_ids_numpy[ii + 1] == split_token_id
663
+ )
664
+ )
665
+ for ii in range(N)
666
+ ]
667
+ keep_flag = torch.tensor(keep_flag).to(self.device)
668
+ past_key_values, past_loss, ready_end = None, None, 0
669
+ self_past_key_values, self_past_loss, self_ready_end = None, None, 0
670
+ pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
671
+ idx = 0
672
+ while end <= compressed_input_ids.shape[1]:
673
+ if end > self.max_position_embeddings and past_key_values is not None:
674
+ # KV-cache compression: the window exceeds max_position_embeddings, so
+ # pop the overflow tokens out of the compressed ids and evict their
+ # entries from the cache (keeping the first cache_bos_num positions).
675
+ e, s = end - self.max_position_embeddings, self.cache_bos_num
676
+ if pop_compressed_input_ids is None:
677
+ pop_compressed_input_ids = compressed_input_ids[:, :e]
678
+ else:
679
+ pop_compressed_input_ids = torch.cat(
680
+ [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
681
+ )
682
+ compressed_input_ids = compressed_input_ids[:, e:]
683
+ compressed_attention_mask = compressed_attention_mask[:, e:]
684
+ past_key_values = [
685
+ [
686
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
687
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
688
+ ]
689
+ for k, v in past_key_values
690
+ ]
691
+ end, ready_end = end - e, ready_end - e
692
+ if condition_compare:
693
+ self_ready_end -= e
694
+ if pop_self_compressed_input_ids is None:
695
+ pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
696
+ else:
697
+ pop_self_compressed_input_ids = torch.cat(
698
+ [
699
+ pop_self_compressed_input_ids,
700
+ self_compressed_input_ids[:, :e],
701
+ ],
702
+ dim=-1,
703
+ )
704
+ self_compressed_input_ids = self_compressed_input_ids[:, e:]
705
+ self_compressed_attention_mask = self_compressed_attention_mask[
706
+ :, e:
707
+ ]
708
+ self_past_key_values = [
709
+ [
710
+ torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
711
+ torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
712
+ ]
713
+ for k, v in self_past_key_values
714
+ ]
715
+
716
+ loss, past_key_values = self.get_ppl(
717
+ "",
718
+ "token",
719
+ compressed_input_ids,
720
+ compressed_attention_mask,
721
+ past_key_values=past_key_values,
722
+ return_kv=True,
723
+ end=end if idx else None,
724
+ )
725
+ if past_loss is not None:
726
+ if end - 1 > len(past_loss):
727
+ past_loss = torch.cat(
728
+ [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
729
+ )
730
+ past_loss[ready_end : end - 1] = loss
731
+ loss = past_loss
732
+ else:
733
+ past_loss = loss
734
+ if idx:
735
+ past_key_values = [
736
+ [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
737
+ for k, v in past_key_values
738
+ ]
739
+ else:
740
+ past_key_values = None
741
+
742
+ if condition_compare:
743
+ self_loss, self_past_key_values = self.get_ppl(
744
+ "",
745
+ "token",
746
+ self_compressed_input_ids,
747
+ self_compressed_attention_mask,
748
+ past_key_values=self_past_key_values,
749
+ return_kv=True,
750
+ end=end - start if idx else None,
751
+ )
752
+ if self_past_loss is not None:
753
+ if end - start - 1 > len(self_past_loss):
754
+ self_past_loss = torch.cat(
755
+ [
756
+ self_past_loss,
757
+ torch.zeros_like(self_loss)[
758
+ : end - 1 - start - len(self_past_loss)
759
+ ],
760
+ ]
761
+ )
762
+ self_past_loss[self_ready_end : end - start - 1] = self_loss
763
+ self_loss = self_past_loss
764
+ else:
765
+ self_past_loss = self_loss
766
+ if idx:
767
+ self_past_key_values = [
768
+ [
769
+ k[:, :, : end - iterative_size - start],
770
+ v[:, :, : end - iterative_size - start],
771
+ ]
772
+ for k, v in self_past_key_values
773
+ ]
774
+ else:
775
+ self_past_key_values = None
776
+
777
+ self_ready_end = (
778
+ end - start - iterative_size if not (start and idx == 0) else 0
779
+ )
780
+ ready_end = end - iterative_size if not (start and idx == 0) else 0
781
+
782
+ for delta_end, ratio in iterative_ratios[idx]:
783
+ loss = past_loss
784
+ if condition_compare:
785
+ self_loss = self_past_loss
786
+ threshold = self.get_estimate_threshold_base_distribution(
787
+ self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
788
+ )
789
+ else:
790
+ threshold = self.get_estimate_threshold_base_distribution(
791
+ loss, ratio, False
792
+ )
793
+
794
+ (
795
+ compressed_input_ids,
796
+ compressed_attention_mask,
797
+ keep_flag,
798
+ end,
799
+ past_loss,
800
+ self_past_loss,
801
+ self_compressed_input_ids,
802
+ self_compressed_attention_mask,
803
+ ) = self.get_compressed_input(
804
+ loss,
805
+ compressed_input_ids,
806
+ compressed_attention_mask,
807
+ end - iterative_size + delta_end,
808
+ iterative_size=delta_end,
809
+ threshold=threshold,
810
+ keep_flag=keep_flag,
811
+ split_token_id=split_token_id,
812
+ start=start,
813
+ self_loss=self_loss if condition_compare else None,
814
+ self_input_ids=self_compressed_input_ids
815
+ if condition_compare
816
+ else None,
817
+ self_attention_mask=self_compressed_attention_mask
818
+ if condition_compare
819
+ else None,
820
+ )
821
+ end += iterative_size
822
+ idx += 1
823
+ if pop_compressed_input_ids is not None:
824
+ compressed_input_ids = torch.cat(
825
+ [pop_compressed_input_ids, compressed_input_ids], dim=-1
826
+ )
827
+ return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
828
+
829
+ def recover(
830
+ self,
831
+ original_prompt: str,
832
+ compressed_prompt: str,
833
+ response: str,
834
+ ):
835
+ def match_from_compressed(response_word):
836
+ response_input_ids = self.tokenizer(
837
+ response_word, add_special_tokens=False
838
+ )["input_ids"]
839
+ response_set, response_c = set(response_input_ids), defaultdict(list)
840
+ for idx in range(M):
841
+ if original_input_ids[idx] in response_set:
842
+ response_c[original_input_ids[idx]].append(idx)
843
+ res, res_min, res_c = None, float("inf"), 1
844
+ n = len(response_input_ids)
845
+ for l in response_c[response_input_ids[0]]:
846
+ x, y, c = 0, l, 1
847
+ for x in range(1, n):
848
+ idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
849
+ if (
850
+ idx >= len(response_c[response_input_ids[x]])
851
+ or response_c[response_input_ids[x]][idx] - y > 10
852
+ ):
853
+ continue
854
+ c += 1
855
+ y = response_c[response_input_ids[x]][idx]
856
+ if c > res_c:
857
+ res_c = c
858
+ res_min = y - l + 1
859
+ res = (l, y + 1)
860
+ elif c == res_c and y - l + 1 < res_min:
861
+ res_min = y - l + 1
862
+ res = (l, y + 1)
863
+
864
+ if res is None:
865
+ return response_word
866
+ # while l > 0 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
867
+ # l -= 1
868
+ # while r < M - 1 and not self.tokenizer.convert_ids_to_tokens(original_input_ids[l]).startswith("_"):
869
+ # l -= 1
870
+ return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
871
+
872
+ response_words = response.split(" ")
873
+
874
+ original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
875
+ "input_ids"
876
+ ]
877
+ N, M = len(response_words), len(original_input_ids)
878
+ recovered_response_words = []
879
+ l = 0
880
+ while l < N:
881
+ if response_words[l] not in compressed_prompt:
882
+ recovered_response_words.append(response_words[l])
883
+ l += 1
884
+ continue
885
+ r = l
886
+ while (
887
+ r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
888
+ ):
889
+ r += 1
890
+
891
+ match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
892
+ recovered_response_words.append(match_words)
893
+ l = r + 1
894
+ return " ".join(recovered_response_words)
895
+
896
+ def get_rank_results(
897
+ self,
898
+ context: list,
899
+ question: str,
900
+ rank_method: str,
901
+ condition_in_question: str,
902
+ context_tokens_length: list,
903
+ ):
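+ # Dispatch to one ranking backend and return (index, score) pairs in
+ # best-first order. Supported rank_method values: bm25, gzip, sentbert,
+ # openai, longllmlingua / llmlingua (conditional perplexity), bge,
+ # bge_reranker, bge_llmembedder, jinza, voyageai, cohere.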
904
+ def get_distance_bm25(corpus, query):
905
+ from rank_bm25 import BM25Okapi
906
+
907
+ tokenized_corpus = [doc.split(" ") for doc in corpus]
908
+ bm25 = BM25Okapi(tokenized_corpus)
909
+ tokenized_query = query.split(" ")
910
+ doc_scores = bm25.get_scores(tokenized_query)
911
+ idx = [(ii, 0) for ii in (-doc_scores).argsort()]
912
+ return idx
913
+
914
+ def get_distance_gzip(corpus, query):
915
+ def get_score(x, y):
916
+ cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
917
+ cxy = len(gzip.compress(f"{x} {y}".encode()))
918
+ return (cxy - min(cx, cy)) / max(cx, cy)
919
+
920
+ import gzip
921
+
922
+ doc_scores = [get_score(doc, query) for doc in corpus]
923
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
924
+ return idx
925
+
926
+ def get_distance_sentbert(corpus, query):
927
+ from sentence_transformers import SentenceTransformer, util
928
+
929
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
930
+ self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
931
+ self.retrieval_model_name = rank_method
932
+ doc_embeds = self.retrieval_model.encode(corpus)
933
+ query = self.retrieval_model.encode(query)
934
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
935
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
936
+ return idx
937
+
938
+ def get_distance_openai(corpus, query):
939
+ import openai
940
+ from sentence_transformers import util
941
+
942
+ openai.api_key = self.open_api_config.get("api_key", "")
943
+ openai.api_base = self.open_api_config.get(
944
+ "api_base", "https://api.openai.com/v1"
945
+ )
946
+ openai.api_type = self.open_api_config.get("api_type", "open_ai")
947
+ openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
948
+ engine = self.open_api_config.get("engine", "text-embedding-ada-002")
949
+
950
+ def get_embed(text):
951
+ return openai.Embedding.create(
952
+ input=[text.replace("\n", " ")], engine=engine
953
+ )["data"][0]["embedding"]
954
+
955
+ doc_embeds = [get_embed(i) for i in corpus]
956
+ query = get_embed(query)
957
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
958
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
959
+ return idx
960
+
961
+ def get_distance_sentbert_bge(corpus, query):
962
+ from sentence_transformers import SentenceTransformer, util
963
+
964
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
965
+ self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
966
+ self.retrieval_model_name = rank_method
967
+ doc_embeds = self.retrieval_model.encode(
968
+ [i for i in corpus], normalize_embeddings=True
969
+ )
970
+ query = self.retrieval_model.encode(query, normalize_embeddings=True)
971
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
972
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
973
+ return idx
974
+
975
+ def get_distance_bge_ranker(corpus, query):
976
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
977
+
978
+ pairs = [[i, query] for i in corpus]
979
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
980
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
981
+ model = (
982
+ AutoModelForSequenceClassification.from_pretrained(
983
+ "BAAI/bge-reranker-large"
984
+ )
985
+ .eval()
986
+ .to(self.device)
987
+ )
988
+ self.retrieval_model = [tokenizer, model]
989
+ self.retrieval_model_name = rank_method
990
+ with torch.no_grad():
991
+ inputs = self.retrieval_model[0](
992
+ pairs,
993
+ padding=True,
994
+ truncation=True,
995
+ return_tensors="pt",
996
+ max_length=512,
997
+ ).to(self.device)
998
+ scores = (
999
+ self.retrieval_model[1](**inputs, return_dict=True)
1000
+ .logits.view(
1001
+ -1,
1002
+ )
1003
+ .float()
1004
+ )
1005
+ idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
1006
+ return idx
1007
+
1008
+ def get_distance_bge_llmembedder(corpus, query):
1009
+ from transformers import AutoModel, AutoTokenizer
1010
+
1011
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
1012
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
1013
+ model = (
1014
+ AutoModel.from_pretrained("BAAI/llm-embedder")
1015
+ .eval()
1016
+ .to(self.device)
1017
+ )
1018
+ self.retrieval_model = [tokenizer, model]
1019
+ self.retrieval_model_name = rank_method
1020
+
1021
+ instruction_qa_query = (
1022
+ "Represent this query for retrieving relevant documents: "
1023
+ )
1024
+ instruction_qa_key = "Represent this document for retrieval: "
1025
+ queries = [instruction_qa_query + query for _ in corpus]
1026
+ keys = [instruction_qa_key + key for key in corpus]
1027
+ with torch.no_grad():
1028
+ query_inputs = self.retrieval_model[0](
1029
+ queries,
1030
+ padding=True,
1031
+ truncation=True,
1032
+ return_tensors="pt",
1033
+ max_length=512,
1034
+ ).to(self.device)
1035
+ key_inputs = self.retrieval_model[0](
1036
+ keys,
1037
+ padding=True,
1038
+ truncation=True,
1039
+ return_tensors="pt",
1040
+ max_length=512,
1041
+ ).to(self.device)
1042
+ query_outputs = self.retrieval_model[1](**query_inputs)
1043
+ key_outputs = self.retrieval_model[1](**key_inputs)
1044
+ # CLS pooling
1045
+ query_embeddings = query_outputs.last_hidden_state[:, 0]
1046
+ key_embeddings = key_outputs.last_hidden_state[:, 0]
1047
+ # Normalize
1048
+ query_embeddings = torch.nn.functional.normalize(
1049
+ query_embeddings, p=2, dim=1
1050
+ )
1051
+ key_embeddings = torch.nn.functional.normalize(
1052
+ key_embeddings, p=2, dim=1
1053
+ )
1054
+ similarity = query_embeddings @ key_embeddings.T
1055
+ idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
1056
+ return idx
1057
+
1058
+ def get_distance_jinza(corpus, query):
1059
+ from numpy.linalg import norm
1060
+
1061
+ from transformers import AutoModel
1062
+
1063
+ def cos_sim(a, b):
1064
+ return (a @ b.T) / (norm(a) * norm(b))
1065
+
1066
+ if self.retrieval_model is None or self.retrieval_model_name != rank_method:
1067
+ model = (
1068
+ AutoModel.from_pretrained(
1069
+ "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
1070
+ )
1071
+ .eval()
1072
+ .to(self.device)
1073
+ )
1074
+ self.retrieval_model = model
1075
+ self.retrieval_model_name = rank_method
1076
+
1077
+ doc_embeds = self.retrieval_model.encode(corpus)
1078
+ query = self.retrieval_model.encode(query)
1079
+ doc_scores = cos_sim(doc_embeds, query)
1080
+ idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
1081
+ return idx
1082
+
1083
+ def get_distance_voyageai(corpus, query):
1084
+ import voyageai
1085
+ from sentence_transformers import util
1086
+
1087
+ voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
1088
+
1089
+ def get_embed(text):
1090
+ return voyageai.get_embedding(text, model="voyage-01")
1091
+
1092
+ doc_embeds = [get_embed(i) for i in corpus]
1093
+ query = get_embed(query)
1094
+ doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
1095
+ idx = [(ii, 0) for ii in np.argsort(doc_scores)]
1096
+ return idx
1097
+
1098
+ def get_distance_cohere(corpus, query):
1099
+ import cohere
1100
+
1101
+ api_key = self.open_api_config.get("cohere_api_key", "")
1102
+ co = cohere.Client(api_key)
1103
+ results = co.rerank(
1104
+ model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
1105
+ )
1106
+ c_map = {jj: ii for ii, jj in enumerate(corpus)}
1107
+ doc_rank = [c_map[ii.document["text"]] for ii in results]
1108
+ idx = [(ii, 0) for ii in doc_rank]
1109
+ return idx
1110
+
1111
+ def get_distance_longllmlingua(corpus, query):
1112
+ context_ppl = [
1113
+ self.get_condition_ppl(
1114
+ d,
1115
+ query
1116
+ + " We can get the answer to this question in the given documents.",
1117
+ condition_in_question,
1118
+ )
1119
+ - dl * 2 / 250 * 0
1120
+ for d, dl in zip(corpus, context_tokens_length)
1121
+ ]
1122
+ sort_direct = -1 if condition_in_question == "none" else 1
1123
+ ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
1124
+ return ys
1125
+
1126
+ method = None
1127
+ if rank_method == "bm25":
1128
+ method = get_distance_bm25
1129
+ elif rank_method == "gzip":
1130
+ method = get_distance_gzip
1131
+ elif rank_method == "sentbert":
1132
+ method = get_distance_sentbert
1133
+ elif rank_method == "openai":
1134
+ method = get_distance_openai
1135
+ elif rank_method in ["longllmlingua", "llmlingua"]:
1136
+ method = get_distance_longllmlingua
1137
+ elif rank_method == "bge":
1138
+ method = get_distance_sentbert_bge
1139
+ elif rank_method == "bge_reranker":
1140
+ method = get_distance_bge_ranker
1141
+ elif rank_method == "bge_llmembedder":
1142
+ method = get_distance_bge_llmembedder
1143
+ elif rank_method == "jinza":
1144
+ method = get_distance_jinza
1145
+ elif rank_method == "voyageai":
1146
+ method = get_distance_voyageai
1147
+ elif rank_method == "cohere":
1148
+ method = get_distance_cohere
1149
+ return method(context, question)
1150
+