Canstralian's picture
Update app.py
16bf80f verified
raw
history blame
4.62 kB
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import logging
import re
# Set up logging: INFO and above are appended to app.log in the working
# directory, formatted as "timestamp:LEVEL:message".
logging.basicConfig(
    filename="app.log",
    level=logging.INFO,
    format="%(asctime)s:%(levelname)s:%(message)s"
)
# Model and tokenizer loading function with caching
@st.cache_resource
def load_model():
    """
    Load and cache the pre-trained language model and tokenizer.

    8-bit quantization (`load_in_8bit`) requires a CUDA backend via
    bitsandbytes, so it is enabled only when a GPU is available; the
    original unconditional `load_in_8bit=True` made loading fail on
    CPU-only hosts even though the dtype selection already handled them.

    Returns:
        tuple: (model, tokenizer) on success, or (None, None) on failure
        (the failure is logged and surfaced in the UI).
    """
    model_path = "Canstralian/pentest_ai"
    try:
        use_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Half precision only makes sense on GPU; stay in fp32 on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto",
            load_in_4bit=False,
            # Quantize to 8-bit only when CUDA is present (see docstring).
            load_in_8bit=use_cuda,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        # Keep the app alive: report the failure and let callers handle None.
        logging.error(f"Error loading model: {e}")
        st.error("Failed to load model. Please check the logs.")
        return None, None
def sanitize_input(text):
    """
    Sanitize and validate user-supplied text.

    Keeps only letters, digits, whitespace, and basic punctuation
    (. , ! ?); every other character is removed, and surrounding
    whitespace is trimmed.

    Args:
        text (str): Raw user input.

    Returns:
        str: The cleaned, stripped text.

    Raises:
        ValueError: If *text* is not a string.
    """
    if not isinstance(text, str):
        raise ValueError("Input must be a string.")
    # Drop every character outside the allow-list in a single pass.
    disallowed = re.compile(r"[^a-zA-Z0-9\s\.,!?]")
    cleaned = disallowed.sub("", text)
    return cleaned.strip()
def generate_text(model, tokenizer, instruction):
    """
    Generate a text response for *instruction* using the loaded model.

    Args:
        model: The causal language model used for generation.
        tokenizer: Tokenizer paired with the model.
        instruction (str): Instruction text; sanitized before encoding.

    Returns:
        str: The generated text, or the fixed string
        "Error in text generation." if anything goes wrong (errors are
        logged, never raised to the caller).
    """
    try:
        # Validate and sanitize instruction input.
        instruction = sanitize_input(instruction)
        # Place inputs on the device the model actually lives on; the
        # original hard-coded 'cuda', which crashes on CPU-only hosts.
        device = next(model.parameters()).device
        tokens = tokenizer.encode(instruction, return_tensors='pt').to(device)
        generated_tokens = model.generate(
            tokens,
            max_length=1024,
            # do_sample=True makes temperature/top_p/top_k effective;
            # without it generate() decodes greedily and silently
            # ignores these sampling parameters.
            do_sample=True,
            top_p=1.0,
            temperature=0.5,
            top_k=50
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        logging.info("Text generated successfully.")
        return generated_text
    except Exception as e:
        logging.error(f"Error generating text: {e}")
        return "Error in text generation."
@st.cache_data
def load_json_data():
    """
    Return sample user records, simulating a JSON loading step.

    Returns:
        list: Dictionaries with "name", "email", "country", and
        "company" keys for each sample user.
    """
    # Static in-memory data: building a literal cannot raise, so the
    # original try/except (returning [] on failure) was unreachable
    # dead code and has been removed — behavior is unchanged.
    json_data = [
        {"name": "Raja Clarke", "email": "consectetuer@yahoo.edu", "country": "Chile", "company": "Urna Nunc Consulting"},
        {"name": "Melissa Hobbs", "email": "massa.non@hotmail.couk", "country": "France", "company": "Gravida Mauris Limited"},
        {"name": "John Doe", "email": "john.doe@example.com", "country": "USA", "company": "Example Corp"},
        {"name": "Jane Smith", "email": "jane.smith@example.org", "country": "Canada", "company": "Innovative Solutions Inc"}
    ]
    logging.info("User JSON data loaded successfully.")
    return json_data
# Streamlit App
st.title("Penetration Testing AI Assistant")

# Load the (cached) model and tokenizer once per session.
model, tokenizer = load_model()

# Prompt the user for an instruction to feed the model.
instruction = st.text_input("Enter an instruction for the model:")

# Run generation as soon as a non-empty instruction is entered.
if instruction:
    try:
        generated_text = generate_text(model, tokenizer, instruction)
        st.subheader("Generated Text:")
        st.write(generated_text)
    except ValueError as ve:
        st.error(f"Invalid input: {ve}")
    except Exception as e:
        logging.error(f"Error during text generation: {e}")
        st.error("An error occurred. Please try again.")

# Render the sample user records, one labelled line per field.
st.subheader("User Data (from JSON)")
for record in load_json_data():
    for field in ("name", "email", "country", "company"):
        # field.capitalize() reproduces the original labels exactly
        # (Name / Email / Country / Company).
        st.write(f"**{field.capitalize()}:** {record[field]}")
    st.write("---")