# mechanical / app.py
# Created by engrharis in commit 4f5520c ("Create app.py", verified); 3.27 kB.
import os
import tempfile

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import open3d as o3d
import numpy as np
import cadquery as cq
# Load the tokenizer from Qwen2-1.5B and model weights from filapro/cad-recode.
TOKENIZER_NAME = "Qwen/Qwen2-1.5B"
MODEL_NAME = "filapro/cad-recode"


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and model once per server process.

    Streamlit re-executes this whole script on every rerun (e.g. each widget
    interaction). Without ``st.cache_resource`` the weights would be reloaded
    each time; with it, one shared instance is reused.

    Returns:
        A ``(tokenizer, model, device)`` tuple. The model has already been moved
        to ``device`` (GPU if available, CPU otherwise).
    """
    # NOTE(review): trust_remote_code=True lets the Hub repos run their own Python
    # code on load; pin a revision if these repos are not under your control.
    loaded_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, trust_remote_code=True)
    loaded_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # Set device (GPU if available, CPU otherwise).
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model.to(target_device)
    print(f"Model loaded on {target_device}")
    return loaded_tokenizer, loaded_model, target_device


# Module-level handles used by the inference code below.
tokenizer, model, device = _load_model_and_tokenizer()
# File extensions Open3D's point-cloud reader supports; it picks the parser from
# the extension, so uploads are validated against this list.
POINT_CLOUD_EXTENSIONS = (".xyz", ".xyzn", ".xyzrgb", ".pts", ".ply", ".pcd")


@st.cache_resource(show_spinner=False, max_entries=8)
def _read_point_cloud_bytes(data, suffix):
    """Parse raw point-cloud file bytes with Open3D.

    ``o3d.io.read_point_cloud`` takes a filesystem path, not a file object, and
    chooses the reader from the path's extension. The bytes are therefore written
    to a temporary file that carries ``suffix``, read back, and the temporary
    file is removed.

    The result is cached on ``(data, suffix)`` so reruns with the same upload do
    not re-parse. ``max_entries`` bounds the cache because every distinct upload
    is a new key. The cached cloud is shared by all sessions, so callers must not
    mutate it.
    """
    # mkstemp + explicit close (rather than an open NamedTemporaryFile) so the
    # path can be reopened by Open3D on every platform.
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(data)
        return o3d.io.read_point_cloud(path)
    finally:
        os.remove(path)


def load_point_cloud(file):
    """Load an uploaded point-cloud file into an Open3D PointCloud.

    Args:
        file: The value returned by ``st.file_uploader`` (an UploadedFile), or
            None when nothing has been uploaded.

    Returns:
        The parsed point cloud, or None in any of these cases:
        - nothing was uploaded;
        - the file extension is not one Open3D can read;
        - parsing raised an error;
        - the file yielded no points.
        Every case except "nothing uploaded" is reported to the user with
        ``st.error``.
    """
    if file is None:
        return None
    suffix = os.path.splitext(file.name)[1].lower()
    if suffix not in POINT_CLOUD_EXTENSIONS:
        st.error("Please upload a point cloud file (.pcd, .xyz, etc.)")
        return None
    try:
        point_cloud = _read_point_cloud_bytes(file.getvalue(), suffix)
    except Exception as e:  # UI boundary: surface any parser failure to the user.
        st.error(f"Error loading point cloud: {e}")
        return None
    # Open3D can signal a failed read by returning an empty cloud instead of raising.
    if not point_cloud.has_points():
        st.error("Error loading point cloud: the file contains no points")
        return None
    return point_cloud
def prepare_input_data(point_cloud):
    """Serialize a point cloud's coordinates into a space-separated string.

    The ``points`` array is flattened in row-major order: for an (N, 3) array
    this is ``x0 y0 z0 x1 y1 z1 ...``. Each value is formatted with ``str``.

    Args:
        point_cloud: An object exposing an array-like ``points`` attribute (for
            example an Open3D PointCloud), or None.

    Returns:
        The serialized coordinates, or None when ``point_cloud`` is None or has
        no points.
    """
    # Explicit None check: truthiness of an arbitrary geometry object says
    # nothing about whether it actually contains points.
    if point_cloud is None:
        return None
    coordinates = np.asarray(point_cloud.points).ravel()
    if coordinates.size == 0:
        # Nothing to encode; use the same "no input" sentinel as a missing cloud.
        return None
    return " ".join(map(str, coordinates))
def generate_cad_code(input_text, *, max_input_tokens=512, max_new_tokens=256):
    """Run the language model on the prompt text and return only the generated code.

    Args:
        input_text: Prompt text (as produced by ``prepare_input_data``). A falsy
            value skips generation.
        max_input_tokens: Cap on the tokenized prompt length. Longer prompts are
            truncated by the tokenizer, so part of the point data is dropped.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of
        surrounding whitespace. Returns None when there is no input or the model
        produced nothing.
    """
    if not input_text:
        return None
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
    )
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
        )
    # For a causal LM, generate() returns the prompt ids followed by the new ids.
    # Decoding outputs[0] whole would put the numeric prompt in front of the code
    # that is later executed, so keep only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_length:]
    cad_code = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    return cad_code or None
def _find_cad_result(namespace):
    """Return the most recently bound CadQuery object in ``namespace``, or None.

    This is a heuristic. Names are scanned from last-inserted to first. A
    non-empty ``cq.Workplane`` yields its first object via ``.val()``, and a bare
    ``cq.Shape`` is returned as-is.
    """
    for value in reversed(list(namespace.values())):
        if isinstance(value, cq.Workplane) and value.objects:
            return value.val()
        if isinstance(value, cq.Shape):
            return value
    return None


def generate_cad_model(cad_code):
    """Execute generated CadQuery code and return the CAD object it builds.

    Args:
        cad_code: Python source expected to build a model with CadQuery. A falsy
            value skips execution.

    Returns:
        The CadQuery object found in the executed code's namespace, or None.
        Returns None without a message when ``cad_code`` is empty. Returns None
        and reports via ``st.error`` when execution fails or produces no
        CadQuery object.
    """
    if not cad_code:
        return None
    # SECURITY(review): cad_code is model output and exec() runs it with full
    # interpreter privileges. Run it in a sandboxed process before exposing this
    # app to untrusted users.
    # An explicit namespace is required: exec() inside a function writes to a
    # throwaway locals mapping, so the generated objects were previously lost.
    namespace = {"cq": cq}
    try:
        exec(cad_code, namespace)
    except Exception as e:  # UI boundary: report any failure of the generated code.
        st.error(f"Error generating CAD model: {e}")
        return None
    cad_model = _find_cad_result(namespace)
    if cad_model is None:
        st.error("Error generating CAD model: the generated code did not produce a CadQuery object")
    return cad_model
def main():
    """Streamlit app for point cloud to CAD code conversion.

    Pipeline: upload -> load_point_cloud -> prepare_input_data ->
    generate_cad_code -> show code -> generate_cad_model. The page stops at the
    first stage that yields nothing.
    """
    st.title("Point Cloud to CAD Code Converter")
    st.write("This app uses the filapro/cad-recode model to generate Python code for a 3D CAD model from your point cloud data.")
    uploaded_file = st.file_uploader("Upload Point Cloud File")

    point_cloud = load_point_cloud(uploaded_file)
    if point_cloud is None:
        return

    input_text = prepare_input_data(point_cloud)
    # Model inference can take a while (especially on CPU); show progress
    # instead of a seemingly frozen page.
    with st.spinner("Generating CAD code..."):
        cad_code = generate_cad_code(input_text)
    if not cad_code:
        return

    st.success("Generated Python CAD Code:")
    st.code(cad_code)

    cad_model = generate_cad_model(cad_code)
    if cad_model is None:
        return
    # TODO: render cad_model (e.g. tessellate/export it and display the mesh).
    st.success("Generated CAD Model (Visualization not yet implemented)")
if __name__ == "__main__":
    # `streamlit run app.py` executes this file as __main__, so this is the app entry point.
    main()