bstraehle committed on
Commit
c962a65
1 Parent(s): 93003ed

Update trace.py

Browse files
Files changed (1) hide show
  1. trace.py +10 -12
trace.py CHANGED
@@ -4,9 +4,7 @@ from wandb.sdk.data_types.trace_tree import Trace
4
 
5
  WANDB_API_KEY = os.environ["WANDB_API_KEY"]
6
 
7
- RAG_OFF = "Off"
8
-
9
- def wandb_trace(rag_option,
10
  prompt,
11
  completion,
12
  result,
@@ -23,27 +21,27 @@ def wandb_trace(rag_option,
23
  name = "" if (chain == None) else type(chain).__name__,
24
  status_code = "success" if (str(err_msg) == "") else "error",
25
  status_message = str(err_msg),
26
- metadata = {"chunk_overlap": "" if (rag_option == RAG_OFF) else config["chunk_overlap"],
27
- "chunk_size": "" if (rag_option == RAG_OFF) else config["chunk_size"],
28
  } if (str(err_msg) == "") else {},
29
- inputs = {"rag_option": rag_option,
30
  "prompt": prompt,
31
- "chain_prompt": (str(chain.prompt) if (rag_option == RAG_OFF) else
32
  str(chain.combine_documents_chain.llm_chain.prompt)),
33
- "source_documents": "" if (rag_option == RAG_OFF) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
34
  } if (str(err_msg) == "") else {},
35
  outputs = {"result": result,
36
  "generation_info": str(generation_info),
37
  "llm_output": str(llm_output),
38
  "completion": str(completion),
39
  } if (str(err_msg) == "") else {},
40
- model_dict = {"client": (str(chain.llm.client) if (rag_option == RAG_OFF) else
41
  str(chain.combine_documents_chain.llm_chain.llm.client)),
42
- "model_name": (str(chain.llm.model_name) if (rag_option == RAG_OFF) else
43
  str(chain.combine_documents_chain.llm_chain.llm.model_name)),
44
- "temperature": (str(chain.llm.temperature) if (rag_option == RAG_OFF) else
45
  str(chain.combine_documents_chain.llm_chain.llm.temperature)),
46
- "retriever": ("" if (rag_option == RAG_OFF) else str(chain.retriever)),
47
  } if (str(err_msg) == "") else {},
48
  start_time_ms = start_time_ms,
49
  end_time_ms = end_time_ms
 
4
 
5
  WANDB_API_KEY = os.environ["WANDB_API_KEY"]
6
 
7
+ def wandb_trace(is_rag_off,
 
 
8
  prompt,
9
  completion,
10
  result,
 
21
  name = "" if (chain == None) else type(chain).__name__,
22
  status_code = "success" if (str(err_msg) == "") else "error",
23
  status_message = str(err_msg),
24
+ metadata = {"chunk_overlap": "" if (is_rag_off) else config["chunk_overlap"],
25
+ "chunk_size": "" if (is_rag_off) else config["chunk_size"],
26
  } if (str(err_msg) == "") else {},
27
+ inputs = {"is_rag": not is_rag_off,
28
  "prompt": prompt,
29
+ "chain_prompt": (str(chain.prompt) if (is_rag_off) else
30
  str(chain.combine_documents_chain.llm_chain.prompt)),
31
+ "source_documents": "" if (is_rag_off) else str([doc.metadata["source"] for doc in completion["source_documents"]]),
32
  } if (str(err_msg) == "") else {},
33
  outputs = {"result": result,
34
  "generation_info": str(generation_info),
35
  "llm_output": str(llm_output),
36
  "completion": str(completion),
37
  } if (str(err_msg) == "") else {},
38
+ model_dict = {"client": (str(chain.llm.client) if (is_rag_off) else
39
  str(chain.combine_documents_chain.llm_chain.llm.client)),
40
+ "model_name": (str(chain.llm.model_name) if (is_rag_off) else
41
  str(chain.combine_documents_chain.llm_chain.llm.model_name)),
42
+ "temperature": (str(chain.llm.temperature) if (is_rag_off) else
43
  str(chain.combine_documents_chain.llm_chain.llm.temperature)),
44
+ "retriever": ("" if (is_rag_off) else str(chain.retriever)),
45
  } if (str(err_msg) == "") else {},
46
  start_time_ms = start_time_ms,
47
  end_time_ms = end_time_ms