|
|
|
from capstone import * |
|
import json |
|
from tqdm import tqdm |
|
import random |
|
from multiprocessing import Process, Queue |
|
from unicorn.x86_const import * |
|
from unicorn import * |
|
from datasets import concatenate_datasets |
|
from keystone import * |
|
import re |
|
from datasets import load_from_disk |
|
|
|
|
|
def test_single(code):
    """Try to assemble *code* as x86-64 with a fresh Keystone instance.

    Returns whatever ``Ks.asm`` returns when assembling succeeds, and the
    integer ``0`` when assembling raises.  Callers in this file rely on
    comparing the result with ``== 0`` to detect the failure case, so that
    contract is kept as is.
    """
    ks = Ks(KS_ARCH_X86, KS_MODE_64)
    try:
        count = ks.asm(code)
    except Exception:
        # Only ordinary errors mean "this text does not assemble";
        # KeyboardInterrupt / SystemExit must keep propagating.
        count = 0
    return count
|
|
|
|
|
def convert_hex_format(assembly):
    """Rewrite ``<hex>h`` literals (e.g. ``0FFh``) as ``0x<hex>``.

    A literal must start with a decimal digit and end at a word boundary.
    Without those two conditions the 8-bit registers ``ah``/``bh``/``ch``/
    ``dh`` would be turned into ``0xa``/``0xb``/... and words that merely
    contain hex letters followed by ``h`` would be mangled.

    Args:
        assembly: assembly text, possibly spanning several lines.

    Returns:
        The text with every matching literal converted; the rest is untouched.
    """
    hex_pattern = re.compile(r'\b([0-9][0-9A-Fa-f]*)h\b')
    return hex_pattern.sub(r'0x\1', assembly)
|
|
|
|
|
def get_name_value(name):
    """Return the leading decimal digits of a ``var_<digits>`` name as an int.

    Only the run of decimal digits directly after ``var_`` is parsed, so a
    name such as ``var_1C`` yields ``1``.  Any name that does not start with
    ``var_`` followed by at least one digit yields ``None``.
    """
    if not name.startswith("var_"):
        return None
    found = re.match(r"var_(\d+)", name)
    return int(found.group(1)) if found else None
|
|
|
|
|
class KeypatchAsm:
    """Normalise an IDA-style x86 listing and assemble it with Keystone.

    ``assemble`` is the entry point; the other public methods are the
    individual text clean-up passes it applies, in order, before handing the
    listing to ``Ks.asm``.
    """

    def __init__(self, arch=KS_ARCH_X86, mode=KS_MODE_64):
        """Create the Keystone assembler for the given architecture and mode."""
        self.arch = arch
        self.mode = mode
        self.ks = Ks(self.arch, self.mode)

    def fix_cmp_instruction_size(self, assembly):
        """Give memory ``cmp`` lines an explicit size, drop ``cmp`` lines with ':'.

        * A line containing ``cmp``, ``[`` and ``]`` but no `` ptr `` gets
          ``dword ptr`` inserted after the first ``cmp``.
        * Any other line containing both ``cmp`` and ``:`` becomes ``nop``.

        NOTE(review): the test is a plain substring match, so mnemonics that
        merely contain "cmp" (e.g. ``cmpxchg``) are rewritten as well.
        """
        updated_lines = []
        for line in assembly.split('\n'):
            if 'cmp' in line and '[' in line and ']' in line:
                if ' ptr ' not in line:
                    line = line.replace('cmp', 'cmp dword ptr', 1)
            elif 'cmp' in line and ':' in line:
                line = 'nop'
            updated_lines.append(line)
        return '\n'.join(updated_lines)

    def replace_calls_and_leas(self, assembly):
        """Replace ``call`` lines without a ``0x`` literal and all ``lea`` lines by ``nop``.

        NOTE(review): both checks are substring matches, so ``leave`` (contains
        "lea") and any line whose text contains "call" are replaced too.
        Left unchanged here because the scores computed later depend on it.
        """
        update_lines = []
        for line in assembly.split('\n'):
            has_hex_literal = '0x' in line or '0X' in line
            if 'call' in line and not has_hex_literal:
                update_lines.append('nop')
            elif 'lea' in line:
                update_lines.append('nop')
            else:
                update_lines.append(line)
        return '\n'.join(update_lines)

    def remove_comments(self, assembly):
        """Cut everything from the first ';' on each line and strip the result."""
        cleaned_lines = [line.split(';', 1)[0] for line in assembly.split('\n')]
        return '\n'.join(cleaned_lines).strip()

    def replace_segment_register_references(self, assembly):
        """Replace ``cs:`` lines and lines that do not assemble alone by ``nop``.

        A line is kept when it has no ``cs:`` and either ``test_single`` does
        not report failure (``0``) for it or it contains the text ``INSTR``.
        """
        updated_lines = []
        for line in assembly.split('\n'):
            if 'cs:' in line:
                updated_lines.append('nop')
            elif test_single(line) == 0 and "INSTR" not in line:
                updated_lines.append('nop')
            else:
                updated_lines.append(line)
        return '\n'.join(updated_lines)

    def ida_resolve(self, assembly, address):
        """Substitute ``var_<digits>`` names in the operands with ``0x<digits>``.

        The text is split once at the first space into a leading word and the
        remainder; the remainder is split on ',' and every piece is scanned
        for names.  ``address`` is accepted but not used.

        NOTE(review): ``get_name_value`` parses decimal digits only and the
        digits are re-emitted behind ``0x``, so ``var_1C`` becomes ``0x1``.
        Confirm the naming scheme of the input before relying on this.
        """
        def _resolve(_op, ignore_kw=True):
            names = re.findall(r"[\$a-z0-9_:\.]+", _op, re.I)
            for name in names:
                # Size/distance keywords in front of '[' are never symbols.
                if ignore_kw and name in ('byte', 'near', 'short', 'word', 'dword', 'ptr', 'offset'):
                    continue
                value = get_name_value(name)
                if value is not None:
                    _op = _op.replace(name, '0x' + str(value))
            return _op

        mnem, _, operand_text = assembly.partition(' ')
        opers = operand_text.split(',')

        for idx, op in enumerate(opers):
            _op = list(op.partition('['))
            ignore_kw = True
            if _op[1] == '':
                # No '[' in this piece: treat the whole piece as the tail part.
                _op[2] = _op[0]
                _op[0] = ''
            else:
                _op[0] = _resolve(_op[0], ignore_kw=True)
                ignore_kw = False
            _op[2] = _resolve(_op[2], ignore_kw=ignore_kw)
            opers[idx] = ''.join(_op)

        return "{0} {1}".format(mnem, ','.join(opers))

    def _fix_ida_syntax(self, assembly):
        """Convert hex literals, upper-case the listing and apply x86 fix-ups.

        Both x86 fix-ups are applied to the whole text: every stand-alone
        ``RETN`` becomes ``RET`` and every ``OFFSET `` is blanked.  Returning
        right after the first ``RETN`` rewrite would skip the ``OFFSET`` pass
        and leave later ``RETN`` lines untouched in a multi-line listing.
        """
        assembly = convert_hex_format(assembly)
        assembly = assembly.upper()
        # upper() also upper-cased the hex prefix; restore it as " 0x".
        assembly = assembly.replace("0X", " 0x")
        if self.arch == KS_ARCH_X86:
            assembly = re.sub(r'\bRETN\b', 'RET', assembly)
            assembly = assembly.replace('OFFSET ', ' ')
        return assembly

    def assemble(self, assembly, address=0, syntax=KS_OPT_SYNTAX_INTEL):
        """Clean up *assembly* and assemble it at *address*.

        Args:
            assembly: listing text, one instruction per line.
            address: base address passed to ``Ks.asm``.
            syntax: Keystone syntax option; ``None`` selects Intel syntax.

        Returns:
            ``(encoding, count)`` as returned by ``Ks.asm``, or ``(None, 0)``
            when Keystone raises ``KsError`` (the error and the prepared
            listing are printed in that case).
        """
        assembly = assembly.replace("endbr64\n", "")
        assembly = self.remove_comments(assembly)
        assembly = self.ida_resolve(assembly, address)
        assembly = self.replace_calls_and_leas(assembly)
        assembly = self.fix_cmp_instruction_size(assembly)
        assembly = self.replace_segment_register_references(assembly)

        if syntax is None:
            syntax = KS_OPT_SYNTAX_INTEL

        # Prepare the text once; it is needed both for assembling and for the
        # diagnostic output on failure.
        prepared = self._fix_ida_syntax(assembly)

        try:
            self.ks.syntax = syntax
            encoding, count = self.ks.asm(prepared, address)
        except KsError as e:
            print(f"Error:{e}")
            print(f"Assembly:\n{prepared}")
            print("-"*50)
            print("")
            encoding, count = None, 0

        return (encoding, count)
|
|
|
|
|
# Register names, in the order in which they appear in the mapping below.
_UC_X86_REG_NAMES = (
    [
        "RAX", "RBX", "RCX", "RDX", "RSI", "RDI", "RBP", "RSP",
        "R8", "R9", "R10", "R11", "R12", "R13", "R14", "R15", "RIP",
    ]
    + [f"XMM{index}" for index in range(16)]
    + [f"YMM{index}" for index in range(16)]
    + ["EFLAGS", "CS", "DS", "ES", "FS", "GS", "SS"]
)

# Maps each Unicorn register id (the module-level ``UC_X86_REG_<NAME>``
# constant) to its printable register name.
UC_X86_REG_MAPPING = {
    globals()["UC_X86_REG_" + reg_name]: reg_name
    for reg_name in _UC_X86_REG_NAMES
}
|
|
|
|
|
class MemoryAccessLogger:
    """Collects the memory accesses reported to its two hook methods.

    Each access is stored as an ``(address, size, value)`` tuple, in call
    order, in ``read_accesses`` or ``write_accesses``.
    """

    def __init__(self):
        # Reads and writes are kept in two separate, initially empty lists.
        self.read_accesses, self.write_accesses = [], []

    def hook_mem_read(self, uc, access, address, size, value, user_data):
        """Record one read access; ``uc``, ``access`` and ``user_data`` are ignored."""
        entry = (address, size, value)
        self.read_accesses.append(entry)

    def hook_mem_write(self, uc, access, address, size, value, user_data):
        """Record one write access; ``uc``, ``access`` and ``user_data`` are ignored."""
        entry = (address, size, value)
        self.write_accesses.append(entry)
|
|
|
|
|
def hook_mem_invalid(uc, access, address, size, value, user_data):
    """Map the missing 4 KiB page on an unmapped read, write or fetch.

    ``Uc.mem_map`` takes ``(address, size)``.  Passing an end address as the
    second argument maps a region that grows with the faulting address and
    can collide with regions that are already mapped, so exactly one page
    (``0x1000`` bytes) is mapped here.

    Args:
        uc: emulator instance the hook is attached to.
        access: one of the ``UC_MEM_*`` access kinds.
        address: faulting address.
        size: size of the access in bytes.
        value: value involved in the access, as reported to the hook.
        user_data: unused.

    Returns:
        Always ``True``, also for access kinds that are not handled here.
    """
    unmapped_kinds = {
        UC_MEM_READ_UNMAPPED: "READ",
        UC_MEM_WRITE_UNMAPPED: "WRITE",
        UC_MEM_FETCH_UNMAPPED: "FETCH",
    }
    kind = unmapped_kinds.get(access)
    if kind is not None:
        # Name the real access kind instead of always reporting a write.
        print(">>> Missing memory is being %s at 0x%x, data size = %u, data value = 0x%x"
              % (kind, address, size, value))
        page_start = address & 0xfffffffffffff000
        uc.mem_map(page_start, 0x1000)
    return True
|
|
|
|
|
def instruction_hook(uc, address, size, user_data):
    """Read the current instruction bytes and the RBP/RSP registers.

    The three values are fetched from *uc* and then discarded: nothing is
    stored, printed or returned.
    """
    instruction_bytes = uc.mem_read(address, size)

    frame_pointer = uc.reg_read(UC_X86_REG_RBP)
    stack_pointer = uc.reg_read(UC_X86_REG_RSP)
|
|
|
|
|
|
|
|
|
|
|
def assemble_wrapper(asm_code, code_address, result_queue):
    """
    Process target: assemble *asm_code* and report the outcome on *result_queue*.

    Runs in a separate process so that a failure while assembling cannot
    take down the main process.  On success ``(encoding, count)`` is put on
    the queue; on any exception ``(None, 0)`` is put instead and the error
    is printed.
    """
    try:
        assembler = KeypatchAsm()
        machine_code, statement_count = assembler.assemble(
            asm_code, code_address)
        result_queue.put((machine_code, statement_count))
    except Exception as exc:
        result_queue.put((None, 0))
        print("Error during assembly:", str(exc))
|
|
|
|
|
def safe_assemble(asm_code, code_address, timeout=3):
    """Assemble *asm_code* in a child process, giving up after *timeout* seconds.

    The result is taken from the queue *before* the child is joined.  A child
    that has put data on a ``multiprocessing.Queue`` does not exit until that
    data has been flushed to the pipe, so joining first can stall on a large
    result until the timeout expires and the result is thrown away.

    Args:
        asm_code: assembly text handed to ``assemble_wrapper``.
        code_address: base address handed to ``assemble_wrapper``.
        timeout: seconds to wait for a result; ``None`` waits indefinitely.

    Returns:
        The ``(encoding, count)`` tuple produced by the child, or ``(None, 0)``
        when the child timed out, died without a result, or the result could
        not be read.
    """
    import time
    from queue import Empty

    result_queue = Queue()
    worker = Process(target=assemble_wrapper, args=(
        asm_code, code_address, result_queue))
    worker.start()

    deadline = None if timeout is None else time.monotonic() + timeout
    result = None
    while result is None:
        try:
            # Short polls so that a crashed child is noticed promptly.
            result = result_queue.get(timeout=0.05)
        except Empty:
            if not worker.is_alive():
                # The child is gone; look once more for a result it may
                # have flushed right before exiting.
                try:
                    result = result_queue.get(timeout=0.05)
                except Exception:
                    pass
                break
            if deadline is not None and time.monotonic() >= deadline:
                break
        except Exception:
            # Unreadable result: treated like "no result".
            break

    if result is None and worker.is_alive():
        worker.terminate()
        print("Terminated the process due to timeout.")

    # Reap the child and release the queue so that repeated calls do not
    # accumulate zombie processes or open pipes.
    worker.join(1)
    if worker.is_alive():
        worker.terminate()
        worker.join(1)
    result_queue.close()

    if result is None:
        return None, 0
    return result
|
|
|
|
|
# Capstone disassembler handle created for the x86 architecture in 64-bit mode.
# NOTE(review): nothing in the visible part of this file uses `md` - confirm
# whether it is still needed.
md = Cs(CS_ARCH_X86, CS_MODE_64)
|
|
|
|
|
def compile_run(asm_code, code_address, seed=0):
    """Assemble *asm_code*, emulate it with Unicorn and report the final state.

    Eleven general purpose registers (RAX, RBX, RCX, RDX, RSI, RDI, R8-R12)
    are seeded with pseudo-random values drawn from ``random`` after
    ``random.seed(seed)``; RSP and RBP point at the top of the stack region.
    Emulation stops at the end of the assembled bytes or after 1000
    instructions.

    ``Uc.mem_map`` takes ``(address, size)``, so the code and stack regions
    are mapped with their sizes, and the code region is sized from the number
    of assembled bytes rather than the statement count.

    Args:
        asm_code: assembly text to assemble and run.
        code_address: address at which the code is assembled and mapped.
        seed: seed for the register values.

    Returns:
        ``(registers, read_accesses, write_accesses)`` where ``registers``
        maps the names in ``UC_X86_REG_MAPPING`` to their final values and the
        two lists hold ``(address, size, value)`` tuples; or
        ``("ERROR", [], [])`` when assembling fails or anything raises.
    """
    try:
        random.seed(seed)
        encoding, count = safe_assemble(asm_code, code_address)
        if encoding is None or count == 0:
            return "ERROR", [], []

        machine_code = bytes(encoding)
        page_size = 0x1000
        # Whole pages, always at least one and always larger than the code.
        code_size = (len(machine_code) + page_size) // page_size * page_size
        stack_address = 0x7fff0000
        stack_size = 0x2000

        mu = Uc(UC_ARCH_X86, UC_MODE_64)
        mu.mem_map(code_address, code_size)
        mu.mem_map(stack_address, stack_size)
        mu.mem_write(code_address, machine_code)

        # The order matters: it fixes which random value each register gets.
        seeded_registers = (
            UC_X86_REG_RAX, UC_X86_REG_RBX, UC_X86_REG_RCX, UC_X86_REG_RDX,
            UC_X86_REG_RSI, UC_X86_REG_RDI, UC_X86_REG_R8, UC_X86_REG_R9,
            UC_X86_REG_R10, UC_X86_REG_R11, UC_X86_REG_R12,
        )
        for reg_id in seeded_registers:
            mu.reg_write(reg_id, random.randint(0, 0x2000))

        mu.reg_write(UC_X86_REG_RSP, stack_address + stack_size)
        mu.reg_write(UC_X86_REG_RBP, stack_address + stack_size)

        mu.hook_add(UC_HOOK_MEM_INVALID |
                    UC_HOOK_MEM_UNMAPPED, hook_mem_invalid)
        memory_logger = MemoryAccessLogger()
        mu.hook_add(UC_HOOK_MEM_READ, memory_logger.hook_mem_read)
        mu.hook_add(UC_HOOK_MEM_WRITE, memory_logger.hook_mem_write)

        mu.emu_start(code_address, code_address + len(machine_code),
                     timeout=0, count=1000)

        registers = {
            reg_name: mu.reg_read(reg_id)
            for reg_id, reg_name in UC_X86_REG_MAPPING.items()
        }
        return registers, memory_logger.read_accesses, memory_logger.write_accesses
    except Exception:
        # Best effort by design: a sample that cannot be run is reported as
        # "ERROR" so the caller can skip it.
        return "ERROR", [], []
|
|
|
|
|
|
|
|
|
def mean_or_zero(values):
    """Return the arithmetic mean of *values*, or ``0.0`` when it is empty."""
    return sum(values) / len(values) if values else 0.0


def memory_match_ratio(expected, observed):
    """Score *observed* memory accesses against the *expected* ones.

    Every ``(address, size, value)`` entry of *observed* that also occurs in
    *expected* counts once, and the count is divided by ``len(expected)``.
    Repeated entries in *observed* each count, so the ratio can exceed 1.0.
    *expected* must not be empty.
    """
    expected_set = set(expected)  # one O(1) lookup per observed access
    hits = sum(1 for access in observed if access in expected_set)
    return hits / len(expected)


# The evaluation spawns child processes (see safe_assemble).  With the
# "spawn" and "forkserver" start methods each child imports this module
# again, so the script body must only run in the main process.
if __name__ == "__main__":
    ds = load_from_disk("./virtual_assembly_and_ground_truth")

    all_results = {
        'ground_truth': [],
        'generated': [],
    }
    test_index = []
    cnt = 0

    # Pass 1: run the reference assembly.  `cnt` counts the samples that ran
    # successfully and doubles as the seed for that sample.
    for idx, code in tqdm(enumerate(ds['asm'])):
        print(idx, cnt)
        regs, read_mem, write_mem = compile_run(code, 0x1000, cnt)
        if regs == "ERROR":
            continue
        test_index.append(idx)
        all_results['ground_truth'].append(
            {
                'regs': regs,
                'read_mem': read_mem,
                'write_mem': write_mem,
            }
        )
        cnt += 1

    # Pass 2: run the generated assembly of every sample kept in pass 1 with
    # the same seed; failures are stored as None to keep the lists aligned.
    for seed, index in tqdm(enumerate(test_index)):
        code = ds[index]['generated_asm']
        regs, read_mem, write_mem = compile_run(code, 0x1000, seed)
        if regs != "ERROR":
            all_results['generated'].append(
                {
                    'regs': regs,
                    'read_mem': read_mem,
                    'write_mem': write_mem,
                }
            )
        else:
            all_results['generated'].append(None)

    evaluation_results = {
        'regs': [],
        'read_mem': [],
        'write_mem': [],
    }
    reg_name_list = ['RAX', 'RSP', 'RBP']

    # Pass 3: compare each generated run with its reference run.
    for overall_index in tqdm(range(len(test_index))):
        ground_truth = all_results['ground_truth'][overall_index]
        compare = all_results['generated'][overall_index]

        if compare is None:
            continue
        if not compare['regs']:
            continue

        count = sum(
            1 for reg_name in reg_name_list
            if ground_truth['regs'][reg_name] == compare['regs'][reg_name]
        )
        evaluation_results['regs'].append(float(count) / len(reg_name_list))

        # Memory scores exist only for samples whose reference run touched
        # memory in the respective direction.
        for key in ('read_mem', 'write_mem'):
            if ground_truth[key]:
                evaluation_results[key].append(
                    memory_match_ratio(ground_truth[key], compare[key]))

    # An empty score list contributes 0.0 instead of raising
    # ZeroDivisionError.
    reg_score = mean_or_zero(evaluation_results['regs'])
    read_score = mean_or_zero(evaluation_results['read_mem'])
    write_score = mean_or_zero(evaluation_results['write_mem'])

    print(f"mean_score: {(reg_score + read_score + write_score) / 3}")
|
|