import torch from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig from rich.console import Console from rich.markdown import Markdown from rich.panel import Panel from rich.progress import Progress import time import os import json from typing import List, Tuple, Dict, Optional from dataclasses import dataclass, field from datetime import datetime import numpy as np from threading import Lock import gc import logging from contextlib import contextmanager # Set up logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('chat_system.log'), logging.StreamHandler() ] ) @dataclass class ConversationTurn: """Represents a single turn in the conversation.""" role: str content: str timestamp: float = field(default_factory=time.time) token_count: int = 0 class TokenManager: """Manages token counting and context window optimization.""" def __init__(self, tokenizer, max_context_tokens: int = 4096): self.tokenizer = tokenizer self.max_context_tokens = max_context_tokens self._token_count_cache = {} self.cache_lock = Lock() def count_tokens(self, text: str) -> int: """Count tokens with caching for efficiency.""" with self.cache_lock: if text not in self._token_count_cache: tokens = self.tokenizer.encode(text, add_special_tokens=True) self._token_count_cache[text] = len(tokens) return self._token_count_cache[text] def optimize_context(self, turns: List[ConversationTurn], max_turns: int = 10) -> List[ConversationTurn]: """Optimize context window while maintaining coherence.""" total_tokens = 0 optimized_turns = [] # Always include the last turn if turns: last_turn = turns[-1] total_tokens += last_turn.token_count optimized_turns.append(last_turn) # Add previous turns while respecting token limit for turn in reversed(turns[:-1]): if total_tokens + turn.token_count > self.max_context_tokens: break if len(optimized_turns) >= max_turns: break total_tokens += turn.token_count optimized_turns.insert(0, turn) return optimized_turns class ConversationManager: """Manages conversation state and history.""" def __init__(self, token_manager: TokenManager): self.token_manager = token_manager self.turns: List[ConversationTurn] = [] self.system_prompt = """You are a highly capable AI assistant with expertise in business and technical domains. You provide detailed, well-reasoned responses while maintaining a professional tone. Focus on delivering accurate, contextual information without repeating previous conversation details.""" self.system_tokens = token_manager.count_tokens(self.system_prompt) def add_turn(self, role: str, content: str): """Add a new conversation turn with token counting.""" turn = ConversationTurn( role=role, content=content, token_count=self.token_manager.count_tokens(content) ) self.turns.append(turn) def get_prompt(self, include_system: bool = True) -> str: """Generate optimized prompt for model input.""" optimized_turns = self.token_manager.optimize_context(self.turns) components = [] if include_system: components.append(f"System: {self.system_prompt}") for turn in optimized_turns: role_prefix = "Human" if turn.role == "user" else "Assistant" components.append(f"{role_prefix}: {turn.content}") return "\n\n".join(components) class ResponseGenerator: """Handles model inference and response generation.""" def __init__(self, model, tokenizer): self.model = model self.tokenizer = tokenizer self.device = next(model.parameters()).device # Enhanced generation parameters self.base_params = { 'do_sample': True, 'top_k': 50, 'top_p': 0.95, 'temperature': 0.8, 'repetition_penalty': 1.1, 'no_repeat_ngram_size': 4, 'num_beams': 2, 'early_stopping': True, 'length_penalty': 1.2, 'bad_words_ids': None, 'min_length': 10, 'use_cache': True, } @contextmanager def inference_mode(self): """Context manager for inference optimization.""" torch.cuda.empty_cache() gc.collect() try: with torch.inference_mode(): yield finally: torch.cuda.empty_cache() gc.collect() def calculate_dynamic_length(self, input_text: str, conversation_length: int) -> int: """Calculate dynamic response length based on input and conversation context.""" input_tokens = len(self.tokenizer.encode(input_text)) base_length = max(100, input_tokens * 2) # Scale based on conversation complexity complexity_factor = min(2.0, 1.0 + (conversation_length / 20)) dynamic_length = int(base_length * complexity_factor) # Ensure length is within reasonable bounds return min(max(dynamic_length, 100), 2048) def generate_response(self, prompt: str, conversation_length: int) -> str: """Generate response with dynamic length and advanced parameters.""" with self.inference_mode(): inputs = self.tokenizer( prompt, return_tensors="pt", padding=True, truncation=True, max_length=4096 ).to(self.device) max_new_tokens = self.calculate_dynamic_length(prompt, conversation_length) generation_params = { **self.base_params, 'max_new_tokens': max_new_tokens, 'pad_token_id': self.tokenizer.pad_token_id, 'eos_token_id': self.tokenizer.eos_token_id, } outputs = self.model.generate( **inputs, **generation_params ) response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract only the assistant's response response_parts = response.split("Assistant:") if len(response_parts) > 1: response = response_parts[-1].strip() return response class EnterpriseQwenChat: """Main chat interface with enterprise-grade features.""" def __init__(self, model_directory: str = "./qwen"): self.console = Console() self.model_directory = model_directory self.setup_components() def setup_components(self): """Initialize components with CUDA support.""" try: self.console.print("Initializing Enterprise Qwen Chat...", style="bold yellow") # Initialize tokenizer self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load model with CUDA optimizations config = AutoConfig.from_pretrained(os.path.join(self.model_directory, "config.json")) self.model = AutoModelForCausalLM.from_pretrained( self.model_directory, config=config, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, ) # Move model to GPU if available self.model.to("cuda" if torch.cuda.is_available() else "cpu") # Initialize managers self.token_manager = TokenManager(self.tokenizer) self.conversation_manager = ConversationManager(self.token_manager) self.response_generator = ResponseGenerator(self.model, self.tokenizer) self.console.print("[bold green]System initialized successfully![/bold green]") except Exception as e: logging.error(f"Initialization failed: {str(e)}") raise def save_conversation(self) -> str: """Save conversation with metadata.""" timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f'conversation_{timestamp}.json' conversation_data = { 'timestamp': timestamp, 'turns': [ { 'role': turn.role, 'content': turn.content, 'timestamp': turn.timestamp, 'token_count': turn.token_count } for turn in self.conversation_manager.turns ], 'metadata': { 'total_turns': len(self.conversation_manager.turns), 'total_tokens': sum(turn.token_count for turn in self.conversation_manager.turns) } } with open(filename, 'w', encoding='utf-8') as f: json.dump(conversation_data, f, indent=2) return filename def run(self): """Run the chat interface with enhanced features.""" self.console.print(Panel.fit( "[bold green]Enterprise Qwen Chat System[/bold green]\n" "[italic]Commands:\n" "- 'exit' or 'quit': End conversation\n" "- 'save': Save conversation\n" "- 'clear': Clear conversation history[/italic]" )) while True: try: user_input = self.console.input("[bold cyan]You:[/bold cyan] ").strip() if user_input.lower() in ['exit', 'quit']: log_file = self.save_conversation() self.console.print(f"Conversation saved to: {log_file}", style="bold green") break if user_input.lower() == 'save': log_file = self.save_conversation() self.console.print(f"Conversation saved to: {log_file}", style="bold green") continue if user_input.lower() == 'clear': self.conversation_manager.turns.clear() self.console.print("Conversation history cleared.", style="bold yellow") continue # Process user input self.conversation_manager.add_turn("user", user_input) # Generate and display response with self.console.status("[bold yellow]Generating response...[/bold yellow]"): start_time = time.time() prompt = self.conversation_manager.get_prompt() response = self.response_generator.generate_response( prompt, len(self.conversation_manager.turns) ) self.conversation_manager.add_turn("assistant", response) end_time = time.time() self.console.print(Markdown(f"**AI:** {response}")) self.console.print( f"[italic grey](Generated in {end_time - start_time:.2f} seconds)[/italic grey]\n" ) except KeyboardInterrupt: self.console.print("\nGracefully shutting down...", style="bold yellow") self.save_conversation() break except Exception as e: logging.error(f"Error during chat: {str(e)}") self.console.print( "[bold red]An error occurred. The conversation has been saved.[/bold red]" ) self.save_conversation() break if __name__ == "__main__": chat = EnterpriseQwenChat() chat.run()