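"""Run a lighteval evaluation of a queued EvalRequest on a Hugging Face inference
endpoint and push the results to the results organization on the Hub."""
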
import json
import os
import logging
from datetime import datetime
from argparse import Namespace

from lighteval.main_accelerate import main, EnvConfig, create_model_config, load_model
from src.envs import RESULTS_REPO, CACHE_PATH, TOKEN, OWNER
from src.backend.manage_requests import EvalRequest
from lighteval.logging.evaluation_tracker import EnhancedJSONEncoder
from lighteval.models.model_loader import ModelInfo

logging.getLogger("openai").setLevel(logging.WARNING)

class DefaultNamespace(Namespace):
    """Namespace that returns None for any attribute that was never set, so optional lighteval arguments can simply be omitted."""
    def __getattr__(self, name):
        return self.__dict__.get(name, None)

def run_evaluation(eval_request: EvalRequest, task_names: str, batch_size: int, local_dir: str, accelerator: str, region: str, vendor: str, instance_size: str, instance_type: str, limit=None):
    """Run a lighteval evaluation of `eval_request` on an inference endpoint and return the results dict.

    `limit` caps the number of samples per task and should only be used for testing.
    """
    if limit:
        print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")

    # Arguments mirroring lighteval's CLI; anything not set here resolves to None via DefaultNamespace.
    args = DefaultNamespace(**{
            "model_config": dict(model=dict(
                type="endpoint",
                base_params=dict(
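                    # Endpoint name derived from the model id: lowercased, dots replaced
                    # with hyphens, truncated to its last 32 characters (assumed to fit
                    # the endpoint name length limit).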
                    endpoint_name=f'{eval_request.model.split("/")[1].replace(".", "-").lower()}-lighteval'[-32:],
                    model=eval_request.model,
                    revision=eval_request.revision,
                    dtype=eval_request.precision,
                    reuse_existing=False
                ),
                instance=dict(
                    accelerator=accelerator,
                    region=region,
                    vendor=vendor,
                    instance_size=instance_size,
                    instance_type=instance_type,
                    framework='pytorch',
                    endpoint_type='protected',
                    namespace=OWNER
                ),
                generation=dict(
                    add_special_tokens=True
                )
            )),
            "max_samples": limit,
            "job_id": str(datetime.now()),
            "push_results_to_hub": True,
            "save_details": False,
            "push_details_to_hub": False,
            "public_run": False,
            "cache_dir": CACHE_PATH,
            "results_org": OWNER,
            "output_dir": local_dir,
            "override_batch_size": batch_size,
            "custom_tasks": "custom_tasks.py",
            "tasks": task_names,
            "dataset_loading_processes": 24,
            "num_fewshot_seeds": 0
    })

    try:
        results = main(args)

        dumped = json.dumps(results, cls=EnhancedJSONEncoder, indent=2)
        print(dumped)
    except Exception as ex: # if the eval failed, force a cleanup of the inference endpoint before propagating the error
        import traceback
        traceback.print_exception(ex)
        env_config = EnvConfig(token=TOKEN, cache_dir=args.cache_dir)
        args.reuse_existing = True
        model_config = create_model_config(args=args, accelerator=accelerator)
        model, _ = load_model(config=model_config, env_config=env_config)
        model.cleanup()
        raise

    return results
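

# Minimal invocation sketch for local debugging (illustrative only): the EvalRequest
# constructor fields and the endpoint instance settings below are assumptions, not
# values shipped with this repository, and running this would spin up a real
# (billable) inference endpoint.
if __name__ == "__main__":
    example_request = EvalRequest(          # field names assumed from the backend's request schema
        model="org/some-model",
        private=False,
        status="RUNNING",
        json_filepath="",
        precision="float16",
        revision="main",
    )
    run_evaluation(
        eval_request=example_request,
        task_names="custom|mytask|0|0",     # lighteval task specification string, illustrative
        batch_size=1,
        local_dir="./eval-results",
        accelerator="gpu",
        region="us-east-1",
        vendor="aws",
        instance_size="medium",
        instance_type="g5.2xlarge",
        limit=10,                           # small sample cap: testing only
    )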