File size: 2,413 Bytes
e874531
fcd3a75
a857f12
 
9e0f123
 
b552593
 
86ffba3
 
 
 
 
 
 
 
 
a857f12
 
 
 
 
 
 
c962a65
 
a857f12
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
 
 
 
 
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
 
 
 
 
86ffba3
a857f12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os, wandb

from wandb.sdk.data_types.trace_tree import Trace

# W&B API key, read from the environment when this module is imported; the
# os.environ[...] lookup raises KeyError if the variable is not set.
# NOTE(review): this constant is not referenced anywhere in the visible code —
# presumably kept as a fail-fast check before wandb.init(); confirm before removing.
WANDB_API_KEY: str = os.environ["WANDB_API_KEY"]

def _prompt_and_llm(chain, is_rag_off):
    """Return the ``(prompt, llm)`` pair used by *chain*.

    Without RAG the chain exposes ``prompt`` and ``llm`` directly; with RAG
    they live on ``chain.combine_documents_chain.llm_chain``.
    """
    owner = chain if is_rag_off else chain.combine_documents_chain.llm_chain
    return owner.prompt, owner.llm


def wandb_trace(config,
                is_rag_off,
                prompt,
                completion,
                result,
                generation_info,
                llm_output,
                chain,
                err_msg,
                start_time_ms,
                end_time_ms):
    """Log one chain invocation to Weights & Biases as a ``Trace``.

    Opens a run in project ``"openai-llm-rag"``, logs the trace under the key
    ``"evaluation"``, and always finishes the run afterwards -- even when
    constructing or logging the trace raises.

    Args:
        config: subscriptable config (e.g. a dict); ``"chunk_overlap"`` and
            ``"chunk_size"`` are read only when RAG is on and there is no error.
        is_rag_off: True when *chain* is a plain LLM chain (``chain.prompt``,
            ``chain.llm``); False when it is a retrieval chain exposing
            ``combine_documents_chain.llm_chain`` and ``retriever``.
        prompt: the user prompt, logged as-is.
        completion: the chain's raw response, logged via ``str()``. When RAG
            is on it must contain ``"source_documents"`` whose items carry
            ``metadata["source"]``.
        result: the answer, logged as-is.
        generation_info: logged via ``str()``.
        llm_output: logged via ``str()``.
        chain: the chain that was run, or None. Its class name becomes the
            trace name ("" when None).
        err_msg: error description; ``str(err_msg) == ""`` means success.
            On error, metadata/inputs/outputs/model_dict are logged empty and
            *chain* is not inspected beyond its type.
        start_time_ms: passed through to ``Trace`` unchanged.
        end_time_ms: passed through to ``Trace`` unchanged.
    """
    succeeded = str(err_msg) == ""

    # Build the payload from plain Python objects *before* opening a run, so
    # a bad chain/config/completion cannot leave an orphaned W&B run behind.
    metadata, inputs, outputs, model_dict = {}, {}, {}, {}
    if succeeded:
        chain_prompt, llm = _prompt_and_llm(chain, is_rag_off)
        metadata = {
            "chunk_overlap": "" if is_rag_off else config["chunk_overlap"],
            "chunk_size": "" if is_rag_off else config["chunk_size"],
        }
        inputs = {
            "is_rag": not is_rag_off,
            "prompt": prompt,
            "chain_prompt": str(chain_prompt),
            "source_documents": (
                "" if is_rag_off
                else str([doc.metadata["source"]
                          for doc in completion["source_documents"]])
            ),
        }
        outputs = {
            "result": result,
            "generation_info": str(generation_info),
            "llm_output": str(llm_output),
            "completion": str(completion),
        }
        model_dict = {
            "client": str(llm.client),
            "model_name": str(llm.model_name),
            "temperature": str(llm.temperature),
            "retriever": "" if is_rag_off else str(chain.retriever),
        }

    wandb.init(project="openai-llm-rag")
    try:
        trace = Trace(
            kind="chain",
            name="" if chain is None else type(chain).__name__,
            status_code="success" if succeeded else "error",
            status_message=str(err_msg),
            metadata=metadata,
            inputs=inputs,
            outputs=outputs,
            model_dict=model_dict,
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
        )
        trace.log("evaluation")
    finally:
        # Ensure the run opened above is never left dangling on failure.
        wandb.finish()