|
import logging
import os
import pickle
import tempfile
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
from grazie.api.client.profiles import LLMProfile

import config
|
|
|
# Agent descriptor (name + version) handed to the gateway client below.
_GRAZIE_AGENT = GrazieAgent(name="commit-rewriting-synthetic-end-to-start", version="dev")

# Module-wide Grazie API gateway client: staging endpoint, user auth with the
# JWT token taken from the local ``config`` module.
client = GrazieApiGatewayClient(
    grazie_agent=_GRAZIE_AGENT,
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.USER,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
|
|
|
# On-disk cache of LLM responses, one file per configured model.
# Maps prompt -> list of responses received for that prompt so far.
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE = {}
# Per-process counter: prompt -> how many cached responses have already been
# handed out in this run (not persisted).
LLM_CACHE_USED = {}

# Make sure the cache directory exists; otherwise the open() below would fail
# with FileNotFoundError at import time.
os.makedirs(os.path.dirname(os.path.abspath(LLM_CACHE_FILE)), exist_ok=True)

# Seed an empty cache file on first run so the load below always succeeds.
if not LLM_CACHE_FILE.exists():
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)

# NOTE: pickle.load executes arbitrary code from the file; this is only safe
# because the cache file is written by this module itself. Never point
# CACHE_DIR at untrusted data.
with open(LLM_CACHE_FILE, "rb") as file:
    LLM_CACHE = pickle.load(file=file)
|
|
|
|
|
def llm_request(prompt, max_attempts=None):
    """Send *prompt* to the configured LLM and return the response text.

    The prompt is sent as a user message after a fixed system message, using
    the ``config.LLM_MODEL`` profile. An attempt fails when the client raises
    an exception or returns a response whose ``content`` is ``None``. Each
    failure is logged and followed by a sleep of ``config.GRAZIE_TIMEOUT_SEC``
    seconds before the next attempt.

    Args:
        prompt: Text of the user message.
        max_attempts: Maximum number of failed attempts tolerated before
            giving up (at least one attempt is always made). ``None`` (the
            default) retries indefinitely, as the original implementation did.

    Returns:
        The ``content`` of the first successful chat response.

    Raises:
        RuntimeError: If ``max_attempts`` is set and that many attempts failed.
            The last client exception, if any, is chained as the cause.
    """
    logger = logging.getLogger(__name__)
    failures = 0
    last_error = None

    while True:
        try:
            output = client.chat(
                chat=ChatPrompt()
                .add_system("You are a helpful assistant.")
                .add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception as err:
            # Deliberately catch Exception (not a bare ``except``) so that
            # KeyboardInterrupt / SystemExit can still stop a stuck run.
            last_error = err
            logger.warning("LLM request failed: %r", err)
        else:
            if output is not None:
                return output
            last_error = None
            logger.warning("LLM request returned no content; retrying")

        failures += 1
        if max_attempts is not None and failures >= max_attempts:
            raise RuntimeError(
                f"LLM request failed after {failures} attempt(s)"
            ) from last_error
        # Back off on every failure (including empty content) instead of
        # hammering the gateway.
        time.sleep(config.GRAZIE_TIMEOUT_SEC)
|
|
|
|
|
def _save_llm_cache():
    """Persist ``LLM_CACHE`` to ``LLM_CACHE_FILE`` atomically.

    The cache is pickled into a temporary file in the same directory, which is
    then moved over the cache file with ``os.replace``. An interruption in the
    middle of the dump therefore leaves the previous cache file intact, instead
    of a truncated pickle that would fail to load on the next import.
    """
    cache_dir = os.path.dirname(os.path.abspath(LLM_CACHE_FILE))
    fd, tmp_path = tempfile.mkstemp(
        dir=cache_dir, prefix=os.path.basename(LLM_CACHE_FILE) + ".", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            pickle.dump(obj=LLM_CACHE, file=tmp_file)
        os.replace(tmp_path, LLM_CACHE_FILE)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is
        # re-raised regardless.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def generate_for_prompt(prompt):
    """Return the next not-yet-used LLM response for *prompt*.

    Responses are taken from ``LLM_CACHE`` in order. Calling this N times with
    the same prompt in one process yields the first N cached responses. When
    the cache runs out, a new response is requested via ``llm_request``,
    appended to the cache, and the cache is saved to disk immediately.

    Args:
        prompt: Prompt text; also used as the cache key.

    Returns:
        The response at position ``LLM_CACHE_USED[prompt]`` (before increment).
    """
    responses = LLM_CACHE.setdefault(prompt, [])
    used = LLM_CACHE_USED.setdefault(prompt, 0)

    # Fetch fresh responses until one is available that this process has not
    # handed out yet.
    while used >= len(responses):
        responses.append(llm_request(prompt))
        # Save after every new response so completed requests survive a crash.
        _save_llm_cache()

    LLM_CACHE_USED[prompt] = used + 1
    return responses[used]
|
|