//============================================================================
//  Copyright (c) Kitware, Inc.
//  All rights reserved.
//  See LICENSE.txt for details.
//
//  This software is distributed WITHOUT ANY WARRANTY; without even
//  the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
//  PURPOSE.  See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_RemoveUnusedPoints_h
#define vtk_m_worklet_RemoveUnusedPoints_h

#include <vtkm/cont/ArrayCopyDevice.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/cont/CellSetExplicit.h>
#include <vtkm/cont/UnknownArrayHandle.h>

#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/ScatterCounting.h>
#include <vtkm/worklet/WorkletMapField.h>

#include <memory>

namespace vtkm
{
namespace worklet
{

/// A collection of worklets used to identify which points are used by at least
/// one cell and then remove the points that are not used by any cells. The
/// class containing these worklets can be used to manage running these
/// worklets, building new cell sets, and redefining field arrays.
///
class RemoveUnusedPoints
{
public:
  /// Worklet for the first pass of unused-point removal: flagging the points
  /// that are actually referenced. It is mapped over an array of point indices
  /// (e.g. the connectivity of a CellSetExplicit) and, for each index it
  /// visits, stores a 1 at that position of the whole-array mask. The mask is
  /// expected to be initialized to 0 by the caller, so entries still 0
  /// afterwards identify points that were never visited.
  ///
  struct GeneratePointMask : public vtkm::worklet::WorkletMapField
  {
    using ControlSignature = void(FieldIn pointIndices, WholeArrayInOut pointMask);
    using ExecutionSignature = void(_1, _2);

    template <typename MaskPortal>
    VTKM_EXEC void operator()(vtkm::Id usedPoint, const MaskPortal& maskPortal) const
    {
      // The same index may appear more than once in the input; every visit
      // stores the same value.
      maskPortal.Set(usedPoint, 1);
    }
  };

  /// Worklet that renumbers point indices. For each entry of the input array
  /// of point indices (e.g. the connectivity of a CellSetExplicit) it looks up
  /// that index in a whole-array map from original indices to new indices and
  /// emits the mapped value into the output array.
  ///
  struct TransformPointIndices : public vtkm::worklet::WorkletMapField
  {
    using ControlSignature = void(FieldIn pointIndex, WholeArrayIn indexMap, FieldOut mappedPoints);
    using ExecutionSignature = _3(_1, _2);

    template <typename MapPortal>
    VTKM_EXEC vtkm::Id operator()(vtkm::Id oldIndex, const MapPortal& oldToNew) const
    {
      const vtkm::Id newIndex = oldToNew.Get(oldIndex);
      return newIndex;
    }
  };

public:
  /// Creates an object with no point information. \c FindPointsStart,
  /// \c FindPoints, and \c FindPointsEnd must be called before the non-static
  /// \c MapCellSet or \c GetPermutationArray can be used.
  ///
  VTKM_CONT RemoveUnusedPoints() = default;

  /// Creates an object and immediately identifies the points used by
  /// \c inCellSet, leaving the object ready to map cell sets.
  ///
  template <typename ShapeStorage, typename ConnectivityStorage, typename OffsetsStorage>
  VTKM_CONT explicit RemoveUnusedPoints(
    const vtkm::cont::CellSetExplicit<ShapeStorage, ConnectivityStorage, OffsetsStorage>& inCellSet)
  {
    // Same as the manual start / find / end sequence for a single cell set.
    FindPointsStart();
    FindPoints(inCellSet);
    FindPointsEnd();
  }

  /// Begins a new point-finding pass by discarding any mask left over from a
  /// previous pass. Follow with one or more calls to \c FindPoints and then a
  /// call to \c FindPointsEnd.
  ///
  VTKM_CONT void FindPointsStart()
  {
    this->MaskArray.ReleaseResources();
  }

  /// Marks every point referenced by the connectivity of \c inCellSet as used.
  /// May be called repeatedly between \c FindPointsStart and \c FindPointsEnd;
  /// a point is considered unused only if no cell set passed to this method
  /// references it. Every cell set passed during one pass must report the same
  /// number of points (checked with \c VTKM_ASSERT).
  ///
  template <typename ShapeStorage, typename ConnectivityStorage, typename OffsetsStorage>
  VTKM_CONT void FindPoints(
    const vtkm::cont::CellSetExplicit<ShapeStorage, ConnectivityStorage, OffsetsStorage>& inCellSet)
  {
    const vtkm::Id numPoints = inCellSet.GetNumberOfPoints();
    auto& mask = this->MaskArray;

    if (mask.GetNumberOfValues() <= 0)
    {
      // No mask yet for this pass: create one with every point unmarked.
      mask.AllocateAndFill(numPoints, 0);
    }
    VTKM_ASSERT(mask.GetNumberOfValues() == numPoints);

    const auto& connectivity = inCellSet.GetConnectivityArray(vtkm::TopologyElementTagCell{},
                                                              vtkm::TopologyElementTagPoint{});

    vtkm::worklet::DispatcherMapField<GeneratePointMask> markUsedPoints;
    markUsedPoints.Invoke(connectivity, mask);
  }

  /// Compile the information collected from calls to \c FindPoints to ready
  /// this class for mapping cell sets and fields.
  ///
  /// Builds a \c ScatterCounting from the point mask, requesting that it save
  /// its input-to-output map (which \c MapCellSet reads), and then releases
  /// the mask array.
  ///
  VTKM_CONT void FindPointsEnd()
  {
    // make_shared instead of reset(new ...): no naked new, and the object and
    // its control block are allocated together.
    this->PointScatter =
      std::make_shared<vtkm::worklet::ScatterCounting>(this->MaskArray, true);

    this->MaskArray.ReleaseResources();
  }

  /// \brief Map cell indices
  ///
  /// Returns a new cell set built from \c inCellSet (typically the cell set
  /// that was given to the constructor or to \c FindPoints) whose cell points
  /// use the indices of the new reduced point arrays. \c FindPointsEnd must
  /// have been called first (checked with \c VTKM_ASSERT).
  ///
  template <typename ShapeStorage, typename ConnectivityStorage, typename OffsetsStorage>
  VTKM_CONT
    vtkm::cont::CellSetExplicit<ShapeStorage, VTKM_DEFAULT_CONNECTIVITY_STORAGE_TAG, OffsetsStorage>
    MapCellSet(const vtkm::cont::CellSetExplicit<ShapeStorage, ConnectivityStorage, OffsetsStorage>&
                 inCellSet) const
  {
    VTKM_ASSERT(this->PointScatter);

    auto& scatter = *this->PointScatter;

    // The output-to-input map has one entry per retained point.
    const vtkm::Id numUsedPoints = scatter.GetOutputToInputMap().GetNumberOfValues();

    return RemoveUnusedPoints::MapCellSet(inCellSet, scatter.GetInputToOutputMap(), numUsedPoints);
  }

  /// \brief Map cell indices
  ///
  /// Returns a new cell set built from \c inCellSet in which every
  /// connectivity entry has been replaced by its value in
  /// \c inputToOutputPointMap (a map from old point indices to new ones). The
  /// shapes and offsets arrays of \c inCellSet are handed to the new cell set
  /// as they are, and \c numberOfPoints is given to it as its point count.
  ///
  /// This helper method can be used by external items that do similar operations
  /// that remove points or otherwise rearrange points in a cell set. If points
  /// were removed by calling \c FindPoints, then you should use the other form
  /// of \c MapCellSet.
  ///
  template <typename ShapeStorage,
            typename ConnectivityStorage,
            typename OffsetsStorage,
            typename MapStorage>
  VTKM_CONT static vtkm::cont::CellSetExplicit<ShapeStorage,
                                               VTKM_DEFAULT_CONNECTIVITY_STORAGE_TAG,
                                               OffsetsStorage>
  MapCellSet(
    const vtkm::cont::CellSetExplicit<ShapeStorage, ConnectivityStorage, OffsetsStorage>& inCellSet,
    const vtkm::cont::ArrayHandle<vtkm::Id, MapStorage>& inputToOutputPointMap,
    vtkm::Id numberOfPoints)
  {
    using CellTag = vtkm::TopologyElementTagCell;
    using PointTag = vtkm::TopologyElementTagPoint;
    using OutConnectivityStorage = VTKM_DEFAULT_CONNECTIVITY_STORAGE_TAG;
    using OutCellSetType =
      vtkm::cont::CellSetExplicit<ShapeStorage, OutConnectivityStorage, OffsetsStorage>;

    // Translate every connectivity entry through the old-to-new point map.
    vtkm::cont::ArrayHandle<vtkm::Id, OutConnectivityStorage> remappedConnectivity;
    vtkm::worklet::DispatcherMapField<TransformPointIndices> remapIndices;
    remapIndices.Invoke(inCellSet.GetConnectivityArray(CellTag{}, PointTag{}),
                        inputToOutputPointMap,
                        remappedConnectivity);

    // Only the connectivity changes; shapes and offsets come from the input.
    OutCellSetType result;
    result.Fill(numberOfPoints,
                inCellSet.GetShapesArray(CellTag{}, PointTag{}),
                remappedConnectivity,
                inCellSet.GetOffsetsArray(CellTag{}, PointTag{}));

    return result;
  }

  /// \brief Returns a permutation map that maps new points to old points.
  ///
  /// The returned array is the output-to-input map of the scatter built by
  /// \c FindPointsEnd, so \c FindPointsEnd (or the constructor taking a cell
  /// set) must have been called first.
  ///
  VTKM_CONT vtkm::cont::ArrayHandle<vtkm::Id> GetPermutationArray() const
  {
    // Same precondition as MapCellSet: on a default-constructed object the
    // scatter pointer is null and must not be dereferenced.
    VTKM_ASSERT(this->PointScatter);

    return this->PointScatter->GetOutputToInputMap();
  }

private:
  /// Per-point mask used while finding points. \c FindPoints fills it with 0
  /// when it is empty, and the \c GeneratePointMask worklet writes 1 at every
  /// point index it visits. Released by \c FindPointsStart and
  /// \c FindPointsEnd.
  ///
  vtkm::cont::ArrayHandle<vtkm::IdComponent> MaskArray;

  /// Manages how the original point indices map to the new point indices.
  /// Created by \c FindPointsEnd; null until then.
  ///
  std::shared_ptr<vtkm::worklet::ScatterCounting> PointScatter;
};
}
} // namespace vtkm::worklet

#endif //vtk_m_worklet_RemoveUnusedPoints_h
