# Penetration Testing AI Assistant — Streamlit app for the Canstralian/pentest_ai model.
# Standard library imports.
import json
import logging
import re

# Third-party imports.
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Log to a file so diagnostics survive Streamlit's script reruns.
logging.basicConfig(
    filename="app.log",
    level=logging.INFO,
    format="%(asctime)s:%(levelname)s:%(message)s",
)
# Model and tokenizer loading function with caching.
@st.cache_resource
def load_model():
    """
    Load and cache the pre-trained language model and tokenizer.

    Decorated with st.cache_resource so the heavyweight model is loaded
    once per server process instead of on every Streamlit rerun (the
    original docstring promised caching but no decorator was applied).

    Returns:
        tuple: (model, tokenizer) on success, (None, None) on failure.
    """
    model_path = "Canstralian/pentest_ai"
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            # Half precision only when a GPU is available; fp32 on CPU.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            load_in_4bit=False,
            load_in_8bit=True,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        logging.info("Model and tokenizer loaded successfully.")
        return model, tokenizer
    except Exception as e:
        # Report failure in both the log and the UI, and let the caller
        # handle the (None, None) sentinel.
        logging.error(f"Error loading model: {e}")
        st.error("Failed to load model. Please check the logs.")
        return None, None
def sanitize_input(text):
    """
    Sanitize and validate user input text to prevent injection or formatting issues.

    Args:
        text (str): User input text.

    Returns:
        str: Sanitized text containing only letters, digits, whitespace and
            basic punctuation (. , ! ?), stripped of surrounding whitespace.

    Raises:
        ValueError: If *text* is not a string.
    """
    if not isinstance(text, str):
        raise ValueError("Input must be a string.")
    # Whitelist approach: drop every character outside the allowed set.
    sanitized_text = re.sub(r"[^a-zA-Z0-9\s\.,!?]", "", text)
    return sanitized_text.strip()
def generate_text(model, tokenizer, instruction):
    """
    Generate text based on the provided instruction using the loaded model.

    Args:
        model: The language model.
        tokenizer: Tokenizer for encoding/decoding.
        instruction (str): Instruction text for the model.

    Returns:
        str: Generated text response, or a fixed error message on failure.
    """
    try:
        # Validate and sanitize the instruction before tokenizing.
        instruction = sanitize_input(instruction)
        # Move tokens to whatever device the model actually lives on;
        # hard-coding 'cuda' crashed on CPU-only hosts even though
        # load_model deliberately supports a CPU fallback.
        tokens = tokenizer.encode(instruction, return_tensors='pt').to(model.device)
        generated_tokens = model.generate(
            tokens,
            max_length=1024,
            top_p=1.0,
            temperature=0.5,
            top_k=50
        )
        generated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        logging.info("Text generated successfully.")
        return generated_text
    except Exception as e:
        # Best-effort: surface a stable error string instead of raising,
        # matching how the Streamlit UI consumes this function.
        logging.error(f"Error generating text: {e}")
        return "Error in text generation."
def load_json_data():
    """
    Load JSON data, simulating the loading process with a hard-coded sample list.

    Returns:
        list: A list of dictionaries with sample user data
            (keys: name, email, country, company); empty list on failure.
    """
    try:
        json_data = [
            {"name": "Raja Clarke", "email": "consectetuer@yahoo.edu", "country": "Chile", "company": "Urna Nunc Consulting"},
            {"name": "Melissa Hobbs", "email": "massa.non@hotmail.couk", "country": "France", "company": "Gravida Mauris Limited"},
            {"name": "John Doe", "email": "john.doe@example.com", "country": "USA", "company": "Example Corp"},
            {"name": "Jane Smith", "email": "jane.smith@example.org", "country": "Canada", "company": "Innovative Solutions Inc"}
        ]
        logging.info("User JSON data loaded successfully.")
        return json_data
    except Exception as e:
        # Defensive catch kept so a future real data source can fail soft.
        logging.error(f"Error loading JSON data: {e}")
        return []
# Streamlit App
st.title("Penetration Testing AI Assistant")

# Load the model and tokenizer; may be (None, None) if loading failed.
model, tokenizer = load_model()

# User instruction input
instruction = st.text_input("Enter an instruction for the model:")

# Generate a response once an instruction is entered.
if instruction:
    if model is None or tokenizer is None:
        # Guard the failure path of load_model instead of calling
        # generate_text with None and relying on its broad except.
        st.error("Model is not available. Please check the logs.")
    else:
        try:
            generated_text = generate_text(model, tokenizer, instruction)
            st.subheader("Generated Text:")
            st.write(generated_text)
        except ValueError as ve:
            st.error(f"Invalid input: {ve}")
        except Exception as e:
            logging.error(f"Error during text generation: {e}")
            st.error("An error occurred. Please try again.")

# Display JSON user data
st.subheader("User Data (from JSON)")
user_data = load_json_data()
for user in user_data:
    st.write(f"**Name:** {user['name']}")
    st.write(f"**Email:** {user['email']}")
    st.write(f"**Country:** {user['country']}")
    st.write(f"**Company:** {user['company']}")
    st.write("---")