yashriva commited on
Commit
7bd36fd
1 Parent(s): 4466594

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +1 -47
README.md CHANGED
@@ -24,50 +24,4 @@ widget:
24
  - 0
25
  - 2
26
  pipeline_tag: tabular-classification
27
- ---
28
-
29
- ## Usage
30
- ```python
31
- import joblib
32
- from sklearn.impute import SimpleImputer
33
- from sklearn.compose import ColumnTransformer
34
- from sklearn.pipeline import Pipeline
35
-
36
- model = joblib.load("iris_svm.joblib")
37
-
38
- column_transformer_pipeline = ColumnTransformer([
39
- ("SepalLengthCm", SimpleImputer(strategy="mean"), ["SepalLengthCm"]),
40
- ("SepalWidthCm", SimpleImputer(strategy="mean"), ["SepalWidthCm"]),
41
- ("PetalLengthCm", SimpleImputer(strategy="mean"), ["PetalLengthCm"]),
42
- ("PetalWidthCm", SimpleImputer(strategy="mean"), ["PetalWidthCm"])])
43
- import json
44
- pipeline = Pipeline([
45
- ('transformation', column_transformer_pipeline),
46
- ('model', model)
47
- ])
48
-
49
- with open("config.json", "r") as f:
50
- config = json.load(f)
51
-
52
- features = config["features"]
53
- target = config["targets"][0]
54
- target_mapping = config["target_mapping"]
55
-
56
- import numpy as np
57
-
58
- # example input data
59
- input_data = np.array([
60
- [5.1, 3.5, 1.4, 0.2],
61
- [4.9, 3.0, 1.4, 0.2],
62
- [6.2, 3.4, 5.4, 2.3]
63
- ])
64
-
65
- # make sure the input data has the correct shape
66
- if input_data.shape[1] != len(features):
67
- raise ValueError(f"Input data must have {len(features)} features.")
68
-
69
- predicted_classes = pipeline.predict(input_data)
70
- predicted_class_names = [list(target_mapping.keys())[list(target_mapping.values()).index(predicted_class)] for predicted_class in predicted_classes]
71
-
72
- print("Predicted classes:", predicted_class_names)
73
- ```
 
24
  - 0
25
  - 2
26
  pipeline_tag: tabular-classification
27
+ ---