import math
import numpy as np
import pandas as pd

import os
import glob
import trimesh
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from matplotlib import pyplot as plt

import gradio as gr
from huggingface_hub import from_pretrained_keras



def conv_bn(x, filters):
    """Pointwise Conv1D (kernel size 1), then batch normalization and ReLU."""
    x = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)


def dense_bn(x, filters):
    """Dense layer, then batch normalization and ReLU."""
    x = layers.Dense(filters)(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)

# @keras.utils.register_keras_serializable
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    def __init__(self, num_features, l2reg=0.001, **kwargs):
        super().__init__(**kwargs)
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

    def get_config(self):
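        # num_features is a required constructor argument, so it must be serialized too
        return {'num_features': self.num_features, 'l2reg': float(self.l2reg)}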
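

# Illustrative sketch (not part of the original app): for a single
# num_features x num_features matrix A, the penalty computed by
# OrthogonalRegularizer.__call__ reduces to l2reg * sum((A @ A.T - I) ** 2),
# which is zero when A is orthogonal. `orthogonality_penalty` is a hypothetical
# helper name written in plain NumPy; nothing in the app calls it.
import numpy as np  # already imported above; repeated so the sketch stands alone


def orthogonality_penalty(matrix, l2reg=0.001):
    matrix = np.asarray(matrix, dtype=np.float64)
    identity = np.eye(matrix.shape[0])
    # Gram matrix of the rows; equals the identity for an orthogonal matrix
    gram = matrix @ matrix.T
    return float(l2reg * np.sum(np.square(gram - identity)))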

def tnet(inputs, num_features):

    # Initialise the bias as the identity matrix
    bias = keras.initializers.Constant(np.eye(num_features).flatten())
    reg = OrthogonalRegularizer(num_features)

    x = conv_bn(inputs, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = dense_bn(x, 128)
    x = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=bias,
        activity_regularizer=reg,
    )(x)
    feat_T = layers.Reshape((num_features, num_features))(x)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])
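

# Illustrative sketch (not part of the original app): why tnet() starts out as
# an identity transform. With a zero kernel and an identity-matrix bias, the
# final Dense layer outputs the flattened identity for any input, so the
# batched matrix product performed by layers.Dot leaves the points unchanged
# at initialisation. `initial_tnet_transform` is a hypothetical helper written
# in plain NumPy; nothing in the app calls it.
import numpy as np  # already imported above; repeated so the sketch stands alone


def initial_tnet_transform(features, num_features):
    features = np.asarray(features, dtype=np.float64)
    # Mirror kernel_initializer="zeros" and the identity bias used in tnet()
    kernel = np.zeros((features.shape[-1], num_features * num_features))
    bias = np.eye(num_features).flatten()
    flat = features @ kernel + bias
    return flat.reshape(-1, num_features, num_features)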

EXAMPLES_PATH = 'examples'
model = from_pretrained_keras('keras-io/PointNet')

CLASS_MAP = {
    0: 'chair',
    1: 'sofa',
    2: 'desk',
    3: 'bed',
    4: 'dresser',
    5: 'night_stand',
    6: 'toilet',
    7: 'bathtub',
    8: 'monitor',
    9: 'table',
}

def infer(img_path):
    # load the uploaded mesh and sample 2048 points from its surface
    mesh = trimesh.load(img_path.name)
    points = mesh.sample(2048)
    points = np.expand_dims(np.asarray(points), axis=0)

    # run the sampled points through the model and keep the most likely class
    preds = model.predict(points)
    preds = tf.math.argmax(preds, -1)

    # plot the points, titled with the predicted class
    fig = plt.figure(figsize=(4, 6))
    ax = fig.add_subplot(2, 1, 1, projection="3d")
    ax.scatter(points[0, :, 0], points[0, :, 1], points[0, :, 2])
    ax.set_title(f"This is {CLASS_MAP[preds[0].numpy()]}")
    ax.set_axis_off()
    return fig
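

# Illustrative sketch (not part of the original app): the label lookup in
# infer() amounts to an argmax over the class scores of the single sample in
# the batch, followed by a dictionary lookup. `label_from_scores` is a
# hypothetical helper written in plain NumPy; nothing in the app calls it.
import numpy as np  # already imported above; repeated so the sketch stands alone


def label_from_scores(scores, class_map):
    scores = np.asarray(scores)
    # scores has shape (1, num_classes); pick the best class for the first row
    index = int(np.argmax(scores[0], axis=-1))
    return class_map[index]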

# get the inputs
inputs = gr.File(type = 'file')

# the app outputs a 3D scatter plot of the sampled points, titled with the predicted class
output = gr.Plot()


# it's good practice to pass examples, description and a title to guide users
title = 'PointNet Classification and Segmentation'
description = 'Classify a 3D mesh by sampling a point cloud from it and running PointNet'
article = "Author: <a href=\"https://huggingface.co/geninhu\">Nhu Hoang</a>. "
examples = [f'{EXAMPLES_PATH}/{f}' for f in os.listdir(EXAMPLES_PATH)]

gr.Interface(
    infer,
    inputs,
    output,
    examples=examples,
    allow_flagging='never',
    cache_examples=False,
    title=title,
    description=description,
    article=article,
    live=False,
).launch(enable_queue=False, debug=True, inbrowser=False)