import requests
import nltk
import random
import json
import os
import pickle
import re

# Sentence tokenizer models used by nltk.sent_tokenize below
nltk.download('punkt')

# Load the pool of Hugging Face API tokens shipped alongside this module
hf_tokens = []
filepath = __file__.replace("\\", "/").replace("utils.py", "")
with open(filepath + "data/hf_tokens.pkl", "rb") as f:
    hf_tokens = pickle.load(f)

MAX_TOKEN_LENGTH = 4096
MAX_CHUNK_SIZE = 16000  # max characters per transcript chunk (see split_chunk)
API_URL = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct"

def prompt_template(prompt, sys_prompt=""):
    # Wrap the user and system prompts in the Llama 3 chat template
    return_prompt = (
        '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n'
        '<system_prompt><|eot_id|><|start_header_id|>user<|end_header_id|>\n\n'
        '<user_prompt><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'
    ).replace('<user_prompt>', prompt).replace('<system_prompt>', sys_prompt)
    return return_prompt
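
# For reference, prompt_template("Hi", "Be brief") renders to the Llama 3
# chat format (special tokens shown literally, \n\n as blank lines):
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   Be brief<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#   Hi<|eot_id|><|start_header_id|>assistant<|end_header_id|>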

def query(payload: dict, hf_token: str):
    headers = {"Authorization": f"Bearer {hf_token}"}
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()
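
# Example call, following the serverless Inference API text-generation
# schema (the token below is a placeholder, not a real credential):
#
#   query({"inputs": prompt_template("Who are you?"),
#          "parameters": {"max_new_tokens": 100}}, "hf_xxx")
#
# On success this returns a list like [{"generated_text": "<prompt +
# completion>"}]; on failure it returns a dict containing an "error" key,
# which is what gen_prompt checks for below.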

def gen_prompt(prompt: str, sys_prompt: str = ""):
    input_prompt = prompt_template(prompt, sys_prompt)
    # Probe the token pool with a cheap request and use the first token
    # that is not rate-limited or otherwise rejected
    selected_token = ''
    for token in hf_tokens:
        test_output = query({
            "inputs": prompt_template("Who are you?"),
            "parameters": {"max_new_tokens": 100}
        }, token)
        if 'error' not in test_output:
            selected_token = token
            break
    if not selected_token:
        raise RuntimeError("No working Hugging Face token in the pool")
    output = query({
        "inputs": input_prompt,
        "parameters": {"max_new_tokens": 512},
    }, selected_token)
    # The API echoes the prompt, so strip it off and return the completion
    return output[0]['generated_text'][len(input_prompt):]
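
# Note: gen_prompt re-probes the token pool on every call, which costs one
# extra API round-trip per generation. A minimal caching variant (a sketch,
# not used elsewhere in this module; _cached_token and gen_prompt_cached
# are names introduced here for illustration):

_cached_token = None

def gen_prompt_cached(prompt: str, sys_prompt: str = ""):
    global _cached_token
    input_prompt = prompt_template(prompt, sys_prompt)
    if _cached_token is None:
        for token in hf_tokens:
            probe = query({"inputs": prompt_template("Who are you?"),
                           "parameters": {"max_new_tokens": 10}}, token)
            if 'error' not in probe:
                _cached_token = token
                break
    output = query({"inputs": input_prompt,
                    "parameters": {"max_new_tokens": 512}}, _cached_token)
    # If the cached token stopped working, forget it so the next call re-probes
    if isinstance(output, dict) and 'error' in output:
        _cached_token = None
        raise RuntimeError("Cached token failed; retry to re-probe the pool")
    return output[0]['generated_text'][len(input_prompt):]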

class Node:
    def __init__(self, summary=None):
        self.summary = summary
        self.children = []
        self.parent = None

    def add_child(self, child_node):
        child_node.parent = self
        self.children.append(child_node)
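
# A small inspection helper added for debugging (not used by the classes
# below): print each node's summary, indented by depth.
def print_tree(node: Node, depth: int = 0):
    print("  " * depth + (node.summary or "")[:80])
    for child in node.children:
        print_tree(child, depth + 1)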

class MemWalker:
    def __init__(self, segments):
        self.segments = segments
        self.root = None

    def build_memory_tree(self):
        # Step 1: summarize each segment into a leaf node
        leaves = [Node(summarize(seg, 0)) for seg in self.segments]
        # Step 2: repeatedly merge adjacent pairs of nodes, compressing
        # their summaries, until a single root remains
        while len(leaves) > 1:
            new_leaves = []
            for i in range(0, len(leaves), 2):
                if i + 1 < len(leaves):
                    combined_summary = summarize(leaves[i].summary + ", " + leaves[i + 1].summary, 1)
                    parent_node = Node(combined_summary)
                    parent_node.add_child(leaves[i])
                    parent_node.add_child(leaves[i + 1])
                else:
                    # Odd node out: promote it unchanged to the next level
                    parent_node = leaves[i]
                new_leaves.append(parent_node)
            leaves = new_leaves
        self.root = leaves[0]
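
# Usage sketch (build_memory_tree makes live API calls via summarize):
#
#   walker = MemWalker(split_chunk(transcript))
#   walker.build_memory_tree()
#   print(walker.root.summary)   # top-level summary
#   print_tree(walker.root)      # whole tree, via the helper above

# The MemWalker approach also navigates the tree at query time, letting the
# LLM pick whichever child's summary looks most relevant. That step is not
# implemented in this module; below is a minimal sketch of the idea (the
# prompt wording is illustrative, not taken from the original method):
def navigate(node: Node, question: str) -> str:
    # Descend while the current node has children, asking the model to
    # choose a branch by index; fall back to the first child if the reply
    # does not contain a clean number
    while node.children:
        menu = "\n".join(f"{i}: {c.summary}" for i, c in enumerate(node.children))
        reply = gen_prompt(
            "Question: " + question + "\n\nWhich summary below is most "
            "relevant? Answer with the number only.\n\n" + menu
        )
        digits = re.findall(r"\d+", reply)
        choice = int(digits[0]) if digits else 0
        node = node.children[min(choice, len(node.children) - 1)]
    return node.summary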

# LLM-backed summarization used by MemWalker and the chunk summarizers below
def summarize(text, sum_type: int = 1):
    assert sum_type in [0, 1], "sum_type should be either 0 (raw segment) or 1 (merge summaries)"
    if sum_type == 0:
        USER_PROMPT = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + text
    else:
        USER_PROMPT = "Compress the following summaries into a much shorter summary: " + "\n\n" + text
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    tmp = gen_prompt(USER_PROMPT, SYS_PROMPT)
    # The model often prefixes its answer with a short preamble followed by
    # a blank line; keep only the part after the first blank line if present
    if len(tmp.split("\n\n")) == 1:
        return tmp
    else:
        return tmp.split("\n\n")[1]

def split_chunk(transcript: str):
    # Greedily pack sentences into chunks of at most MAX_CHUNK_SIZE
    # characters, re-seeding each new chunk with the previous ~10 sentences
    # so context is not cut mid-discussion
    sentences = nltk.sent_tokenize(transcript)
    idx = 0
    chunk = []
    current_chunk = ""
    while idx < len(sentences):
        if len(current_chunk + sentences[idx]) < MAX_CHUNK_SIZE:
            current_chunk += sentences[idx] + " "
        else:
            chunk.append(current_chunk)
            current_chunk = ''
            # Overlap window, guarded against negative indices near the
            # start of the transcript
            for i in range(min(idx, 10), -1, -1):
                current_chunk += sentences[idx - i] + " "
        idx += 1
    chunk.append(current_chunk)
    return chunk
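
# Illustration of the overlap: if the first chunk fills up at sentence s_k,
# the next chunk is re-seeded with s_(k-10) .. s_k, so consecutive chunks
# share roughly eleven sentences:
#
#   chunks = split_chunk(transcript)  # list of <16k-character strings
#   # the tail of chunks[0] and the head of chunks[1] share sentences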

def summarize_three_ways(chunks: list[str]):
    # Run refine-style summarization over the chunks and return three
    # variants: the first-chunk summary ('truncated'), the fully refined
    # summary ('accumulate'), and a coherency rewrite of all partial
    # summaries ('rewrite')
    SYS_PROMPT = "Act as a professional technical meeting minutes writer."
    PROMPT_TEMPLATE = "Write a concise summary of the meeting transcript in maximum 5 sentences:" + "\n\n" + "{text}"
    REFINE_TEMPLATE = (
        "Your job is to produce a final summary.\n"
        "We have provided an existing summary up to a certain point: {existing_answer}\n"
        "We have the opportunity to refine the existing summary "
        "(only if needed) with some more context below.\n"
        "------------\n"
        "{text}\n"
        "------------\n"
        "Given the new context, refine the original summary in English within 5 sentences. If the context isn't useful, return the original summary."
    )
    step = 0
    partial_sum = []
    return_dict = {}
    for chunk in chunks:
        if step == 0:
            # First chunk: summarize from scratch
            CUR_PROMPT = PROMPT_TEMPLATE.replace("{text}", chunk)
            cur_sum = gen_prompt(CUR_PROMPT, SYS_PROMPT)
        else:
            # Later chunks: refine the latest summary with the new context
            CUR_PROMPT = REFINE_TEMPLATE.replace("{existing_answer}", partial_sum[-1])
            CUR_PROMPT = CUR_PROMPT.replace("{text}", chunk)
            cur_sum = gen_prompt(CUR_PROMPT, SYS_PROMPT)
        # Drop the model's preamble before the first blank line, if any
        if len(cur_sum.split("\n\n")) > 1:
            cur_sum = cur_sum.split("\n\n")[1]
        partial_sum.append(cur_sum)
        step += 1
    # Third variant: rewrite all partial summaries into one coherent text
    CUR_PROMPT = "Rewrite the following text by maintaining coherency: " + "\n\n"
    CUR_PROMPT += ' '.join(partial_sum)
    tmp = gen_prompt(CUR_PROMPT, SYS_PROMPT)
    if len(tmp.split("\n\n")) == 1:
        final_sum = tmp
    else:
        final_sum = tmp.split("\n\n")[1]
    return_dict['truncated'] = partial_sum[0]
    return_dict['accumulate'] = partial_sum[-1]
    return_dict['rewrite'] = final_sum
    return return_dict
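
# Usage sketch (live API calls: one generation per chunk plus one rewrite):
#
#   results = summarize_three_ways(split_chunk(transcript))
#   results['truncated']   # summary of the first chunk only
#   results['accumulate']  # final iteratively-refined summary
#   results['rewrite']     # coherency rewrite of all partial summaries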

def get_example() -> list[str]:
    # Load the bundled test set (one JSON object per line) and return a
    # fixed selection of transcripts, one sentence per line
    data = []
    with open(filepath + "data/test.json", "r") as f:
        for line in f:
            data.append(json.loads(line))
    # Fixed indices rather than random.sample, for reproducible demos
    random_idx = [1, 2, 9, 13]
    return ['\n'.join(nltk.sent_tokenize(data[i]['transcript'])) for i in random_idx]

if __name__ == "__main__":
    # Exploratory snippet: print the lengths of the first 100 transcripts
    # in the bundled test set
    data = []
    with open(filepath + "data/test.json", "r") as f:
        for line in f:
            data.append(json.loads(line))
    for j, i in enumerate(data[:100]):
        print(j, len(i['transcript']))
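
    # End-to-end demo (left commented out: it issues many live API calls):
    #
    #   transcript = get_example()[0]
    #   walker = MemWalker(split_chunk(transcript))
    #   walker.build_memory_tree()
    #   print(walker.root.summary)
    #   print(summarize_three_ways(split_chunk(transcript))['rewrite'])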