jitesh commited on
Commit
4fb52dd
1 Parent(s): bf1d936

cache the models

Browse files
Files changed (1) hide show
  1. story_gen.py +16 -7
story_gen.py CHANGED
@@ -18,14 +18,23 @@ class StoryGenerator:
18
  self.stories = []
19
  self.data = []
20
 
21
- @st.cache()
 
 
 
 
 
 
 
 
 
 
22
  def initialise_models(self):
23
- start = time.time()
24
- self.generator = pipeline('text-generation', model='gpt2')
25
- self.classifier = pipeline("text-classification",
26
- model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
27
- initialising_time = time.time()-start
28
- print(f'Initialising Time: {initialising_time}')
29
  # set_seed(42)
30
  # sys.exit()
31
 
 
18
  self.stories = []
19
  self.data = []
20
 
21
+ @staticmethod
22
+ @st.cache(allow_output_mutation=True)
23
+ def get_generator():
24
+ return pipeline('text-generation', model='gpt2')
25
+
26
+ @staticmethod
27
+ @st.cache(allow_output_mutation=True)
28
+ def get_classifier():
29
+ return pipeline("text-classification",
30
+ model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
31
+
32
  def initialise_models(self):
33
+ # start = time.time()
34
+ self.generator = self.get_generator()
35
+ self.classifier = self.get_classifier()
36
+ # initialising_time = time.time()-start
37
+ # print(f'Initialising Time: {initialising_time}')
 
38
  # set_seed(42)
39
  # sys.exit()
40