import asyncio
import json
import logging
from typing import TypeVar, Type, Optional, Callable

from pydantic import BaseModel
from langchain_mistralai.chat_models import ChatMistralAI
from langchain.schema import SystemMessage, HumanMessage
from langchain.schema.messages import BaseMessage

T = TypeVar('T', bound=BaseModel)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Available Mistral models:
# - mistral-tiny   : fastest, cheapest, good for testing
# - mistral-small  : good balance of speed and quality
# - mistral-medium : better quality, slower than small
# - mistral-large  : best quality, slowest and most expensive
#
# mistral-large-latest:      currently points to mistral-large-2411.
# pixtral-large-latest:      currently points to pixtral-large-2411.
# mistral-moderation-latest: currently points to mistral-moderation-2411.
# ministral-3b-latest:       currently points to ministral-3b-2410.
# ministral-8b-latest:       currently points to ministral-8b-2410.
# open-mistral-nemo:         currently points to open-mistral-nemo-2407.
# mistral-small-latest:      currently points to mistral-small-2409.
# codestral-latest:          currently points to codestral-2501.
#
# Pricing: https://docs.mistral.ai/platform/pricing/


class MistralAPIError(Exception):
    """Base class for Mistral API errors."""
    pass


class MistralRateLimitError(MistralAPIError):
    """Raised when hitting rate limits."""
    pass


class MistralParsingError(MistralAPIError):
    """Raised when response parsing fails."""
    pass


class MistralValidationError(MistralAPIError):
    """Raised when response validation fails."""
    pass


class MistralClient:
    def __init__(self, api_key: str, model_name: str = "mistral-small-latest", max_tokens: int = 1000):
        logger.info(f"Initializing MistralClient with model: {model_name}, max_tokens: {max_tokens}")
        self.model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )
        # Secondary model instance, reserved for fix-up calls
        self.fixing_model = ChatMistralAI(
            mistral_api_key=api_key,
            model=model_name,
            max_tokens=max_tokens
        )

        # Rate-limit handling
        self.last_call_time = 0
        self.min_delay = 1  # at least 1 second between calls
        self.max_retries = 5
        self.backoff_factor = 2  # for exponential backoff
        self.max_backoff = 30  # maximum backoff time in seconds

    async def _wait_for_rate_limit(self):
        """Wait as long as needed to respect the rate limit."""
        current_time = asyncio.get_event_loop().time()
        time_since_last_call = current_time - self.last_call_time
        if time_since_last_call < self.min_delay:
            delay = self.min_delay - time_since_last_call
            logger.debug(f"Rate limit: waiting for {delay:.2f} seconds")
            await asyncio.sleep(delay)
        self.last_call_time = asyncio.get_event_loop().time()

    async def _handle_api_error(self, error: Exception, retry_count: int) -> float:
        """Handle an API error and return the wait time before the next retry.

        Rate-limit errors are retryable, so a backoff delay is returned;
        authentication errors are fatal and raised immediately.
        """
        wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
        if "rate limit" in str(error).lower():
            logger.warning(f"Rate limit hit, waiting {wait_time}s before retry")
            return wait_time
        if "403" in str(error):
            logger.error("Authentication error - invalid API key or quota exceeded")
            raise MistralAPIError("Authentication failed")
        return wait_time

    async def _generate_with_retry(
        self,
        messages: list[BaseMessage],
        response_model: Optional[Type[T]] = None,
        custom_parser: Optional[Callable[[str], T]] = None,
        error_feedback: Optional[str] = None
    ) -> T | str:
        retry_count = 0
        last_error = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                current_messages = messages.copy()
                if retry_count > 0 and last_error is not None:
                    if isinstance(last_error, MistralParsingError):
                        # For parsing errors, add a structured-format reminder
                        current_messages.append(HumanMessage(
                            content="Please ensure your response is in valid JSON format."))
                    elif isinstance(last_error, MistralValidationError):
                        # For validation errors, feed the specific error back to the model
                        feedback = error_feedback or str(last_error)
                        current_messages.append(HumanMessage(
                            content=f"Previous error: {feedback}. Please try again."))

                await self._wait_for_rate_limit()

                try:
                    response = await self.model.ainvoke(current_messages)
                    content = response.content
                    logger.debug(f"Raw response: {content[:100]}...")
                except Exception as api_error:
                    wait_time = await self._handle_api_error(api_error, retry_count)
                    retry_count += 1
                    if retry_count < self.max_retries:
                        await asyncio.sleep(wait_time)
                        continue
                    raise

                # If no parsing is required, return the raw content
                if not response_model and not custom_parser:
                    return content

                # Parse the response
                try:
                    if custom_parser:
                        return custom_parser(content)

                    # Try to parse with the Pydantic model
                    data = json.loads(content)
                    return response_model(**data)
                except json.JSONDecodeError as e:
                    last_error = MistralParsingError(f"Invalid JSON format: {str(e)}")
                    logger.error(f"JSON parsing error: {str(e)}")
                    raise last_error
                except Exception as e:
                    last_error = MistralValidationError(str(e))
                    logger.error(f"Validation error: {str(e)}")
                    raise last_error

            except (MistralParsingError, MistralValidationError) as e:
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                last_error = e
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = min(self.backoff_factor ** retry_count, self.max_backoff)
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")

    async def generate(self,
                       messages: list[BaseMessage],
                       response_model: Optional[Type[T]] = None,
                       custom_parser: Optional[Callable[[str], T]] = None) -> T | str:
        """Generate a response from a list of messages, with optional parsing."""
        return await self._generate_with_retry(messages, response_model, custom_parser)

    async def transform_prompt(self, story_text: str, art_prompt: str) -> str:
        """Transform a story text into an artistic prompt."""
        messages = [
            SystemMessage(content=art_prompt),
            HumanMessage(content=f"Transform this story text into a comic panel description:\n{story_text}")
        ]
        try:
            return await self._generate_with_retry(messages)
        except Exception as e:
            logger.error(f"Error transforming prompt: {str(e)}")
            return story_text

    async def generate_text(self, messages: list[BaseMessage]) -> str:
        """
        Generate a plain text response with no JSON structure.
        Useful for narrative or descriptive text generation.
        Args:
            messages: List of messages for the model

        Returns:
            str: The generated text
        """
        retry_count = 0
        last_error = None

        while retry_count < self.max_retries:
            try:
                logger.info(f"Attempt {retry_count + 1}/{self.max_retries}")
                await self._wait_for_rate_limit()
                response = await self.model.ainvoke(messages)
                return response.content.strip()
            except Exception as e:
                last_error = e
                logger.error(f"Error on attempt {retry_count + 1}/{self.max_retries}: {str(e)}")
                retry_count += 1
                if retry_count < self.max_retries:
                    wait_time = 2 * retry_count
                    logger.info(f"Waiting {wait_time} seconds before retry...")
                    await asyncio.sleep(wait_time)
                    continue

        logger.error(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")
        raise MistralAPIError(f"Failed after {self.max_retries} attempts. Last error: {str(last_error)}")

    async def check_health(self) -> bool:
        """
        Check the availability of the Mistral service with a single call, no retries.

        Returns:
            bool: True if the service is available, False otherwise
        """
        try:
            await self.model.ainvoke([SystemMessage(content="Hi")])
            return True
        except Exception as e:
            logger.error(f"Health check failed: {str(e)}")
            return False
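
# ---------------------------------------------------------------------------
# Example custom parser (a sketch, not part of the client above). Models often
# wrap JSON answers in markdown fences, which json.loads rejects; a parser like
# this can be passed as generate()'s custom_parser hook to strip the fences
# before validation. make_fenced_json_parser is a hypothetical helper defined
# here only to illustrate the hook.
# ---------------------------------------------------------------------------
def make_fenced_json_parser(model_cls: Type[T]) -> Callable[[str], T]:
    """Return a parser that strips ```json fences before Pydantic validation."""
    def parse(content: str) -> T:
        cleaned = content.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (with its optional language tag)...
            cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
            # ...and the closing fence, if present
            if cleaned.rstrip().endswith("```"):
                cleaned = cleaned.rstrip()[:-3]
        return model_cls(**json.loads(cleaned))
    return parse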
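
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). Assumptions: a valid key is available in
# the MISTRAL_API_KEY environment variable, and ComicPanel is a hypothetical
# Pydantic model defined here purely for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    class ComicPanel(BaseModel):  # hypothetical demo model
        description: str
        mood: str

    async def main():
        client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

        # Plain text generation, with rate limiting and retries
        text = await client.generate_text([
            SystemMessage(content="You are a concise storyteller."),
            HumanMessage(content="Describe a rainy street in one sentence.")
        ])
        print(text)

        # Structured generation: ask for JSON and validate it into ComicPanel,
        # using the fence-stripping parser sketched above
        panel = await client.generate(
            [
                SystemMessage(content='Reply only with JSON: {"description": "...", "mood": "..."}'),
                HumanMessage(content="A detective enters a dim office.")
            ],
            custom_parser=make_fenced_json_parser(ComicPanel)
        )
        print(panel.description, "/", panel.mood)

    asyncio.run(main())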