"""Evaluate AMBER models"""

import argparse

import mteb

# PROMPTS maps each supported model type to its prompt configuration (defined in models.py).
from models import PROMPTS


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True, help="Model name", choices=PROMPTS.keys())
    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
    parser.add_argument("--corpus_chunk_size", type=int, default=50000)
    parser.add_argument("--convert_to_tensor", action="store_true")
    return parser.parse_args()


def main():
    args = get_args()
    # Look up the prompt configuration for the chosen model type and load the model via mteb.
    prompt = PROMPTS[args.model_type]
    model = mteb.get_model(args.model_name_or_path, model_prompts=prompt)

    # Evaluate on the Japanese subset of the MultiLongDocRetrieval task.
    tasks = [mteb.get_task("MultiLongDocRetrieval", languages=["jpn"])]
    evaluation = mteb.MTEB(tasks=tasks)

    # Options forwarded to the model's encode() calls during evaluation.
    encode_kwargs = {
        "batch_size": args.batch_size,
        "convert_to_tensor": args.convert_to_tensor,
    }

    # Run the benchmark; corpus_chunk_size bounds how many corpus documents are
    # encoded at once during retrieval.
    evaluation.run(
        model,
        output_folder=args.output_dir,
        encode_kwargs=encode_kwargs,
        corpus_chunk_size=args.corpus_chunk_size,
    )


if __name__ == "__main__":
    main()