"""Gradio demo app: classify a handwritten digit with a trained Keras model."""

import pickle
from typing import Optional

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from tensorflow.keras.models import load_model

# Input geometry fed to the pipeline: a 28x28 grayscale image flattened into
# one row of 784 values.
IMAGE_SIZE = (28, 28)
N_PIXELS = IMAGE_SIZE[0] * IMAGE_SIZE[1]

# Column names given to the single-row DataFrame handed to the pipeline.
# Built once here instead of on every request.
PIXEL_COLUMNS = [f"pixel{i}" for i in range(N_PIXELS)]

# Load the trained model
model = load_model("./model_final.keras")

# Load the fitted pipeline.
# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# "pipeline_1.pkl" must be a trusted artifact shipped with the app, never a
# user-supplied file.
with open("pipeline_1.pkl", "rb") as f:
    pipeline_1 = pickle.load(f)


def preprocess_and_predict(image: Optional[Image.Image]) -> str:
    """Preprocess the input image using the pipeline and make a prediction.

    Args:
        image: PIL image supplied by the Gradio ``Image`` input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``"Predicted Digit: <n>"`` where ``<n>`` is the index of the largest
        value in the model output, or a short prompt asking for an image when
        ``image`` is ``None``.
    """
    # Without this guard an empty submission fails with AttributeError on
    # ``None.resize``.
    if image is None:
        return "Please upload an image of a digit."

    # Resize and convert the input image to grayscale (28x28)
    grayscale = image.resize(IMAGE_SIZE).convert("L")

    # Flatten the image to a single row of 784 float32 values
    pixel_row = np.array(grayscale).reshape(1, -1).astype(np.float32)

    # Wrap the row in a DataFrame with pixel0..pixel783 column names
    # (presumably the names the pipeline was fitted with -- confirm against
    # the training code).
    image_df = pd.DataFrame(pixel_row, columns=PIXEL_COLUMNS)

    # Transform the input using the fitted pipeline; reshape to a
    # (1, n_features) matrix because the model expects a batch dimension.
    model_input = pipeline_1.transform(image_df).reshape(1, -1)

    # Make predictions with the model
    predictions = model.predict(model_input)

    # Index of the largest output for the only row in the batch
    predicted_digit = np.argmax(predictions, axis=1)[0]

    return f"Predicted Digit: {predicted_digit}"


# Define sample examples with paths to example images
examples = [
    ["./examples/0.jpg"],
    ["./examples/1.jpg"],
    ["./examples/2_high_contrast.jpg"],
    ["./examples/4.jpg"],
    ["./examples/6.jpg"],
    ["./examples/7.jpg"],
    ["./examples/8_high_contrast.jpg"],
    ["./examples/8.jpg"],
]

# Define Gradio interface
demo = gr.Interface(
    fn=preprocess_and_predict,  # Function to be called
    inputs=gr.Image(type="pil"),  # Input type: Image
    outputs="text",  # Output type: Text
    title="MNIST Digit Classifier",  # Title
    description=(
        "Upload an image of a digit (0-9) from the MNIST dataset "
        "(https://huggingface.co/datasets/ylecun/mnist) "
        "[The model will perform poorly for custom images bcz it has only been "
        "trained using \"as is\" images from MNIST i.e\n"
        "(i) pretty much centered\n"
        " (ii) 28x28 pixels\n"
        " (iii) perfectly black background\n"
        " (iv) white font color images. A custom image will have to be resized "
        "(to be 28x28) and still might not have the above things and thus, the "
        "model performs poorly], and the model will predict the digit."
    ),
    examples=examples,  # Add sample examples
)

# Launch the app only when run as a script, so importing this module (e.g.
# for reload tooling or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()