# utils.py — helpers for chunking meeting transcripts and summarizing them
# via the Hugging Face Inference API (meta-llama/Meta-Llama-3-8B-Instruct).
import requests
import nltk
import random
import json
import os
import pickle
import re

# Fetch the sentence tokenizer model needed by nltk.sent_tokenize below.
nltk.download('punkt')

# Hugging Face API tokens, loaded from a pickled list shipped next to this file.
hf_tokens = []
# Directory containing this file (with trailing slash), forward-slash normalized.
filepath = __file__.replace("\\", "/").replace("utils.py", "")
with open(filepath + "data/hf_tokens.pkl", "rb") as f:
    # NOTE(review): pickle deserialization — ensure this bundled file is trusted.
    hf_tokens = pickle.load(f)

MAX_TOKEN_LENGTH = 4096  # model context budget in tokens (not referenced in this chunk)
MAX_CHUNK_SIZE = 16000   # max characters per transcript chunk (see split_chunk)
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"
def prompt_template(prompt: str, sys_prompt: str = "") -> str:
    """Wrap a user prompt (and optional system prompt) in the Llama-3 chat format.

    Built with f-strings instead of chained str.replace: the old version
    substituted '<user_prompt>' first, so a prompt that itself contained
    the literal text '<system_prompt>' was then corrupted by the second
    replace.

    Args:
        prompt: the user turn content.
        sys_prompt: optional system turn content (empty by default).

    Returns:
        The full chat-formatted string, ending with the assistant header so
        the model continues from there.
    """
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        f"{sys_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )
def query(payload: dict, hf_token: str):
    """POST *payload* to the inference API authorized by *hf_token*.

    Returns the response body parsed as JSON (a list of generations on
    success, or a dict carrying an 'error' key on failure).
    """
    auth_header = {"Authorization": f"Bearer {hf_token}"}
    response = requests.post(API_URL, json=payload, headers=auth_header)
    return response.json()
def gen_prompt(prompt: str, sys_prompt: str = "") -> str:
    """Generate a completion for *prompt* via the HF inference API.

    Probes the configured tokens with a cheap test request and uses the
    first one that does not return an error, then strips the echoed input
    prompt from the front of the generation.

    Args:
        prompt: user prompt text.
        sys_prompt: optional system prompt text.

    Returns:
        Only the newly generated text (input prefix removed).

    Raises:
        RuntimeError: if no token passes the probe, or the API returns an
            error payload for the real request.
    """
    input_prompt = prompt_template(prompt, sys_prompt)
    selected_token = None
    for token in hf_tokens:
        test_output = query({
            "inputs": prompt_template("Who are you?"),
            "parameters": {"max_new_tokens": 100}
        }, token)
        if 'error' not in test_output:
            selected_token = token
            break
    if selected_token is None:
        # Previously this fell through with an empty token and crashed with
        # an opaque KeyError on the error response; fail loudly instead.
        raise RuntimeError("No working Hugging Face token found")
    output = query({
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 512},
    }, selected_token)
    if not isinstance(output, list):
        # Error responses are dicts like {"error": ...}, not generation lists.
        raise RuntimeError(f"Inference API error: {output}")
    return output[0]['generated_text'][len(input_prompt):]
class Node:
    """A node of the memory tree: a summary plus parent/child links."""

    def __init__(self, summary=None):
        """Create a detached node carrying *summary* (may be None)."""
        self.summary = summary  # summary text covering this subtree
        self.parent = None      # set when attached via add_child
        self.children = []      # child nodes, in insertion order

    def add_child(self, child_node):
        """Attach *child_node* beneath this node and record its back-link."""
        self.children.append(child_node)
        child_node.parent = self
class MemWalker:
    """Builds a hierarchical summary tree over a list of transcript segments."""

    def __init__(self, segments):
        self.segments = segments  # raw text segments (the tree's leaves)
        self.root = 0             # replaced by the root Node after build_memory_tree()

    def build_memory_tree(self):
        """Summarize each segment, then pairwise merge summaries bottom-up."""
        # Leaf level: one node per segment, each summarized individually.
        level = [Node(summarize(segment, 0)) for segment in self.segments]
        # Merge neighbouring pairs until one node remains; an odd trailing
        # node is carried up to the next level unchanged.
        while len(level) > 1:
            next_level = []
            for left, right in zip(level[0::2], level[1::2]):
                merged = Node(summarize(left.summary + ", " + right.summary, 1))
                merged.add_child(left)
                merged.add_child(right)
                next_level.append(merged)
            if len(level) % 2 == 1:
                next_level.append(level[-1])
            level = next_level
        self.root = level[0]
# Placeholder functions for LLM operations
def summarize(text: str, sum_type: int = 1) -> str:
    """Summarize *text* with the LLM.

    Args:
        text: the content to summarize.
        sum_type: 0 to summarize a raw transcript segment (max 5 sentences),
            1 to compress a concatenation of existing summaries.

    Returns:
        The summary text.  The model often prefixes a lead-in paragraph;
        when the reply has multiple blank-line-separated parts, the second
        part is taken as the actual summary.

    Raises:
        ValueError: if sum_type is not 0 or 1.
    """
    if sum_type not in (0, 1):
        # Real exception instead of `assert`, which is stripped under -O.
        raise ValueError("sum_type should be either 0 or 1")
    if sum_type == 0:
        USER_PROMPT = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + text
    else:
        USER_PROMPT = "Compress the following summaries into a much shorter summary: " + "\n\n" + text
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    reply = gen_prompt(USER_PROMPT, SYS_PROMPT)
    parts = reply.split("\n\n")
    return parts[0] if len(parts) == 1 else parts[1]
def split_chunk(transcript: str):
    """Split *transcript* into character-bounded chunks of whole sentences.

    Each chunk stays under MAX_CHUNK_SIZE characters.  When a chunk fills
    up, the next chunk is seeded with up to the 10 preceding sentences plus
    the current one, so consecutive chunks overlap for context.

    Args:
        transcript: full meeting transcript text.

    Returns:
        A list of chunk strings (empty list for an empty transcript; the
        old version returned [''] in that case).
    """
    sentences = nltk.sent_tokenize(transcript)
    chunks = []
    current_chunk = ""
    idx = 0
    while idx < len(sentences):
        if len(current_chunk + sentences[idx]) < MAX_CHUNK_SIZE:
            current_chunk += sentences[idx] + " "
        else:
            chunks.append(current_chunk)
            # Seed the next chunk with up to the 10 previous sentences.
            # max(0, ...) fixes a bug: when idx < 10 the old `sentences[idx - i]`
            # went negative and silently pulled sentences from the END of the list.
            current_chunk = ""
            for j in range(max(0, idx - 10), idx + 1):
                current_chunk += sentences[j] + " "
        idx += 1
    if current_chunk:
        chunks.append(current_chunk)
    return chunks
def summarize_three_ways(chunks: list[str]) -> dict:
    """Produce three summaries of a chunked transcript.

    Args:
        chunks: transcript chunks (see split_chunk); must be non-empty.

    Returns:
        dict with keys:
          'truncated'  — summary of the first chunk only,
          'accumulate' — iteratively refined summary over all chunks,
          'rewrite'    — the partial summaries rewritten for coherency.

    Raises:
        ValueError: if *chunks* is empty (previously crashed later with an
            opaque IndexError).
    """
    if not chunks:
        raise ValueError("chunks must be non-empty")
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    PROMPT_TEMPLATE = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + "{text}"
    REFINE_TEMPLATE = (
        "Your job is to produce a final summary\n"
        "We have provided an existing summary up to a certain point: {existing_answer}\n"
        "We have the opportunity to refine the existing summary"
        "(only if needed) with some more context below.\n"
        "------------\n"
        "{text}\n"
        "------------\n"
        "Given the new context, refine the original summary in English within 5 sentences. If the context isn't useful, return the original summary."
    )
    partial_sum = []
    for step, chunk in enumerate(chunks):
        if step == 0:
            # First chunk: plain summarization.
            cur_prompt = PROMPT_TEMPLATE.replace("{text}", chunk)
        else:
            # Later chunks: refine the previous partial summary with new context.
            cur_prompt = REFINE_TEMPLATE.replace("{existing_answer}", partial_sum[-1])
            cur_prompt = cur_prompt.replace("{text}", chunk)
        cur_sum = gen_prompt(cur_prompt, SYS_PROMPT)
        # The model often prefixes a lead-in paragraph; keep the part after it.
        if len(cur_sum.split("\n\n")) > 1:
            cur_sum = cur_sum.split("\n\n")[1]
        partial_sum.append(cur_sum)
    final_prompt = "Rewrite the following text by maintaining coherency: " + "\n\n"
    final_prompt += ' '.join(partial_sum)
    tmp = gen_prompt(final_prompt, SYS_PROMPT)
    parts = tmp.split("\n\n")
    final_sum = parts[0] if len(parts) == 1 else parts[1]
    return {
        'truncated': partial_sum[0],
        'accumulate': partial_sum[-1],
        'rewrite': final_sum,
    }
def get_example() -> list[str]:
    """Load data/test.json and return four sample transcripts, one sentence per line."""
    with open(filepath + "data/test.json", "r") as f:
        records = [json.loads(line) for line in f]
    # Fixed sample indices (random sampling alternatives were abandoned).
    sample_indices = [2, 89, 94, 97]
    return ['\n'.join(nltk.sent_tokenize(records[i]['transcript'])) for i in sample_indices]
if __name__ == "__main__":
    # Quick inspection: print the transcript length of the first 100 records.
    with open(filepath + "data/test.json", "r") as f:
        records = [json.loads(line) for line in f]
    for index, record in enumerate(records[:100]):
        print(index, len(record['transcript']))