# title: create earthwork train dataset
# author: Taewook Kang
# date: 2024.3.27
# description: create earthwork train dataset
# license: MIT
# reference: https://pyautocad.readthedocs.io/en/latest/_modules/pyautocad/api.html
# version
#   0.1. 2024.3.27. create file
#
import os, math, argparse, json, re, traceback, numpy as np, pandas as pd, trimesh, laspy, shutil
import pyautocad, open3d as o3d, seaborn as sns, win32com.client, pythoncom
import matplotlib.pyplot as plt
from scipy.spatial import distance
from tqdm import trange, tqdm
from math import pi

# Layer whose polylines delimit one cross section each. It can be overridden
# with the optional 'frame_layer' key of the config file.
FRAME_LAYER = 'Nru_Frame_Crs_Design'
# Station label format 'xxx+xxx.xx'.
STATION_PATTERN = re.compile(r'\d+\+\d+\.\d+')
# Station assigned to a frame that contains no text matching STATION_PATTERN.
UNKNOWN_STATION = '-1+000.00'


def get_layer_to_label(cfg, layer):
    """Return the training label mapped to an AutoCAD layer name.

    cfg['layers'] is a list of {'layer': ..., 'label': ...} entries; the first
    entry whose 'layer' equals *layer* wins. Returns '' when the layer is not
    mapped, which callers treat as "ignore this entity".
    """
    for lay in cfg['layers']:
        if lay['layer'] == layer:
            return lay['label']
    return ''


def get_entity_from_acad(entity_names=('AcDbLine', 'AcDbPolyline', 'AcDbText')):
    """Ask the user to select entities in AutoCAD and keep the supported ones.

    entity_names: EntityName values to keep (a tuple, so the default cannot be
    mutated across calls).

    Returns the list of selected COM entities whose EntityName is in
    *entity_names*, or None when none were selected.
    """
    acad = pyautocad.Autocad(create_if_not_exists=True)
    selections = acad.get_selection('Select entities to extract geometry')

    geoms = []
    for entity in tqdm(selections):
        try:
            if entity.EntityName in entity_names:
                geoms.append(entity)
        except Exception as e:
            # An entity whose EntityName cannot be read is reported and skipped.
            print(f'error: {e}')

    if not geoms:
        print("No entities found in the drawing.")
        return None
    return geoms


def get_bbox(polyline):
    """Return the axis-aligned bounding box (xmin, ymin, xmax, ymax).

    polyline: non-empty sequence of points; only the first two coordinates of
    each point are used.
    """
    xs = [p[0] for p in polyline]
    ys = [p[1] for p in polyline]
    return (min(xs), min(ys), max(xs), max(ys))


def _vertices_to_points(coordinates):
    """Pair a flat coordinate array [x0, y0, x1, y1, ...] into (x, y) tuples."""
    return list(zip(coordinates[0::2], coordinates[1::2]))


def _is_inside(bbox, outer):
    """True when *bbox* lies entirely within *outer* (both xmin, ymin, xmax, ymax)."""
    return (bbox[0] >= outer[0] and bbox[1] >= outer[1]
            and bbox[2] <= outer[2] and bbox[3] <= outer[3])


def _station_key(station):
    """Numeric sort key for a station string of the form 'xxx+xxx.xx'.

    Sorting the raw strings is lexicographic ('10+000.00' < '2+000.00'), so the
    parts on both sides of '+' are compared numerically. Strings that do not
    parse sort before every valid station.
    """
    left, _, right = station.partition('+')
    try:
        return (0, int(left), float(right))
    except ValueError:
        return (-1, 0, 0.0)


def _entity_polyline(entity):
    """Return (points, closed) for an AcDbLine / AcDbPolyline, else (None, False).

    Line end points are reduced to (x, y) so both entity kinds produce the same
    2D vertex format as the AcDbPolyline coordinate pairs.
    """
    name = entity.EntityName
    if name == 'AcDbLine':
        start, end = entity.StartPoint, entity.EndPoint
        return [(start[0], start[1]), (end[0], end[1])], False
    if name == 'AcDbPolyline':
        return _vertices_to_points(entity.Coordinates), bool(entity.Closed)
    return None, False


def get_xsections_from_acad(cfg):
    """Interactively extract labelled cross-section geometry from AutoCAD.

    The user's selection is split into frame polylines on the frame layer (one
    per cross section) and other entities. Each frame gets the first text
    inside it that matches STATION_PATTERN as its station (UNKNOWN_STATION if
    none). Every line/polyline whose layer maps to a label via cfg['layers'] is
    attached to each frame whose bbox fully contains it.

    If no frame is selected, a single section spanning the whole plane with
    station '0+000.00' is used.

    Returns a list of {'bbox', 'station', 'geom'} dicts sorted by station. It
    returns [] when nothing usable was selected.
    """
    entities = get_entity_from_acad()
    if not entities:
        return []

    frame_layer = cfg.get('frame_layer', FRAME_LAYER)

    # 1. split the selection into section frames and everything else
    xsec_list = []
    xsec_entities = []
    for entity in entities:
        if entity.Layer == frame_layer and entity.EntityName == 'AcDbPolyline':
            polyline = _vertices_to_points(entity.Coordinates)
            if len(polyline) < 2:
                continue
            xsec_list.append({'bbox': get_bbox(polyline), 'station': '', 'geom': []})
        else:
            xsec_entities.append(entity)

    if not xsec_entities:
        print("No cross section found in the drawing.")
        return []

    # 2. label each frame with the first station-like text inside it; other
    #    texts in the frame must not overwrite an already found station
    texts = []
    for entity in xsec_entities:
        if entity.EntityName != 'AcDbText':
            continue
        pt = entity.InsertionPoint
        texts.append(((pt[0], pt[1]), entity.TextString))

    for xsec in xsec_list:
        xsec['station'] = UNKNOWN_STATION
        for (x, y), text in texts:
            if not _is_inside((x, y, x, y), xsec['bbox']):
                continue
            match = STATION_PATTERN.search(text)
            if match:
                xsec['station'] = match.group()
                break

    if not xsec_list:
        xsec_list.append({
            'bbox': (-9999999999.0, -9999999999.0, 9999999999.0, 9999999999.0),
            'station': '0+000.00',
            'geom': [],
        })

    xsec_list.sort(key=lambda xsec: _station_key(xsec['station']))

    # 3. read label and geometry of each entity once (instead of once per
    #    frame), then distribute it to every frame that contains it
    shapes = []
    for entity in xsec_entities:
        label = get_layer_to_label(cfg, entity.Layer)
        if label == '':
            continue
        polyline, closed = _entity_polyline(entity)
        if not polyline:
            continue
        shapes.append((label, polyline, closed, get_bbox(polyline)))

    for xsec in tqdm(xsec_list):
        for label, polyline, closed, bbox in shapes:
            if not _is_inside(bbox, xsec['bbox']):
                continue
            xsec['geom'].append({
                'label': label,
                'polyline': list(polyline),
                'closed': closed,
                'earthwork_feature': [],
            })
    return xsec_list


# state shared by the matplotlib viewer callbacks below
_draw_xsection_index = 0
_xsections = None
_plot_ax = None


def draw_xsections(ax, index):
    """Plot every geometry of the loaded cross section *index* on *ax*."""
    xsec = _xsections[index]
    station = xsec['station']
    # set the title even when the section has no geometry
    ax.set_title(f'station: {station}')
    for geo in xsec['geom']:
        polyline = np.array(geo['polyline'])
        if polyline.ndim != 2:
            # empty or malformed vertex list: nothing to plot
            continue
        ax.plot(polyline[:, 0], polyline[:, 1], label=geo['label'])
    ax.set_aspect('equal', 'box')


def _show_current():
    """Redraw the axes with the current section and request a canvas refresh."""
    _plot_ax.clear()
    draw_xsections(_plot_ax, _draw_xsection_index)
    # widget/key callbacks do not trigger a redraw by themselves
    _plot_ax.figure.canvas.draw_idle()


def next_button(event):
    """Show the next cross section, wrapping around to the first."""
    global _draw_xsection_index
    _draw_xsection_index = (_draw_xsection_index + 1) % len(_xsections)
    _show_current()


def prev_button(event):
    """Show the previous cross section, wrapping around to the last."""
    global _draw_xsection_index
    _draw_xsection_index = (_draw_xsection_index - 1) % len(_xsections)
    _show_current()


def on_key_press(event):
    """Map the right/left arrow keys to next/previous section."""
    if event.key == 'right':
        next_button(None)
    elif event.key == 'left':
        prev_button(None)


def show_xsections(xsections):
    """Open an interactive viewer that pages through *xsections*.

    Use the prev/next buttons or the left/right arrow keys. Blocks in
    plt.show() until the window is closed.
    """
    from matplotlib.widgets import Button
    global _draw_xsection_index, _xsections, _plot_ax

    if not xsections:
        print("No cross section to show.")
        return

    _xsections = xsections
    _draw_xsection_index = 0  # always start at the first section
    fig = plt.figure()
    _plot_ax = fig.subplots()
    plt.subplots_adjust(left=0.3, bottom=0.25)
    draw_xsections(_plot_ax, _draw_xsection_index)

    # prev/next buttons; the local references keep the widgets alive while
    # plt.show() runs
    axprev = fig.add_axes([0.7, 0.05, 0.1, 0.075])
    bprev = Button(axprev, 'prev', color="white")
    bprev.on_clicked(prev_button)

    axnext = fig.add_axes([0.81, 0.05, 0.1, 0.075])
    bnext = Button(axnext, 'next', color="white")
    bnext.on_clicked(next_button)

    fig.canvas.mpl_connect('key_press_event', on_key_press)
    plt.show()


def main():
    """Command-line entry point.

    With --view, only display a previously saved chunk file. Otherwise, it
    repeatedly asks for an AutoCAD selection and saves each result as
    <output>/chain_chunk_<n>.json. Numbering continues after the highest
    existing chunk, and the loop ends when a selection yields no sections.
    """
    parser = argparse.ArgumentParser(description='create earthwork train dataset')
    parser.add_argument('--config', type=str, default='config.json', help='config file')
    parser.add_argument('--output', type=str, default='output/', help='output directory')
    parser.add_argument('--view', type=str, default='',
                        help='view file (e.g. output/chain_chunk_6.json); when given, only show it')
    args = parser.parse_args()

    try:
        if args.view:
            with open(args.view, 'r', encoding='utf-8') as f:
                xsections = json.load(f)
            show_xsections(xsections)
            return

        with open(args.config, 'r', encoding='utf-8') as f:
            cfg = json.load(f)

        # continue numbering after the highest existing chain_chunk_<n>.json
        os.makedirs(args.output, exist_ok=True)
        pattern = re.compile(r'chain_chunk_(\d+)\.json')
        file_names = os.listdir(args.output)
        indices = [int(m.group(1)) for m in map(pattern.match, file_names) if m]
        chunk_index = max(indices) + 1 if indices else 0
        if file_names:
            print(file_names)

        while True:
            xsections = get_xsections_from_acad(cfg)
            if not xsections:
                break
            geo_file = os.path.join(args.output, f'chain_chunk_{chunk_index}.json')
            with open(geo_file, 'w', encoding='utf-8') as f:
                json.dump(xsections, f, indent=4)
            print(f'{geo_file} was saved in {args.output}')
            chunk_index += 1
    except Exception as e:
        print(f'error: {e}')
        traceback.print_exc()


if __name__ == '__main__':
    main()