import wandb
from wandb.sdk.data_types.trace_tree import Trace


def wandb_trace(rag_option, prompt, completion, result, generation_info,
                llm_output, chain, err_msg, start_time_ms, end_time_ms):
    """Log one chain invocation to Weights & Biases as a single Trace span.

    Starts a W&B run in the "openai-llm-rag" project, logs a "chain" span
    under the key "evaluation", and always finishes the run afterwards.

    Args:
        rag_option: RAG mode; compared against the module-level ``RAG_OFF``
            to decide whether retrieval-specific fields are recorded.
        prompt: the user prompt that was sent to the chain.
        completion: raw chain output. With RAG on, it is indexed with
            ``["source_documents"]`` and each document's
            ``metadata["source"]`` is read.
        result: the final answer, logged as-is.
        generation_info: extra generation details, logged via ``str()``.
        llm_output: raw LLM output, logged via ``str()``.
        chain: the chain object that ran. If it is None, the span name is
            empty. Its prompt/LLM attributes are read only on success.
        err_msg: error message. An empty string (after ``str()``) marks the
            span as "success"; anything else marks it as "error" and leaves
            metadata, inputs, outputs and model_dict empty.
        start_time_ms: span start timestamp, passed through to ``Trace``.
        end_time_ms: span end timestamp, passed through to ``Trace``.
    """
    succeeded = str(err_msg) == ""

    wandb.init(project = "openai-llm-rag")
    # try/finally so the W&B run is closed even if reading chain attributes,
    # building the Trace, or logging it raises.
    try:
        metadata, inputs, outputs, model_dict = {}, {}, {}, {}

        if succeeded:
            # RAG_OFF / config are module-level names; only touched on success,
            # matching the original evaluation order.
            rag_off = rag_option == RAG_OFF
            # Without RAG the chain itself carries .prompt/.llm; with RAG
            # they live on the inner combine-documents LLM chain.
            llm_chain = chain if rag_off else chain.combine_documents_chain.llm_chain
            llm = llm_chain.llm

            metadata = {
                "chunk_overlap": "" if rag_off else config["chunk_overlap"],
                "chunk_size": "" if rag_off else config["chunk_size"],
            }
            inputs = {
                "rag_option": rag_option,
                "prompt": prompt,
                "chain_prompt": str(llm_chain.prompt),
                "source_documents": (
                    ""
                    if rag_off
                    else str([doc.metadata["source"]
                              for doc in completion["source_documents"]])
                ),
            }
            outputs = {
                "result": result,
                "generation_info": str(generation_info),
                "llm_output": str(llm_output),
                "completion": str(completion),
            }
            model_dict = {
                "client": str(llm.client),
                "model_name": str(llm.model_name),
                "temperature": str(llm.temperature),
                "retriever": "" if rag_off else str(chain.retriever),
            }

        trace = Trace(
            kind = "chain",
            name = "" if chain is None else type(chain).__name__,
            status_code = "success" if succeeded else "error",
            status_message = str(err_msg),
            metadata = metadata,
            inputs = inputs,
            outputs = outputs,
            model_dict = model_dict,
            start_time_ms = start_time_ms,
            end_time_ms = end_time_ms,
        )
        trace.log("evaluation")
    finally:
        wandb.finish()