from __future__ import annotations

import importlib
import importlib.abc
import sys
import io
import threading
from typing import Optional, List

# Thread-local storage; `ModuleSpaceMeta.stack` keeps each thread's stack of entered `ModuleSpace`s on it
# (as the `module_space_stack` attribute).
_CURRENT_MODULE_SPACE_THREADLOCAL = threading.local()

class ModuleSpaceFinder(importlib.abc.MetaPathFinder):
    '''
    Finds modules in the current `ModuleSpace`'s `path` and loads them into
    its `modules`.
    '''

    def find_spec(self, fullname, path, target=None):
        '''
        Asks each finder of the current `ModuleSpace`, in order, for a spec for `fullname`.

        Returns the first spec a finder produces, or `None` when there is no current `ModuleSpace`
        or none of its finders knows the module (which lets the next meta path finder try).
        '''

        module_space = ModuleSpace.current
        if module_space is None:
            return None

        for finder in module_space.finders():
            # The per-path finders take `(fullname, target)`. The meta path `path` argument has no
            # meaning for them, so forward `target` rather than `path`.
            spec = finder.find_spec(fullname, target)
            if spec is not None:
                return spec

        return None

class ModuleDictShim(dict):
    '''
    Redirects queries for a module to the current `ModuleSpace` before falling back on the original global
    `sys.modules` dictionary.

    The wrapped dictionary is kept in `inner`. Only the operations implemented below are supported; the
    remaining `dict` methods raise `NotImplementedError`.
    '''

    def __init__(self, inner):
        # `inner` is the dictionary being wrapped (the original `sys.modules`)
        self.inner = inner

    def __getitem__(self, key):
        '''
        Returns the module for `key` from the current `ModuleSpace` if it has one, otherwise from `inner`.
        Raises `KeyError` if neither has it.
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            try:
                return current_space.modules[key]
            except KeyError:
                pass

        return self.inner[key]

    def __setitem__(self, key, value):
        '''
        Stores `value` in the current `ModuleSpace` when `key` belongs (or recently belonged) to it,
        otherwise in `inner`.
        '''

        # If the key already exists in the current `ModuleSpace`, prefer putting it there
        # See `ModuleDictShim.pop` for an explanation of why we need to do this.
        current_space = ModuleSpace.current
        if current_space is not None:
            if key in current_space.modules or key in current_space._previously_loaded_modules:
                current_space.modules[key] = value
                return

        self.inner[key] = value

    def __contains__(self, key):
        '''
        True if `key` is in the current `ModuleSpace` or in `inner`.
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            if key in current_space.modules:
                return True

        return key in self.inner

    def __iter__(self):
        '''
        Iterates over the names in the current `ModuleSpace` first, then over the names in `inner`.
        A name present in both is yielded only once.
        '''

        current_space = ModuleSpace.current
        if current_space is None:
            yield from self.inner
            return

        space_modules = current_space.modules
        yield from space_modules

        for k in self.inner:
            # Already yielded from the `ModuleSpace` above; a mapping must not repeat a key
            if k not in space_modules:
                yield k

    def get(self, key, default=None):
        '''
        Like `__getitem__`, but returns `default` instead of raising when `key` is missing everywhere.
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            if key in current_space.modules:
                return current_space.modules[key]

        return self.inner.get(key, default)

    # `importlib._bootstrap._find_and_load` uses this to rearrange module order after import. This creates a problem
    # for us, because we don't want it to reinsert modules into `sys.modules` right after we've loaded them into the
    # current `ModuleSpace` (effectively making `ModuleSpace` useless). To combat this, `pop`ing a module that exists
    # in the current `ModuleSpace` adds the key to the `ModuleSpace._previously_loaded_modules` set, which `__setitem__`
    # uses to determine if we should set the entry in `ModuleSpace.modules` instead of the inner `sys.modules`.
    def pop(self, key, *default):
        '''
        Removes `key` and returns its module, preferring the current `ModuleSpace` over `inner`.

        As with `dict.pop`, an optional `default` is returned instead of raising `KeyError` when `key`
        is missing (e.g. `sys.modules.pop(name, None)`).
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            if key in current_space.modules:
                value = current_space.modules.pop(key)
                current_space._previously_loaded_modules.add(key)
                return value

        # Forwarding `*default` keeps the exact `dict.pop` semantics, including its argument checking
        return self.inner.pop(key, *default)

    def __delitem__(self, key):
        '''
        Deletes `key` from the current `ModuleSpace` if present there (remembering it, as `pop` does),
        otherwise from `inner`.
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            if key in current_space.modules:
                del current_space.modules[key]
                current_space._previously_loaded_modules.add(key)
                return

        del self.inner[key]

    def clear(self):
        '''
        Empties both the current `ModuleSpace`'s modules (if there is a current one) and `inner`.
        '''

        current_space = ModuleSpace.current
        if current_space is not None:
            current_space.modules.clear()

        self.inner.clear()

    # Haven't implemented these methods because we haven't encountered any code that tries to use them on
    # `sys.modules`... yet
    def items(self, *args, **kwargs):
        raise NotImplementedError()
    def keys(self, *args, **kwargs):
        raise NotImplementedError()
    def popitem(self, *args, **kwargs):
        raise NotImplementedError()
    def setdefault(self, *args, **kwargs):
        raise NotImplementedError()
    def update(self, *args, **kwargs):
        raise NotImplementedError()
    def values(self, *args, **kwargs):
        raise NotImplementedError()
    def __len__(self, *args, **kwargs):
        raise NotImplementedError()
    def __eq__(self, *args, **kwargs):
        raise NotImplementedError()
    def __reduce__(self, *args, **kwargs):
        raise NotImplementedError()

# Install the module space finder at the head of import resolution, so it is asked before any finder
# already on `sys.meta_path`.
sys.meta_path.insert(0, ModuleSpaceFinder())
# Replace `sys.modules` with our wrapper; the original dictionary stays reachable as the wrapper's `inner`.
# Note that there is a stored pointer to the original `sys.modules` in the CPython interpreter state and
# this does not replace that value. That is handled in `PythonManager::PythonManager()`.
sys.modules = ModuleDictShim(sys.modules)

class ModuleSpaceMeta(type):
    '''
    Metaclass that exposes the per-thread `ModuleSpace` stack as class-level properties.
    '''

    @property
    def current(cls) -> Optional[ModuleSpace]:
        '''
        The innermost entered `ModuleSpace` on this thread, or `None` when none has been entered.
        '''

        entered = cls.stack
        if not entered:
            return None
        return entered[-1]

    @property
    def stack(cls) -> List[ModuleSpace]:
        '''
        This thread's list of entered `ModuleSpace`s, oldest first; the last element is the current one.
        The list is created the first time a thread asks for it.
        '''

        storage = _CURRENT_MODULE_SPACE_THREADLOCAL
        if not hasattr(storage, 'module_space_stack'):
            storage.module_space_stack = []
        return storage.module_space_stack

class ModuleSpace(metaclass=ModuleSpaceMeta):
    '''
    A set of paths and loaded modules from those paths that is made temporarily available for
    import resolution.  When a `ModuleSpace` is not registered as the current `ModuleSpace`, its
    imports will not be available. This allows different parts of the program to have different copies
    of modules at the same name (which may or may not come from the same source file).

    Use it as a context manager to make it the current `ModuleSpace` for the duration of the block.
    '''

    def __init__(self, msg_callback=None):
        '''
        `msg_callback`, if given, is called with every message passed to `write_stdout`.
        '''

        # The search path (namespaced equivalent of sys.path)
        self.path = []
        # The module cache (namespaced equivalent of sys.modules)
        self.modules = {}
        # See `ModuleDictShim.pop` for an explanation of this value
        self._previously_loaded_modules = set()

        # Cache of finders keyed by path entry, filled lazily by `finders`
        self._finders = {}

        self._msg_callback = msg_callback

    def __enter__(self):
        '''
        Pushes this `ModuleSpace` onto the calling thread's stack, making it the current one.
        Returns `self` so that `with ModuleSpace() as space:` binds the instance.
        '''

        ModuleSpace.stack.append(self)
        return self

    def __exit__(self, ty, value, traceback):
        '''
        Pops the top of the calling thread's stack. Exceptions from the block are not suppressed.
        '''

        ModuleSpace.stack.pop()

    def write_stdout(self, message):
        '''
        Hands `message` to the message callback; does nothing if no callback was given.
        '''

        if self._msg_callback is not None:
            self._msg_callback(message)

    def finders(self):
        '''
        Provides a method for iterating over a list of caching finders for each path in the
        `ModuleSpace`. These are enumerated in the `ModuleSpace.path` order.
        '''

        for p in self.path:
            try:
                yield self._finders[p]
            except KeyError:
                # First time this path entry is seen: build a finder that resolves `.py` files
                # through `_make_loader`, and remember it for later calls.
                finder = importlib.machinery.FileFinder(p, (self._make_loader, ('.py',)))
                self._finders[p] = finder
                yield finder

    def _make_loader(self, fullname, found_path):
        '''
        Loader factory handed to the finders; binds the created loader to this `ModuleSpace`.
        '''

        return ModuleSpaceLoader(self, fullname, found_path)


class ModuleSpaceLoader:
    '''
    Loads a module into a particular `ModuleSpace` instead of the global module dictionary.
    '''

    def __init__(self, module_space, fullname, found_path):
        '''
        `module_space` receives the loaded module in its `modules`; `fullname` and `found_path` name
        the module and the source file it is executed from.
        '''

        # Reading and executing the source file is delegated to the standard source loader
        self._inner_loader = importlib.machinery.SourceFileLoader(fullname, found_path)
        self.module_space = module_space
        self._found_path = found_path

    # We don't use the newer `create_module`/`exec_module` combination because we want to
    # override the handling of `sys.modules`; specifically we want to NOT use `sys.modules`
    # and instead keep the modules scoped to our `ModuleSpace.modules`.
    def load_module(self, name):
        '''
        Executes the module `name` and registers it in `module_space.modules`.

        If `name` is already registered, the existing module object is re-executed in place (a reload).
        Otherwise a fresh module is created and registered before its body runs; if that first
        execution fails, the registration is removed again.

        Returns the loaded module. Any exception raised while executing the module propagates.
        '''

        # This is largely taken from `importlib.util._module_to_load`
        is_reload = name in self.module_space.modules

        module = self.module_space.modules.get(name)
        if not is_reload:
            module = type(sys)(name)
            # Mark the module as initializing before it becomes visible in the module cache
            module.__initializing__ = True
            module.__file__ = self._found_path
            self.module_space.modules[name] = module
        try:
            self._inner_loader.exec_module(module)
        except BaseException:
            # `BaseException` rather than `Exception`: an interrupted first load (`SystemExit`,
            # `KeyboardInterrupt`) must not leave a half-initialised module in the cache either.
            if not is_reload:
                try:
                    del self.module_space.modules[name]
                except KeyError:
                    pass
            raise
        finally:
            module.__initializing__ = False

        return module


# Text stream used to hijack stdout: output goes to the current `ModuleSpace` when there is one.
class RoutedIO(io.TextIOBase):
    '''
    Writable text stream that hands messages to the current `ModuleSpace`'s `write_stdout`, and to the
    wrapped original stream when no `ModuleSpace` is current.
    '''

    def __init__(self, original_stdout):
        # The stream to fall back on when no `ModuleSpace` is current
        self._original_stdout = original_stdout

    def write(self, message):
        '''
        Routes `message` and returns the number of characters written.
        '''

        current_space = ModuleSpace.current
        if current_space is None:
            return self._original_stdout.write(message)

        current_space.write_stdout(message)
        # `write_stdout` returns nothing; report the count that stream writers are expected to return
        return len(message)

    def flush(self):
        '''
        Flushes the wrapped original stream, so `print(..., flush=True)` and `sys.stdout.flush()`
        still reach it. Does nothing to the wrapped stream if it is missing or already closed.
        '''

        super().flush()
        original = getattr(self, '_original_stdout', None)
        if original is not None and not getattr(original, 'closed', False):
            original.flush()

# Swap `sys.stdout` for a `RoutedIO` wrapper when this module is imported. The previous stream is kept
# in `old_stdout` (it is also what the wrapper falls back on) and the wrapper itself in `new_stdout`.
old_stdout = sys.stdout
new_stdout = RoutedIO(old_stdout)
sys.stdout = new_stdout
