# GPTLens / src/run_auditor.py
# Aishwarya Solanki
# initial commit
# ee7776a
import json
import random
import argparse
import os
from tqdm import tqdm
from utils import dotdict
from stqdm import stqdm
import openai
from model import gpt, gpt_usage, OPENAI_API_KEY
from prompts import auditor_prompt, auditor_format_constrain
from prompts import topk_prompt1, topk_prompt2
# Module-level token counters, both starting at zero (presumably token-usage
# tallies; neither is read or updated by the functions in this module as shown).
completion_tokens, prompt_tokens = 0, 0
def remove_spaces(s):
    """Collapse every whitespace run in ``s`` into one space and trim both ends.

    Newlines and tabs count as whitespace, so multi-line source becomes a
    single line.
    """
    words = s.split()
    return " ".join(words)
def prompt_wrap(prompt, format_constraint, code, topk):
    """Assemble the full auditor prompt string.

    The pieces are concatenated in this order: the base instruction
    ``prompt``, the contract ``code``, the output ``format_constraint``,
    ``topk_prompt1`` with its ``{topk}`` placeholder filled in, and finally
    ``topk_prompt2``.
    """
    topk_request = topk_prompt1.format(topk=topk)
    return "".join((prompt, code, format_constraint, topk_request, topk_prompt2))
def auditor_response_parse(auditor_outputs):
    """Extract and concatenate the ``output_list`` findings from raw auditor replies.

    Each reply is expected to contain a JSON object somewhere in its text. The
    span from the first ``{`` to the last ``}`` is decoded, and its
    ``output_list`` entry (which must be a JSON array) is appended to the result.
    Replies that cannot be decoded print ``"parsing json fail."``. Replies whose
    object has no usable ``output_list`` print ``"No vulnerability detected"``.
    Both kinds of reply are skipped.

    Args:
        auditor_outputs: Iterable of model completions (normally strings).

    Returns:
        A single flat list of all findings from every usable reply, in order.
    """
    output_list = []
    for auditor_output in auditor_outputs:
        # A missing/None completion cannot be searched for JSON.
        if not isinstance(auditor_output, str):
            print("parsing json fail.")
            continue
        start_idx = auditor_output.find("{")
        end_idx = auditor_output.rfind("}")
        # No brace pair at all: nothing to decode.
        if start_idx == -1 or end_idx < start_idx:
            print("parsing json fail.")
            continue
        try:
            data = json.loads(auditor_output[start_idx:end_idx + 1])
        except json.JSONDecodeError:
            print("parsing json fail.")
            continue
        findings = data.get("output_list") if isinstance(data, dict) else None
        # Only accept a JSON array. `list += "text"` would otherwise splice
        # individual characters (or dict keys) into the results.
        if not isinstance(findings, list):
            print("No vulnerability detected")
            continue
        output_list.extend(findings)
    return output_list
def solve(args, code):
    """Query the auditor model once for ``code`` and return its parsed findings.

    Builds the prompt with ``prompt_wrap``, requests ``args.num_auditor``
    completions from ``args.backend`` at ``args.temperature``, and flattens
    them with ``auditor_response_parse``. Any exception raised while querying
    or parsing is printed, and an empty list is returned instead, so the
    caller sees "no findings" rather than an exception.
    """
    auditor_input = prompt_wrap(auditor_prompt, auditor_format_constrain, code, args.topk)
    try:
        responses = gpt(
            auditor_input,
            model=args.backend,
            temperature=args.temperature,
            n=args.num_auditor,
        )
        return auditor_response_parse(responses)
    except Exception as err:  # deliberate best-effort boundary: report and continue
        print(err)
        return []
def run(args):
    """Audit every labelled CVE contract and write the findings to per-CVE JSON logs.

    For each entry of ``data/CVE_label/CVE2label.json``:

    1. Read the matching Solidity file from ``data/CVE_clean/``.
    2. Collapse its whitespace and audit it via ``solve``.
    3. Annotate each returned finding with ``file_name``, ``label`` and
       ``description``.
    4. Write the findings to
       ``./src/logs/auditor_<backend>_<temperature>_top<topk>_<num_auditor>/<CVE>.json``.

    CVEs that yield no usable findings are reported as failed and skipped.

    Args:
        args: Run configuration exposing ``backend``, ``temperature``,
            ``topk``, ``num_auditor`` and optionally ``openai_api_key`` as
            attributes. Either an ``argparse.Namespace`` (CLI entry point) or
            a dict-like object with a ``get`` method (e.g. ``dotdict``).
    """
    # argparse.Namespace has no ``.get``; dict-like configs do. Support both,
    # and fall back to the key imported from ``model`` when none is supplied.
    getter = getattr(args, "get", None)
    if callable(getter):
        api_key = getter("openai_api_key")
    else:
        api_key = getattr(args, "openai_api_key", None)
    openai.api_key = OPENAI_API_KEY if api_key is None else api_key

    with open("data/CVE_label/CVE2description.json", "r", encoding="utf-8") as f:
        CVE2description = json.load(f)
    with open("data/CVE_label/CVE2label.json", "r", encoding="utf-8") as f:
        CVE2label = json.load(f)

    # log output directory, one per configuration
    log_dir = f"./src/logs/auditor_{args.backend}_{args.temperature}_top{args.topk}_{args.num_auditor}"
    for CVE_index, label in stqdm(CVE2label.items()):
        description = CVE2description[CVE_index]
        # "CVE-2018-10299" -> "2018-10299.sol"
        file_name = "-".join(CVE_index.split("-")[1:]) + ".sol"
        with open("data/CVE_clean/" + file_name, "r", encoding="utf-8") as f:
            code = remove_spaces(f.read())

        bug_info_list = solve(args, code)
        # Sometimes the query fails because the model does not strictly follow the format.
        if not bug_info_list:
            print(f"{CVE_index} failed")
            continue

        all_bug_info_list = []
        for info in bug_info_list:
            # A finding that is not a JSON object cannot be annotated; skip it
            # instead of letting `info.update` abort the whole batch.
            if not isinstance(info, dict):
                continue
            info.update({"file_name": file_name, "label": label, "description": description})
            all_bug_info_list.append(info)
        if not all_bug_info_list:
            print(f"{CVE_index} failed")
            continue

        file = f"{log_dir}/{CVE_index}.json"
        os.makedirs(os.path.dirname(file), exist_ok=True)
        with open(file, "w", encoding="utf-8") as f:
            json.dump(all_bug_info_list, f, indent=4)
def parse_args(argv=None):
    """Parse the auditor's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for programmatic use and tests).

    Returns:
        argparse.Namespace with ``backend``, ``temperature``, ``dataset``,
        ``topk``, ``num_auditor`` and ``openai_api_key``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--backend', type=str, choices=['gpt-3.5-turbo', 'gpt-4', 'gpt-4-turbo-preview'], default='gpt-4-turbo-preview')
    parser.add_argument('--temperature', type=float, default=0.7)
    parser.add_argument('--dataset', type=str, default="CVE")
    parser.add_argument('--topk', type=int, default=5)  # the topk per each auditor
    parser.add_argument('--num_auditor', type=int, default=1)  # completions requested per query
    # `run` reads `openai_api_key`; when left as None it uses the key from `model`.
    parser.add_argument('--openai_api_key', type=str, default=None)
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Command-line entry point: parse flags, echo the resolved configuration,
    # then run the full audit with it.
    args = parse_args()
    print(args)
    run(args)
def mainfnc(args=None):
    """Run the auditor for a programmatic caller.

    Presumably invoked from a Streamlit front end, given the ``stqdm``
    progress bar used in ``run``; confirm against the caller.

    Args:
        args: Run configuration (e.g. a ``dotdict``) with the fields ``run``
            reads. When omitted, the configuration is parsed from the command
            line with ``parse_args``.
    """
    # The previous default was the ``dotdict`` class itself rather than an
    # instance, which ``run`` cannot use as a configuration. Fall back to CLI
    # parsing instead.
    if args is None:
        args = parse_args()
    run(args)