# poetica/app/services/poetry_generation.py
from typing import Optional

import concurrent.futures
import logging
import os
from functools import lru_cache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelManager:
    """Singleton that owns the shared tokenizer and model instances."""

    _instance = None
    _initialized = False
    _model_name = "meta-llama/Llama-3.2-1B-Instruct"

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ModelManager() call, even for the singleton,
        # so guard against reloading the weights more than once.
        if ModelManager._initialized:
            return

        # Initialize tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            self._model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # device_map="auto" already places the weights; just switch to eval mode.
        self.model.eval()
        ModelManager._initialized = True

    def __del__(self):
        # Best-effort cleanup; ignore errors during interpreter shutdown.
        try:
            del self.model
            del self.tokenizer
            torch.cuda.empty_cache()
        except Exception:
            pass
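
# Minimal usage sketch (illustrative only): every ModelManager() call returns
# the same instance, and the weights are loaded only on the first call.
#
#   manager_a = ModelManager()
#   manager_b = ModelManager()
#   assert manager_a is manager_b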
@lru_cache(maxsize=1)
def get_hf_token() -> str:
"""Get Hugging Face token from environment variables."""
token = os.getenv("HF_TOKEN")
if not token:
raise EnvironmentError(
"HF_TOKEN environment variable not found. "
"Please set your Hugging Face access token."
)
return token
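
# Hedged sketch: get_hf_token() is not wired into the loaders above. Since the
# Llama checkpoints are gated, the token could be passed explicitly, e.g.:
#
#   AutoTokenizer.from_pretrained(model_name, token=get_hf_token())
#
# The `token` kwarg is the current transformers argument for authenticated
# downloads (older releases used `use_auth_token`).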
# Keep the module-level name consistent with the singleton's model name.
model_name = ModelManager._model_name
class PoetryGenerationService:
def __init__(self):
# Get model manager instance
model_manager = ModelManager()
self.model = model_manager.model
self.tokenizer = model_manager.tokenizer
self.cache = {}
    def preload_models(self):
        """Preload the models during application startup.

        Note: this loads a fresh tokenizer/model rather than reusing the
        instances already held by the ModelManager singleton in __init__.
        """
        try:
            # Initialize tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto"
            )

            # device_map="auto" already places the weights; just switch to eval mode.
            self.model.eval()
logger.info("Models preloaded successfully")
except Exception as e:
logger.error(f"Error preloading models: {str(e)}")
raise
def generate_poem(
self,
prompt: str,
temperature: Optional[float] = 0.7,
top_p: Optional[float] = 0.9,
top_k: Optional[int] = 50,
max_length: Optional[int] = 100,
repetition_penalty: Optional[float] = 1.1
) -> str:
try:
inputs = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
do_sample=True,
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_length=max_length,
repetition_penalty=repetition_penalty,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
return self.tokenizer.decode(
outputs[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
        except Exception as e:
            # Chain the original exception so the underlying traceback is preserved.
            raise Exception(f"Error generating poem: {str(e)}") from e
    def generate_poems(self, prompts: list[str]) -> list[str]:
        """Generate one poem per prompt, dispatching the calls across threads."""
        with concurrent.futures.ThreadPoolExecutor() as executor:
            poems = list(executor.map(self.generate_poem, prompts))
        return poems
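

# Hedged usage sketch (not part of the original service code): shows how the
# class might be driven end to end. Assumes the gated Llama model is accessible
# in this environment and that a GPU (or enough RAM) is available.
if __name__ == "__main__":
    service = PoetryGenerationService()
    service.preload_models()
    poem = service.generate_poem(
        "Write a short poem about the sea at dawn.",
        temperature=0.8,
        max_length=120,
    )
    print(poem)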