# File size: 3,968 Bytes
# 2fe3da0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import os
import torch
import psutil
import gc
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.objaverse import load_obj
from src.utils import mesh
from src.utils.material import Material
import argparse


def bytes_to_megabytes(bytes):
    """Return the byte count *bytes* expressed in megabytes of 2**20 bytes."""
    mebibyte = 1 << 20
    return bytes / mebibyte


def bytes_to_gigabytes(bytes):
    """Return the byte count *bytes* expressed in gigabytes of 2**30 bytes."""
    gibibyte = 1 << 30
    return bytes / gibibyte


def print_memory_usage(stage):
    """Print this process's resident memory and PyTorch's CUDA allocator stats.

    All figures are reported in MB (units of 2**20 bytes), prefixed with the
    caller-supplied *stage* label.
    """
    mib = 1024**2
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / mib
    allocated_mb = torch.cuda.memory_allocated() / mib
    reserved_mb = torch.cuda.memory_reserved() / mib
    message = (
        f"[{stage}] Process memory: {rss_mb:.2f} MB, "
        f"Allocated CUDA memory: {allocated_mb:.2f} MB, Cached CUDA memory: {reserved_mb:.2f} MB"
    )
    print(message)


def process_obj(index, root_dir, final_save_dir, paths, max_mtl_lines=250):
    """Convert one OBJ asset into a ``.pth`` file of mesh tensors.

    With ``name = paths[index]``, the asset is read from
    ``root_dir/<name>/<name>.obj`` (with sibling ``<name>.mtl``). The result
    is written to ``final_save_dir/<name>.pth``.

    Args:
        index: Position of the asset name in ``paths``.
        root_dir: Directory holding one sub-directory per asset.
        final_save_dir: Output directory for the ``.pth`` files.
        paths: Sequence of asset names.
        max_mtl_lines: Assets whose ``.mtl`` file has at least this many
            lines are skipped. The default matches the original hard-coded
            limit.

    Returns:
        The asset name on success. ``None`` when the output already exists,
        the asset is rejected by the ``.mtl`` size filter, or processing
        fails. Failures are printed rather than raised, so one bad asset
        does not abort the caller's batch.
    """
    name = paths[index]
    asset_dir = os.path.join(root_dir, name)
    obj_path = os.path.join(asset_dir, name + '.obj')
    mtl_path = os.path.join(asset_dir, name + '.mtl')
    save_path = os.path.join(final_save_dir, f"{name}.pth")

    # Skip assets already converted by a previous run.
    if os.path.exists(save_path):
        return None

    try:
        # Run the cheap .mtl size filter before the expensive mesh load.
        # Count lines in binary mode so the result does not depend on the
        # locale encoding, and without holding the whole file in memory.
        with open(mtl_path, 'rb') as file:
            num_mtl_lines = sum(1 for _ in file)
        if num_mtl_lines >= max_mtl_lines:
            return None

        with torch.no_grad():
            # Only the mesh object is used; the extra returned attributes are ignored.
            ref_mesh, *_ = load_obj(obj_path, return_attributes=True)
            ref_mesh = mesh.compute_tangents(ref_mesh)

        final_mesh_attributes = {
            "v_pos": ref_mesh.v_pos.detach().cpu(),
            "v_nrm": ref_mesh.v_nrm.detach().cpu(),
            "v_tex": ref_mesh.v_tex.detach().cpu(),
            "v_tng": ref_mesh.v_tng.detach().cpu(),
            "t_pos_idx": ref_mesh.t_pos_idx.detach().cpu(),
            "t_nrm_idx": ref_mesh.t_nrm_idx.detach().cpu(),
            "t_tex_idx": ref_mesh.t_tex_idx.detach().cpu(),
            "t_tng_idx": ref_mesh.t_tng_idx.detach().cpu(),
            # NOTE(review): material values are stored as-is (not moved to
            # CPU). Confirm they load on machines without a GPU.
            "mat_dict": {key: ref_mesh.material[key] for key in ref_mesh.material.mat_keys},
        }
        del ref_mesh

        # Save to a temporary name, then atomically rename it into place.
        # An interrupted save therefore never leaves a truncated .pth that
        # the skip check above would treat as finished.
        tmp_path = save_path + ".tmp"
        try:
            torch.save(final_mesh_attributes, tmp_path)
            os.replace(tmp_path, save_path)
        finally:
            # After a successful rename the temp file is gone. Otherwise,
            # remove the partial file.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(f"==> Saved to {final_save_dir}/{name}.pth")
        return name

    except Exception as e:
        # Best-effort batch conversion: report the failure and continue.
        print(f"Failed to process {name}: {e}")
        return None

    finally:
        gc.collect()
        torch.cuda.empty_cache()


def main(root_dir, save_dir, batch_size=100, max_workers=8):
    """Convert every asset sub-directory of ``root_dir`` into ``.pth`` files.

    Assets are processed in batches of ``batch_size``, each with a fresh
    thread pool of ``max_workers`` threads. Memory usage is reported after
    every batch, and Python/CUDA caches are released between batches.

    Args:
        root_dir: Directory with one sub-directory per asset.
        save_dir: Output directory; created if missing.
        batch_size: Number of assets submitted per batch.
        max_workers: Worker threads per batch.

    Returns:
        Names of the assets converted during this run, in completion order.
        Assets already present in ``save_dir`` are not included.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Assets finished by an earlier run. Filtering them here avoids
    # submitting no-op tasks; process_obj would skip them anyway.
    suffix = ".pth"
    finished = {
        entry[: -len(suffix)] for entry in os.listdir(save_dir) if entry.endswith(suffix)
    }

    # Only sub-directories are assets; plain files would just fail in
    # process_obj. Sorting makes the batch order reproducible across runs.
    paths = sorted(
        entry
        for entry in os.listdir(root_dir)
        if entry not in finished and os.path.isdir(os.path.join(root_dir, entry))
    )

    valid_uid = []

    print_memory_usage("Start")

    for batch, start_index in enumerate(tqdm(range(0, len(paths), batch_size))):
        end_index = min(start_index + batch_size, len(paths))

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(process_obj, index, root_dir, save_dir, paths)
                for index in range(start_index, end_index)
            ]
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    valid_uid.append(result)

        print_memory_usage(f"=====> After processing batch {batch + 1}")
        torch.cuda.empty_cache()
        gc.collect()

    print_memory_usage("End")
    # NOTE(review): the module imports `json` but never uses it. Perhaps
    # valid_uid was meant to be written to disk; it is returned for callers.
    return valid_uid


if __name__ == "__main__":
    # Command-line entry point: two positional directory arguments.
    cli = argparse.ArgumentParser(description="Process OBJ files and save final results.")
    for arg_name, arg_help in (
        ("root_dir", "Directory containing the root OBJ files."),
        ("save_dir", "Directory to save the processed results."),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    parsed = cli.parse_args()

    main(parsed.root_dir, parsed.save_dir)