# Copyright (c) 2014, Fundacion Dr. Manuel Sadosky
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:

# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.

# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.

# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
This module contains two basic classes for gadget processing: RawGadget
and TypedGadget. The first is used to describe the gadgets found by
GadgetFinder. These are candidate gadgets, as they are not validated yet.
However, a RawGadget object contains the assembly code as well as its
REIL representation. A given gadget can be classified as one or more
gadget types. At that point, a TypedGadget object is created for each
classified type and the RawGadget object is associated with it.

"""
from __future__ import absolute_import

from barf.core.reil import ReilEmptyOperand
from barf.core.reil import ReilImmediateOperand


class RawGadget(object):

    """Candidate gadget described by its list of assembly instructions.
    """

    __slots__ = [
        '_instrs',
        '_id',
    ]

    def __init__(self, instrs):
        # Identifier of the gadget; stays None until set through ``id``.
        self._id = None

        # Assembly instructions that make up the gadget.
        self._instrs = instrs

    @property
    def address(self):
        """Get the address of the first instruction of the gadget.
        """
        first_instr = self._instrs[0]

        return first_instr.address

    @property
    def instrs(self):
        """Get the list of instructions the gadget was built from.
        """
        return self._instrs

    @property
    def ir_instrs(self):
        """Get the IR instructions of all gadget instructions, in order.

        A new list is built on every access by concatenating the
        ``ir_instrs`` of each instruction.
        """
        return [ir for asm in self._instrs for ir in asm.ir_instrs]

    @property
    def id(self):
        """Get the gadget id.
        """
        return self._id

    @id.setter
    def id(self, value):
        """Set the gadget id.
        """
        self._id = value

    def __str__(self):
        # One line per assembly instruction (address and text), followed
        # by its IR instructions indented by two spaces.
        rendered = []

        for asm in self._instrs:
            rendered.append("0x%08x : %s" % (asm.address, asm))
            rendered.extend("  %s" % ir for ir in asm.ir_instrs)

        return "\n".join(rendered)


class TypedGadget(RawGadget):

    """Represent a gadget together with its semantic classification.

    Attributes that cannot be found on the typed gadget itself are looked
    up on the wrapped raw gadget (see ``__getattr__``).
    """

    __slots__ = [
        '_gadget',
        '_sources',
        '_destination',
        '_modified_regs',
        '_gadget_type',
        '_verified',
        '_is_valid',
        '_operation',
    ]

    def __init__(self, gadget, gadget_type, instrs):
        super(TypedGadget, self).__init__(instrs)

        # The raw gadget this classification refers to.
        self._gadget = gadget

        # A list of sources.
        self._sources = []

        # A list of destinations.
        self._destination = []

        # A list of registers that are modified after gadget execution.
        self._modified_regs = []

        # Type of the gadget.
        self._gadget_type = gadget_type

        # Verification flag.
        self._verified = False

        # If the gadget was verified and it turned out to be correctly
        # classified, this flag is True. Otherwise, it is False. It stays
        # None until a value is assigned through ``is_valid``.
        self._is_valid = None

        # Operation computed by the gadget.
        self._operation = None

    # Properties
    # ======================================================================== #
    @property
    def sources(self):
        """Get gadget sources.
        """
        return self._sources

    @sources.setter
    def sources(self, value):
        """Set gadget sources.
        """
        self._sources = value

    @property
    def destination(self):
        """Get gadget destination.
        """
        return self._destination

    @destination.setter
    def destination(self, value):
        """Set gadget destination.
        """
        self._destination = value

    @property
    def modified_registers(self):
        """Get gadget modified registers.
        """
        return self._modified_regs

    @modified_registers.setter
    def modified_registers(self, value):
        """Set gadget modified registers.
        """
        self._modified_regs = value

    @property
    def verified(self):
        """Get gadget verification status.
        """
        return self._verified

    @verified.setter
    def verified(self, value):
        """Set gadget verification status.
        """
        self._verified = value

    @property
    def is_valid(self):
        """Get gadget validity status.

        Raises Exception if the gadget has not been verified yet.
        """
        if not self._verified:
            raise Exception("Typed Gadget not Verified!")

        return self._is_valid

    @is_valid.setter
    def is_valid(self, value):
        """Set gadget validity status and mark the gadget as verified.
        """
        self._verified = True
        self._is_valid = value

    @property
    def type(self):
        """Get gadget type.
        """
        return self._gadget_type

    @property
    def operation(self):
        """Get gadget operation.
        """
        return self._operation

    @operation.setter
    def operation(self, value):
        """Set gadget operation.
        """
        self._operation = value

    def __key(self):
        """Build the hashable tuple backing ``__hash__``.

        Only the fields compared by ``__eq__`` take part, so gadgets that
        compare equal always hash alike. List-valued fields are converted
        to tuples because lists are not hashable.
        """
        def freeze(value):
            return tuple(value) if isinstance(value, list) else value

        return (freeze(self._sources),
                freeze(self._destination),
                freeze(self._modified_regs),
                self._operation)

    def __str__(self):
        # The table is built at call time because the dump functions and
        # GadgetType are defined further down in the module.
        strings = {
            GadgetType.NoOperation:     dump_no_operation,
            GadgetType.Jump:            dump_jump,
            GadgetType.MoveRegister:    dump_move_register,
            GadgetType.LoadConstant:    dump_load_constant,
            GadgetType.Arithmetic:      dump_arithmetic,
            GadgetType.LoadMemory:      dump_load_memory,
            GadgetType.StoreMemory:     dump_store_memory,
            GadgetType.ArithmeticLoad:  dump_arithmetic_load,
            GadgetType.ArithmeticStore: dump_arithmetic_store,
            GadgetType.Undefined:       dump_undefined,
        }

        return strings[self._gadget_type](self)

    def __eq__(self, other):
        """Return self == other.

        Two typed gadgets are equal when they are of the same class and
        have equal sources, destination, modified registers and operation.
        """
        if type(other) is type(self):
            same_sources = self._sources == other._sources
            same_destination = self._destination == other._destination
            same_modified = self._modified_regs == other._modified_regs
            same_operation = self._operation == other._operation

            return same_sources and same_destination and same_modified and \
                same_operation
        else:
            return False

    def __ne__(self, other):
        """Return self != other."""
        return not self.__eq__(other)

    def __hash__(self):
        """Return a hash consistent with ``__eq__``.

        Raises TypeError if any of the compared values is unhashable.
        """
        return hash(self.__key())

    # Misc
    # ======================================================================== #
    def __getattr__(self, name):
        """Forward failed attribute lookups to the wrapped raw gadget.

        Raises AttributeError if the wrapped gadget is not set or does not
        provide the attribute either.
        """
        # __getattr__ only runs after normal lookup failed. If '_gadget'
        # itself is missing (e.g. an instance created without __init__, as
        # copy/unpickling do), reading self._gadget below would re-enter
        # this method forever; fail with AttributeError instead.
        if name == '_gadget':
            raise AttributeError(name)

        return getattr(self._gadget, name)


class GadgetType(object):

    """Integer codes identifying the supported gadget categories.
    """

    NoOperation     = 0
    Jump            = 1
    MoveRegister    = 2
    LoadConstant    = 3
    Arithmetic      = 4
    LoadMemory      = 5
    StoreMemory     = 6
    ArithmeticLoad  = 7
    ArithmeticStore = 8
    Undefined       = 9

    @staticmethod
    def to_string(gadget_type):
        """Return the human-readable name of a gadget type code.

        Raises KeyError if the code is not one of the class constants.
        """
        names_by_code = dict([
            (GadgetType.NoOperation, "No Operation"),
            (GadgetType.Jump, "Jump"),
            (GadgetType.MoveRegister, "Move Register"),
            (GadgetType.LoadConstant, "Load Constant"),
            (GadgetType.Arithmetic, "Arithmetic"),
            (GadgetType.LoadMemory, "Load Memory"),
            (GadgetType.StoreMemory, "Store Memory"),
            (GadgetType.ArithmeticLoad, "Arithmetic Load"),
            (GadgetType.ArithmeticStore, "Arithmetic Store"),
            (GadgetType.Undefined, "Undefined"),
        ])

        return names_by_code[gadget_type]


# Gadget dump functions
# ============================================================================ #
def dump_no_operation(gadget):
    """Return the textual form of a no-operation gadget.

    The gadget argument is not inspected; the text is constant.
    """
    text = "nop <- nop > {}"

    return text


def dump_jump(gadget):
    """Return the placeholder text used for jump gadgets.

    The gadget argument is not inspected; the text is constant.
    """
    text = "NOT SUPPORTED YET!"

    return text


def dump_move_register(gadget):
    """Render a move-register gadget as ``dst <- src > {modified}``.

    Uses the first destination, the first source and all modified
    registers of the gadget; modified registers are joined with "; ".
    """
    modified = tuple(gadget.modified_registers)

    dst_text = str(gadget.destination[0])
    src_text = str(gadget.sources[0])
    modified_text = "; ".join(str(reg) for reg in modified)

    return "%s <- %s > {%s}" % (dst_text, src_text, modified_text)


def dump_load_constant(gadget):
    """Render a load-constant gadget as ``dst <- src > {modified}``.

    Uses the first destination, the first source and all modified
    registers of the gadget; modified registers are joined with "; ".
    """
    modified = tuple(gadget.modified_registers)

    dst_text = str(gadget.destination[0])
    src_text = str(gadget.sources[0])
    modified_text = "; ".join(str(reg) for reg in modified)

    return "%s <- %s > {%s}" % (dst_text, src_text, modified_text)


def dump_arithmetic(gadget):
    """Render an arithmetic gadget as ``dst <- src1 op src2 > {modified}``.

    Uses the first destination, the first two sources, the gadget
    operation and all modified registers (joined with "; ").
    """
    modified = tuple(gadget.modified_registers)

    dst_text = str(gadget.destination[0])
    lhs_text = str(gadget.sources[0])
    operator = gadget.operation
    rhs_text = str(gadget.sources[1])
    modified_text = "; ".join(str(reg) for reg in modified)

    pieces = (dst_text, lhs_text, operator, rhs_text, modified_text)

    return "%s <- %s %s %s > {%s}" % pieces


def dump_load_memory(gadget):
    """Render a load-memory gadget as ``dst <- mem[addr] > {modified}``.

    The address is built from the gadget sources joined with " + ";
    empty operands and immediates equal to zero are left out. Modified
    registers are joined with "; ".
    """
    def contributes(operand):
        # Empty operands and zero immediates are not shown in the address.
        if isinstance(operand, ReilEmptyOperand):
            return False

        return not (isinstance(operand, ReilImmediateOperand) and
                    operand.immediate == 0)

    modified = tuple(gadget.modified_registers)
    address_terms = [op for op in gadget.sources if contributes(op)]

    dst_text = str(gadget.destination[0])
    address_text = " + ".join(str(op) for op in address_terms)
    modified_text = "; ".join(str(reg) for reg in modified)

    return "%s <- mem[%s] > {%s}" % (dst_text, address_text, modified_text)


def dump_store_memory(gadget):
    """Render a store-memory gadget as ``mem[addr] <- src > {modified}``.

    The address is built from the gadget destination operands joined
    with " + "; empty operands and immediates equal to zero are left out.
    The stored value is the first source. Modified registers are joined
    with "; ".
    """
    def contributes(operand):
        # Empty operands and zero immediates are not shown in the address.
        if isinstance(operand, ReilEmptyOperand):
            return False

        return not (isinstance(operand, ReilImmediateOperand) and
                    operand.immediate == 0)

    modified = tuple(gadget.modified_registers)
    address_terms = [op for op in gadget.destination if contributes(op)]

    address_text = " + ".join(str(op) for op in address_terms)
    src_text = str(gadget.sources[0])
    modified_text = "; ".join(str(reg) for reg in modified)

    return "mem[%s] <- %s > {%s}" % (address_text, src_text, modified_text)


def dump_arithmetic_load(gadget):
    """Render an arithmetic-load gadget.

    The result has the form ``dst <- src op mem[addr] > {modified}``,
    where ``src`` is the first source and the address is built from the
    remaining sources joined with " + "; empty operands and immediates
    equal to zero are left out of the address. Modified registers are
    joined with "; ".
    """
    def contributes(operand):
        # Empty operands and zero immediates are not shown in the address.
        if isinstance(operand, ReilEmptyOperand):
            return False

        return not (isinstance(operand, ReilImmediateOperand) and
                    operand.immediate == 0)

    modified = tuple(gadget.modified_registers)
    address_terms = [op for op in gadget.sources[1:] if contributes(op)]

    dst_text = str(gadget.destination[0])
    lhs_text = str(gadget.sources[0])
    operator = gadget.operation
    address_text = " + ".join(str(op) for op in address_terms)
    modified_text = "; ".join(str(reg) for reg in modified)

    pieces = (dst_text, lhs_text, operator, address_text, modified_text)

    return "%s <- %s %s mem[%s] > {%s}" % pieces


def dump_arithmetic_store(gadget):
    """Render an arithmetic-store gadget.

    The result has the form
    ``mem[dst_addr] <- mem[src_addr] op src > {modified}``. The source
    address is built from the first two sources and the destination
    address from the destination operands, each joined with " + "; empty
    operands and immediates equal to zero are left out of both. ``src``
    is the third source. Modified registers are joined with "; ".
    """
    def contributes(operand):
        # Empty operands and zero immediates are not shown in an address.
        if isinstance(operand, ReilEmptyOperand):
            return False

        return not (isinstance(operand, ReilImmediateOperand) and
                    operand.immediate == 0)

    modified = tuple(gadget.modified_registers)
    src_terms = [op for op in gadget.sources[0:2] if contributes(op)]
    dst_terms = [op for op in gadget.destination if contributes(op)]

    dst_address = " + ".join(str(op) for op in dst_terms)
    src_address = " + ".join(str(op) for op in src_terms)
    operator = gadget.operation
    rhs_text = str(gadget.sources[2])
    modified_text = "; ".join(str(reg) for reg in modified)

    pieces = (dst_address, src_address, operator, rhs_text, modified_text)

    return "mem[%s] <- mem[%s] %s %s > {%s}" % pieces


def dump_undefined(gadget):
    """Return the textual form of a gadget of undefined type.

    The gadget argument is not inspected; the text is constant.
    """
    text = "undefined"

    return text
