import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
import re

# Set up logging
logging.basicConfig(
    filename="app.log",
    level=logging.INFO,
    format="%(asctime)s:%(levelname)s:%(message)s"
)


# Model and tokenizer loading function with caching
@st.cache_resource
def load_model():
    """
    Loads and caches the pre-trained language model and tokenizer.

    Returns:
        model: Pre-trained language model.
        tokenizer: Tokenizer for the model.
    """
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model_path = "Canstralian/pentest_ai"  # Replace with the actual path if different
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map={"": device},  # Place the model explicitly on CPU or GPU
            load_in_8bit=False,  # Disabled for stability
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        st.error("Failed to load model. Please check the logs.")
        return None, None


def sanitize_input(text):
    """
    Sanitizes and validates user input text to prevent injection or formatting issues.

    Args:
        text (str): User input text.

    Returns:
        str: Sanitized text.
    """
    if not isinstance(text, str):
        raise ValueError("Input must be a string.")
    # Basic sanitization to remove unwanted characters
    sanitized_text = re.sub(r"[^a-zA-Z0-9\s\.,!?]", "", text)
    return sanitized_text.strip()


def generate_text(model, tokenizer, instruction):
    """
    Generates text based on the provided instruction using the loaded model.

    Args:
        model: The language model.
        tokenizer: Tokenizer for encoding/decoding.
        instruction (str): Instruction text for the model.

    Returns:
        str: Generated text response from the model.
    """
    try:
        # Validate and sanitize instruction input
        instruction = sanitize_input(instruction)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokens = tokenizer.encode(instruction, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            tokens,
            max_length=1024,
            do_sample=True,  # Enable sampling so temperature, top_p, and top_k take effect
            top_p=1.0,
            temperature=0.5,
            top_k=50
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        logging.info("Text generated successfully.")
        return generated_text
    except Exception as e:
        logging.error(f"Error generating text: {e}")
        return "Error in text generation."


@st.cache_data
def load_json_data():
    """
    Loads JSON data, simulating the loading process with a sample list.

    Returns:
        list: A list of dictionaries with sample user data.
    """
    try:
        json_data = [
            {"name": "Raja Clarke", "email": "consectetuer@yahoo.edu", "country": "Chile", "company": "Urna Nunc Consulting"},
            {"name": "Melissa Hobbs", "email": "massa.non@hotmail.couk", "country": "France", "company": "Gravida Mauris Limited"},
            {"name": "John Doe", "email": "john.doe@example.com", "country": "USA", "company": "Example Corp"},
            {"name": "Jane Smith", "email": "jane.smith@example.org", "country": "Canada", "company": "Innovative Solutions Inc"}
        ]
        logging.info("User JSON data loaded successfully.")
        return json_data
    except Exception as e:
        logging.error(f"Error loading JSON data: {e}")
        return []


# Streamlit App
st.title("Penetration Testing AI Assistant")

# Load the model and tokenizer
model, tokenizer = load_model()
if not model or not tokenizer:
    st.error("Failed to load model or tokenizer. Please check your configuration.")
    st.stop()  # Halt the app so the rest of the page does not run without a model

# User instruction input
instruction = st.text_input("Enter an instruction for the model:")

# Generate text when an instruction is provided
if instruction:
    try:
        generated_text = generate_text(model, tokenizer, instruction)
        st.subheader("Generated Text:")
        st.write(generated_text)
    except ValueError as ve:
        st.error(f"Invalid input: {ve}")
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        st.error("An error occurred. Please try again.")

# Display JSON user data
st.subheader("User Data (from JSON)")
user_data = load_json_data()
for user in user_data:
    st.write(f"**Name:** {user['name']}")
    st.write(f"**Email:** {user['email']}")
    st.write(f"**Country:** {user['country']}")
    st.write(f"**Company:** {user['company']}")
    st.write("---")