import json
import os

import gradio as gr
import numpy as np
from sklearn.feature_extraction.text import (
    CountVectorizer,
    HashingVectorizer,
    TfidfTransformer,
    TfidfVectorizer,
)
from sklearn.linear_model import LogisticRegression

from lr.hyperparameters import SEARCH_SPACE, HyperparameterSearch, RandomSearch


def load_model(serialization_dir):
    """Rebuild a logistic-regression classifier and its vectorizer from disk.

    Files read from ``serialization_dir``:

    * ``best_hyperparameters.json`` -- must contain ``stopwords``, ``weight``
      and ``ngram_range`` (consumed here to configure the vectorizer) plus
      ``C`` and ``tol``; whatever remains is passed verbatim to
      ``LogisticRegression``.
    * ``vocab.json`` -- assigned to the vectorizer's ``vocabulary_``; skipped
      when ``weight`` is ``"hash"``.
    * ``archive/idf.npy`` -- optional; assigned to ``idf_`` when present.
    * ``archive/coef.npy``, ``archive/intercept.npy``, ``archive/classes.npy``
      -- assigned to the matching classifier attributes.

    Args:
        serialization_dir: Directory holding the files listed above.

    Returns:
        A ``(classifier, vectorizer)`` tuple.
    """
    archive_dir = os.path.join(serialization_dir, "archive")

    def read_array(filename):
        # Every saved array lives directly inside the archive directory.
        return np.load(os.path.join(archive_dir, filename))

    config_path = os.path.join(serialization_dir, "best_hyperparameters.json")
    with open(config_path, 'r') as config_file:
        hyperparameters = json.load(config_file)

    # Vectorizer-only settings are popped so they never reach the classifier.
    stop_words = 'english' if hyperparameters.pop('stopwords') == 1 else None
    weight = hyperparameters.pop('weight')
    binary = weight == 'binary'

    # Stored as a whitespace-separated string of integers.
    raw_range = hyperparameters.pop('ngram_range')
    ngram_range = sorted(int(token) for token in raw_range.split())

    shared_options = {
        "stop_words": stop_words,
        "lowercase": True,
        "ngram_range": ngram_range,
    }
    if weight == 'tf-idf':
        vect = TfidfVectorizer(**shared_options)
    elif weight == 'hash':
        vect = HashingVectorizer(**shared_options)
    else:
        vect = CountVectorizer(binary=binary, **shared_options)

    # No vocabulary file is read for the hashing vectorizer.
    if weight != "hash":
        vocab_path = os.path.join(serialization_dir, "vocab.json")
        with open(vocab_path, 'r') as vocab_file:
            vect.vocabulary_ = json.load(vocab_file)

    for numeric_key in ('C', 'tol'):
        hyperparameters[numeric_key] = float(hyperparameters[numeric_key])
    classifier = LogisticRegression(**hyperparameters)

    # Restore the fitted state directly instead of refitting.
    if os.path.exists(os.path.join(archive_dir, "idf.npy")):
        vect.idf_ = read_array("idf.npy")
    classifier.coef_ = read_array("coef.npy")
    classifier.intercept_ = read_array("intercept.npy")
    classifier.classes_ = read_array("classes.npy")

    return classifier, vect


def score(x, clf, vectorizer):
    """Return ``clf.predict_proba`` for the single document ``x``.

    Args:
        x: The document to score; it is wrapped in a one-element list before
            being vectorized.
        clf: Classifier exposing ``predict_proba``.
        vectorizer: Vectorizer exposing ``transform``.

    Returns:
        Whatever ``clf.predict_proba`` returns for the one-document batch.
    """
    features = vectorizer.transform([x])
    return clf.predict_proba(features)


# Loaded once at import time and shared by every call to ``start``.
clf, vectorizer = load_model("model/")


def start(text):
    """Score ``text`` with the module-level model and label the result.

    Takes the entry at index ``[0][1]`` of the probability output (first row,
    second column -- presumably the "high quality" class; confirm against
    ``classes.npy``) and rounds it to two decimals.

    Returns:
        A one-item dict mapping the display label to the rounded score.
    """
    probabilities = score(text, clf, vectorizer)
    k = round(probabilities[0][1], 2)
    return {"GPT-3 Filter Quality Score": k}