# Hugging Face Space (status when captured: Running)
import numpy as np | |
import cv2 | |
import gradio as gr | |
PCA_MODEL_PATH = "pca_texture_model.npy"
COMPONENT_NAMES_PATH = "component_names.txt"

# Load the fitted PCA model. The .npy file stores a pickled Python object
# (unwrapped with ``.item()``), which is why ``allow_pickle=True`` is needed.
# NOTE(review): unpickling can execute arbitrary code -- only ever load a
# trusted, locally shipped model file here.
# The object exposes ``mean_``, ``components_`` and ``explained_variance_``
# (presumably a fitted sklearn PCA -- confirm against the training script).
pca = np.load(PCA_MODEL_PATH, allow_pickle=True).item()
mean_texture = pca.mean_
components = pca.components_
explained_variance = pca.explained_variance_
n_components = components.shape[0]

# Textures are flattened square images with 3 colour channels, so the mean
# vector must hold exactly TEXTURE_SIZE * TEXTURE_SIZE * 3 values. Fail fast
# here rather than with an opaque reshape error on the first render.
TEXTURE_SIZE = int(np.sqrt(mean_texture.shape[0] // 3))
if TEXTURE_SIZE * TEXTURE_SIZE * 3 != mean_texture.shape[0]:
    raise ValueError(
        f"PCA mean texture has {mean_texture.shape[0]} values, which is not "
        "3 * size**2 for any integer size"
    )

# Each slider spans +/- 3 standard deviations of its component.
slider_ranges = [3 * np.sqrt(var) for var in explained_variance]
def _load_component_names(path, count):
    """Return exactly ``count`` slider labels, using names from ``path`` if present.

    Line ``i`` of the file (0-based) labels component ``i + 1`` as
    ``"Component {i+1} (name)"``.  Blank lines, lines past the end of the file
    and a missing file all fall back to the plain ``"Component {i+1}"`` label.
    Lines beyond ``count`` are ignored.
    """
    try:
        # Explicit UTF-8 so non-ASCII names decode the same on every platform.
        with open(path, "r", encoding="utf-8") as f:
            labels = [line.strip() for line in f]
    except FileNotFoundError:
        labels = []
    # Pad with blanks (-> generic label) and drop any surplus lines.
    labels = (labels + [""] * count)[:count]
    return [
        f"Component {i+1} ({label})" if label else f"Component {i+1}"
        for i, label in enumerate(labels)
    ]


# Load component names if available
component_names = _load_component_names(COMPONENT_NAMES_PATH, n_components)
def generate_texture(*component_values):
    """Reconstruct a displayable texture image from PCA coefficients.

    Each positional argument is the coefficient of one PCA component, in
    component order.  The flat texture ``mean_texture + coeffs @ components``
    is clipped to the 0-255 byte range, reshaped into a
    ``TEXTURE_SIZE x TEXTURE_SIZE x 3`` image and converted from the model's
    BGR channel order to RGB for display.
    """
    coeffs = np.array(component_values)
    flat = np.clip(mean_texture + coeffs @ components, 0, 255)
    image = flat.astype(np.uint8).reshape(TEXTURE_SIZE, TEXTURE_SIZE, 3)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
def randomize_texture(rng=None):
    """Sample random PCA coefficients for a new texture.

    Each coefficient is drawn from a zero-mean normal distribution whose
    variance is that component's explained variance.  It is then clipped to
    the component's slider range (+/- 3 standard deviations, from
    ``slider_ranges``).  That way the values pushed to the sliders stay within
    the sliders' configured bounds, and the rendered texture matches what the
    sliders show.  This mirrors the clipping applied to coefficients
    projected from uploaded images.

    Args:
        rng: Optional ``numpy.random.Generator`` for reproducible sampling.
            When omitted, the global ``np.random`` state is used.

    Returns:
        A list of ``n_components`` Python floats.
    """
    scale = np.sqrt(explained_variance)
    if rng is None:
        sampled = np.random.normal(0, scale, size=n_components)
    else:
        sampled = rng.normal(0, scale, size=n_components)
    limits = np.asarray(slider_ranges)
    return np.clip(sampled, -limits, limits).tolist()
def update_texture(*component_values):
    """Render the texture for the given slider values.

    Thin wrapper that forwards every coefficient unchanged to
    ``generate_texture`` and returns its image.
    """
    return generate_texture(*component_values)
def on_random_click():
    """Handle the "Randomize Texture" button.

    Returns one ``gr.update`` per slider carrying a freshly sampled
    coefficient, followed by the texture rendered from those same
    coefficients.  This matches the ``[*sliders, output_image]`` output list
    the button is wired to.
    """
    coeffs = randomize_texture()
    slider_updates = [gr.update(value=c) for c in coeffs]
    return slider_updates + [generate_texture(*coeffs)]
def process_uploaded_image(uploaded_image):
    """Project an uploaded image onto the PCA texture basis.

    The image is resized to the model's square texture size and converted to
    the BGR channel order the model works in.  It is then centred on the mean
    texture and projected onto every component.  Coefficients are clipped to
    the slider ranges so they can be written straight back to the sliders.

    Args:
        uploaded_image: RGB image array as delivered by a
            ``gr.Image(type="numpy")`` component.  Greyscale (2-D) and RGBA
            (4-channel) arrays are also accepted and converted to BGR.

    Returns:
        A list of ``n_components`` Python floats, one per component.
    """
    resized = cv2.resize(np.asarray(uploaded_image), (TEXTURE_SIZE, TEXTURE_SIZE))
    # Pick the conversion from the resized array: cv2.resize drops a trailing
    # singleton channel, so (H, W, 1) input arrives here as 2-D.
    if resized.ndim == 2:
        conversion = cv2.COLOR_GRAY2BGR
    elif resized.shape[2] == 4:
        conversion = cv2.COLOR_RGBA2BGR
    else:
        conversion = cv2.COLOR_RGB2BGR
    bgr = cv2.cvtColor(resized, conversion)
    # uint8 - float64 promotes to float64, so no wrap-around on subtraction.
    centered = bgr.reshape(-1) - mean_texture
    coefficients = centered @ components.T
    limits = np.asarray(slider_ranges)
    return np.clip(coefficients, -limits, limits).tolist()
def on_image_upload(image):
    """Handle "Get Components from Image" by pushing projected coefficients to the sliders.

    Returns one ``gr.update`` per slider.  When the button is clicked with no
    image uploaded (``image`` is ``None``), every slider is left unchanged.
    Without this check, ``None`` would be passed into OpenCV and raise.
    """
    if image is None:
        return [gr.update() for _ in range(n_components)]
    return [gr.update(value=value) for value in process_uploaded_image(image)]
def on_update_click(*component_values):
    """Handle "Update Texture": render the texture for the current slider values."""
    return generate_texture(*component_values)
# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            sliders = []
            for i in range(n_components):
                range_limit = slider_ranges[i]
                # A fixed step of 10 can exceed the whole range of a
                # low-variance component and freeze its slider.  Use a finer
                # step for narrow ranges; ranges of 1000 or more keep the
                # original step of 10.  A zero range (zero variance) gets 1.
                step = float(min(10.0, range_limit / 100)) if range_limit > 0 else 1.0
                slider = gr.Slider(
                    minimum=-range_limit,
                    maximum=range_limit,
                    step=step,
                    value=0,
                    label=component_names[i],
                )
                sliders.append(slider)
        with gr.Column():
            output_image = gr.Image(
                label="Generated Texture"
            )
            upload_image = gr.Image(
                label="Upload Image",
                sources=['upload', 'clipboard'],
                type="numpy"
            )
            update_texture_button = gr.Button("Update Texture")
            random_button = gr.Button("Randomize Texture")
            get_components_button = gr.Button("Get Components from Image")

    # Update texture when clicking the "Update Texture" button
    update_texture_button.click(
        fn=on_update_click,
        inputs=sliders,
        outputs=output_image
    )
    # Randomize texture and update sliders and image
    random_button.click(
        fn=on_random_click,
        inputs=None,
        outputs=[*sliders, output_image]
    )
    # Update sliders based on the uploaded image when clicking
    # "Get Components from Image". The upload itself only keeps the image for
    # reference; nothing is recomputed until that button is pressed.
    get_components_button.click(
        fn=on_image_upload,
        inputs=upload_image,
        outputs=sliders
    )

if __name__ == "__main__":
    demo.launch()