package com.almworks.jira.structure.api.util;

import com.almworks.integers.*;
import com.almworks.jira.structure.api.forest.raw.Forest;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.*;
import java.util.function.Supplier;

/**
 * Forest-based indexes for common operations like getting parents, children, subtree or siblings.
 * <p>
 * Subtree ends and parents are computed lazily, on first request, and cached in per-row index lists,
 * so repeated queries are cheap. Where a method documents it, a negative index (normally {@code -1})
 * denotes the "super-root" whose children are the roots of the forest.
 * <p>
 * Not thread-safe: the caches and the temporary buffers are mutated without synchronization.
 * NOTE(review): the caches are never invalidated, so the wrapped {@link Forest} is presumably not
 * expected to change while this index is in use — confirm with callers.
 */
public class IndexedForest {
  private final Forest myForest;

  /**
   * mySubtreeEnds[i] contains the exclusive end of the subtree rooted at [i], i.e. the index of the
   * first row after that subtree; -1 means "not computed yet".
   */
  private final Supplier<WritableIntList> mySubtreeEnds;

  /**
   * myParents[i] contains the index of the parent of row at [i]; -1 means "not computed yet".
   * Roots never get an entry: {@link #parent(int)} answers them without consulting this index.
   */
  private final Supplier<WritableIntList> myParents;

  /**
   * Temporary buffer for mySubtreeEnds calculation (stack of rows whose subtrees are still open)
   */
  @Nullable
  private WritableIntList mySubtreeStarts;

  /**
   * Temporary buffer for myParents calculation: element [d] collects visited rows at relative depth d
   * whose parent has not been reached yet
   */
  @Nullable
  private List<WritableIntList> mySiblingBuffers;

  public IndexedForest(Forest forest) {
    myForest = forest;
    int size = forest.size();
    mySubtreeEnds = createForestIndex(size);
    myParents = createForestIndex(size);
  }

  /**
   * Creates a lazily allocated per-row index, initially filled with -1 ("not computed").
   */
  private static Supplier<WritableIntList> createForestIndex(int size) {
    return StructureUtil.memoize(() -> new IntArray(IntCollections.repeat(-1, size)));
  }

  /**
   * Returns the exclusive end of the subtree rooted at {@code idx}: the index of the first row after
   * {@code idx} whose depth is not greater than {@code depth(idx)}, or {@link #size()} if there is no
   * such row. For a negative {@code idx} (the super-root) returns {@link #size()}.
   *
   * @param idx row index, or a negative value for the super-root
   * @return the exclusive end of the subtree
   */
  public int subtreeEnd(int idx) {
    if (idx < 0) return size();

    WritableIntList subtreeEnds = mySubtreeEnds.get();
    int end = subtreeEnds.get(idx);
    if (end < 0) {
      end = calcSubtreeEnds(idx, subtreeEnds);
    }
    assert end >= 0;
    return end;
  }

  /**
   * Scans forward from {@code start} to the end of its subtree and records subtree ends both for
   * {@code start} and for every row inside its subtree that gets closed during the scan.
   *
   * @param start the row whose subtree end is requested
   * @param subtreeEnds the cache to fill
   * @return the exclusive end of the subtree rooted at {@code start}
   */
  private int calcSubtreeEnds(int start, WritableIntList subtreeEnds) {
    // In a well-formed forest the size of this stack equals the relative depth of row i - 1
    WritableIntList subtreeStarts = prepareSubtreeStarts();
    int n = myForest.size();
    int startDepth = myForest.getDepth(start);
    for (int i = start + 1; i <= n; ++i) {
      // All depths here are relative to startDepth; rows outside the subtree and the virtual row n get 0
      int depth = i < n ? Math.max(myForest.getDepth(i) - startDepth, 0) : 0;
      int lastDepth = subtreeStarts.size();
      if (depth > lastDepth) {
        // Row i - 1 has children: its subtree stays open
        subtreeStarts.add(i - 1);
      } else {
        // Row i - 1 is not followed by a deeper row, so it is a leaf whose subtree ends right here...
        subtreeEnds.set(i - 1, i);
        // ...and so do all open subtrees at relative depths [depth, lastDepth)
        for (int j = lastDepth - 1; j >= depth; --j) {
          subtreeEnds.set(subtreeStarts.get(j), i);
        }
        if (depth < lastDepth) {
          subtreeStarts.removeRange(depth, lastDepth);
        }
      }
      if (depth == 0) {
        assert subtreeStarts.isEmpty();
        return i;
      }
    }
    // Unreachable: the virtual row n always has relative depth 0 and terminates the loop
    throw new AssertionError("subtree end not found for row " + start);
  }

  @NotNull
  private WritableIntList prepareSubtreeStarts() {
    WritableIntList subtreeStarts = mySubtreeStarts;
    if (subtreeStarts == null) {
      mySubtreeStarts = subtreeStarts = new IntArray();
    } else {
      subtreeStarts.clear();
    }
    return subtreeStarts;
  }

  /**
   * Returns the index of the parent of row {@code idx}, or -1 if {@code idx} is a root (depth 0) or
   * negative.
   */
  public int parent(int idx) {
    if (idx < 0) return -1;
    // Roots have no parent; checking this first avoids allocating the parent index just for them
    if (myForest.getDepth(idx) == 0) return -1;
    WritableIntList parents = myParents.get();
    int parent = parents.get(idx);
    if (parent < 0) {
      parent = calcParent(idx, parents);
    }
    assert parent >= 0;
    return parent;
  }

  /**
   * Walks backwards from {@code idx} until a row with a smaller depth (the parent) is met. Rows visited
   * on the way get their parents cached as well once those parents are reached; rows whose parents are
   * already cached let the walk jump straight to the parent instead of scrolling row by row.
   * Expects {@code idx} to have a non-zero depth.
   *
   * @return the parent of {@code idx}
   */
  private int calcParent(int idx, WritableIntList parents) {
    List<WritableIntList> siblingBuffers = prepareSiblingBuffers();

    int i = idx;
    int targetDepth = myForest.getDepth(idx);
    int lastRelDepth = -1;
    while (true) {
      int relDepth = myForest.getDepth(i) - targetDepth;
      if (relDepth < lastRelDepth) {
        // Moving backwards (or jumping to a parent) the depth can decrease by at most one
        assert relDepth + 1 == lastRelDepth;
        // Found parent for siblings at lastRelDepth
        for (IntIterator siblingIdx : siblingBuffers.get(lastRelDepth)) {
          parents.set(siblingIdx.value(), i);
        }
      }
      if (relDepth < 0) {
        // Found parent for idx
        break;
      }
      // Entering deeper levels starts new sibling groups: reset (or create) their buffers
      for (int j = lastRelDepth + 1; j <= relDepth; ++j) {
        if (j == siblingBuffers.size()) {
          siblingBuffers.add(new IntArray());
        }
        siblingBuffers.get(j).clear();
      }
      siblingBuffers.get(relDepth).add(i);
      lastRelDepth = relDepth;
      int parent = parents.get(i);
      if (parent >= 0) {
        // Jump
        i = parent;
      } else {
        // Scroll
        i -= 1;
      }
    }

    return parents.get(idx);
  }

  @NotNull
  private List<WritableIntList> prepareSiblingBuffers() {
    List<WritableIntList> siblingBuffers = mySiblingBuffers;
    if (siblingBuffers == null) {
      mySiblingBuffers = siblingBuffers = new ArrayList<>(4);
    }
    return siblingBuffers;
  }


  /**
   * Returns an iterator that provides indexes of all direct children of the row at index idx.
   * If idx is -1, returns roots. The iterator advances lazily via {@link #nextSibling(int)}.
   */
  @NotNull
  public IntIterator children(int idx) {
    int first = firstChild(idx);
    if (first < 0) return IntIterator.EMPTY;
    return new IntFindingIterator() {
      private int myPending = first;

      @Override
      protected boolean findNext() throws ConcurrentModificationException {
        if (myPending < 0) {
          return false;
        }
        myNext = myPending;
        myPending = nextSibling(myPending);
        return true;
      }
    };
  }

  /**
   * Returns an iterator over the indexes of all roots.
   */
  @NotNull
  public IntIterator roots() {
    return children(-1);
  }

  /**
   * Returns a list of rows of the direct children of the row at index idx. If idx = -1, returns a list
   * of roots. Returns {@code LongList.EMPTY} when there are no children.
   */
  public LongList childrenRows(int idx) {
    IntIterator children = children(idx);
    if (!children.hasNext()) return LongList.EMPTY;
    return new LongArray(rows(children));
  }

  /**
   * Returns the index of the first child under [idx], or -1 if there is none or idx is out of range.
   * If idx is equal to -1, returns the index of the first root (0), or -1 for an empty forest.
   */
  public int firstChild(int idx) {
    int size = size();
    if (idx == -1) {
      // super-root
      return size > 0 ? 0 : -1;
    }
    if (idx < 0 || idx >= size) return -1;
    int k = idx + 1;
    return k < size && depth(k) == depth(idx) + 1 ? k : -1;
  }

  /**
   * Returns the index of the next sibling of [idx] (the row right after its subtree, if it has the same
   * depth), or -1 if there is none or idx is out of range.
   */
  public int nextSibling(int idx) {
    int size = size();
    if (idx < 0 || idx >= size) return -1;
    int k = subtreeEnd(idx);
    return k < size && depth(k) == depth(idx) ? k : -1;
  }

  /**
   * Returns the depth of the row at index idx, as reported by the underlying forest.
   */
  public int depth(int idx) {
    return myForest.getDepth(idx);
  }

  /**
   * Returns the row at index idx, as reported by the underlying forest.
   */
  public long row(int idx) {
    return myForest.getRow(idx);
  }

  /**
   * Returns a lazy iterator that maps each index from {@code indices} to its row.
   */
  public LongIterator rows(final IntIterator indices) {
    return new LongFindingIterator() {
      @Override
      protected boolean findNext() {
        if (indices.hasNext()) {
          myNext = row(indices.nextValue());
          return true;
        }
        return false;
      }
    };
  }

  /**
   * Returns the number of rows in the underlying forest.
   */
  public int size() {
    return myForest.size();
  }

  public Forest getForest() {
    return myForest;
  }

  /**
   * Returns the index of the root of the tree containing the row at idx (idx itself if it is a root).
   * A negative idx is returned unchanged.
   */
  public int root(int idx) {
    while (idx >= 0 && depth(idx) > 0) {
      idx = parent(idx);
    }
    return idx;
  }
}
