Update README.md
Browse files
README.md
CHANGED
@@ -29,9 +29,23 @@ pipeline_tag: tabular-classification
|
|
29 |
## Usage
|
30 |
```python
|
31 |
import joblib
|
|
|
|
|
|
|
32 |
|
33 |
model = joblib.load("iris_svm.joblib")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
import json
|
|
|
|
|
|
|
|
|
35 |
|
36 |
with open("config.json", "r") as f:
|
37 |
config = json.load(f)
|
@@ -53,7 +67,7 @@ input_data = np.array([
|
|
53 |
if input_data.shape[1] != len(features):
|
54 |
raise ValueError(f"Input data must have {len(features)} features.")
|
55 |
|
56 |
-
predicted_classes =
|
57 |
predicted_class_names = [list(target_mapping.keys())[list(target_mapping.values()).index(predicted_class)] for predicted_class in predicted_classes]
|
58 |
|
59 |
print("Predicted classes:", predicted_class_names)
|
|
|
29 |
## Usage
|
30 |
```python
|
31 |
import joblib
|
32 |
+
from sklearn.impute import SimpleImputer
|
33 |
+
from sklearn.compose import ColumnTransformer
|
34 |
+
from sklearn.pipeline import Pipeline
|
35 |
|
36 |
model = joblib.load("iris_svm.joblib")
|
37 |
+
|
38 |
+
column_transformer_pipeline = ColumnTransformer([
|
39 |
+
("loading_missing_value_imputer", SimpleImputer(strategy="mean"), ["loading"]),
|
40 |
+
("numerical_missing_value_imputer", SimpleImputer(strategy="mean"), list(df.columns[df.dtypes == 'float64'])),
|
41 |
+
("attribute_0_encoder", OneHotEncoder(categories = "auto"), ["attribute_0"]),
|
42 |
+
("attribute_1_encoder", OneHotEncoder(categories = "auto"), ["attribute_1"]),
|
43 |
+
("product_code_encoder", OneHotEncoder(categories = "auto"), ["product_code"])])
|
44 |
import json
|
45 |
+
pipeline = Pipeline([
|
46 |
+
('transformation', column_transformer_pipeline),
|
47 |
+
('model', model)
|
48 |
+
])
|
49 |
|
50 |
with open("config.json", "r") as f:
|
51 |
config = json.load(f)
|
|
|
67 |
if input_data.shape[1] != len(features):
|
68 |
raise ValueError(f"Input data must have {len(features)} features.")
|
69 |
|
70 |
+
predicted_classes = pipeline.predict(input_data)
|
71 |
predicted_class_names = [list(target_mapping.keys())[list(target_mapping.values()).index(predicted_class)] for predicted_class in predicted_classes]
|
72 |
|
73 |
print("Predicted classes:", predicted_class_names)
|