import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path

# Model configuration
MODEL_DIR = Path("./AI_Chatbot").resolve()
REQUIRED_FILES = [
    "config.json",
    "pytorch_model.bin",  # Or model.safetensors for newer checkpoints
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.txt"  # Change to merges.txt/vocab.json if using a different tokenizer
]


def load_legal_model():
    """Load the model and tokenizer from disk, validating the files first."""
    # Verify the model directory exists
    if not MODEL_DIR.exists():
        raise FileNotFoundError(f"Model directory missing at {MODEL_DIR}")

    # Check for required files
    missing_files = [f for f in REQUIRED_FILES if not (MODEL_DIR / f).exists()]
    if missing_files:
        raise FileNotFoundError(f"Missing files: {', '.join(missing_files)}")

    # Load model components from local files only (no Hub download)
    try:
        tokenizer = AutoTokenizer.from_pretrained(str(MODEL_DIR), local_files_only=True)
        model = AutoModelForCausalLM.from_pretrained(str(MODEL_DIR), local_files_only=True)
        return tokenizer, model
    except Exception as e:
        raise RuntimeError(f"Model loading failed: {e}")


def legal_analysis(query):
    """Generate a legal analysis for the user query."""
    try:
        # Tokenize input (returns input_ids and attention_mask)
        inputs = tokenizer(query, return_tensors="pt")

        # Generate response with beam search
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=4,
            early_stopping=True
        )

        # Decode only the newly generated tokens, skipping the echoed prompt
        generated = outputs[0][inputs["input_ids"].shape[-1]:]
        response = tokenizer.decode(generated, skip_special_tokens=True)
        return response
    except Exception as e:
        return f"Error processing request: {e}"


# Load model once at startup
try:
    tokenizer, model = load_legal_model()
except Exception as e:
    print(f"Critical error during initialization: {e}")
    raise

# Create Gradio interface
with gr.Blocks(title="Legal AI Assistant", theme=gr.themes.Soft()) as app:
    gr.Markdown("# ⚖️ Legal AI Counsel")
    gr.Markdown("Expert-level legal analysis powered by AI")

    with gr.Row():
        with gr.Column(scale=3):
            input_query = gr.Textbox(
                label="Legal Query",
                placeholder="Enter your legal question...",
                lines=3
            )
            submit_btn = gr.Button("Analyze", variant="primary")

        with gr.Column(scale=7):
            output_response = gr.Textbox(
                label="Legal Analysis",
                interactive=False,
                lines=10
            )

    gr.Examples(
        examples=[
            ["What constitutes breach of contract?"],
            ["How to file for intellectual property protection?"],
            ["What are the requirements for a valid will?"]
        ],
        inputs=[input_query]
    )

    submit_btn.click(
        fn=legal_analysis,
        inputs=[input_query],
        outputs=[output_response]
    )

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )