Daniel Nichols
add max code len to control length
972361d
raw
history blame
7.89 kB
""" Techniques for formatting the prompts that are passed to the LLMs.
Each formatter needs to handle two major tasks:
1. Taking a set of source code files and embedding them in the prompt meaningfully (possibly concatenating them).
2. Embedding a performance profile in the prompt, if one is available.
"""
from abc import ABC, abstractmethod
import logging
from os import PathLike
from os.path import basename
import random
from typing import Callable, List, Mapping, Optional

from function_grabber import get_function_at_line
from profiles import Profile
def truncate_string(s: str, max_len: int) -> str:
    """Clip *s* to its first *max_len* characters, marking the cut with "...".

    Strings that already fit are returned unchanged. Note that a clipped
    result is ``max_len + 3`` characters long because of the ellipsis.
    """
    if len(s) > max_len:
        head = s[:max_len]
        return head + "..."
    return s
class PerfGuruPromptFormatter(ABC):
    """Base class for strategies that turn source code (and optionally a
    performance profile) into a single LLM prompt string.

    Subclasses implement :meth:`format_prompt`; this class supplies shared
    helpers for loading the code files and the profile.
    """

    def __init__(self, name: str):
        # Short identifier for the formatting strategy (e.g. "basic").
        self.name = name

    def _read_code_files(self, code_paths: List[PathLike]) -> Mapping[PathLike, str]:
        """Read every file in *code_paths* into a ``{path: contents}`` dict.

        Files are decoded as UTF-8 with undecodable bytes replaced by U+FFFD,
        so a non-UTF-8 platform locale or a stray byte in a source file cannot
        abort prompt construction. Keys keep the order of *code_paths*;
        repeated paths collapse to a single entry.
        """
        code_files = {}
        for code_path in code_paths:
            with open(code_path, "r", encoding="utf-8", errors="replace") as file:
                code_files[code_path] = file.read()
        return code_files

    def _read_profile(self, profile_path: PathLike, profile_type: str) -> "Profile":
        """Load the profile at *profile_path*, interpreted according to *profile_type*."""
        return Profile(profile_path, profile_type)

    @abstractmethod
    def format_prompt(
        self,
        prompt: str,
        code_paths: List[PathLike],
        profile_path: Optional[PathLike] = None,
        profile_type: Optional[str] = None,
        error_fn: Optional[Callable[[str], None]] = None,
    ) -> Optional[str]:
        """Build the full prompt from *prompt*, the code files and an optional profile.

        Args:
            prompt: the task/instruction text, placed after the code and profile.
            code_paths: source files to embed.
            profile_path: optional performance profile to summarize.
            profile_type: profile format; required when *profile_path* is given.
            error_fn: optional callback that receives a message describing invalid input.

        Returns:
            The formatted prompt, or ``None`` when the inputs are invalid.
        """
class BasicPromptFormatter(PerfGuruPromptFormatter):
    """Embed every code file verbatim, plus the profile rendered as a tree.

    Both sections are clipped with ``truncate_string`` so the prompt length
    stays bounded.
    """

    def __init__(self, max_code_len: int = 4000, max_profile_len: int = 4000):
        """
        Args:
            max_code_len: character budget for the concatenated code section.
            max_profile_len: character budget for the rendered profile tree.
        """
        super().__init__("basic")
        self.max_code_len = max_code_len
        self.max_profile_len = max_profile_len

    def format_prompt(
        self,
        prompt: str,
        code_paths: List[PathLike],
        profile_path: Optional[PathLike] = None,
        profile_type: Optional[str] = None,
        error_fn: Optional[Callable[[str], None]] = None,
    ) -> Optional[str]:
        """Return ``Code:`` + all files, an optional ``<type> Profile:`` section, then *prompt*.

        Returns ``None`` (after notifying *error_fn*, if given) when no code
        files are supplied, or when *profile_path* is given without *profile_type*.
        """
        if not code_paths:
            if error_fn:
                error_fn("No code files provided. At least one code file must be provided.")
            return None

        code_file_contents = self._read_code_files(code_paths)
        concatenated_code = "".join(
            f"{basename(code_path)}:\n{content}\n\n"
            for code_path, content in code_file_contents.items()
        )
        concatenated_code = truncate_string(concatenated_code, self.max_code_len)

        profile_content = ""
        if profile_path:
            if not profile_type:
                if error_fn:
                    error_fn("Profile type must be provided if a profile file is provided.")
                return None
            profile = self._read_profile(profile_path, profile_type)
            profile_content = truncate_string(profile.profile_to_tree_str(), self.max_profile_len)

        if profile_content:
            return f"Code:\n{concatenated_code}\n\n{profile_type} Profile:\n{profile_content}\n\n{prompt}"
        # Without a profile, omit the section rather than emitting a
        # misleading "None Profile:" header into the prompt.
        return f"Code:\n{concatenated_code}\n\n{prompt}"
class SlowestFunctionPromptFormatter(PerfGuruPromptFormatter):
    """Embed every code file plus a short text summary of the profile.

    The summary lists the ``k`` slowest functions by the profile's ``time``
    column and the first ``k`` named frames on the profile's hot path.
    """

    def __init__(self, k, max_code_len: int = 4000):
        """
        Args:
            k: how many slowest functions / hot-path frames to report.
            max_code_len: character budget for the concatenated code section.
        """
        super().__init__("slowest_function")
        self.k = k
        self.max_code_len = max_code_len

    def _summarize_profile(self, profile_path: PathLike, profile_type: str) -> str:
        """Describe the ``k`` slowest functions and the leading hot-path frames in one sentence pair."""
        profile = self._read_profile(profile_path, profile_type)
        # nlargest(k) already returns at most k rows.
        slowest = profile.gf.dataframe.nlargest(self.k, 'time')
        # .tolist() yields plain Python values, so the prompt reads e.g. [1.5]
        # instead of NumPy 2 reprs such as [np.float64(1.5)].
        function_names = slowest['name'].tolist()
        execution_times = slowest['time'].tolist()
        hot_path_functions = [
            node.frame["name"]
            for node in profile.gf.hot_path()
            if "name" in node.frame.attrs
        ][:self.k]
        return (f"The slowest functions are {function_names} and they took {execution_times} seconds, respectively." +
                f" Also, these functions were in the hot path: {hot_path_functions}.")

    def format_prompt(
        self,
        prompt: str,
        code_paths: List[PathLike],
        profile_path: Optional[PathLike] = None,
        profile_type: Optional[str] = None,
        error_fn: Optional[Callable[[str], None]] = None,
    ) -> Optional[str]:
        """Return ``Code:`` + all files, an optional profile summary section, then *prompt*.

        Returns ``None`` (after notifying *error_fn*, if given) when no code
        files are supplied, or when *profile_path* is given without *profile_type*.
        A profile that cannot be summarized is logged and skipped.
        """
        if not code_paths:
            if error_fn:
                error_fn("No code files provided. At least one code file must be provided.")
            return None

        code_file_contents = self._read_code_files(code_paths)
        concatenated_code = "".join(
            f"{basename(code_path)}:\n{content}\n\n"
            for code_path, content in code_file_contents.items()
        )
        concatenated_code = truncate_string(concatenated_code, self.max_code_len)

        profile_content = ""
        if profile_path:
            if not profile_type:
                if error_fn:
                    error_fn("Profile type must be provided if a profile file is provided.")
                return None
            try:
                profile_content = self._summarize_profile(profile_path, profile_type)
            except Exception as err:
                # Best effort: an unreadable profile degrades to a code-only prompt.
                logging.getLogger(__name__).warning(
                    "Could not summarize %s profile %s: %s", profile_type, profile_path, err
                )
                profile_content = ""

        if profile_content:
            return f"Code:\n{concatenated_code}\n\n{profile_type} Profile:\n{profile_content}\n\n{prompt}"
        # Omit the profile section instead of emitting an empty or "None Profile:" header.
        return f"Code:\n{concatenated_code}\n\n{prompt}"
class SlowestFunctionParsedPromptFormatter(PerfGuruPromptFormatter):
    """Show only the source of the single slowest function, when it can be located.

    Falls back to embedding every code file when there is no profile, the
    profile cannot be read, or the function's source cannot be found.
    """

    def __init__(self, max_code_len: int = 4000):
        """
        Args:
            max_code_len: character budget for the code section (function snippet or all files).
        """
        super().__init__("slowest_function_parsed")
        self.max_code_len = max_code_len

    def _find_slowest_function(self, code_paths: List[PathLike], profile_path: PathLike, profile_type: str):
        """Return ``(function_name, snippet)`` for the slowest profiled function.

        ``snippet`` is ``"<file>:\\n<code>\\n\\n"`` for the first code file in
        which ``get_function_at_line`` finds code at the profiled line, or
        ``None`` when no file yields code.
        """
        profile = self._read_profile(profile_path, profile_type)
        slowest = profile.gf.dataframe.nlargest(1, 'time')
        if slowest.empty:
            return None, None
        function_name = slowest['name'].values[0]
        line_number = slowest['line'].values[0]
        # Skip missing/zero line numbers; NaN is the only value unequal to itself.
        if not line_number or line_number != line_number:
            return function_name, None
        # int() normalizes float-typed columns so "12" is passed rather than "12.0".
        line_str = str(int(line_number))
        for code_path in code_paths:
            # NOTE(review): the original passed basename(code_path); the full path is
            # the one known to be openable from the current directory — confirm
            # get_function_at_line accepts it.
            code, _ = get_function_at_line(code_path, line_str)
            if code:
                return function_name, f"{basename(code_path)}:\n{code}\n\n"
        return function_name, None

    def format_prompt(
        self,
        prompt: str,
        code_paths: List[PathLike],
        profile_path: Optional[PathLike] = None,
        profile_type: Optional[str] = None,
        error_fn: Optional[Callable[[str], None]] = None,
    ) -> Optional[str]:
        """Return ``Code:`` + the slowest function (or all files), an optional profile line, then *prompt*.

        Returns ``None`` (after notifying *error_fn*, if given) when no code
        files are supplied, or when *profile_path* is given without *profile_type*.
        """
        if not code_paths:
            if error_fn:
                error_fn("No code files provided. At least one code file must be provided.")
            return None

        concatenated_code = ""
        profile_content = ""
        if profile_path:
            if not profile_type:
                if error_fn:
                    error_fn("Profile type must be provided if a profile file is provided.")
                return None
            try:
                function_name, snippet = self._find_slowest_function(code_paths, profile_path, profile_type)
            except Exception as err:
                # Best effort: fall back to embedding every code file.
                logging.getLogger(__name__).warning(
                    "Could not locate slowest function from %s profile %s: %s", profile_type, profile_path, err
                )
            else:
                if snippet:
                    concatenated_code = snippet
                    logging.getLogger(__name__).debug("Only function code: %s", snippet)
                    profile_content = f"The slowest function is {function_name}."

        if not concatenated_code:
            code_file_contents = self._read_code_files(code_paths)
            concatenated_code = "".join(
                f"{basename(code_path)}:\n{content}\n\n"
                for code_path, content in code_file_contents.items()
            )
        # Apply the length budget to the extracted snippet as well as the fallback.
        concatenated_code = truncate_string(concatenated_code, self.max_code_len)

        if profile_content:
            return f"Code:\n{concatenated_code}\n\n{profile_type} Profile:\n{profile_content}\n\n{prompt}"
        # Omit the profile section instead of emitting an empty or "None Profile:" header.
        return f"Code:\n{concatenated_code}\n\n{prompt}"
# Registry of formatters eligible for random selection, in a fixed order
# (k = 1, 5, 10, then the parsed variant). BasicPromptFormatter is
# currently not part of the rotation.
AVAILABLE_FORMATTERS = [SlowestFunctionPromptFormatter(k=top_k) for top_k in (1, 5, 10)]
AVAILABLE_FORMATTERS.append(SlowestFunctionParsedPromptFormatter())
def select_random_formatter() -> PerfGuruPromptFormatter:
    """Pick one entry of ``AVAILABLE_FORMATTERS`` uniformly at random.

    Draws from the module-level ``random`` generator, so the choice is
    reproducible after ``random.seed``.
    """
    candidates = AVAILABLE_FORMATTERS
    chosen = random.choice(candidates)
    return chosen