|
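"""Batch image tagger.

Walks a folder of images, runs a multi-label classifier on the GPU, and writes
each image's predicted tags (plus their probabilities) to a .txt file next to
the image. Inference and file writing run in separate processes connected by a
queue, so GPU work is not blocked on disk I/O.
"""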
import json
import os

import torch
import torch.multiprocessing as multiprocessing
import torchvision.transforms as transforms
from PIL import Image
from torch import autocast
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

# Allow TF32 on Ampere and newer GPUs: faster matmul/conv kernels at negligible accuracy cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Let cuDNN autotune kernels; input shapes are fixed at 224x224 below, so this pays off.
torch.backends.cudnn.benchmark = True
# Anomaly detection is a debugging aid that slows autograd down; keep it off.
torch.autograd.set_detect_anomaly(False)
# Note: torch.autograd.profiler.emit_nvtx() and .profile() are context managers;
# constructing them without `with` does nothing, so those two calls were removed as no-ops.
|
|
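# The mean/std constants below are the standard CLIP image-normalization values,
# so the checkpoint is presumably a CLIP-style vision backbone. Images with
# extreme aspect ratios (taller than 2:1 or wider than 1:2) are resized on the
# short side and center-cropped rather than squashed, which preserves detail in
# the middle of long/thin images at the cost of discarding their ends.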
class ImageDataset(Dataset):

    def __init__(self, image_folder_path, allowed_extensions):
        self.allowed_extensions = allowed_extensions
        self.all_image_paths, self.all_image_names, self.image_base_paths = self.get_image_paths(image_folder_path)
        self.train_size = len(self.all_image_paths)
        print(f"Number of images to be tagged: {self.train_size}")
        self.thin_transform = transforms.Compose([
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711]),
        ])
        self.normal_transform = transforms.Compose([
            transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711]),
        ])
|
    def get_image_paths(self, folder_path):
        image_paths = []
        image_file_names = []
        image_base_paths = []
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                if file.lower().split(".")[-1] in self.allowed_extensions:
                    image_paths.append(os.path.abspath(os.path.join(root, file)))
                    # os.path.splitext (rather than split(".")[0]) keeps names
                    # like "my.photo.jpg" intact as "my.photo".
                    image_file_names.append(os.path.splitext(file)[0])
                    image_base_paths.append(root)
        return image_paths, image_file_names, image_base_paths
|
    def __len__(self):
        return len(self.all_image_paths)
|
    def __getitem__(self, index):
        image = Image.open(self.all_image_paths[index]).convert("RGB")
        # Crop-based transform for very tall or very wide images;
        # squash everything else straight to 224x224.
        ratio = image.height / image.width
        if ratio > 2.0 or ratio < 0.5:
            image = self.thin_transform(image)
        else:
            image = self.normal_transform(image)

        return {
            'image': image,
            "image_name": self.all_image_names[index],
            "image_root": self.image_base_paths[index],
        }
|
|
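# Assumption: the checkpoint is a fully pickled nn.Module (saved via
# torch.save(model)), not a state_dict. On PyTorch >= 2.6, torch.load defaults
# to weights_only=True, so loading a pickled module there needs an explicit
# torch.load(model_path, weights_only=False).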
def prepare_model(model_path: str):
    model = torch.load(model_path)
    model.to(memory_format=torch.channels_last)
    model = model.eval()
    return model
|
|
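# Despite its name, this function only runs inference; no parameters are
# updated. bfloat16 autocast assumes an Ampere-or-newer GPU; on older cards,
# dtype=torch.float16 is the usual substitute.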
def train(tagging_is_running, model, dataloader, train_data, output_queue):
    print('Begin tagging')
    model.eval()

    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
            this_data = data['image'].to("cuda")
            with autocast(device_type='cuda', dtype=torch.bfloat16):
                outputs = model(this_data)

            probabilities = torch.sigmoid(outputs)
            output_queue.put((probabilities.to("cpu"), data["image_name"], data["image_root"]))

    # Signal completion: clear the "running" flag and send a sentinel so the
    # writer process drains the queue and then stops.
    _ = tagging_is_running.get()
    output_queue.put(None)
    print("Tagging finished!")
|
|
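# tags.json is assumed to hold the model's tag vocabulary as a JSON array of
# strings, e.g. ["1girl", "blue_sky", "smile", ...]; after sorting and appending
# the three placeholder entries it must line up 1:1 with the model's 7704 outputs.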
def tag_writer(tagging_is_running, output_queue, threshold):
    with open("tags.json", "r") as file:
        tags = json.load(file)
    allowed_tags = sorted(tags)
    del tags
    allowed_tags.extend(["placeholder0", "placeholder1", "placeholder2"])
    tag_count = len(allowed_tags)
    assert tag_count == 7704, f"Tag list has the wrong length. Expected: 7704, got: {tag_count}"

    while True:
        item = output_queue.get()
        if item is None:
            # Sentinel from the tagging process: every batch has been written.
            break
        tag_probabilities, image_names, image_roots = item
        tag_probabilities = tag_probabilities.tolist()

        for per_image_tag_probabilities, image_name, image_root in zip(tag_probabilities, image_names, image_roots,
                                                                       strict=True):
            this_image_tags = []
            this_image_tag_probabilities = []
            for index, per_tag_probability in enumerate(per_image_tag_probabilities):
                if per_tag_probability > threshold:
                    tag = allowed_tags[index]
                    if "placeholder" not in tag:
                        this_image_tags.append(tag)
                        # Store the probability in thousandths (0.374 -> "374").
                        # Rounding once avoids float truncation artifacts such as
                        # int(round(0.29, 3) * 1000) == 289.
                        this_image_tag_probabilities.append(str(round(per_tag_probability * 1000)))
            output_file = os.path.join(image_root, image_name + ".txt")
            with open(output_file, "w", encoding="utf-8") as this_output:
                this_output.write(" ".join(this_image_tags))
                this_output.write("\n")
                this_output.write(" ".join(this_image_tag_probabilities))
|
|
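# The two paths below are placeholders to edit before running. batch_size=64
# and num_workers=6 are starting points to adjust for your GPU memory and CPU
# core count; threshold is the minimum sigmoid probability for a tag to be kept.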
def main():
    image_folder_path = "/path/to/your/folder/"
    model_path = "/path/to/your/model.pth"
    allowed_extensions = {"jpg", "jpeg", "png", "webp"}
    batch_size = 64
    threshold = 0.3

    # CUDA cannot be re-initialized in forked subprocesses, so "spawn" is required here.
    multiprocessing.set_start_method('spawn')
    output_queue = multiprocessing.Queue()
    tagging_is_running = multiprocessing.Queue(maxsize=5)
    tagging_is_running.put("Running!")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available!")

    model = prepare_model(model_path).to("cuda")

    dataset = ImageDataset(image_folder_path, allowed_extensions)

    batched_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=6,
        pin_memory=True,
        drop_last=False,
    )
    process_writer = multiprocessing.Process(target=tag_writer,
                                             args=(tagging_is_running, output_queue, threshold))
    process_writer.start()
    process_tagger = multiprocessing.Process(target=train,
                                             args=(tagging_is_running, model, batched_loader, dataset, output_queue))
    process_tagger.start()
    process_writer.join()
    process_tagger.join()
|
|
if __name__ == "__main__":
    main()
|
|