"""Evaluate AMBER models"""
import argparse
import mteb
from models import PROMPTS
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, required=True, help="Model name", choices=PROMPTS.keys())
parser.add_argument("--model_name_or_path", type=str, required=True)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--output_dir", type=str, required=True, help="Output directory")
parser.add_argument("--corpus_chunk_size", type=int, default=50000)
parser.add_argument("--convert_to_tensor", action="store_true")
return parser.parse_args()
def main():
args = get_args()
prompt = PROMPTS[args.model_type]
model = mteb.get_model(args.model_name_or_path, model_prompts=prompt)
tasks = [mteb.get_task("MultiLongDocRetrieval", languages=["jpn"])]
evaluation = mteb.MTEB(tasks=tasks)
encode_kwargs = {
"batch_size": args.batch_size,
"convert_to_tensor": args.convert_to_tensor,
}
evaluation.run(
model,
output_folder=args.output_dir,
encode_kwargs=encode_kwargs,
corpus_chunk_size=args.corpus_chunk_size,
)
if __name__ == "__main__":
main()
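
# Example invocation (a sketch only: the script filename, model type, model path,
# and output directory below are illustrative placeholders, not values taken from
# this repository):
#
#   python evaluate.py \
#       --model_type <key from PROMPTS> \
#       --model_name_or_path <path or hub id of the embedding model> \
#       --output_dir results \
#       --batch_size 32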