File size: 3,428 Bytes
ec9188f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from flask import Flask, render_template, request, jsonify, abort
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer, AutoModel
import nvdlib

# Flask app initialization: the application object that the @app.route
# decorators further down register their view functions on.
app = Flask(__name__)

# Define the model architecture
class Model(nn.Module):
    """Transformer encoder followed by a dropout + linear classification head.

    The defaults reproduce the original fixed architecture (SecRoBERTa
    encoder, dropout 0.3, 14 output logits), so ``Model()`` builds the same
    module layout as before.

    Args:
        model_name: Checkpoint identifier handed to
            ``AutoModel.from_pretrained``.
        num_labels: Number of logits produced by the linear head.
        dropout: Dropout probability applied to the encoder output before
            the linear head.
    """

    def __init__(self, model_name='jackaduma/SecRoBERTa', num_labels=14, dropout=0.3):
        super().__init__()
        # The attribute names below become state-dict key prefixes, so they
        # must not be renamed or previously saved weights will not load.
        self.transformer_model = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the encoder's own config rather than assuming
        # 768; fall back to 768 if the config does not expose hidden_size.
        hidden_size = getattr(self.transformer_model.config, 'hidden_size', 768)
        self.output = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return raw (pre-sigmoid) logits for a batch of token ids.

        Args:
            input_ids: Token-id tensor forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            The output of the linear head applied to the second element of
            the encoder's tuple output.
        """
        encoder_outputs = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        # With return_dict=False the encoder returns a tuple; element 1 is
        # the pooled sentence representation in Hugging Face BERT/RoBERTa
        # models. Indexing (instead of two-name unpacking) keeps this
        # working if the encoder is configured to return extra elements.
        pooled = encoder_outputs[1]
        return self.output(self.dropout(pooled))

# Function to predict MITRE ATT&CK tactics from a CVE description
def predict_techniques(model, tokenizer, cve_description, device,
                       max_length=320, threshold=0.5):
    """Run the classifier on one description and return 0/1 label flags.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            a tensor of logits.
        tokenizer: Object exposing ``encode_plus`` that returns a mapping
            with ``'input_ids'`` and ``'attention_mask'`` tensors.
        cve_description: Text to classify.
        device: Device the input tensors are moved to before inference.
        max_length: Length the input is padded/truncated to. Defaults to
            320, the value that was previously hard-coded.
        threshold: A label is flagged when its sigmoid probability is
            strictly greater than this value. The default of 0.5 gives the
            same result as the previous ``np.round`` call (which rounds an
            exact 0.5 down to 0).

    Returns:
        A NumPy array with the same shape and dtype as the model's sigmoid
        output, holding 1.0 where the probability exceeds ``threshold`` and
        0.0 elsewhere.
    """
    encoded = tokenizer.encode_plus(
        cve_description,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probs = torch.sigmoid(logits).cpu().numpy()

    # Explicit thresholding; keep the float dtype so callers comparing the
    # entries against 1 keep working.
    return (probs > threshold).astype(probs.dtype)

# Global variables for model and tokenizer.
# Both start as None and are filled in by get_model_and_tokenizer() the first
# time it is called, so the weights are not loaded at import time.
global_model = None
global_tokenizer = None

# Lazy loading function to get the model and tokenizer
def get_model_and_tokenizer(device='cpu'):
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    Args:
        device: Device used as ``map_location`` for the checkpoint and as the
            target of ``model.to``. It is only consulted when loading
            actually happens; once both globals are set, the cached objects
            are returned as-is regardless of this argument.

    Returns:
        Tuple of the module-level ``global_model`` and ``global_tokenizer``.
    """
    global global_model, global_tokenizer

    # Fast path: both module-level slots are already populated.
    if global_model is not None and global_tokenizer is not None:
        return global_model, global_tokenizer

    global_model = Model()
    # weights_only=True restricts unpickling to tensor data.
    state_dict = torch.load('tactic_predict.pt', map_location=device, weights_only=True)
    global_model.load_state_dict(state_dict)
    global_model.to(device)
    # Switch to inference mode (disables dropout).
    global_model.eval()

    global_tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecRoBERTa')
    return global_model, global_tokenizer

# Route for the home page
@app.route('/')
def home():
    """Serve the home page by rendering the ``index.html`` template."""
    landing_page = render_template('index.html')
    return landing_page

# Tactic label for each position of the model output; predict() pairs the
# i-th flag returned by predict_techniques with the i-th name here.
TACTIC_NAMES = (
    "Reconnaissance", "Resource Development", "Initial Access", "Execution",
    "Persistence", "Privilege Escalation", "Defense Evasion",
    "Credential Access", "Discovery", "Lateral Movement", "Collection",
    "Command and Control", "Exfiltration", "Impact",
)


# Route to handle form submission and return results
@app.route('/predict', methods=['POST'])
def predict():
    """Look up a CVE on NVD, classify its description and render the result.

    Reads ``cve_id`` from the submitted form, fetches the matching record
    with ``nvdlib.searchCVE``, runs the English description through the
    model and renders ``result.html`` with the names of the flagged tactics.

    Error responses (raised through ``abort``):
        400: ``cve_id`` is missing or blank.
        404: the search returned no record, or the record has no English
            description.
    """
    cve_id = request.form.get('cve_id', '').strip()
    if not cve_id:
        abort(400, description="A CVE ID is required.")

    # An unknown ID yields an empty result; indexing it would raise
    # IndexError and surface as an HTTP 500.
    results = nvdlib.searchCVE(cveId=cve_id)
    if not results:
        abort(404, description="No NVD entry was found for the given CVE ID.")

    # Give next() a default so a record without an English description does
    # not raise StopIteration.
    cve_data = next(
        (desc.value for desc in results[0].descriptions if desc.lang == "en"),
        None,
    )
    if cve_data is None:
        abort(404, description="The CVE entry has no English description.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model and tokenizer lazily
    model, tokenizer = get_model_and_tokenizer(device)

    predicted_techniques = predict_techniques(model, tokenizer, cve_data, device)

    # Row 0 holds the flags for the single description that was submitted.
    predicted_tactic_names = [
        name
        for name, flag in zip(TACTIC_NAMES, predicted_techniques[0])
        if flag == 1
    ]

    return render_template('result.html', tactics=predicted_tactic_names, cve_id=cve_id, cve_desc=cve_data)

# Run the app with Flask's built-in server when this file is executed as a
# script (not when it is imported). host="0.0.0.0" listens on all network
# interfaces; the server is bound to port 7860.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)