1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
def main():
    """Command-line entry point: rebuild a flattened function's control flow.

    Reads ``-f/--file`` (binary to analyze) and ``--addr`` (hex address of
    the target function) from the command line, then:

    1. builds a normalized CFG with angr and converts the target function's
       transition graph into a supergraph;
    2. locates the prologue, main dispatcher, pre-dispatcher and return
       blocks, and asks ``get_relevant_nop_nodes`` to split the remaining
       blocks into relevant blocks and blocks to be NOP-ed out;
    3. runs ``symbolic_execution`` from every relevant block to find the
       relevant block(s) that follow it;
    4. NOPs the unneeded blocks, rewrites the block endings as direct
       (conditional) jumps and writes the result to ``<file>_recovered``.

    Exits with status 0 after printing the help text when an argument is
    missing, and with status -1 when the expected blocks cannot be found.
    """
    parser = argparse.ArgumentParser(description="deflat control flow script")
    parser.add_argument("-f", "--file", help="binary to analyze")
    parser.add_argument(
        "--addr", help="address of target function in hex format")
    args = parser.parse_args()

    if args.file is None or args.addr is None:
        parser.print_help()
        sys.exit(0)

    filename = args.file
    start = int(args.addr, 16)

    project = angr.Project(filename, load_options={'auto_load_libs': False})
    cfg = project.analyses.CFGFast(normalize=True, force_complete_scan=False)
    target_function = cfg.functions.get(start)
    if target_function is None:
        # Without this guard the next line fails with an opaque
        # AttributeError on None.
        print("Cannot find a function at %#x" % start)
        sys.exit(-1)
    supergraph = am_graph.to_supergraph(target_function.transition_graph)

    # Clear the low 12 bits of the mapped base. Addresses are turned into
    # offsets into the file contents by subtracting this value.
    # NOTE(review): this presumes file offsets mirror the mapped layout.
    base_addr = project.loader.main_object.mapped_base >> 12 << 12

    # Prologue: node without predecessors. Return block: node without
    # successors and without outgoing branches. As before, the last matching
    # node wins when several nodes qualify.
    prologue_node = None
    retn_node = None
    for node in supergraph.nodes():
        if supergraph.in_degree(node) == 0:
            prologue_node = node
        if supergraph.out_degree(node) == 0 and len(node.out_branches) == 0:
            retn_node = node

    if prologue_node is None or prologue_node.addr != start:
        print("Something must be wrong...")
        sys.exit(-1)
    if retn_node is None:
        print("Cannot find the return block of the target function")
        sys.exit(-1)

    prologue_successors = list(supergraph.successors(prologue_node))
    if not prologue_successors:
        print("Cannot find the main dispatcher: the prologue has no successor")
        sys.exit(-1)
    main_dispatcher_node = prologue_successors[0]

    # The pre-dispatcher is the first predecessor of the main dispatcher
    # that is not the prologue itself.
    pre_dispatcher_node = next(
        (node for node in supergraph.predecessors(main_dispatcher_node)
         if node.addr != prologue_node.addr),
        None)
    if pre_dispatcher_node is None:
        print("Cannot find the pre-dispatcher block of the target function")
        sys.exit(-1)

    relevant_nodes, nop_nodes = get_relevant_nop_nodes(
        supergraph, pre_dispatcher_node, prologue_node, retn_node)

    print('*******************relevant blocks************************')
    print('prologue: %#x' % start)
    print('main_dispatcher: %#x' % main_dispatcher_node.addr)
    print('pre_dispatcher: %#x' % pre_dispatcher_node.addr)
    print('retn: %#x' % retn_node.addr)
    relevant_block_addrs = [node.addr for node in relevant_nodes]
    print('relevant_blocks:', [hex(addr) for addr in relevant_block_addrs])

    print('*******************symbolic execution*********************')
    # Blocks to execute from: every relevant block plus the prologue. The
    # return block is only a possible destination, never a starting point.
    # A fresh list is built so the helper's result is not mutated.
    relevants_without_retn = [*relevant_nodes, prologue_node]
    relevant_block_addrs.extend([prologue_node.addr, retn_node.addr])

    flow = defaultdict(list)   # block node -> successor addresses found
    patch_instrs = {}          # block node -> first conditional-select insn
    for relevant in relevants_without_retn:
        print('-------------------dse %#x---------------------' % relevant.addr)
        block = project.factory.block(relevant.addr, size=relevant.size)
        has_branches = False
        hook_addrs = set()     # call sites handed to symbolic_execution
        for ins in block.capstone.insns:
            if project.arch.name in ARCH_X86:
                if ins.insn.mnemonic.startswith('cmov'):
                    # Only the first conditional move of a block is recorded.
                    if relevant not in patch_instrs:
                        patch_instrs[relevant] = ins
                        has_branches = True
                elif ins.insn.mnemonic.startswith('call'):
                    hook_addrs.add(ins.insn.address)
            elif project.arch.name in ARCH_ARM:
                # Any mnemonic that starts with 'mov' but is not the plain
                # 'mov' is treated as the branch selector.
                if ins.insn.mnemonic != 'mov' and ins.insn.mnemonic.startswith('mov'):
                    if relevant not in patch_instrs:
                        patch_instrs[relevant] = ins
                        has_branches = True
                elif ins.insn.mnemonic in {'bl', 'blx'}:
                    hook_addrs.add(ins.insn.address)
            elif project.arch.name in ARCH_ARM64:
                if ins.insn.mnemonic.startswith('cset'):
                    if relevant not in patch_instrs:
                        patch_instrs[relevant] = ins
                        has_branches = True
                elif ins.insn.mnemonic in {'bl', 'blr'}:
                    hook_addrs.add(ins.insn.address)

        if has_branches:
            # Explore both outcomes of the condition: first with the 1-bit
            # value 1, then with 0. The order fixes which successor becomes
            # childs[0] (conditional target) and childs[1] (fall-through).
            for bit in (1, 0):
                tmp_addr = symbolic_execution(project, relevant_block_addrs,
                                              relevant.addr, hook_addrs,
                                              claripy.BVV(bit, 1), True)
                if tmp_addr is not None:
                    flow[relevant].append(tmp_addr)
        else:
            tmp_addr = symbolic_execution(project, relevant_block_addrs,
                                          relevant.addr, hook_addrs)
            if tmp_addr is not None:
                flow[relevant].append(tmp_addr)

    print('************************flow******************************')
    for k, v in flow.items():
        print('%#x: ' % k.addr, [hex(child) for child in v])

    print('%#x: ' % retn_node.addr, [])

    print('************************patch*****************************')
    with open(filename, 'rb') as origin:
        # bytearray so the image can be patched in place.
        origin_data = bytearray(origin.read())
        origin_data_len = len(origin_data)

    recovery_file = filename + '_recovered'

    # Blank out the blocks that are not needed any more.
    for nop_node in nop_nodes:
        fill_nop(origin_data, nop_node.addr-base_addr,
                 nop_node.size, project.arch)

    for parent, childs in flow.items():
        if len(childs) == 1:
            # Single successor: turn the block's last instruction into an
            # unconditional jump to it.
            parent_block = project.factory.block(parent.addr, size=parent.size)
            last_instr = parent_block.capstone.insns[-1]
            file_offset = last_instr.address - base_addr
            if project.arch.name in ARCH_X86:
                fill_nop(origin_data, file_offset, last_instr.size, project.arch)
                patch_value = ins_j_jmp_hex_x86(last_instr.address, childs[0], 'jmp')
            elif project.arch.name in ARCH_ARM:
                patch_value = ins_b_jmp_hex_arm(last_instr.address, childs[0], 'b')
                if project.arch.memory_endness == "Iend_BE":
                    patch_value = patch_value[::-1]
            elif project.arch.name in ARCH_ARM64:
                if parent.addr == start:
                    # For the prologue the last instruction is kept and the
                    # jump is written over the instruction 4 bytes later.
                    file_offset += 4
                    patch_value = ins_b_jmp_hex_arm64(last_instr.address+4, childs[0], 'b')
                else:
                    patch_value = ins_b_jmp_hex_arm64(last_instr.address, childs[0], 'b')
                if project.arch.memory_endness == "Iend_BE":
                    patch_value = patch_value[::-1]
            patch_instruction(origin_data, file_offset, patch_value)
        else:
            # Two successors: NOP everything from the recorded conditional
            # instruction to the end of the block, then emit a conditional
            # jump to childs[0] followed by an unconditional jump to
            # childs[1].
            instr = patch_instrs[parent]
            file_offset = instr.address - base_addr
            fill_nop(origin_data, file_offset,
                     parent.addr + parent.size - base_addr - file_offset,
                     project.arch)
            if project.arch.name in ARCH_X86:
                # 'cmovXX' -> condition suffix 'XX' for the jump mnemonic.
                patch_value = ins_j_jmp_hex_x86(instr.address, childs[0], instr.mnemonic[len('cmov'):])
                patch_instruction(origin_data, file_offset, patch_value)

                # The second jump is placed 6 bytes after the first one
                # (presumably the size of the emitted conditional jump).
                file_offset += 6
                patch_value = ins_j_jmp_hex_x86(instr.address+6, childs[1], 'jmp')
                patch_instruction(origin_data, file_offset, patch_value)
            elif project.arch.name in ARCH_ARM:
                # 'movXX' -> 'bXX'.
                bx_cond = 'b' + instr.mnemonic[len('mov'):]
                patch_value = ins_b_jmp_hex_arm(instr.address, childs[0], bx_cond)
                if project.arch.memory_endness == 'Iend_BE':
                    patch_value = patch_value[::-1]
                patch_instruction(origin_data, file_offset, patch_value)

                file_offset += 4
                patch_value = ins_b_jmp_hex_arm(instr.address+4, childs[1], 'b')
                if project.arch.memory_endness == 'Iend_BE':
                    patch_value = patch_value[::-1]
                patch_instruction(origin_data, file_offset, patch_value)
            elif project.arch.name in ARCH_ARM64:
                # The condition is the last operand of the 'cset' instruction.
                bx_cond = instr.op_str.split(',')[-1].strip()
                patch_value = ins_b_jmp_hex_arm64(instr.address, childs[0], bx_cond)
                if project.arch.memory_endness == 'Iend_BE':
                    patch_value = patch_value[::-1]
                patch_instruction(origin_data, file_offset, patch_value)

                file_offset += 4
                patch_value = ins_b_jmp_hex_arm64(instr.address+4, childs[1], 'b')
                if project.arch.memory_endness == 'Iend_BE':
                    patch_value = patch_value[::-1]
                patch_instruction(origin_data, file_offset, patch_value)

    # Patching must only overwrite bytes, never grow or shrink the image.
    assert len(origin_data) == origin_data_len, "Error: size of data changed!!!"

    # The output file is created only after patching succeeded, and the
    # context manager closes it even if the write fails.
    with open(recovery_file, 'wb') as recovery:
        recovery.write(origin_data)
    print('Successful! The recovered file: %s' % recovery_file)
|