Spaces:
Build error
Build error
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForSequenceClassification | |
from transformers import pipeline | |
from flair.data import Sentence | |
from flair.models import SequenceTagger | |
import pickle | |
class Models:
    """Load, pickle-cache and expose the NLP models used by this app.

    Four models are managed and stored as instance attributes:

    * ``ner_dates`` -- transformers ``"ner"`` pipeline built from
      ``Jean-Baptiste/camembert-ner-with-dates``.
    * ``zero_shot_classifier`` -- transformers ``"zero-shot-classification"``
      pipeline built from ``valhalla/distilbart-mnli-12-6``.
    * ``ner`` -- transformers ``"ner"`` pipeline built from
      ``dslim/bert-base-NER``.
    * ``tagger`` -- flair ``SequenceTagger`` loaded from
      ``flair/pos-english-fast``.
    """

    # Model identifiers, kept in one place so they are easy to swap
    # (e.g. by overriding them in a subclass).
    NER_DATES_MODEL = "Jean-Baptiste/camembert-ner-with-dates"
    # Lighter alternative to 'facebook/bart-large-mnli'.
    ZERO_SHOT_MODEL = "valhalla/distilbart-mnli-12-6"
    NER_MODEL = "dslim/bert-base-NER"
    POS_TAGGER_MODEL = "flair/pos-english-fast"

    # Base file names (".pickle" is appended) shared by pickle_models() and
    # load_pickled_models() so the writer and reader cannot drift apart.
    NER_PICKLE = "ner"
    NER_DATES_PICKLE = "ner_dates"
    ZERO_SHOT_PICKLE = "zero_shot_classifier_6"
    POS_TAGGER_PICKLE = "pos_tagger_fast"

    def pickle_it(self, obj, file_name):
        """Serialize *obj* to ``<file_name>.pickle``.

        Args:
            obj: Any picklable object.
            file_name: Target path without the ``.pickle`` suffix; relative
                paths resolve against the current working directory.
        """
        with open(f"{file_name}.pickle", "wb") as f:
            pickle.dump(obj, f)

    def unpickle_it(self, file_name):
        """Load and return the object stored in ``<file_name>.pickle``.

        SECURITY: ``pickle.load`` can execute arbitrary code. Only use this
        on files this application wrote itself, never on untrusted input.
        """
        with open(f"{file_name}.pickle", "rb") as f:
            return pickle.load(f)

    def load_trained_models(self, pickle=False):
        """Load all four models from their pretrained sources.

        The models are stored on the instance (``self.ner``,
        ``self.ner_dates``, ``self.zero_shot_classifier``, ``self.tagger``).

        Args:
            pickle: If true, also write every model to disk via
                :meth:`pickle_models`. This parameter shadows the ``pickle``
                module inside this method only; the module is not used here.

        Returns:
            Tuple ``(ner, ner_dates, zero_shot_classifier, tagger)``.
        """
        # NER that also tags dates.
        tokenizer = AutoTokenizer.from_pretrained(self.NER_DATES_MODEL)
        model = AutoModelForTokenClassification.from_pretrained(self.NER_DATES_MODEL)
        self.ner_dates = pipeline(
            "ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple"
        )

        # Zero-shot classification.
        self.zero_shot_classifier = pipeline(
            "zero-shot-classification", model=self.ZERO_SHOT_MODEL
        )

        # General-purpose NER. aggregation_strategy="simple" is the
        # non-deprecated equivalent of the former grouped_entities=True and
        # matches the ner_dates pipeline above.
        tokenizer = AutoTokenizer.from_pretrained(self.NER_MODEL)
        model = AutoModelForTokenClassification.from_pretrained(self.NER_MODEL)
        self.ner = pipeline(
            "ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple"
        )

        # Part-of-speech tagging.
        self.tagger = SequenceTagger.load(self.POS_TAGGER_MODEL)

        if pickle:
            self.pickle_models()
        return self.ner, self.ner_dates, self.zero_shot_classifier, self.tagger

    def pickle_models(self):
        """Pickle the four loaded models (relative to the working directory).

        Requires the model attributes to exist, i.e. :meth:`load_trained_models`
        or :meth:`load_pickled_models` must have been called first.
        """
        self.pickle_it(self.ner, self.NER_PICKLE)
        self.pickle_it(self.zero_shot_classifier, self.ZERO_SHOT_PICKLE)
        self.pickle_it(self.ner_dates, self.NER_DATES_PICKLE)
        self.pickle_it(self.tagger, self.POS_TAGGER_PICKLE)

    def load_pickled_models(self):
        """Load the models previously written by :meth:`pickle_models`.

        The models are also stored on the instance, mirroring
        :meth:`load_trained_models`. Attribute access and
        :meth:`pickle_models` therefore work after either loader.

        NOTE: the return order differs from :meth:`load_trained_models`. It
        is kept unchanged for backward compatibility with existing callers.

        SECURITY: unpickling can execute arbitrary code. Only load files this
        application produced.

        Returns:
            Tuple ``(ner_dates, ner, zero_shot_classifier, tagger)``.
        """
        self.ner_dates = self.unpickle_it(self.NER_DATES_PICKLE)
        self.ner = self.unpickle_it(self.NER_PICKLE)
        self.zero_shot_classifier = self.unpickle_it(self.ZERO_SHOT_PICKLE)
        self.tagger = self.unpickle_it(self.POS_TAGGER_PICKLE)
        return self.ner_dates, self.ner, self.zero_shot_classifier, self.tagger

    def get_flair_sentence(self, sent):
        """Wrap the text *sent* in a ``flair.data.Sentence``."""
        return Sentence(sent)