abhisheksan committed
Commit f55cd01
1 Parent(s): 86e94f2

Refactor PoetryGenerationService to streamline model initialization and improve error handling

Files changed (1)
  1. app/services/poetry_generation.py +13 -27
app/services/poetry_generation.py CHANGED
@@ -1,27 +1,27 @@
-from typing import Optional, Dict, List
+from typing import Optional, List
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
 import logging
 from functools import lru_cache
 import concurrent.futures
-from torch.cuda import empty_cache
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+model_name = "meta-llama/Llama-3.2-1B-Instruct"
+
 class ModelManager:
     _instance = None
-    _initialized = False
-    _model_name = "meta-llama/Llama-3.2-1B-Instruct"
 
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super().__new__(cls)
+            cls._initialized = False
         return cls._instance
 
     def __init__(self):
-
+        if not ModelManager._initialized:
         # Initialize tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -30,9 +30,7 @@ class ModelManager:
             torch_dtype=torch.float16,
             device_map="auto"
         )
-
         # Set model to evaluation mode and move to GPU
-        self.model = self.model.to(self.model.device)
         self.model.eval()
         ModelManager._initialized = True
 
@@ -41,8 +39,8 @@ class ModelManager:
             del self.model
             del self.tokenizer
             torch.cuda.empty_cache()
-        except:
-            pass
+        except Exception as e:
+            logger.error(f"Error during cleanup: {str(e)}")
 
 @lru_cache(maxsize=1)
 def get_hf_token() -> str:
@@ -54,35 +52,23 @@ def get_hf_token() -> str:
             "Please set your Hugging Face access token."
         )
     return token
-model_name = "meta-llama/Llama-3.2-1B-Instruct"
+
 class PoetryGenerationService:
     def __init__(self):
         # Get model manager instance
         model_manager = ModelManager()
         self.model = model_manager.model
         self.tokenizer = model_manager.tokenizer
-        self.cache = {}
+
     def preload_models(self):
         """Preload the models during application startup"""
         try:
-            # Initialize tokenizer and model
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.tokenizer.pad_token = self.tokenizer.eos_token
-
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16,
-                device_map="auto"
-            )
-
-            # Set model to evaluation mode and move to GPU
-            self.model = self.model.to(self.model.device)
-            self.model.eval()
-
+            _ = ModelManager()  # Ensure ModelManager singleton is initialized
            logger.info("Models preloaded successfully")
         except Exception as e:
             logger.error(f"Error preloading models: {str(e)}")
             raise
+
     def generate_poem(
         self,
         prompt: str,
@@ -119,7 +105,7 @@ class PoetryGenerationService:
         except Exception as e:
             raise Exception(f"Error generating poem: {str(e)}")
 
-    def generate_poems(self, prompts: list[str]) -> list[str]:
+    def generate_poems(self, prompts: List[str]) -> List[str]:
         with concurrent.futures.ThreadPoolExecutor() as executor:
             poems = list(executor.map(self.generate_poem, prompts))
-        return poems
+        return poems
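
For orientation, the net effect of this refactor: the heavy from_pretrained calls now run exactly once, inside the ModelManager singleton (__new__ caches the instance, __init__ is guarded by _initialized), and every PoetryGenerationService borrows that shared model and tokenizer. A minimal usage sketch under those assumptions — it presumes the service module is importable and that a valid Hugging Face access token is configured the way get_hf_token expects; the prompt strings are illustrative, not part of the commit:

    # Sketch only: exercises the singleton behavior introduced by this commit.
    from app.services.poetry_generation import PoetryGenerationService

    service_a = PoetryGenerationService()
    service_b = PoetryGenerationService()  # reuses the already-loaded model

    # Both services hold references to the same underlying objects, because
    # ModelManager.__new__ returns the cached instance and the _initialized
    # flag makes __init__ skip reloading.
    assert service_a.model is service_b.model
    assert service_a.tokenizer is service_b.tokenizer

    # generate_poems fans prompts out over a ThreadPoolExecutor, mapping each
    # one through generate_poem (see the last hunk above).
    poems = service_a.generate_poems([
        "Write a haiku about autumn rain",
        "Write a sonnet about a quiet harbor",
    ])
    for poem in poems:
        print(poem)

This is also why preload_models collapses to a bare ModelManager() call: constructing the singleton is the preload, so the method only needs to trigger it and log the outcome.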