# -*- coding: utf-8 -*-
"""Curses output management class."""

import dataclasses
from enum import Enum
import io
import textwrap
import threading
import traceback
from typing import Any, Callable, Dict, List, Optional, Union

import curses
import shutil
import signal


class Status(Enum):
  """Enum class for module states.

  The order here is important, as it's the order modules will be displayed:
  CursesDisplayManager.Draw() iterates over this enum and prints the modules
  in each status in turn. Each member's value is the text shown on screen
  for that status.

  Module.SetStatus() treats ERROR, COMPLETED and CANCELLED as final states and
  ignores further status updates once a module has reached one of them."""
  COMPLETED = 'Completed'
  SETTINGUP = 'Setting Up'
  ERROR = 'Error'
  # Threaded modules list their per-thread lines only in RUNNING/PROCESSING.
  RUNNING = 'Running'
  PREPROCESSING = 'Preprocessing'
  PROCESSING = 'Processing'
  POSTPROCESSING = 'Postprocessing'
  # Initial status of every Module; dependencies are shown alongside it.
  PENDING = 'Pending'
  CANCELLED = 'Cancelled'


@dataclasses.dataclass
class _ModuleThread:
  """Display state for a single thread of a threaded Module."""
  # Current status of the thread.
  status: Status
  # Name of the container the thread is currently processing.
  container: str
  # Completion percentage of the form 'XX.X%'; None when not reported.
  progress: Optional[str] = dataclasses.field(default=None)


class Module:
  """An object used by the CursesDisplayManager used to represent a DFTW module.

  Tracks the module's display status, an optional overall progress
  percentage and, for threaded modules, per-thread state plus container
  completion counts.
  """
  def __init__(self,
               name: str,
               dependencies: List[str],
               runtime_name: Optional[str] = None):
    """Initialize the Module object.

    Args:
      name: The module name of this module.
      dependencies: A list of Runtime names that this module is blocked on.
      runtime_name: The runtime name of this module. Falls back to name when
          not provided (or empty).
    """
    self.name = name
    self.runtime_name = runtime_name if runtime_name else name
    self.status: Status = Status.PENDING
    self._dependencies: List[str] = dependencies
    self._error_message: str = ''
    self._threads: Dict[str, _ModuleThread] = {}
    self._threads_containers_max: int = 0
    self._threads_containers_completed: int = 0
    self._progress: Optional[str] = None  # Of the form 'XX.X%'

  @staticmethod
  def _FormatProgress(steps_taken: int,
                      steps_expected: int) -> Optional[str]:
    """Formats progress as a percentage string of the form 'XX.X%'.

    Args:
      steps_taken: The number of steps taken so far.
      steps_expected: The number of total steps expected for completion.

    Returns:
      The formatted percentage, or None when steps_expected is not positive:
      no meaningful percentage exists and dividing would raise
      ZeroDivisionError. A None progress is simply not displayed.
    """
    if steps_expected <= 0:
      return None
    return f'{steps_taken / steps_expected * 100:.1f}%'

  def Stringify(self) -> List[str]:
    """Returns an CursesDisplayManager friendly string describing the module.

    Returns:
      A list of display lines: the module's status line, followed by one line
      per tracked thread when the module is RUNNING or PROCESSING and has
      threads.
    """
    progress = f' {self._progress}' if self._progress else ''
    module_line = f'     {self.runtime_name}: {self.status.value}{progress}'
    thread_lines = []

    if self.status == Status.PENDING and self._dependencies:
      # Show which runtimes this module is still waiting on.
      module_line += f' ({", ".join(self._dependencies)})'
    elif self.status == Status.ERROR:
      module_line += f': {self._error_message}'
    elif self.status in (Status.RUNNING, Status.PROCESSING) and self._threads:
      module_line += (f' - {self._threads_containers_completed} of '
          f'{self._threads_containers_max} containers completed')
      for thread_name, thread in self._threads.items():
        thread_progress = f'{thread.progress} ' if thread.progress else ''
        thread_lines.append(
            f'       {thread_name}: {thread.status.value} '
            f'{thread_progress}({thread.container})')

    return [module_line] + thread_lines

  def SetStatus(self, status: Status) -> None:
    """Set the status of the module.

    ERROR, COMPLETED and CANCELLED are final: once the module is in one of
    them, further status updates are ignored.

    Args:
      status: The status to set this module to."""
    if self.status not in (Status.ERROR, Status.COMPLETED, Status.CANCELLED):
      self.status = status

  def SetThreadState(self, thread: str, status: Status, container: str) -> None:
    """Set the state of a thread within a threaded module.

    The thread's previous state, including any progress, is replaced. Every
    call with a COMPLETED status increments the completed-container count.

    Args:
      thread: The name of this thread (eg ThreadPoolExecutor-0_5).
      status: The current status of the thread.
      container: The name of the container the thread is currently processing.
    """
    self._threads[thread] = _ModuleThread(status, container)
    if status == Status.COMPLETED:
      self._threads_containers_completed += 1

  def SetProgress(self,
                  steps_taken: int,
                  steps_expected: int) -> None:
    """Sets the modules progress values.

    Args:
      steps_taken: The number of steps taken so far.
      steps_expected: The number of total steps expected for completion. If
          not positive, the progress is cleared instead of raising.
    """
    self._progress = self._FormatProgress(steps_taken, steps_expected)

  def SetThreadProgress(self,
                        thread_id: str,
                        steps_taken: int,
                        steps_expected: int) -> None:
    """Sets a threads progress values.

    Args:
      thread_id: The thread id in question.
      steps_taken: The number of steps taken so far.
      steps_expected: The number of total steps expected for completion. If
          not positive, the thread's progress is cleared instead of raising.

    Raises:
      ValueError: if thread_id is not being tracked.
    """
    if thread_id not in self._threads:
      raise ValueError(f'{thread_id} not found')

    self._threads[thread_id].progress = self._FormatProgress(
        steps_taken, steps_expected)

  def SetError(self, message: str) -> None:
    """Sets the error for the module.

    Unlike SetStatus, this always moves the module to ERROR, regardless of
    its current status.

    Args:
      message: The error message string."""
    self._error_message = message
    self.status = Status.ERROR

  def SetContainerCount(self, count: int) -> None:
    """Sets the maximum container count for the module.

    Args:
      count: The total number of containers to be processed."""
    self._threads_containers_max = count


class Message:
  """A single line of display output, attributed to the source that sent it."""

  def __init__(self, source: str, content: str, is_error: bool = False) -> None:
    """Creates a Message.

    Args:
      source: Who produced the message, eg 'dftimewolf' or a runtime name.
      content: The text of the message.
      is_error: Whether the message reports an error."""
    self.source: str = source
    self.content: str = content
    self.is_error: bool = is_error

  def Stringify(self, source_len: int = 0, colorize: bool = False) -> str:
    """Renders the Message as a single display line.

    Args:
      source_len: Minimum width of the source column, so that messages from
          sources of different lengths line up.
      colorize: Whether error messages are wrapped in ANSI red/bold codes.

    Returns:
      A line of the form '[ <source padded> ] <content>'.
    """
    width = max(len(self.source), source_len)
    if colorize and self.is_error:
      start, end = '\u001b[31;1m', '\u001b[0m'
    else:
      start, end = '', ''
    return f'[ {self.source:{width}} ] {start}{self.content}{end}'


class CursesDisplayManager:
  """Handles the curses based console output, based on information passed in.
  """

  def __init__(self) -> None:
    """Initializes the CursesDisplayManager."""
    self._recipe_name: str = ''
    self._exception: Union[Exception, None] = None
    self._preflights: Dict[str, Module] = {}
    self._modules: Dict[str, Module] = {}
    self._messages: List[Message] = []
    self._messages_longest_source_len: int = 0
    # Re-entrant on purpose: SIGWINCH_Handler runs on the main thread and
    # calls Draw(). If the signal arrives while the main thread already holds
    # this lock (inside Draw() or Pause()), a plain Lock would deadlock.
    self._lock = threading.RLock()
    self._stdscr: curses.window = None  # type: ignore

  def StartCurses(self) -> None:
    """Start the curses display.

    Initialises curses, turns off input echo, enables cbreak and keypad mode,
    and installs SIGWINCH_Handler so the display is redrawn on resize."""
    self._stdscr = curses.initscr()
    curses.noecho()
    curses.cbreak()
    self._stdscr.keypad(True)
    signal.signal(signal.SIGWINCH, self.SIGWINCH_Handler)

  def EndCurses(self) -> None:
    """Curses finalisation actions.

    If any error message or exception was recorded, first waits for a
    keypress so the user can read the screen. The terminal is restored even
    if that wait is interrupted. Does nothing if curses was never started."""
    if not self._stdscr:
      return

    try:
      if self._exception or any(m.is_error for m in self._messages):
        self.Pause()
    finally:
      # Stop redrawing on resize: refreshing the window after endwin() would
      # switch the terminal back into curses mode over normal output.
      signal.signal(signal.SIGWINCH, signal.SIG_DFL)
      with self._lock:
        curses.nocbreak()
        self._stdscr.keypad(False)
        curses.echo()
        curses.endwin()
        # Make any later Draw()/Pause() calls no-ops.
        self._stdscr = None  # type: ignore

  def SetRecipe(self, recipe: str) -> None:
    """Set the recipe name.

    Args:
      recipe: The recipe name"""
    self._recipe_name = recipe

  def SetException(self, e: Exception) -> None:
    """Set an Exception to be included in the display.

    Args:
      e: The exception object."""
    self._exception = e

  def SetError(self, module: str, message: str) -> None:
    """Sets the error state and message for a module.

    The message is also enqueued as an error message.

    Args:
      module: The module name generating the error.
      message: The error message content."""
    if module in self._preflights:
      self._preflights[module].SetError(message)
    if module in self._modules:
      self._modules[module].SetError(message)

    self.EnqueueMessage(module, message, True)

    self.Draw()

  def EnqueueMessage(self,
                     source: str,
                     content: str,
                     is_error: bool = False) -> None:
    """Enqueue a message to be displayed.

    Multi-line content is split into one Message per non-empty line.

    Args:
      source: The source of the message, eg 'dftimewolf' or a runtime name.
      content: The message content.
      is_error: True if the message is an error message, False otherwise."""
    if self._messages_longest_source_len < len(source):
      self._messages_longest_source_len = len(source)

    for line in content.split('\n'):
      if line:
        self._messages.append(Message(source, line, is_error))

    self.Draw()

  def PrepareMessagesForDisplay(self, available_lines: int) -> List[str]:
    """Prepares the list of messages to be displayed.

    Messages are wrapped to the current window width, and only the most
    recent wrapped lines that fit are kept.

    Args:
      available_lines: The number of lines available to print messages.

    Returns:
      A list of strings, formatted for display. Empty when available_lines is
      zero or negative."""
    # lines[-0:] is the *whole* list, so "no room" must be handled explicitly.
    if available_lines <= 0:
      return []

    _, x = self._stdscr.getmaxyx()

    lines = []
    width = max(x - self._messages_longest_source_len - 8, 1)

    for m in self._messages:
      lines.extend(
        textwrap.wrap(m.Stringify(self._messages_longest_source_len),
                      width=width,
                      initial_indent='  ', subsequent_indent='    ',
                      replace_whitespace=False, break_long_words=False))

    return lines[-available_lines:]

  def EnqueuePreflight(self,
                       name: str,
                       dependencies: List[str],
                       runtime_name: Optional[str]) -> None:
    """Enqueue a preflight module object for display.

    Args:
      name: The name of the preflight module.
      dependencies: runtime names of blocking modules.
      runtime_name: the runtime name of the preflight module."""
    m = Module(name, dependencies, runtime_name)
    self._preflights[m.runtime_name] = m

  def EnqueueModule(self,
                    name: str,
                    dependencies: List[str],
                    runtime_name: Optional[str]) -> None:
    """Enqueue a module object for display.

    Args:
      name: The name of the module.
      dependencies: runtime names of blocking modules.
      runtime_name: the runtime name of the module."""
    m = Module(name, dependencies, runtime_name)
    self._modules[m.runtime_name] = m

  def UpdateModuleStatus(self, module: str, status: Status) -> None:
    """Update the status of a module for display.

    Args:
      module: The runtime name of the module.
      status: the status of the module."""
    if module in self._preflights:
      self._preflights[module].SetStatus(status)
    if module in self._modules:
      self._modules[module].SetStatus(status)

    self.Draw()

  def SetThreadedModuleContainerCount(self, module: str, count: int) -> None:
    """Set the container count that a threaded module will operate on.

    Args:
      module: The runtime name of the threaded module.
      count: The total number of containers the module will process."""
    if module in self._preflights:
      self._preflights[module].SetContainerCount(count)
    if module in self._modules:
      self._modules[module].SetContainerCount(count)

  def UpdateModuleThreadState(self,
                              module: str,
                              status: Status,
                              thread: str,
                              container: str) -> None:
    """Update the state of a thread within a threaded module for display.

    Args:
      module: The runtime name of the module.
      status: The status of the thread.
      thread: The name of the thread, eg 'ThreadPoolExecutor-0_0'.
      container: The name of the container being processed."""
    if module in self._preflights:
      self._preflights[module].SetThreadState(thread, status, container)
    if module in self._modules:
      self._modules[module].SetThreadState(thread, status, container)

    self.Draw()

  def SetModuleProgress(self,
                        module_name: str,
                        steps_taken: int,
                        steps_expected: int) -> None:
    """Sets the progress values for a module.

    Args:
      module_name: The module in question.
      steps_taken: The number of steps taken so far.
      steps_expected: The number of total steps expected for completion.

    Raises:
      ValueError: If module_name is not being tracked.
    """
    if module_name not in self._modules:
      raise ValueError(f'{module_name} not found')

    self._modules[module_name].SetProgress(steps_taken, steps_expected)

    self.Draw()

  def SetModuleThreadProgress(self,
                              module_name: str,
                              thread_id: str,
                              steps_taken: int,
                              steps_expected: int) -> None:
    """Sets the thread progress values for a processing thread in a module.

    Args:
      module_name: The module in question.
      thread_id: The thread id in question.
      steps_taken: The number of steps taken so far.
      steps_expected: The number of total steps expected for completion.

    Raises:
      ValueError: If module_name or thread_id is not being tracked.
    """
    if module_name not in self._modules:
      raise ValueError(f'{module_name} not found')

    self._modules[module_name].SetThreadProgress(
        thread_id, steps_taken, steps_expected)
    self.Draw()

  def Draw(self) -> None:
    """Draws the window.

    Does nothing unless curses is running. If the content does not fit, a
    "terminal not large enough" banner is drawn instead."""
    with self._lock:
      # Checked under the lock so EndCurses() cannot clear it mid-draw.
      if not self._stdscr:
        return

      self._stdscr.clear()
      y, x = self._stdscr.getmaxyx()

      try:
        curr_line = 0
        self._stdscr.addstr(curr_line, 0, f' {self._recipe_name}'[:x])
        curr_line += 1

        # Preflights
        if self._preflights:
          self._stdscr.addstr(curr_line, 0, '   Preflights:'[:x])
          curr_line += 1
          for module in self._preflights.values():
            for line in module.Stringify():
              self._stdscr.addstr(curr_line, 0, line[:x])
              curr_line += 1

        # Modules
        self._stdscr.addstr(curr_line, 0, '   Modules:'[:x])
        curr_line += 1
        for status in Status:  # Print the modules in Status order
          for module in self._modules.values():
            if module.status != status:
              continue
            for line in module.Stringify():
              self._stdscr.addstr(curr_line, 0, line[:x])
              curr_line += 1

        # Messages
        curr_line += 1
        self._stdscr.addstr(curr_line, 0, ' Messages:'[:x])
        curr_line += 1

        # Keep the bottom rows free: the exception is drawn at y - 2 and the
        # cursor is parked at y - 1.
        message_space = y - 4 - curr_line
        for m in self.PrepareMessagesForDisplay(message_space):
          self._stdscr.addstr(curr_line, 0, m[:x])
          curr_line += 1

        # Exceptions
        if self._exception:
          self._stdscr.addstr(y - 2, 0,
              f' Exception encountered: {str(self._exception)}'[:x])
      except curses.error:
        self._DrawTooSmallBanner(y, x)

      try:
        self._stdscr.move(y - 1, 0)
      except curses.error:
        # y may be stale if a resize redraw happened during this draw.
        pass
      self._stdscr.refresh()

  def _DrawTooSmallBanner(self, y: int, x: int) -> None:
    """Best-effort warning that the terminal is too small for the display.

    Each row is written independently and errors are ignored. On a tiny
    terminal the rows may not exist, and writing the bottom-right cell makes
    curses raise even though the text was drawn.

    Args:
      y: The window height.
      x: The window width."""
    # pylint: disable=line-too-long
    banner = (
        (3, '*********************************************************************** '),
        (2, '*** Terminal not large enough, consider increasing your window size *** '),
        (1, '*********************************************************************** '),
    )
    # pylint: enable=line-too-long
    for offset, text in banner:
      try:
        self._stdscr.addstr(y - offset, 0, text[:x])
      except curses.error:
        pass

  def PrintMessages(self) -> None:
    """Dump all messages to stdout. Intended to be used when exiting, after
    calling EndCurses()."""

    if self._messages:
      print('Messages')
      for m in self._messages:
        print(f'  {m.Stringify(self._messages_longest_source_len, True)}')

    if self._exception:
      print('\nException encountered during execution:')
      print(''.join(traceback.format_exception(None,
                                self._exception,
                                self._exception.__traceback__)))

  def Pause(self) -> None:
    """Ask the user to press any key to continue.

    Blocks on a keypress while holding the display lock. Does nothing unless
    curses is running."""
    with self._lock:
      if not self._stdscr:
        return

      y, _ = self._stdscr.getmaxyx()

      self._stdscr.addstr(y - 1, 0, "Press any key to continue")
      self._stdscr.getkey()
      self._stdscr.addstr(y - 1, 0, "                         ")

  def SIGWINCH_Handler(self, *unused_argvs: Any) -> None:
    """Redraw the window when SIGWINCH is raised."""
    # shutil.get_terminal_size() is (columns, lines) while
    # curses.resizeterm() takes (nlines, ncols); pass the fields by name so
    # rows and columns are not swapped.
    size = shutil.get_terminal_size()
    curses.resizeterm(size.lines, size.columns)
    self.Draw()


class CDMStringIOWrapper(io.StringIO):
  """io.StringIO subclass that mirrors written lines to a callback.

  Everything passed to write() is buffered as usual. In addition, each
  non-empty line of the written text is handed to the callback as
  (source, line, is_error); the callback is intended to be EnqueueMessage of
  a CDM instance."""

  def __init__(self,
               source: str,
               is_error: bool,
               callback: Callable[[str, str, bool], None]) -> None:
    """Initialise the wrapper.

    Args:
      source: Source name passed to the callback with every line.
      is_error: Error flag passed to the callback with every line.
      callback: Invoked once per non-empty line written."""
    super().__init__()
    self._callback = callback
    self._source = source
    self._is_error = is_error

  def write(self, s: str) -> int:
    """Forwards each non-empty line of s to the callback, then buffers s.

    Returns:
      The result of io.StringIO.write (the number of characters written)."""
    for chunk in filter(None, s.split('\n')):
      self._callback(self._source, chunk, self._is_error)
    return super().write(s)
