# %% from capstone import * import json from tqdm import tqdm import random from multiprocessing import Process, Queue from unicorn.x86_const import * from unicorn import * from datasets import concatenate_datasets from keystone import * import re from datasets import load_from_disk def test_single(code): ks = Ks(KS_ARCH_X86, KS_MODE_64) try: count = ks.asm(code) except: count = 0 return count def convert_hex_format(assembly): hex_pattern = re.compile(r'\b([0-9A-Fa-f]+)h') converted_assembly = hex_pattern.sub(r'0x\1', assembly) return converted_assembly def get_name_value(name): # match "VAR_num" label if name.startswith("var_"): match = re.match(r"var_(\d+)", name) if match: # get number return int(match.group(1)) # else case return None class KeypatchAsm: def __init__(self, arch=KS_ARCH_X86, mode=KS_MODE_64): self.arch = arch self.mode = mode self.ks = Ks(self.arch, self.mode) def fix_cmp_instruction_size(self, assembly): lines = assembly.split('\n') updated_lines = [] for line in lines: if 'cmp' in line and '[' in line and ']' in line: # add default size indicator 'dword ptr' if ' ptr ' not in line: line = line.replace('cmp', 'cmp dword ptr', 1) elif 'cmp' in line and ':' in line: line = 'nop' updated_lines.append(line) return '\n'.join(updated_lines) def replace_calls_and_leas(self, assembly): lines = assembly.split('\n') update_lines = [] for line in lines: if ('call' in line) and not any(x in line for x in ['0x', '0X']): update_lines.append('nop') elif ('lea' in line): update_lines.append('nop') else: update_lines.append(line) # print(line) return '\n'.join(update_lines) def remove_comments(self, assembly): # remove ';' lines = assembly.split('\n') cleaned_lines = [line.split(';', 1)[0] for line in lines] return '\n'.join(cleaned_lines).strip() def replace_segment_register_references(self, assembly): lines = assembly.split('\n') updated_lines = [] for line in lines: if 'cs:' in line: updated_lines.append('nop') else: if test_single(line) == 0 and "INSTR" not in line: updated_lines.append('nop') else: updated_lines.append(line) return '\n'.join(updated_lines) def ida_resolve(self, assembly, address): def _resolve(_op, ignore_kw=True): names = re.findall(r"[\$a-z0-9_:\.]+", _op, re.I) for name in names: # ingnore known key words if ignore_kw and name in ('byte', 'near', 'short', 'word', 'dword', 'ptr', 'offset'): continue # use get_name_value fucntion value = get_name_value(name) if value is not None: _op = _op.replace(name, '0x'+str(value)) return _op # split the part and anylaize each oprand _asm = assembly.partition(' ') mnem = _asm[0] opers = _asm[2].split(',') for idx, op in enumerate(opers): _op = list(op.partition('[')) ignore_kw = True if _op[1] == '': _op[2] = _op[0] _op[0] = '' else: _op[0] = _resolve(_op[0], ignore_kw=True) ignore_kw = False _op[2] = _resolve(_op[2], ignore_kw=ignore_kw) opers[idx] = ''.join(_op) asm = "{0} {1}".format(mnem, ','.join(opers)) return asm def assemble(self, assembly, address=0, syntax=KS_OPT_SYNTAX_INTEL): assembly = assembly.replace("endbr64\n", "") assembly = self.remove_comments(assembly) assembly = self.ida_resolve(assembly, address) assembly = self.replace_calls_and_leas(assembly) assembly = self.fix_cmp_instruction_size(assembly) assembly = self.replace_segment_register_references(assembly) def fix_ida_syntax(assembly): assembly = convert_hex_format(assembly) assembly = assembly.upper() assembly = assembly.replace("0X", " 0x") if self.arch == KS_ARCH_X86: if 'RETN' in assembly: return assembly.replace('RETN', 'RET', 1) if 'OFFSET ' in assembly: return assembly.replace('OFFSET ', ' ') return assembly if syntax is None: syntax = KS_OPT_SYNTAX_INTEL # print(fix_ida_syntax(assembly)) try: self.ks.syntax = syntax encoding, count = self.ks.asm(fix_ida_syntax(assembly), address) except KsError as e: print(f"Error:{e}") print(f"Assembly:\n{fix_ida_syntax(assembly)}") print("-"*50) print("") encoding, count = None, 0 return (encoding, count) UC_X86_REG_MAPPING = { UC_X86_REG_RAX: "RAX", UC_X86_REG_RBX: "RBX", UC_X86_REG_RCX: "RCX", UC_X86_REG_RDX: "RDX", UC_X86_REG_RSI: "RSI", UC_X86_REG_RDI: "RDI", UC_X86_REG_RBP: "RBP", UC_X86_REG_RSP: "RSP", UC_X86_REG_R8: "R8", UC_X86_REG_R9: "R9", UC_X86_REG_R10: "R10", UC_X86_REG_R11: "R11", UC_X86_REG_R12: "R12", UC_X86_REG_R13: "R13", UC_X86_REG_R14: "R14", UC_X86_REG_R15: "R15", UC_X86_REG_RIP: "RIP", # FPU register, vector register and flag register UC_X86_REG_XMM0: "XMM0", UC_X86_REG_XMM1: "XMM1", UC_X86_REG_XMM2: "XMM2", UC_X86_REG_XMM3: "XMM3", UC_X86_REG_XMM4: "XMM4", UC_X86_REG_XMM5: "XMM5", UC_X86_REG_XMM6: "XMM6", UC_X86_REG_XMM7: "XMM7", UC_X86_REG_XMM8: "XMM8", UC_X86_REG_XMM9: "XMM9", UC_X86_REG_XMM10: "XMM10", UC_X86_REG_XMM11: "XMM11", UC_X86_REG_XMM12: "XMM12", UC_X86_REG_XMM13: "XMM13", UC_X86_REG_XMM14: "XMM14", UC_X86_REG_XMM15: "XMM15", # YMM register UC_X86_REG_YMM0: "YMM0", UC_X86_REG_YMM1: "YMM1", UC_X86_REG_YMM2: "YMM2", UC_X86_REG_YMM3: "YMM3", UC_X86_REG_YMM4: "YMM4", UC_X86_REG_YMM5: "YMM5", UC_X86_REG_YMM6: "YMM6", UC_X86_REG_YMM7: "YMM7", UC_X86_REG_YMM8: "YMM8", UC_X86_REG_YMM9: "YMM9", UC_X86_REG_YMM10: "YMM10", UC_X86_REG_YMM11: "YMM11", UC_X86_REG_YMM12: "YMM12", UC_X86_REG_YMM13: "YMM13", UC_X86_REG_YMM14: "YMM14", UC_X86_REG_YMM15: "YMM15", # EFLAGS register segment register UC_X86_REG_EFLAGS: "EFLAGS", UC_X86_REG_CS: "CS", UC_X86_REG_DS: "DS", UC_X86_REG_ES: "ES", UC_X86_REG_FS: "FS", UC_X86_REG_GS: "GS", UC_X86_REG_SS: "SS" } class MemoryAccessLogger: def __init__(self): self.read_accesses = [] self.write_accesses = [] def hook_mem_read(self, uc, access, address, size, value, user_data): self.read_accesses.append((address, size, value)) def hook_mem_write(self, uc, access, address, size, value, user_data): self.write_accesses.append((address, size, value)) def hook_mem_invalid(uc, access, address, size, value, user_data): if access == UC_MEM_WRITE_UNMAPPED or access == UC_MEM_READ_UNMAPPED or access == UC_MEM_FETCH_UNMAPPED: print(">>> Missing memory is being WRITE at 0x%x, data size = %u, data value = 0x%x" % (address, size, value)) start_map_addr = address & 0xfffffffffffff000 uc.mem_map(start_map_addr, start_map_addr+0x1000) return True return True def instruction_hook(uc, address, size, user_data): # get the current instruction code = uc.mem_read(address, size) rbp = uc.reg_read(UC_X86_REG_RBP) rsp = uc.reg_read(UC_X86_REG_RSP) # print(f"RBP: 0x{rbp:016x}, RSP: 0x{rsp:016x}") # for i in md.disasm(code, address): # print("0x%x:\t%s\t%s" %(i.address, i.mnemonic, i.op_str)) def assemble_wrapper(asm_code, code_address, result_queue): """ execute the function in new process and catch any exception to avoid crash the main process """ try: keypatch_asm = KeypatchAsm() encoding, count = keypatch_asm.assemble(asm_code, code_address) result_queue.put((encoding, count)) except Exception as e: result_queue.put((None, 0)) print("Error during assembly:", str(e)) def safe_assemble(asm_code, code_address, timeout=3): result_queue = Queue() p = Process(target=assemble_wrapper, args=( asm_code, code_address, result_queue)) p.start() p.join(timeout) if p.is_alive(): p.terminate() print("Terminated the process due to timeout.") return None, 0 try: result = result_queue.get_nowait() return result except Exception: return None, 0 md = Cs(CS_ARCH_X86, CS_MODE_64) def compile_run(asm_code, code_address, seed=0): try: random.seed(seed) encoding, count = safe_assemble(asm_code, code_address) if encoding is None or count == 0: return "ERROR", [], [] CODE_SIZE = (count+0x1000) // 0x1000 * 0x1000 CODE_ADDRESS = code_address STACK_ADDRESS = 0x7fff0000 STACK_SIZE = 0x2000 mu = Uc(UC_ARCH_X86, UC_MODE_64) mu.mem_map(CODE_ADDRESS, CODE_ADDRESS+CODE_SIZE) mu.mem_map(STACK_ADDRESS, STACK_ADDRESS+STACK_SIZE) mu.mem_write(CODE_ADDRESS, bytes(encoding)) mu.reg_write(UC_X86_REG_RAX, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RBX, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RCX, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RDX, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RSI, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RDI, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_R8, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_R9, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_R10, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_R11, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_R12, random.randint(0, 0x2000)) mu.reg_write(UC_X86_REG_RSP, STACK_ADDRESS + STACK_SIZE) mu.reg_write(UC_X86_REG_RBP, STACK_ADDRESS + STACK_SIZE) mu.hook_add(UC_HOOK_MEM_INVALID | UC_HOOK_MEM_UNMAPPED, hook_mem_invalid) memory_logger = MemoryAccessLogger() mu.hook_add(UC_HOOK_MEM_READ, memory_logger.hook_mem_read) mu.hook_add(UC_HOOK_MEM_WRITE, memory_logger.hook_mem_write) mu.emu_start(CODE_ADDRESS, CODE_ADDRESS + len(bytes(encoding)), timeout=0, count=1000) registers = {} for reg_id, reg_name in UC_X86_REG_MAPPING.items(): registers[reg_name] = mu.reg_read(reg_id) return registers, memory_logger.read_accesses, memory_logger.write_accesses except Exception as e: return "ERROR", [], [] # %% ds = load_from_disk("./virtual_assembly_and_ground_truth") # %% all_results = { 'ground_truth': [], 'generated': [] } test_index = [] cnt = 0 for idx, code in tqdm(enumerate(ds['asm'])): print(idx, cnt) regs, read_mem, write_mem = compile_run(code, 0x1000, cnt) if regs == "ERROR": pass else: test_index.append(idx) all_results['ground_truth'].append( { 'regs': regs, 'read_mem': read_mem, 'write_mem': write_mem } ) cnt += 1 for seed, index in tqdm(enumerate(test_index)): code = ds[index]['generated_asm'] regs, read_mem, write_mem = compile_run(code, 0x1000, seed) if regs != "ERROR": all_results['generated'].append( { 'regs': regs, 'read_mem': read_mem, 'write_mem': write_mem } ) else: all_results['generated'].append(None) evaluation_results = { 'regs': [], 'read_mem': [], 'write_mem': [], } for overall_index in tqdm(range(len(test_index))): ground_truth = all_results['ground_truth'][overall_index] compare = all_results['generated'][overall_index] if compare is None: continue # compare regs if len(compare['regs']) == 0: continue reg_name_list = [ 'RAX', 'RSP', 'RBP' ] count = 0 for reg_name in reg_name_list: if ground_truth['regs'][reg_name] == compare['regs'][reg_name]: count += 1 evaluation_results['regs'].append( float(count) / len(reg_name_list)) # compare read_mem if len(ground_truth['read_mem']) != 0: if len(compare['read_mem']) == 0: evaluation_results['read_mem'].append(0) # calculate the matching score of read_mem else: matching_score = 0 for address, size, value in compare['read_mem']: if (address, size, value) in ground_truth['read_mem']: matching_score += 1 evaluation_results['read_mem'].append( float(matching_score) / len(ground_truth['read_mem'])) # compare write_mem if len(ground_truth['write_mem']) != 0: if len(compare['write_mem']) == 0: evaluation_results['write_mem'].append(0) # calculate the matching score of write_mem else: matching_score = 0 for address, size, value in compare['write_mem']: if (address, size, value) in ground_truth['write_mem']: matching_score += 1 evaluation_results['write_mem'].append( matching_score / len(ground_truth['write_mem'])) # %% reg_score = sum(evaluation_results['regs']) / len(evaluation_results['regs']) read_score = sum(evaluation_results['read_mem']) / len( evaluation_results['read_mem']) write_score = sum(evaluation_results['write_mem']) / len( evaluation_results['write_mem']) print(f"mean_score: {(reg_score + read_score + write_score) / 3}")