yashriva commited on
Commit
98fa8fa
1 Parent(s): 82bcf04

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +15 -1
README.md CHANGED
@@ -29,9 +29,23 @@ pipeline_tag: tabular-classification
29
  ## Usage
30
  ```python
31
  import joblib
 
 
 
32
 
33
  model = joblib.load("iris_svm.joblib")
 
 
 
 
 
 
 
34
  import json
 
 
 
 
35
 
36
  with open("config.json", "r") as f:
37
  config = json.load(f)
@@ -53,7 +67,7 @@ input_data = np.array([
53
  if input_data.shape[1] != len(features):
54
  raise ValueError(f"Input data must have {len(features)} features.")
55
 
56
- predicted_classes = model.predict(input_data)
57
  predicted_class_names = [list(target_mapping.keys())[list(target_mapping.values()).index(predicted_class)] for predicted_class in predicted_classes]
58
 
59
  print("Predicted classes:", predicted_class_names)
 
29
  ## Usage
30
  ```python
31
  import joblib
32
+ from sklearn.impute import SimpleImputer
33
+ from sklearn.compose import ColumnTransformer
34
+ from sklearn.pipeline import Pipeline
35
 
36
  model = joblib.load("iris_svm.joblib")
37
+
38
+ column_transformer_pipeline = ColumnTransformer([
39
+ ("loading_missing_value_imputer", SimpleImputer(strategy="mean"), ["loading"]),
40
+ ("numerical_missing_value_imputer", SimpleImputer(strategy="mean"), list(df.columns[df.dtypes == 'float64'])),
41
+ ("attribute_0_encoder", OneHotEncoder(categories = "auto"), ["attribute_0"]),
42
+ ("attribute_1_encoder", OneHotEncoder(categories = "auto"), ["attribute_1"]),
43
+ ("product_code_encoder", OneHotEncoder(categories = "auto"), ["product_code"])])
44
  import json
45
+ pipeline = Pipeline([
46
+ ('transformation', column_transformer_pipeline),
47
+ ('model', model)
48
+ ])
49
 
50
  with open("config.json", "r") as f:
51
  config = json.load(f)
 
67
  if input_data.shape[1] != len(features):
68
  raise ValueError(f"Input data must have {len(features)} features.")
69
 
70
+ predicted_classes = pipeline.predict(input_data)
71
  predicted_class_names = [list(target_mapping.keys())[list(target_mapping.values()).index(predicted_class)] for predicted_class in predicted_classes]
72
 
73
  print("Predicted classes:", predicted_class_names)