# Spaces: Runtime error
import numpy as np
import nvdlib
import torch
import torch.nn as nn
from flask import Flask, abort, jsonify, render_template, request
from transformers import AutoModel, AutoTokenizer
# Flask application instance for this module (started by app.run at the bottom of the file)
app = Flask(__name__)
# Model architecture: pretrained SecRoBERTa encoder + dropout + linear head
class Model(nn.Module):
    """Pretrained ``jackaduma/SecRoBERTa`` encoder with a 14-output linear head.

    The attribute names (``transformer_model``, ``dropout``, ``output``) are
    the keys a saved ``state_dict`` is matched against, so they must not be
    renamed.
    """

    def __init__(self):
        """Load the pretrained encoder and create the dropout and linear layers."""
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained('jackaduma/SecRoBERTa')
        self.dropout = nn.Dropout(0.3)
        # 768 input features -> 14 logits.
        self.output = nn.Linear(768, 14)

    def forward(self, input_ids, attention_mask=None):
        """Return the head's logits for a batch of token ids.

        Args:
            input_ids: Token id tensor passed to the encoder.
            attention_mask: Optional mask passed to the encoder unchanged.

        Returns:
            Output of the linear head applied (after dropout) to the second
            element of the encoder's tuple output.
        """
        # With return_dict=False the encoder returns a tuple; only the second
        # element is used (presumably the pooled representation — TODO confirm).
        _first_output, second_output = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        return self.output(self.dropout(second_output))
# Function to predict MITRE ATT&CK techniques
def predict_techniques(model, tokenizer, cve_description, device):
    """Classify one CVE description and return a 0/1 flag per model output.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)`` that
            returns a tensor of logits.
        tokenizer: Hugging Face tokenizer; it is called directly to encode
            the text (``encode_plus`` is deprecated in favor of ``__call__``).
        cve_description: Text to classify.
        device: Device the encoded tensors are moved to before the model call.

    Returns:
        NumPy array with the same shape as the logits, holding the sigmoid
        probabilities rounded to 0.0 or 1.0.
    """
    # Pad/truncate to a fixed 320 tokens and return PyTorch tensors.
    encoded = tokenizer(
        cve_description,
        max_length=320,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    probs = torch.sigmoid(logits).cpu().numpy()
    # np.round rounds half to even, so a probability of exactly 0.5 becomes 0.
    return np.round(probs)
# Module-level cache for the model and tokenizer; both are None until
# get_model_and_tokenizer() fills them in.
global_model = global_tokenizer = None
# Lazy loader: builds the model and tokenizer on first use, then reuses them
def get_model_and_tokenizer(device='cpu'):
    """Return the cached ``(model, tokenizer)`` pair, creating it if needed.

    Args:
        device: Device used as ``map_location`` for the checkpoint and as the
            target of ``model.to``. It is only consulted when the pair is
            first created; later calls return the cached objects as they are.

    Returns:
        Tuple ``(global_model, global_tokenizer)``.
    """
    global global_model, global_tokenizer

    # Fast path: both objects already exist.
    if global_model is not None and global_tokenizer is not None:
        return global_model, global_tokenizer

    global_model = Model()
    # weights_only=True restricts unpickling to tensor data.
    state_dict = torch.load('tactic_predict.pt', map_location=device, weights_only=True)
    global_model.load_state_dict(state_dict)
    global_model.to(device)
    # eval() switches dropout off for inference.
    global_model.eval()
    global_tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecRoBERTa')
    return global_model, global_tokenizer
# Route for the home page
@app.route('/')
def home():
    """Render the ``index.html`` template for the site root."""
    return render_template('index.html')
# Names for the classifier outputs, in output order (index i labels output i)
_TACTIC_NAMES = (
    "Reconnaissance", "Resource Development", "Initial Access", "Execution",
    "Persistence", "Privilege Escalation", "Defense Evasion",
    "Credential Access", "Discovery", "Lateral Movement", "Collection",
    "Command and Control", "Exfiltration", "Impact",
)


# Route to handle form submission and return results
# NOTE(review): the '/predict' path is an assumption — confirm it matches the
# form action used in index.html.
@app.route('/predict', methods=['POST'])
def predict():
    """Look up a CVE, classify its English description and render the result.

    Reads ``cve_id`` from the submitted form, fetches the record with
    ``nvdlib.searchCVE``, runs the classifier on the English description and
    renders ``result.html`` with the names of the outputs flagged as 1.

    Aborts with 400 when the submitted id is blank, and with 404 when the
    lookup returns no record or the record has no English description.
    """
    cve_id = request.form['cve_id'].strip()
    if not cve_id:
        abort(400, description="A CVE ID is required.")

    matches = nvdlib.searchCVE(cveId=cve_id)
    if not matches:
        # Indexing an empty result would raise IndexError (HTTP 500).
        abort(404, description="No NVD record was found for the requested CVE ID.")

    # Default of None avoids StopIteration when no English entry exists.
    cve_data = next(
        (desc.value for desc in matches[0].descriptions if desc.lang == "en"),
        None,
    )
    if cve_data is None:
        abort(404, description="The NVD record has no English description.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Load model and tokenizer lazily
    model, tokenizer = get_model_and_tokenizer(device)
    predicted_techniques = predict_techniques(model, tokenizer, cve_data, device)

    # Row 0 is the single description that was classified.
    predicted_tactic_names = [
        name
        for name, flag in zip(_TACTIC_NAMES, predicted_techniques[0])
        if flag == 1
    ]
    return render_template(
        'result.html',
        tactics=predicted_tactic_names,
        cve_id=cve_id,
        cve_desc=cve_data,
    )
# Run the app
if __name__ == "__main__":
    # Only when executed as a script: listen on all interfaces, port 7860.
    app.run(host="0.0.0.0", port=7860)