CVE2TTP / app.py
HarshV1315's picture
Upload 5 files
ec9188f verified
from flask import Flask, render_template, request, jsonify
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
import nvdlib
# Flask application object; the @app.route decorators further down register
# their view functions on it, and the __main__ guard starts its server.
app = Flask(__name__)
# Define the model architecture
class Model(nn.Module):
    """Pretrained transformer encoder with a dropout + linear multi-label head.

    The attribute names ``transformer_model``, ``dropout`` and ``output``
    become the key prefixes of the module's ``state_dict``, so they must stay
    as they are for existing checkpoints to keep loading.

    Args:
        pretrained_name: Model identifier (or local path) handed to
            ``AutoModel.from_pretrained``.
        num_labels: Number of logits produced by the linear head.
        dropout_prob: Dropout probability applied to the pooled encoder output.
    """

    def __init__(self, pretrained_name='jackaduma/SecRoBERTa', num_labels=14,
                 dropout_prob=0.3):
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained(pretrained_name)
        self.dropout = nn.Dropout(dropout_prob)
        # Size the head from the encoder's own config rather than a literal
        # 768, so a different encoder width cannot silently mismatch.
        hidden_size = self.transformer_model.config.hidden_size
        self.output = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return the linear head's raw logits (no activation is applied here).

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Optional attention mask forwarded to the encoder.

        Returns:
            Output of ``self.output`` applied to the dropped-out pooled
            encoder representation.
        """
        encoder_outputs = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # With return_dict=False the encoder returns a tuple whose second
        # element is the pooled output (BERT/RoBERTa-style encoders). Indexing
        # instead of a strict two-way unpack keeps this working if the encoder
        # is configured to append extra items to that tuple.
        pooled = encoder_outputs[1]
        return self.output(self.dropout(pooled))
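# Predict MITRE ATT&CK tactics for a CVE description (despite the function's
# name, the 14 model outputs are mapped to tactic names in predict()).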
def predict_techniques(model, tokenizer, cve_description, device,
                       max_length=320, threshold=0.5):
    """Run the classifier on one description and return a 0/1 label array.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)`` that
            returns a logits tensor.
        tokenizer: Hugging Face-style tokenizer, called directly (the
            ``encode_plus`` method used previously is deprecated in favour of
            ``__call__``). Must return ``input_ids`` and ``attention_mask``
            tensors.
        cve_description: Text to classify.
        device: Device the input tensors are moved to before the forward pass.
        max_length: Length the text is padded/truncated to.
        threshold: A label is set to 1 when its sigmoid probability is
            strictly greater than this value. The default of 0.5 matches the
            earlier ``np.round`` behaviour, which rounds exactly 0.5 down to 0.

    Returns:
        NumPy array with the same shape and dtype as the sigmoid
        probabilities, holding 1.0 for predicted labels and 0.0 otherwise.
    """
    encoded = tokenizer(
        cve_description,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)

    probs = torch.sigmoid(logits).cpu().numpy()
    return (probs > threshold).astype(probs.dtype)
# Module-level cache for the loaded model and tokenizer. Both start as None
# and are assigned by get_model_and_tokenizer(), which declares them `global`.
global_model = None
global_tokenizer = None
# Lazy loading function to get the model and tokenizer
def get_model_and_tokenizer(device='cpu'):
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    Args:
        device: Used as ``map_location`` when reading ``tactic_predict.pt``
            and as the target of ``model.to``. It only takes effect on the
            call that actually performs the load; later calls return the
            cached objects unchanged.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    global global_model, global_tokenizer

    if global_model is not None and global_tokenizer is not None:
        return global_model, global_tokenizer

    # Build everything in locals and publish to the module-level cache only
    # once both objects are fully ready. Previously ``global_model`` was
    # assigned before its weights were loaded, so a failure (or a concurrent
    # request) could observe an untrained model in the cache.
    model = Model()
    state_dict = torch.load('tactic_predict.pt', map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # evaluation mode, so the dropout layer is inactive
    tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecRoBERTa')

    global_model, global_tokenizer = model, tokenizer
    return model, tokenizer
# Route for the home page
@app.route("/")
def home():
    """Render the ``index.html`` template for the site root."""
    return render_template("index.html")
# Route to handle form submission and return results
@app.route('/predict', methods=['POST'])
def predict():
    """Look up a CVE via nvdlib and render the tactics predicted for it.

    Reads ``cve_id`` from the submitted form, fetches the matching record
    with ``nvdlib.searchCVE``, runs the classifier on the record's English
    description and renders ``result.html``.

    Returns:
        The rendered result page on success. Otherwise a JSON error body
        with status 400 (missing or blank ``cve_id``) or 404 (no record
        returned, or the record has no English description).
    """
    # Order must match the order of the model's 14 outputs.
    tactic_names = [
        "Reconnaissance", "Resource Development", "Initial Access", "Execution",
        "Persistence", "Privilege Escalation", "Defense Evasion",
        "Credential Access", "Discovery", "Lateral Movement", "Collection",
        "Command and Control", "Exfiltration", "Impact",
    ]

    # A missing or whitespace-only field is rejected up front instead of
    # being sent to the NVD lookup.
    cve_id = request.form.get('cve_id', '').strip()
    if not cve_id:
        return jsonify(error="Missing required form field 'cve_id'."), 400

    # The search result is indexed as a list; guard against it being empty
    # so an unknown id is a 404 rather than an IndexError (HTTP 500).
    matches = nvdlib.searchCVE(cveId=cve_id)
    if not matches:
        return jsonify(error=f"No NVD record found for '{cve_id}'."), 404

    # next() with a default: a record without an English description would
    # otherwise raise StopIteration.
    cve_data = next(
        (desc.value for desc in matches[0].descriptions if desc.lang == "en"),
        None,
    )
    if cve_data is None:
        return jsonify(error=f"No English description available for '{cve_id}'."), 404

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Model and tokenizer are loaded lazily and cached across requests.
    model, tokenizer = get_model_and_tokenizer(device)
    predicted_techniques = predict_techniques(model, tokenizer, cve_data, device)

    # Row 0 is the single description submitted; keep names flagged with 1.
    predicted_tactic_names = [
        name
        for name, flag in zip(tactic_names, predicted_techniques[0])
        if flag == 1
    ]
    return render_template(
        'result.html',
        tactics=predicted_tactic_names,
        cve_id=cve_id,
        cve_desc=cve_data,
    )
# Run the app
if __name__ == "__main__":
    # Start Flask's built-in server on all network interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")