import pickle
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
from grazie.api.client.profiles import LLMProfile

import config

client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev"),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.USER,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)

# On-disk cache of LLM responses, keyed by prompt. LLM_CACHE_USED tracks how many
# cached responses have been consumed per prompt during this run, so repeated calls
# with the same prompt walk through the cache in order instead of returning the
# same response twice.
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE = {}
LLM_CACHE_USED = {}

# Create an empty cache file on first use, then load whatever is persisted.
if not LLM_CACHE_FILE.exists():
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)

with open(LLM_CACHE_FILE, "rb") as file:
    LLM_CACHE = pickle.load(file=file)


def llm_request(prompt):
    """Send a single chat request, retrying after a timeout on any failure."""
    output = None
    while output is None:
        try:
            output = client.chat(
                chat=ChatPrompt()
                .add_system("You are a helpful assistant.")
                .add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception:
            # Back off and retry; a bare `except:` would also swallow
            # KeyboardInterrupt, so catch Exception instead.
            time.sleep(config.GRAZIE_TIMEOUT_SEC)
    return output


def generate_for_prompt(prompt):
    """Return the next unused cached response for `prompt`, generating and
    persisting new responses whenever the cache is exhausted."""
    if prompt not in LLM_CACHE:
        LLM_CACHE[prompt] = []
    if prompt not in LLM_CACHE_USED:
        LLM_CACHE_USED[prompt] = 0
    # Keep requesting until the cache holds more responses than we have used.
    while LLM_CACHE_USED[prompt] >= len(LLM_CACHE[prompt]):
        new_response = llm_request(prompt)
        LLM_CACHE[prompt].append(new_response)
        with open(LLM_CACHE_FILE, "wb") as file:
            pickle.dump(obj=LLM_CACHE, file=file)
    result = LLM_CACHE[prompt][LLM_CACHE_USED[prompt]]
    LLM_CACHE_USED[prompt] += 1
    return result
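
# Usage sketch (hypothetical prompt text; assumes `config` defines LLM_MODEL,
# CACHE_DIR, GRAZIE_API_JWT_TOKEN, and GRAZIE_TIMEOUT_SEC as referenced above):
#
#     first = generate_for_prompt("Rewrite this commit message: ...")
#     second = generate_for_prompt("Rewrite this commit message: ...")
#
# `second` is a distinct sample: the per-run counter has advanced past the cached
# response returned as `first`, so a new request is issued (and persisted) if the
# on-disk cache holds no further responses for that prompt.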