# Petr Tsvetkov
# Length comparison template notebook; a Grazie token is needed to run it.
import pickle
import time

from grazie.api.client.chat.prompt import ChatPrompt
from grazie.api.client.endpoints import GrazieApiGatewayUrls
from grazie.api.client.gateway import GrazieApiGatewayClient, GrazieAgent, AuthType
from grazie.api.client.profiles import LLMProfile

import config
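# The local config module is assumed to look roughly like the sketch below.
# The attribute names are taken from their usages in this file; the values
# shown are placeholders (the JWT token must be supplied by the user):
#
#     from pathlib import Path
#
#     GRAZIE_API_JWT_TOKEN = "<your Grazie JWT token>"
#     GRAZIE_TIMEOUT_SEC = 10          # retry back-off in seconds, hypothetical value
#     LLM_MODEL = "<profile name>"     # also used as the cache file name
#     CACHE_DIR = Path("cache")        # hypothetical location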
# Grazie API gateway client; the JWT token comes from config.py.
client = GrazieApiGatewayClient(
    grazie_agent=GrazieAgent("grazie-toolformers", "v1.0"),
    url=GrazieApiGatewayUrls.STAGING,
    auth_type=AuthType.SERVICE,
    grazie_jwt_token=config.GRAZIE_API_JWT_TOKEN,
)
# On-disk cache of LLM responses, keyed by prompt: LLM_CACHE maps each prompt
# to a list of responses, and LLM_CACHE_USED counts how many of those
# responses have been consumed for that prompt during the current run.
LLM_CACHE_FILE = config.CACHE_DIR / f"{config.LLM_MODEL}.cache.pkl"
LLM_CACHE = {}
LLM_CACHE_USED = {}

# Create an empty cache file on first run, then load whatever is stored.
if not LLM_CACHE_FILE.exists():
    with open(LLM_CACHE_FILE, "wb") as file:
        pickle.dump(obj=LLM_CACHE, file=file)

with open(LLM_CACHE_FILE, "rb") as file:
    LLM_CACHE = pickle.load(file=file)
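# Note: LLM_CACHE_USED is kept in memory only, so every fresh run replays the
# cached responses from the start. As long as no new prompts are introduced,
# re-running the notebook reproduces the previous outputs without API calls.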
def llm_request(prompt):
    """Send a single chat request, retrying until the gateway responds."""
    output = None
    while output is None:
        try:
            output = client.chat(
                chat=ChatPrompt()
                .add_system("You are a helpful assistant.")
                .add_user(prompt),
                profile=LLMProfile(config.LLM_MODEL),
            ).content
        except Exception:
            # Back off and retry on any gateway error (rate limits, timeouts).
            time.sleep(config.GRAZIE_TIMEOUT_SEC)
    assert output is not None
    return output
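# A minimal smoke test for the raw request path (hypothetical prompt; this
# bypasses the cache, so each call issues one real LLM request):
#
#     print(llm_request("Reply with a single word."))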
def generate_for_prompt(prompt):
    """Return the next cached response for `prompt`, generating and
    persisting a new one when the cache is exhausted."""
    if prompt not in LLM_CACHE:
        LLM_CACHE[prompt] = []
    if prompt not in LLM_CACHE_USED:
        LLM_CACHE_USED[prompt] = 0
    # Generate until there is an unused cached response for this prompt.
    while LLM_CACHE_USED[prompt] >= len(LLM_CACHE[prompt]):
        new_response = llm_request(prompt)
        LLM_CACHE[prompt].append(new_response)
        # Persist after every generation so a crash loses no responses.
        with open(LLM_CACHE_FILE, "wb") as file:
            pickle.dump(obj=LLM_CACHE, file=file)
    result = LLM_CACHE[prompt][LLM_CACHE_USED[prompt]]
    LLM_CACHE_USED[prompt] += 1
    return result
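# Usage sketch, guarded so that importing this module stays side-effect free.
# The prompt is hypothetical; repeated calls with the same prompt consume
# successive cache slots, so a re-run replays the same two outputs in order.
if __name__ == "__main__":
    demo_prompt = "Compare the lengths of these two commit messages."  # hypothetical
    print(generate_for_prompt(demo_prompt))
    print(generate_for_prompt(demo_prompt))  # a second, independent sample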