File size: 2,295 Bytes
e874531
fcd3a75
a857f12
 
9e0f123
 
301617d
b552593
86ffba3
 
 
 
c097673
86ffba3
 
 
a857f12
 
 
 
 
 
 
c962a65
 
a857f12
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
b4f203e
a857f12
 
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
c962a65
a857f12
 
 
 
 
 
86ffba3
a857f12
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os, wandb

from wandb.sdk.data_types.trace_tree import Trace

# Read eagerly at import time: raises KeyError if WANDB_API_KEY is not set.
# NOTE(review): this constant is not referenced by the code visible here;
# presumably wandb reads the key from the environment itself -- confirm
# before removing.
WANDB_API_KEY = os.environ["WANDB_API_KEY"]

def trace_wandb(
    config,
    is_rag_off,
    prompt,
    completion,
    result,
    chain,
    cb,
    err_msg,
    start_time_ms,
    end_time_ms,
):
    """Log one chain invocation to Weights & Biases as a ``Trace``.

    A run is opened in the "openai-llm-rag" project, a single trace named
    "evaluation" is logged, and the run is finished again -- also when
    building or logging the trace raises.

    The call counts as successful when ``str(err_msg)`` is the empty string.
    Only then are metadata, inputs, outputs and model details recorded; on
    error those four dicts are left empty and only the status is logged.

    Args:
        config: Mapping read for "chunk_overlap" and "chunk_size" when RAG
            is on.
        is_rag_off: Truthy when the chain was run without retrieval.
        prompt: The user prompt, logged as-is.
        completion: Chain output, logged via ``str()``. With RAG on it is
            indexed with "source_documents" and each item's
            ``metadata["source"]`` is read.
        result: Logged under outputs as "result".
        chain: The chain object. May be ``None`` only when ``err_msg`` is
            non-empty; on success its prompt/LLM attributes are read.
        cb: Logged under outputs as ``str(cb)``.
        err_msg: Error (or ""); its ``str()`` becomes the status message.
        start_time_ms: Trace start time, passed through to ``Trace``.
        end_time_ms: Trace end time, passed through to ``Trace``.
    """
    status_message = str(err_msg)
    succeeded = status_message == ""

    wandb.init(project="openai-llm-rag")

    # Everything after init sits in try/finally so the run is always closed,
    # even if an attribute lookup, Trace construction or logging fails.
    try:
        metadata, inputs, outputs, model_dict = {}, {}, {}, {}

        if succeeded:
            # Without RAG the chain itself carries the prompt and LLM; with
            # RAG they live on the nested combine-documents LLM chain.
            llm_chain = (chain if is_rag_off
                         else chain.combine_documents_chain.llm_chain)
            llm = llm_chain.llm

            metadata = {
                "chunk_overlap": "" if is_rag_off else config["chunk_overlap"],
                "chunk_size": "" if is_rag_off else config["chunk_size"],
            }
            inputs = {
                "is_rag": not is_rag_off,
                "prompt": prompt,
                "chain_prompt": str(llm_chain.prompt),
                "source_documents": "" if is_rag_off else str(
                    [doc.metadata["source"]
                     for doc in completion["source_documents"]]
                ),
            }
            outputs = {
                "result": result,
                "cb": str(cb),
                "completion": str(completion),
            }
            model_dict = {
                "client": str(llm.client),
                "model_name": str(llm.model_name),
                "temperature": str(llm.temperature),
                "retriever": "" if is_rag_off else str(chain.retriever),
            }

        trace = Trace(
            kind="chain",
            name="" if chain is None else type(chain).__name__,
            status_code="success" if succeeded else "error",
            status_message=status_message,
            metadata=metadata,
            inputs=inputs,
            outputs=outputs,
            model_dict=model_dict,
            start_time_ms=start_time_ms,
            end_time_ms=end_time_ms,
        )

        trace.log("evaluation")
    finally:
        wandb.finish()