"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from medomni.common.registry import registry
from medomni.tasks.base_task import BaseTask
from medomni.common.logger import MetricLogger, SmoothedValue
from medomni.datasets.data_utils import prepare_sample

import torch.distributed as dist


@registry.register_task("image_text_pretrain")
class ImageTextPretrainTask(BaseTask):
    """Task registered under the name ``"image_text_pretrain"``.

    Overrides ``evaluation`` so that a plain iterable data loader is accepted
    in addition to an iterator.
    """

    def __init__(self):
        super().__init__()

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run ``valid_step`` over every batch and collect the outputs.

        Args:
            model: Forwarded unchanged to ``self.valid_step``.
            data_loader: Iterable or iterator of sample batches. Anything
                without a ``__next__`` attribute is wrapped with ``iter()``.
            cuda_enabled: Forwarded to ``prepare_sample`` for each batch.

        Returns:
            A list built by extending with the output of every
            ``valid_step`` call, in iteration order.
        """
        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []

        # NOTE: `valid_step` is not defined in this class; presumably it is
        # inherited from BaseTask -- confirm against the base class.
        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        # Only synchronise when a process group actually exists; calling
        # barrier() without an initialised group raises. `is_available()` is
        # checked first because `is_initialized()` may be missing on builds
        # compiled without distributed support.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()

        return results