# tts/rvc_service.py
import torch
import numpy as np
import logging
import queue
import threading
import time
from dataclasses import dataclass
from typing import Optional, Dict, List
import gc
from datetime import datetime, time as dt_time
import os
from collections import deque
import asyncio
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class JobRequest:
"""Represents a single voice conversion request"""
id: str # Unique identifier for the job
audio_data: np.ndarray # Input audio data
model_name: str # Name of the RVC model to use
priority: int = 1 # Priority level (1-5, 5 being highest)
    timestamp: Optional[float] = None  # When the job was submitted
def __post_init__(self):
self.timestamp = time.time() if self.timestamp is None else self.timestamp
class ModelCache:
"""Manages cached models with LRU eviction policy"""
def __init__(self, max_models: int = 3):
self.max_models = max_models
self.models: Dict[str, torch.nn.Module] = {}
self.model_usage: deque = deque()
self.lock = threading.Lock()
def get_model(self, model_name: str) -> Optional[torch.nn.Module]:
"""Get model from cache, implementing LRU policy"""
with self.lock:
if model_name in self.models:
# Update usage history
self.model_usage.remove(model_name)
self.model_usage.append(model_name)
return self.models[model_name]
return None
def add_model(self, model_name: str, model: torch.nn.Module):
"""Add model to cache, evicting least recently used if necessary"""
with self.lock:
if len(self.models) >= self.max_models:
# Evict least recently used model
lru_model = self.model_usage.popleft()
del self.models[lru_model]
                # Force garbage collection to free GPU memory
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
self.models[model_name] = model
self.model_usage.append(model_name)
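
# Illustrative usage sketch (comments only): how the cache is meant to be
# combined with an actual loader. `load_rvc_model` is a hypothetical stand-in;
# the real loading code is not part of this module.
#
#   cache = ModelCache(max_models=3)
#   model = cache.get_model("voice_a")
#   if model is None:
#       model = load_rvc_model("voice_a")  # hypothetical loader
#       cache.add_model("voice_a", model)  # evicts the LRU entry if the cache is full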
class JobQueue:
"""Manages prioritized job queue with rate limiting"""
def __init__(self, max_size: int = 100):
self.queue = queue.PriorityQueue(maxsize=max_size)
self.processing: Dict[str, JobRequest] = {}
self.lock = threading.Lock()
self.last_processed = time.time()
self.rate_limit = 1.0 # Minimum seconds between jobs
def add_job(self, job: JobRequest) -> bool:
"""Add job to queue with priority"""
try:
# Priority tuple: (priority reversed, timestamp, job)
# Lower number = higher priority
self.queue.put((6 - job.priority, job.timestamp, job), block=False)
logger.info(f"Added job {job.id} to queue. Priority: {job.priority}")
return True
except queue.Full:
logger.warning("Queue is full, job rejected")
return False
def get_next_job(self) -> Optional[JobRequest]:
"""Get next job respecting rate limiting"""
if time.time() - self.last_processed < self.rate_limit:
return None
try:
_, _, job = self.queue.get(block=False)
with self.lock:
self.processing[job.id] = job
self.last_processed = time.time()
return job
except queue.Empty:
return None
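
# Ordering note: queue.PriorityQueue pops the smallest tuple first, so a job
# submitted with priority=5 is stored under key (6 - 5) = 1 and is dequeued
# before a priority=1 job stored under key 5; equal keys fall back to the
# submission timestamp, giving FIFO order within a priority level.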
class RVCService:
"""Main service class for RVC processing"""
def __init__(self):
self.model_cache = ModelCache(max_models=3)
self.job_queue = JobQueue(max_size=100)
self.is_running = False
self.worker_thread = None
# Operating hours (24-hour format)
self.start_time = dt_time(9, 0) # 9:00 AM
self.end_time = dt_time(0, 0) # 12:00 AM
def within_operating_hours(self) -> bool:
"""Check if current time is within operating hours"""
current_time = datetime.now().time()
# For testing/development, always return True
# TODO: Implement proper operating hours check for production
return True
# When ready for production, uncomment this:
# if self.start_time <= self.end_time:
# return self.start_time <= current_time <= self.end_time
# else: # Handles overnight operation (e.g., 9 AM to 12 AM)
# return current_time >= self.start_time or current_time <= self.end_time
    async def process_audio(self, job: JobRequest) -> Optional[np.ndarray]:
        """Process a single audio conversion job"""
        try:
            # Get or load model
            model = self.model_cache.get_model(job.model_name)
            if model is None:
                logger.info(f"Loading model {job.model_name}")
                # Here you would load your RVC model, e.g.:
                # model = load_rvc_model(job.model_name)
                # Only cache the model once loading has actually succeeded;
                # adding None here would poison the LRU cache.
                if model is not None:
                    self.model_cache.add_model(job.model_name, model)
            # Process audio (autocast is a no-op when CUDA is unavailable)
            with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
                # Your RVC processing logic here
                # output = model.convert_voice(job.audio_data)
                output = job.audio_data  # Placeholder: passes input through unchanged
            return output
except Exception as e:
logger.error(f"Error processing job {job.id}: {str(e)}")
return None
async def worker_loop(self):
"""Main worker loop processing jobs from queue"""
while self.is_running:
try:
# Check operating hours
if not self.within_operating_hours():
logger.info("Outside operating hours, worker sleeping...")
await asyncio.sleep(300) # Check every 5 minutes
continue
# Get next job
job = self.job_queue.get_next_job()
if job is None:
await asyncio.sleep(0.1) # Prevent busy waiting
continue
logger.info(f"Processing job {job.id}")
output = await self.process_audio(job)
if output is not None:
logger.info(f"Successfully processed job {job.id}")
else:
logger.error(f"Failed to process job {job.id}")
# Cleanup
with self.job_queue.lock:
self.job_queue.processing.pop(job.id, None)
except Exception as e:
logger.error(f"Worker error: {str(e)}")
await asyncio.sleep(1) # Prevent rapid error loops
    def start(self):
        """Start the service"""
        if not self.is_running:
            self.is_running = True
            # Create a dedicated event loop for the background worker
            loop = asyncio.new_event_loop()

            # Run the worker loop in a daemon thread
            def run_worker():
                # Bind the loop to the worker thread (set_event_loop applies
                # to the calling thread, so it must happen inside the thread)
                asyncio.set_event_loop(loop)
                loop.run_until_complete(self.worker_loop())

            self.worker_thread = threading.Thread(target=run_worker, daemon=True)
            self.worker_thread.start()
            logger.info("RVC Service started")
    def stop(self):
        """Stop the service and wait briefly for the worker to finish"""
        self.is_running = False
        logger.info("RVC Service stopping...")
        if self.worker_thread is not None:
            self.worker_thread.join(timeout=5)
    async def submit_job(self, audio_data: np.ndarray, model_name: str, priority: int = 1) -> Optional[str]:
"""Submit a new job to the service"""
job_id = f"job_{int(time.time())}_{id(audio_data)}"
job = JobRequest(
id=job_id,
audio_data=audio_data,
model_name=model_name,
priority=priority
)
if self.job_queue.add_job(job):
return job_id
return None
# Memory management utilities
def cleanup_gpu_memory():
    """Force cleanup of GPU memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def monitor_gpu_memory():
"""Log GPU memory usage"""
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1024**2
reserved = torch.cuda.memory_reserved() / 1024**2
logger.info(f"GPU Memory: {allocated:.2f}MB allocated, {reserved:.2f}MB reserved")