# dawood's picture
# dawood HF staff
# Update app.py
# 0c50f29
# raw
# history blame
# 2.57 kB
import os
import gradio as gr
import numpy as np
import glob
import warnings
import pandas as pd
import matplotlib.pyplot as plt
from utils import OrthogonalRegularizer
from huggingface_hub.keras_mixin import from_pretrained_keras
# Load the pretrained PointNet segmentation model from the Hugging Face Hub.
# The saved model references OrthogonalRegularizer, so it is supplied via
# custom_objects to let Keras resolve that name when deserializing.
model = from_pretrained_keras(
    "keras-io/pointnet_segmentation",
    custom_objects=dict(OrthogonalRegularizer=OrthogonalRegularizer),
)
# Example inputs: every CSV point cloud under asset/source/.
# glob.glob's result order depends on the file system, so sort it to give the
# examples gallery (and Gradio's example cache) a stable order everywhere.
samples = []  # NOTE(review): unused in this file; kept in case something imports it.
input_images = sorted(glob.glob("asset/source/*.csv"))  # CSV files, despite the name
examples = [[csv_path] for csv_path in input_images]

# Class names for the segmentation output and the plot colour for each one.
# COLORS[i] is used for LABELS[i]; keep the two lists the same length.
LABELS = ["wing", "body", "tail", "engine"]
COLORS = ["blue", "green", "red", "pink"]
def visualize_data(point_cloud, labels, output_path=None):
    """Draw a 3D scatter plot of a point cloud, coloured by per-point label.

    Args:
        point_cloud: Array indexed as ``point_cloud[:, k]``; columns 0, 1 and 2
            are plotted as the x, y and z coordinates.
        labels: One label per point (same length as ``point_cloud``). Only
            points whose label appears in ``LABELS`` are drawn, each class in
            the matching ``COLORS`` entry.
        output_path: If given, the figure is saved to this path (its parent
            directory is created if needed) and then closed. If omitted, the
            figure is left open for the caller (e.g. to ``plt.show()`` it).
    """
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(projection="3d")
    # zip stops at the shorter list, so a label without a colour is skipped
    # instead of raising, without hiding any unrelated errors.
    for label, color in zip(LABELS, COLORS):
        c_df = df[df["label"] == label]
        ax.scatter(c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=color)
    ax.legend()
    if output_path:
        out_dir = os.path.dirname(output_path)
        # dirname is "" for a bare file name, and os.makedirs("") would raise.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path)
        # This runs once per request in a long-lived server; pyplot keeps every
        # figure alive until it is closed explicitly.
        plt.close(fig)
def inference(
    csv_file,
    output_path="asset/output",
    cpu=False,
):
    """Segment an uploaded point-cloud CSV and render the prediction to a PNG.

    Args:
        csv_file: The uploaded file, either as an object exposing its path via
            ``.name`` or as a plain path (``str`` / ``os.PathLike``). ``None``
            (nothing uploaded) is tolerated. The CSV must have ``x``, ``y``
            and ``z`` columns.
        output_path: Directory the PNG is written to.
        cpu: Unused; kept for interface compatibility.

    Returns:
        ``f"{output_path}/{name}.png"``, where ``name`` is the CSV file name
        without its extension. Returns ``None`` (after emitting a warning) if
        no file was given or the file does not exist.
    """
    if csv_file is None:
        warnings.warn("no CSV file provided")
        return None
    if isinstance(csv_file, (str, os.PathLike)):
        csv_path = os.fspath(csv_file)
    else:
        csv_path = csv_file.name
    # basename/splitext cope with OS-specific separators and dotted file names.
    im_name = os.path.splitext(os.path.basename(csv_path))[0]
    if not os.path.exists(csv_path):
        warnings.warn(f"CSV file {csv_path!r} not found; cannot segment {im_name!r}")
        return None

    df = pd.read_csv(csv_path, index_col=None)
    inputs = df[["x", "y", "z"]].to_numpy()
    # TODO: also render the ground truth (df.iloc[:, 3:]) when the CSV has it.

    # The model is fed a batch of one cloud; [0] takes that single result back out.
    preds = model.predict(np.expand_dims(inputs, 0))[0]
    label_map = LABELS + ["none"]
    # Each point gets the label whose score is highest.
    point_labels = [label_map[np.argmax(scores)] for scores in preds]

    png_path = f"{output_path}/{im_name}.png"
    visualize_data(inputs, point_labels, png_path)
    return png_path
# Footer shown under the demo: credits for the Space and the original Keras example.
article = (
    "<div style='text-align: center;'>"
    "<a href='https://nouamanetazi.me/' target='_blank'>Space by Nouamane Tazi</a><br>"
    "<a href='https://keras.io/examples/vision/pointnet_segmentation' target='_blank'>"
    "Keras example by Soumik Rakshit, Sayak Paul</a></div>"
)

# Keep `iface` bound to the Interface itself, not to launch()'s return value,
# so the UI object stays reachable after the server starts.
# NOTE(review): gr.outputs.* and enable_queue are Gradio 3.x-era APIs; newer
# Gradio versions removed them, so confirm the pinned version before upgrading.
iface = gr.Interface(
    inference,  # main function
    inputs=[
        "file",
    ],
    outputs=[
        gr.outputs.Image(label="result"),  # generated image
    ],
    title="Point cloud segmentation with PointNet",
    article=article,
    examples=examples,
    cache_examples=True,  # have Gradio cache the example outputs server-side
)
iface.launch(enable_queue=True)