HarshV1315 committed on
Commit
ec9188f
1 Parent(s): 43f1117

Upload 5 files

Browse files
Files changed (5) hide show
  1. .dockerignore +5 -0
  2. Dockerfile +20 -0
  3. app.py +97 -0
  4. requirements.txt +5 -0
  5. tactic_predict.pt +3 -0
.dockerignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .env
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set the working directory in the container
WORKDIR /app

# Install dependencies before copying the source so this layer stays cached
# until requirements.txt itself changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application into the container at /app
COPY . /app

# Expose the port the application declares in app.py (7860)
EXPOSE 7860

# Tell the flask CLI which module holds the application
ENV FLASK_APP=app.py

# `flask run` does not execute app.py's __main__ block and defaults to port
# 5000, so the port must be passed explicitly to match EXPOSE above
CMD ["flask", "run", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from transformers import AutoTokenizer, AutoModel
6
+ import nvdlib
7
+
8
+ # Flask application object; the @app.route decorators in this file register their handlers on it
9
+ app = Flask(__name__)
10
+
11
+ # Define the model architecture
12
class Model(nn.Module):
    """Pretrained transformer encoder with a linear multi-label head.

    The encoder's pooled output goes through dropout and one linear layer
    that emits a raw logit per label (no activation is applied here).

    Args:
        model_name: Hugging Face model id or local path of the encoder.
        hidden_size: Width of the encoder's pooled output.
        num_labels: Number of logits produced by the head.
        dropout: Dropout probability applied before the head.
    """

    def __init__(self, model_name='jackaduma/SecRoBERTa', hidden_size=768,
                 num_labels=14, dropout=0.3):
        super().__init__()
        self.transformer_model = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.output = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return the logits for a batch of token ids.

        Args:
            input_ids: Token id tensor passed to the encoder.
            attention_mask: Optional mask passed to the encoder unchanged.

        Returns:
            Tensor of raw logits produced by the linear head.
        """
        encoder_outputs = self.transformer_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Take the second tuple item (the pooled output) by index. Unpacking
        # into exactly two names would fail if the encoder is configured to
        # also return hidden states or attentions.
        pooled = encoder_outputs[1]
        return self.output(self.dropout(pooled))
28
+
29
+ # Function to predict MITRE ATT&CK techniques
30
def predict_techniques(model, tokenizer, cve_description, device,
                       max_length=320, threshold=0.5):
    """Run the classifier on one CVE description and binarize its outputs.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            a tensor of logits.
        tokenizer: Hugging Face style tokenizer; called directly on the text.
        cve_description: Text to classify.
        device: Device the input tensors are moved to before inference.
        max_length: Length the text is padded/truncated to.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        NumPy array with the same shape and dtype as the probabilities,
        holding 1.0 for predicted labels and 0.0 otherwise.
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus`` method and accepts the same arguments.
    encoded = tokenizer(
        cve_description,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    probs = torch.sigmoid(logits).cpu().numpy()

    # With the default 0.5 this matches rounding the probabilities: a value
    # of exactly 0.5 rounds to 0 (round-half-to-even), hence the strict ">".
    return (probs > threshold).astype(probs.dtype)
47
+
48
# Process-wide cache for the model and tokenizer; both stay None until the
# first call to get_model_and_tokenizer()
global_model = None
global_tokenizer = None


def get_model_and_tokenizer(device='cpu'):
    """Return the cached ``(model, tokenizer)`` pair, loading it on first use.

    The first call builds ``Model``, loads its weights from
    ``tactic_predict.pt`` onto ``device``, switches it to eval mode and
    creates the matching tokenizer. Later calls return the cached objects
    as they are, so ``device`` only has an effect while the cache is empty.

    Args:
        device: Device used for ``map_location`` and ``Module.to`` when the
            model is loaded.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    global global_model, global_tokenizer

    # Fast path: everything is already loaded
    if global_model is not None and global_tokenizer is not None:
        return global_model, global_tokenizer

    global_model = Model()
    state_dict = torch.load('tactic_predict.pt', map_location=device, weights_only=True)
    global_model.load_state_dict(state_dict)
    global_model.to(device)
    global_model.eval()
    global_tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecRoBERTa')
    return global_model, global_tokenizer
62
+
63
+ # Route for the home page
64
@app.route('/')
def home():
    """Render the ``index.html`` template for the site root."""
    landing_page = render_template('index.html')
    return landing_page
67
+
68
+ # Route to handle form submission and return results
69
# Tactic labels; entry i is paired with output i of the classifier in predict()
TACTIC_NAMES = (
    "Reconnaissance", "Resource Development", "Initial Access", "Execution",
    "Persistence", "Privilege Escalation", "Defense Evasion",
    "Credential Access", "Discovery", "Lateral Movement", "Collection",
    "Command and Control", "Exfiltration", "Impact",
)


@app.route('/predict', methods=['POST'])
def predict():
    """Look up a CVE in the NVD and render the tactics predicted for it.

    Reads the ``cve_id`` form field, fetches the CVE with ``nvdlib``, runs
    the classifier on its English description and renders ``result.html``
    with the predicted tactic names.

    Returns:
        The rendered result page on success, or a JSON error body with
        status 400 (missing id), 404 (unknown CVE or no English
        description) or 502 (the NVD lookup itself failed).
    """
    cve_id = request.form.get('cve_id', '').strip()
    if not cve_id:
        return jsonify(error="Missing required form field 'cve_id'."), 400

    # The lookup talks to an external service; report failures to the client
    # instead of letting them surface as an unhandled 500.
    try:
        results = nvdlib.searchCVE(cveId=cve_id)
    except Exception:
        app.logger.exception("NVD lookup failed for %s", cve_id)
        return jsonify(error=f"Could not retrieve '{cve_id}' from the NVD."), 502

    # An unknown id yields no results; indexing [0] blindly would raise
    if not results:
        return jsonify(error=f"No CVE found for id '{cve_id}'."), 404

    # Default of None avoids StopIteration when no English text is present
    cve_data = next(
        (desc.value for desc in results[0].descriptions if desc.lang == "en"),
        None,
    )
    if cve_data is None:
        return jsonify(error=f"No English description available for '{cve_id}'."), 404

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model and tokenizer lazily
    model, tokenizer = get_model_and_tokenizer(device)

    predicted_techniques = predict_techniques(model, tokenizer, cve_data, device)

    # Row 0 is the only row: a single description was classified
    predicted_tactic_names = [
        name
        for name, flag in zip(TACTIC_NAMES, predicted_techniques[0])
        if flag == 1
    ]

    return render_template(
        'result.html',
        tactics=predicted_tactic_names,
        cve_id=cve_id,
        cve_desc=cve_data,
    )
93
+
94
+ # Run the app
95
if __name__ == "__main__":
    # Reached only when this file is executed directly (python app.py):
    # start Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")
97
+
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Flask
2
+ torch
3
+ numpy
4
+ transformers
5
+ nvdlib
tactic_predict.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bdbe380766df249f0c50081b3a482e6eab588552603c23329087342ed4190870
3
+ size 333893590