Update news_category_prediction.py
Browse files
news_category_prediction.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
-
|
5 |
|
6 |
def parse_prediction(tflite_pred, label_encoder):
|
7 |
tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
|
|
|
1 |
import pandas as pd
|
2 |
import numpy as np
|
3 |
import tensorflow as tf
|
4 |
+
from config import DISTILBERT_TOKENIZER_N_TOKENS, NEWS_CATEGORY_CLASSIFIER_N_CLASSES, CLASSIFIER_THRESHOLD
|
5 |
|
6 |
def parse_prediction(tflite_pred, label_encoder):
|
7 |
tflite_pred_argmax = np.argmax(tflite_pred, axis=1)
|