# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Utility functions for the project.""" | |
from __future__ import print_function | |
# pylint: disable=g-importing-member | |
from collections import defaultdict | |
from collections import deque | |
from copy import deepcopy | |
import datetime | |
import errno | |
import os | |
import sys | |
import time | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
import yaml | |
# pylint: disable=g-bad-import-order | |
from data.voc import CLASS2ID | |
from data.voc import VOC_CLASSES | |
# Bytes per mebibyte; used by MetricLogger.log_every to report CUDA memory in MB.
_MB = 1024.0 * 1024.0
# Preprocessing for the DINO backbone: PIL image -> float tensor, normalized
# with the usual ImageNet channel statistics.
DINO_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
class Config:
  """Exposes a (possibly nested) keyword mapping as attribute access.

  Nested dicts are wrapped recursively, so `Config(a={"b": 1}).a.b == 1`.
  """

  def __init__(self, **kwargs):
    for name, val in kwargs.items():
      # Dict values become nested Config objects; everything else is stored as-is.
      setattr(self, name, Config(**val) if isinstance(val, dict) else val)
def load_yaml(filename):
  """Parse a YAML file.

  Args:
    filename: path to the YAML file.

  Returns:
    The parsed data, or None when parsing fails (the error is printed;
    best-effort behavior preserved from the original).
  """
  # Fix: open text files with an explicit encoding instead of the locale default.
  with open(filename, encoding="utf-8") as file:
    try:
      data = yaml.safe_load(file)
    except yaml.YAMLError as e:
      print(f"Error while loading YAML file: {e}")
      return None  # fix: was an implicit None fall-through
    return data
def normalize(x, dim=None, eps=1e-15):
  """Min-max normalize `x` into [0, 1].

  Args:
    x: torch.Tensor to normalize.
    dim: optional dimension to normalize along; None normalizes over all
      elements.
    eps: small constant added to the denominator to avoid division by zero.

  Returns:
    Tensor of the same shape as `x`, scaled to [0, 1].
  """
  if dim is None:
    # Fix: add eps here too so a constant input yields 0 instead of NaN,
    # consistent with the per-dim branch below.
    return (x - x.min()) / (x.max() - x.min() + eps)
  # torch reductions return (values, indices); [0] selects the values.
  numerator = x - x.min(axis=dim, keepdims=True)[0]
  denominator = (
      x.max(axis=dim, keepdims=True)[0]
      - x.min(axis=dim, keepdims=True)[0]
      + eps
  )
  return numerator / denominator
class SmoothedValue(object):
  """Track a series of values and provide access to smoothed values over a window or the global series average."""

  def __init__(self, window_size=20, fmt=None):
    """Args:
      window_size: number of recent values kept for windowed statistics.
      fmt: format string; may reference {median}, {avg}, {global_avg},
        {max}, {value}.
    """
    if fmt is None:
      fmt = "{median:.4f} ({global_avg:.4f})"
    self.deque = deque(maxlen=window_size)  # recent values only (windowed stats)
    self.total = 0.0  # running sum over ALL values ever seen
    self.count = 0  # running count over ALL values ever seen
    self.fmt = fmt

  def update(self, value, n=1):
    """Record `value` with weight `n` (n affects global stats, not the window)."""
    self.deque.append(value)
    self.count += n
    self.total += value * n

  def synchronize_between_processes(self):
    """No-op placeholder so MetricLogger.synchronize_between_processes works.

    The original implementation all-reduced (count, total) via
    torch.distributed but was disabled; single-process behavior is unchanged
    and the deque is never synchronized.
    """
    return

  # Fix: the accessors below must be properties -- __str__ formats them with
  # numeric format specs (e.g. "{median:.4f}"), which raises TypeError on a
  # bound method.
  @property
  def median(self):
    d = torch.tensor(list(self.deque))
    return d.median().item()

  @property
  def avg(self):
    d = torch.tensor(list(self.deque), dtype=torch.float32)
    return d.mean().item()

  @property
  def global_avg(self):
    return self.total / self.count

  @property
  def max(self):
    return max(self.deque)

  @property
  def value(self):
    return self.deque[-1]

  def __str__(self):
    return self.fmt.format(
        median=self.median,
        avg=self.avg,
        global_avg=self.global_avg,
        max=self.max,
        value=self.value,
    )
class MetricLogger(object):
  """Log the metrics."""

  def __init__(self, delimiter="\t"):
    # Meters are created lazily on the first update() of each key.
    self.meters = defaultdict(SmoothedValue)
    self.delimiter = delimiter

  def update(self, **kwargs):
    """Update named meters; scalar tensors are unwrapped with .item()."""
    for k, v in kwargs.items():
      if isinstance(v, torch.Tensor):
        v = v.item()
      assert isinstance(v, (float, int))
      self.meters[k].update(v)

  def __getattr__(self, attr):
    # Only invoked when normal attribute lookup fails, so the __dict__
    # branch below can never hit; kept for backward compatibility.
    if attr in self.meters:
      return self.meters[attr]
    if attr in self.__dict__:
      return self.__dict__[attr]
    raise AttributeError(
        "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
    )

  def __str__(self):
    loss_str = []
    for name, meter in self.meters.items():
      loss_str.append("{}: {}".format(name, str(meter)))
    return self.delimiter.join(loss_str)

  def synchronize_between_processes(self):
    for meter in self.meters.values():
      meter.synchronize_between_processes()

  def add_meter(self, name, meter):
    self.meters[name] = meter

  def log_every(self, iterable, print_freq, header=None):
    """Yield items from `iterable`, printing stats every `print_freq` steps.

    Requires `iterable` to support len(). Prints a final total-time line.
    """
    i = 0
    if not header:
      header = ""
    start_time = time.time()
    end = time.time()
    iter_time = SmoothedValue(fmt="{avg:.4f}")
    data_time = SmoothedValue(fmt="{avg:.4f}")
    space_fmt = ":" + str(len(str(len(iterable)))) + "d"
    # Fix: only report CUDA memory when CUDA is actually available;
    # torch.cuda.max_memory_allocated() is meaningless (or fails) on
    # CPU-only builds.
    has_cuda = torch.cuda.is_available()
    msg_parts = [
        header,
        "[{0" + space_fmt + "}/{1}]",
        "eta: {eta}",
        "{meters}",
        "time: {time}",
        "data: {data}",
    ]
    if has_cuda:
      msg_parts.append("max mem: {memory:.0f}")
    log_msg = self.delimiter.join(msg_parts)
    for obj in iterable:
      data_time.update(time.time() - end)
      yield obj
      iter_time.update(time.time() - end)
      if i % print_freq == 0:
        eta_seconds = iter_time.global_avg * (len(iterable) - i)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        fmt_kwargs = dict(
            eta=eta_string,
            meters=str(self),
            time=str(iter_time),
            data=str(data_time),
        )
        if has_cuda:
          fmt_kwargs["memory"] = torch.cuda.max_memory_allocated() / _MB
        print(log_msg.format(i, len(iterable), **fmt_kwargs))
        sys.stdout.flush()
      i += 1
      end = time.time()
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("{} Total time: {}".format(header, total_time_str))
def mkdir(path):
  """Create `path` (including parents); silently succeed if it already exists."""
  try:
    os.makedirs(path)
  except FileExistsError:
    # Idiomatic equivalent of the old `errno.EEXIST` check: FileExistsError
    # is exactly OSError with errno EEXIST. Any other OSError propagates.
    pass
def pad_to_square(im):
  """Zero-pad a PIL image (copy) to a centered square canvas.

  Returns whatever add_margin returns for the computed margins.
  """
  im = deepcopy(im)
  w, h = im.size
  side = max(w, h)
  # Split each axis' deficit as evenly as possible (extra pixel goes to
  # bottom/right).
  top = (side - h) // 2
  bottom = side - h - top
  left = (side - w) // 2
  right = side - w - left
  n_bands = len(im.mode)
  if n_bands == 3:
    fill = (0, 0, 0)
  elif n_bands == 1:
    fill = 0
  else:
    raise ValueError(f"Image mode not supported. Image has {im.mode} channels.")
  return add_margin(im, top, right, bottom, left, color=fill)
def add_margin(pil_img, top, right, bottom, left, color=(0, 0, 0)):
  """Expand the canvas of `pil_img` by the given margins, filled with `color`.

  Returns the padded image and `[left, top, width, height]`, the placement
  of the original image inside the new canvas.
  Ref: https://note.nkmk.me/en/python-pillow-add-margin-expand-canvas/.
  """
  w, h = pil_img.size
  canvas = Image.new(pil_img.mode, (w + left + right, h + top + bottom), color)
  canvas.paste(pil_img, (left, top))
  placement = [left, top, w, h]
  return canvas, placement
def process_sentence(sentence, ds_name):
  """Dataset specific sentence processing.

  Args:
    sentence: a string, or a list/tuple of strings (first element used).
    ds_name: dataset name, e.g. "refcoco*", "voc".

  Returns:
    The processed sentence as a single string.
  """
  if "refcoco" in ds_name:
    # RefCOCO annotations come as a list of expressions; use the first.
    sentence = sentence[0].lower()
    # Strip characters that would break downstream use of the sentence.
    sentence = sentence.replace('"', "")
    sentence = sentence.replace("/", "")
  if ds_name == "voc":
    # Fix: membership test directly on the dict, no list(...) wrapper needed.
    if sentence in CLASS2ID:
      # Map the class name to its canonical VOC spelling (IDs are 1-based).
      label_id = CLASS2ID[sentence] - 1
      sentence = VOC_CLASSES[label_id]
  if not isinstance(sentence, str):
    sentence = sentence[0]
  return sentence