logistic-regression-iris

A logistic regression model trained on the Iris dataset.

It takes two inputs, 'PetalLengthCm' and 'PetalWidthCm' (petal length and petal width in centimeters), and predicts whether the flower's species is 'Iris-setosa'.

It is a PyTorch adaptation of the scikit-learn model in Chapter 10 of Aurélien Géron's book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow'.
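In equation form, this is what the PyTorch module in the usage code computes. Here $x_1$ is `PetalLengthCm`, $x_2$ is `PetalWidthCm`, $m_j$ and $s_j$ are the training-set means and standard deviations (`X_means`, `X_stds`), and $w_1, w_2, b$ are the weights and bias of `nn.Linear(2, 1)`. The model predicts Iris-setosa when $\hat{p} > 0.5$:

$$
z = w_1 \, \frac{x_1 - m_1}{s_1} + w_2 \, \frac{x_2 - m_2}{s_2} + b,
\qquad
\hat{p} = P(\text{Iris-setosa} \mid x_1, x_2) = \frac{1}{1 + e^{-z}}
$$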

Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/logistic_regression_iris.ipynb

Experiment tracking: https://wandb.ai/sadhaklal/logistic-regression-iris

Usage

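```python
# Notebook (Jupyter/Colab) syntax. The code below also needs torch, scikit-learn and huggingface_hub.
!pip install -q datasets

from datasets import load_dataset

iris = load_dataset("scikit-learn/iris")
iris.set_format("pandas")
iris_df = iris['train'][:]
X = iris_df[['PetalLengthCm', 'PetalWidthCm']]
y = (iris_df['Species'] == "Iris-setosa").astype(int)  # 1 = Iris-setosa, 0 = any other species.

class_names = ["Not Iris-setosa", "Iris-setosa"]

from sklearn.model_selection import train_test_split

# Hold out 30% of the examples for validation, stratified on the label.
X_train, X_val, y_train, y_val = train_test_split(X.values, y.values, test_size=0.3, stratify=y, random_state=42)
# Training-set statistics, reused below to standardize new inputs.
X_means, X_stds = X_train.mean(axis=0), X_train.std(axis=0)
```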
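`X_train` here is a NumPy array, so `X_train.std(axis=0)` is the population standard deviation (`ddof=0`). pandas' `DataFrame.std()` instead defaults to the sample standard deviation (`ddof=1`), so recomputing the statistics from a DataFrame would put inputs on a slightly different scale from the one the usage code uses. The minimal sketch below illustrates the standardization step and the size of that difference; the petal measurements in it are hypothetical, not taken from the dataset.

```python
import numpy as np


def standardize(X, means, stds):
    """Z-score each column using precomputed (training-set) statistics."""
    return (X - means) / stds


# Hypothetical petal measurements in cm: [PetalLengthCm, PetalWidthCm].
X_train_demo = np.array([
    [1.4, 0.2],
    [1.5, 0.3],
    [4.7, 1.4],
    [5.1, 1.8],
    [6.0, 2.5],
])

means = X_train_demo.mean(axis=0)
stds_pop = X_train_demo.std(axis=0)             # NumPy default: ddof=0 (population std)
stds_sample = X_train_demo.std(axis=0, ddof=1)  # pandas default: ddof=1 (sample std)

# Standardized training columns have mean 0 and (population) std 1.
Z = standardize(X_train_demo, means, stds_pop)
print("means after standardizing:", Z.mean(axis=0))
print("stds after standardizing: ", Z.std(axis=0))

# With n rows, the two definitions differ by a factor of sqrt(n / (n - 1)).
n = X_train_demo.shape[0]
print("sample/pop ratio:", stds_sample / stds_pop, "expected:", np.sqrt(n / (n - 1)))

# New flowers are scaled with the training statistics, never with their own.
X_new_demo = np.array([[2.0, 0.5], [3.0, 1.0]])
print(standardize(X_new_demo, means, stds_pop))
```

The gap shrinks as the training set grows, but it is cheap to avoid by using the same `ddof` everywhere.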

```python
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

device = torch.device("cpu")

# A single linear layer mapping the 2 standardized features to 1 logit.
# The architecture must match the saved checkpoint so from_pretrained() can load the weights.
class LinearModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        out = self.fc(x)
        return out

model = LinearModel.from_pretrained("sadhaklal/logistic-regression-iris")
model.to(device)
```
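Because the model is one `nn.Linear(2, 1)` layer applied to standardized inputs, its three parameters (`model.fc.weight` of shape `(1, 2)` and `model.fc.bias` of shape `(1,)`) define a straight-line decision boundary that can be expressed directly in centimeters. The NumPy sketch below folds the standardization into the linear layer; the weight, bias, mean and std values in it are hypothetical placeholders, not the trained model's parameters. With the real model one could pass `model.fc.weight.detach().numpy().ravel()`, `model.fc.bias.item()`, `X_means` and `X_stds` instead.

```python
import numpy as np


def to_raw_units(w, b, means, stds):
    """Fold standardization into the linear layer.

    Returns (a, c) such that a @ x + c equals w @ ((x - means) / stds) + b
    for a raw input x measured in cm.
    """
    a = w / stds
    c = b - np.sum(w * means / stds)
    return a, c


def boundary_petal_width(petal_length, a, c):
    """Petal width (cm) on the decision boundary (logit 0) for a given petal length."""
    return -(a[0] * petal_length + c) / a[1]


# Hypothetical values, NOT the trained model's parameters.
w = np.array([-2.0, -1.5])    # stands in for model.fc.weight, flattened
b = -0.5                      # stands in for model.fc.bias
means = np.array([3.8, 1.2])  # stands in for X_means
stds = np.array([1.7, 0.75])  # stands in for X_stds

a, c = to_raw_units(w, b, means, stds)

# Both parameterizations give the same logit for a raw input.
x = np.array([2.0, 0.5])
z_std = w @ ((x - means) / stds) + b
z_raw = a @ x + c
print(z_std, z_raw)

# Points on the boundary have logit 0, i.e. a predicted probability of exactly 0.5.
for length in (2.0, 2.5, 3.0):
    width = boundary_petal_width(length, a, c)
    print(f"length={length:.1f} cm -> boundary width={width:.3f} cm")
```

`boundary_petal_width` assumes the petal-width weight is nonzero; otherwise the boundary is a vertical line in petal length.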

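```python
# Inference on new data:
import numpy as np

X_new = np.array([[2.0, 0.5], [3.0, 1.0]])  # Petal length and width (cm) of 2 new flowers.
X_new = (X_new - X_means) / X_stds  # Standardize with the training-set statistics.
X_new = torch.from_numpy(X_new).float()

model.eval()
X_new = X_new.to(device)
with torch.no_grad():
    logits = model(X_new)  # Shape: (num_flowers, 1).
proba = torch.sigmoid(logits.squeeze(-1))  # squeeze(-1) keeps a 1-D tensor even for a single flower.
preds = (proba > 0.5).long()  # 1 = Iris-setosa, 0 = not.

print(f"Predicted classes: {preds}")
print(f"Predicted class names: {[class_names[i] for i in preds.tolist()]}")
print(f"Predicted probabilities of being Iris-setosa: {proba}")
```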
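The usage code thresholds the sigmoid output at 0.5. Since the sigmoid is strictly increasing and equals exactly 0.5 at a logit of 0, this gives the same predictions as checking whether the logit is positive. The sketch below checks that in NumPy and maps the 0/1 predictions to `class_names`; the logits in it are hypothetical stand-ins for `model(X_new).squeeze(-1).numpy()`, not outputs of this model.

```python
import numpy as np

class_names = ["Not Iris-setosa", "Iris-setosa"]


def sigmoid(z):
    """Logistic function, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical logits, not produced by the trained model.
logits = np.array([3.2, -0.7, 0.0, -4.1, 0.4])

proba = sigmoid(logits)
preds_from_proba = (proba > 0.5).astype(int)
preds_from_logits = (logits > 0).astype(int)

# Identical predictions: sigmoid is strictly increasing and sigmoid(0) == 0.5.
assert np.array_equal(preds_from_proba, preds_from_logits)

for z, p, k in zip(logits, proba, preds_from_logits):
    print(f"logit={z:+.2f}  P(Iris-setosa)={p:.3f}  -> {class_names[k]}")
```

Thresholding the logit skips the sigmoid entirely, which is convenient when only the class is needed.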

Metric

As shown above, the validation set contains 30% of the examples (selected at random in a stratified fashion).

Accuracy on the validation set: 1.0
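Accuracy here is the fraction of validation examples whose predicted class matches the label. A minimal sketch of that computation in NumPy, with binary confusion counts alongside; the label and prediction arrays are hypothetical stand-ins, not the model's validation output.

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of examples whose predicted label equals the true label."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


def confusion_counts(y_true, y_pred):
    """Return (tn, fp, fn, tp) for binary 0/1 labels (1 = Iris-setosa)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))
    return tn, fp, fn, tp


# Hypothetical labels and predictions (not the model's validation results).
y_val_example = np.array([1, 0, 0, 1, 0, 0])
preds_example = np.array([1, 0, 0, 1, 0, 1])

print(accuracy(y_val_example, preds_example))
print(confusion_counts(y_val_example, preds_example))
```

To reproduce the reported figure, one would apply the usage code's standardization, model and 0.5 threshold to `X_val` and compare the result with `y_val`. An accuracy of 1.0 means both `fp` and `fn` are 0.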
