|
import math |
|
import numpy as np |
|
import pandas as pd |
|
|
|
import os |
|
import glob |
|
import trimesh |
|
import tensorflow as tf |
|
from tensorflow import keras |
|
from tensorflow.keras import layers |
|
from matplotlib import pyplot as plt |
|
|
|
import gradio as gr |
|
from huggingface_hub import from_pretrained_keras |
|
|
|
|
|
|
|
def conv_bn(x, filters):
    """Apply a kernel-size-1 Conv1D, batch normalisation and ReLU to ``x``.

    Args:
        x: Input tensor fed to the convolution.
        filters: Number of output filters of the convolution.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    stages = (
        layers.Conv1D(filters, kernel_size=1, padding="valid"),
        # momentum=0.0: the moving statistics track only the latest batch.
        layers.BatchNormalization(momentum=0.0),
        layers.Activation("relu"),
    )
    for stage in stages:
        x = stage(x)
    return x
|
|
|
|
|
def dense_bn(x, filters):
    """Apply a Dense layer, batch normalisation and ReLU to ``x``.

    Args:
        x: Input tensor fed to the dense layer.
        filters: Number of units of the dense layer.

    Returns:
        The tensor produced by the final ReLU activation.
    """
    stages = (
        layers.Dense(filters),
        # momentum=0.0: the moving statistics track only the latest batch.
        layers.BatchNormalization(momentum=0.0),
        layers.Activation("relu"),
    )
    for stage in stages:
        x = stage(x)
    return x
|
|
|
|
|
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Regularizer penalising the squared deviation of ``x @ x^T`` from identity.

    The input is viewed as a stack of ``num_features x num_features``
    matrices; the penalty is ``l2reg * sum((x . x^T - I) ** 2)``.

    Args:
        num_features: Side length of the square matrices being regularised.
        l2reg: Scale factor applied to the squared deviation.
        **kwargs: Forwarded unchanged to the base ``Regularizer`` initialiser.
    """

    def __init__(self, num_features, l2reg=0.001, **kwargs):
        # The signature previously spelled this ``**kwarg`` while the body
        # forwarded ``**kwargs``, so every instantiation raised NameError.
        super().__init__(**kwargs)
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        """Return the scalar orthogonality penalty for ``x``."""
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        # NOTE(review): tensordot over axes (2, 2) also contracts across the
        # leading axis pairs (result rank 4) before being reshaped back to
        # rank 3; for more than one matrix this is not a per-matrix x @ x^T.
        # Behaviour kept as-is -- confirm whether a batched matmul was meant.
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

    def get_config(self):
        """Return the constructor arguments needed to rebuild this object."""
        # ``num_features`` is a required constructor argument, so it must be
        # part of the config for ``from_config`` to be able to rebuild us.
        return {
            "num_features": int(self.num_features),
            "l2reg": float(self.l2reg),
        }
|
|
|
def tnet(inputs, num_features):
    """Build a transformation sub-network and apply its matrix to ``inputs``.

    A stack of ``conv_bn`` blocks, a global max-pool and ``dense_bn`` blocks
    feed a Dense layer that emits ``num_features * num_features`` values,
    which are reshaped to a square matrix and dotted with ``inputs``.

    Args:
        inputs: Tensor to be transformed.
        num_features: Side length of the predicted square transform.

    Returns:
        The result of ``layers.Dot(axes=(2, 1))`` on ``inputs`` and the
        predicted transform.
    """
    # Zero kernel + flattened-identity bias: the predicted transform starts
    # out as the identity matrix.
    identity_bias = keras.initializers.Constant(np.eye(num_features).flatten())
    orthogonality_penalty = OrthogonalRegularizer(num_features)

    hidden = inputs
    for width in (32, 64, 512):
        hidden = conv_bn(hidden, width)
    hidden = layers.GlobalMaxPooling1D()(hidden)
    for width in (256, 128):
        hidden = dense_bn(hidden, width)

    flat_transform = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=identity_bias,
        activity_regularizer=orthogonality_penalty,
    )(hidden)
    transform = layers.Reshape((num_features, num_features))(flat_transform)

    return layers.Dot(axes=(2, 1))([inputs, transform])
|
|
|
# Directory whose files are listed as clickable examples in the UI.
EXAMPLES_PATH = "examples"

# Pretrained model fetched via huggingface_hub when the module is imported.
model = from_pretrained_keras("keras-io/PointNet")
|
|
|
# Predicted class index -> label; the tuple position is the index.
CLASS_MAP = dict(
    enumerate(
        (
            "chair",
            "sofa",
            "desk",
            "bed",
            "dresser",
            "night_stand",
            "toilet",
            "bathtub",
            "monitor",
            "table",
        )
    )
)
|
|
|
def infer(img_path):
    """Classify an uploaded mesh file and plot the points sampled from it.

    Args:
        img_path: Uploaded file object; only its ``.name`` attribute (the
            path handed to ``trimesh.load``) is read.

    Returns:
        The Matplotlib figure containing a 3-D scatter of the sampled points,
        titled with the predicted class label from ``CLASS_MAP``.
    """
    mesh = trimesh.load(img_path.name)
    # Sample 2048 points from the mesh and add a leading batch axis.
    points = np.expand_dims(np.asarray(mesh.sample(2048)), axis=0)

    scores = model.predict(points)
    # Single-item batch: take the arg-max over the class axis of item 0.
    class_index = int(np.argmax(scores, axis=-1)[0])

    fig = plt.figure(figsize=(4, 6))
    ax = fig.add_subplot(2, 1, 1, projection="3d")
    ax.scatter(points[0, :, 0], points[0, :, 1], points[0, :, 2])
    ax.set_title(f"This is {CLASS_MAP[class_index]}")
    ax.set_axis_off()

    # Deregister the figure from pyplot's global registry so a long-running
    # server does not accumulate one open figure per request. The Figure
    # object itself stays valid for the caller to render.
    plt.close(fig)
    # Return the figure we built rather than plt.gcf(): the "current figure"
    # is global state and may belong to another concurrent request.
    return fig
|
|
|
|
|
# Gradio components: a file upload in, a plot out.
inputs = gr.File(type="file")
output = gr.Plot()

# Texts shown on the demo page.
title = "PointNet Classification and Segmentation"
description = "Classify images using point cloud Segmentation"
article = 'Author: <a href="https://huggingface.co/geninhu">Nhu Hoang</a>. '

# One example entry per file found in the examples directory.
examples = [f"{EXAMPLES_PATH}/{name}" for name in os.listdir(EXAMPLES_PATH)]

# Build the interface and start serving immediately.
gr.Interface(
    infer,
    inputs,
    output,
    examples=examples,
    allow_flagging="never",
    cache_examples=False,
    title=title,
    description=description,
    article=article,
    live=False,
).launch(
    enable_queue=False,
    debug=True,
    inbrowser=False,
)
|
|