|
import os, wandb |
|
|
|
from wandb.sdk.data_types.trace_tree import Trace |
|
|
|
# W&B credential, read eagerly at import time: indexing os.environ raises
# KeyError right here if the variable is unset, so a misconfigured
# environment fails before any tracing is attempted.
WANDB_API_KEY: str = os.environ["WANDB_API_KEY"]
|
|
|
def _chain_llm(chain, is_rag_off):
    """Return the LLM object held by ``chain``.

    A plain chain (``is_rag_off`` true) exposes it as ``chain.llm``; a
    retrieval chain nests it under ``combine_documents_chain.llm_chain.llm``.
    """
    if is_rag_off:
        return chain.llm
    return chain.combine_documents_chain.llm_chain.llm


def _chain_prompt(chain, is_rag_off):
    """Return the prompt object held by ``chain``.

    A plain chain (``is_rag_off`` true) exposes it as ``chain.prompt``; a
    retrieval chain nests it under ``combine_documents_chain.llm_chain.prompt``.
    """
    if is_rag_off:
        return chain.prompt
    return chain.combine_documents_chain.llm_chain.prompt


def wandb_trace(config,
                is_rag_off,
                prompt,
                completion,
                result,
                generation_info,
                llm_output,
                chain,
                err_msg,
                start_time_ms,
                end_time_ms):
    """Log one chain invocation to Weights & Biases as a single Trace span.

    Starts a run in the ``openai-llm-rag`` project, logs one ``"chain"`` span
    under the key ``"evaluation"``, and always finishes the run, even if
    constructing or logging the span raises.

    Args:
        config: Mapping read for ``"chunk_overlap"`` and ``"chunk_size"``
            (only when RAG is on and the call succeeded).
        is_rag_off: True when ``chain`` is a plain chain exposing ``llm`` and
            ``prompt`` directly. False when it is a retrieval chain exposing
            ``combine_documents_chain.llm_chain`` and ``retriever``.
        prompt: User prompt, logged as-is under inputs.
        completion: Raw chain output. It is stringified into outputs. When RAG
            is on, it is also indexed with ``"source_documents"``, whose items
            must carry ``metadata["source"]``.
        result: Answer, logged as-is under outputs.
        generation_info: Stringified into outputs.
        llm_output: Stringified into outputs.
        chain: The chain that was run, or None. Its class name becomes the
            span name ("" when None), and its prompt and LLM settings are
            logged when present.
        err_msg: Error text. "" (or None) means success. Any other value marks
            the span as an error and leaves metadata, inputs, outputs, and
            model details empty.
        start_time_ms: Passed through to ``Trace`` unchanged.
        end_time_ms: Passed through to ``Trace`` unchanged.

    Raises:
        KeyError: On success with RAG on, if ``config`` lacks a chunk key or a
            source document lacks ``metadata["source"]``. This is raised
            before any W&B run is opened.
    """
    # Normalize None to "" so a missing error is not reported as the text "None".
    status_message = "" if err_msg is None else str(err_msg)
    succeeded = status_message == ""
    has_chain = chain is not None

    # On error, only the status is logged; the detail sections stay empty.
    metadata, inputs, outputs, model_dict = {}, {}, {}, {}
    if succeeded:
        metadata = {
            "chunk_overlap": "" if is_rag_off else config["chunk_overlap"],
            "chunk_size": "" if is_rag_off else config["chunk_size"],
        }
        inputs = {
            "is_rag": not is_rag_off,
            "prompt": prompt,
            "chain_prompt": (str(_chain_prompt(chain, is_rag_off))
                             if has_chain else ""),
            "source_documents": ("" if is_rag_off else
                                 str([doc.metadata["source"]
                                      for doc in completion["source_documents"]])),
        }
        outputs = {
            "result": result,
            "generation_info": str(generation_info),
            "llm_output": str(llm_output),
            "completion": str(completion),
        }
        # Model details can only be read from an actual chain object.
        if has_chain:
            llm = _chain_llm(chain, is_rag_off)
            model_dict = {
                "client": str(llm.client),
                "model_name": str(llm.model_name),
                "temperature": str(llm.temperature),
                "retriever": "" if is_rag_off else str(chain.retriever),
            }

    # The payload is fully built above, so a failure there never opens a run.
    wandb.init(project="openai-llm-rag")
    try:
        trace = Trace(
            kind="chain",
            name=type(chain).__name__ if has_chain else "",
            status_code="success" if succeeded else "error",
            status_message=status_message,
            metadata=metadata,
            inputs=inputs,
            outputs=outputs,
            model_dict=model_dict,
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
        )
        trace.log("evaluation")
    finally:
        # Always close the run so a failed log does not leave it dangling.
        wandb.finish()