doberst committed on
Commit be094e1
1 Parent(s): 23fd48e

Upload generation_test_llmware_script.py

Files changed (1)
  1. generation_test_llmware_script.py +16 -10
generation_test_llmware_script.py CHANGED
@@ -11,7 +11,7 @@ def load_rag_benchmark_tester_ds():
 
     dataset = load_dataset(ds_name)
 
-    print("update: loading test dataset - ", dataset)
+    print("update: loading RAG Benchmark test dataset - ", dataset)
 
     test_set = []
     for i, samples in enumerate(dataset["train"]):
@@ -25,9 +25,10 @@ def load_rag_benchmark_tester_ds():
 
 def run_test(model_name, prompt_list):
 
-    print("\nupdate: Starting RAG Benchmark Inference Test")
+    print("\nupdate: Starting RAG Benchmark Inference Test - ", model_name)
 
-    prompter = Prompt().load_model(model_name,from_hf=True)
+    # pull DRAGON / BLING model directly from catalog, e.g., no from_hf=True
+    prompter = Prompt().load_model(model_name)
 
     for i, entries in enumerate(prompt_list):
 
@@ -36,21 +37,25 @@ def run_test(model_name, prompt_list):
 
         response = prompter.prompt_main(prompt,context=context,prompt_name="default_with_context", temperature=0.3)
 
+        print("\nupdate: model inference output - ", i, response["llm_response"])
+        print("update: gold_answer - ", i, entries["answer"])
+
         fc = prompter.evidence_check_numbers(response)
         sc = prompter.evidence_comparison_stats(response)
         sr = prompter.evidence_check_sources(response)
 
-        print("\nupdate: model inference output - ", i, response["llm_response"])
-        print("update: gold_answer - ", i, entries["answer"])
+        print("\nFact-Checking Tools")
 
         for entries in fc:
-            print("update: fact check - ", entries["fact_check"])
+            for f, facts in enumerate(entries["fact_check"]):
+                print("update: fact check - ", f, facts)
 
         for entries in sc:
             print("update: comparison stats - ", entries["comparison_stats"])
 
         for entries in sr:
-            print("update: sources - ", entries["source_review"])
+            for s, sources in enumerate(entries["source_review"]):
+                print("update: sources - ", s, sources)
 
     return 0
 
@@ -59,6 +64,7 @@ if __name__ == "__main__":
 
     core_test_set = load_rag_benchmark_tester_ds()
 
-    model_name = "llmware/dragon-red-pajama-7b-v0"
-
-    output = run_test(model_name, core_test_set)
+    # one of the 7 gpu dragon models
+    gpu_model_name = "llmware/dragon-red-pajama-7b-v0"
+
+    output = run_test(gpu_model_name, core_test_set)
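
Note that the diff starts at line 11 of the script, so the imports and the ds_name variable used by load_rag_benchmark_tester_ds() are not shown here. A minimal sketch of that missing preamble follows, assuming the usual llmware import path and using a placeholder dataset name; the actual ds_name is defined earlier in the file and is not part of this commit.

# sketch of the script preamble not visible in this diff (assumptions noted below)
from datasets import load_dataset      # Hugging Face datasets loader used by load_rag_benchmark_tester_ds()
from llmware.prompts import Prompt     # assumed import path for llmware's Prompt class

# placeholder name - the real ds_name is set above line 11 and does not appear in this commit
ds_name = "llmware/rag_instruct_benchmark_tester"

With the change in this commit, run_test() loads the named DRAGON / BLING model directly from the llmware model catalog, so the from_hf=True flag is no longer passed to load_model().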