package com.almworks.jira.structure.api.util;

import com.almworks.integers.*;
import com.almworks.jira.structure.api.forest.raw.Forest;

import java.util.ArrayList;
import java.util.List;

/**
 * Index over a {@link Forest} that answers structural queries by row index: the exclusive end of
 * a row's subtree ({@link #subtreeEnd(int)}) and a row's parent ({@link #parent(int)}).
 *
 * <p>Answers are computed lazily on first request and cached per row index. A single computation
 * may also cache answers for other rows visited during the same scan.
 *
 * <p>The caches are sized from {@code forest.size()} on first use and are never invalidated.
 * NOTE(review): this presumably requires that the wrapped forest is not modified while this index
 * is in use — confirm with callers.
 *
 * <p>The scans assume a well-formed depth sequence: a row is at most one level deeper than the row
 * before it ({@code calcParent} asserts the mirror of this when walking backward).
 *
 * <p>Caches and scratch buffers are mutable and unsynchronized, so one instance must not be used
 * from several threads concurrently without external synchronization.
 */
public class IndexedForest {
  private final Forest myForest;

  // Cached exclusive subtree end per row index; -1 means "not computed yet". Allocated on first use.
  private WritableIntList mySubtreeEnds;
  // Scratch stack for calcSubtreeEnds: rows whose subtrees are still open during a forward scan.
  private WritableIntList mySubtreeStarts;
  // Cached parent index per row index; -1 means "not computed yet" (and stays -1 for roots).
  // Allocated on first use.
  private WritableIntList myParents;
  // Scratch buffers for calcParent: element d holds rows at depth (targetDepth + d) whose parent
  // has not been reached yet during the backward walk.
  private List<WritableIntList> mySiblingBuffers;

  /**
   * Creates an index over the given forest. No work is done until a query needs it.
   *
   * @param forest the forest to index
   */
  public IndexedForest(Forest forest) {
    myForest = forest;
  }

  /**
   * Returns the exclusive end of the subtree rooted at {@code idx}: the index of the first row after
   * {@code idx} that is not its descendant, or {@link #size()} if the subtree runs to the end of the
   * forest.
   *
   * @param idx row index; a negative value is treated as the whole forest
   * @return index one past the last row of the subtree; {@code size()} when {@code idx < 0}
   */
  public int subtreeEnd(int idx) {
    if (idx < 0) return size();

    if (mySubtreeEnds == null) {
      mySubtreeEnds = new IntArray(IntCollections.repeat(-1, myForest.size()));
    }
    int end = mySubtreeEnds.get(idx);
    if (end < 0) {
      end = calcSubtreeEnds(idx);
    }
    assert end >= 0;
    return end;
  }

  /**
   * Scans forward from {@code start} to the end of its subtree. It caches in
   * {@code mySubtreeEnds} the subtree end of {@code start} and of every row inside its subtree.
   *
   * @param start row index whose subtree end is requested
   * @return the exclusive end of {@code start}'s subtree
   */
  private int calcSubtreeEnds(int start) {
    if (mySubtreeStarts == null) mySubtreeStarts = new IntArray();
    else mySubtreeStarts.clear();

    int n = myForest.size();
    int startDepth = myForest.getDepth(start);
    // i == n acts as a sentinel row at relative depth 0 that closes whatever is still open.
    for (int i = start + 1; i <= n; ++i) {
      // Depth relative to startDepth; rows outside the subtree clamp to 0, which ends the scan.
      int depth = i < n ? Math.max(myForest.getDepth(i) - startDepth, 0) : 0;
      // The stack holds one open ancestor per level, so its size is the relative depth of row i - 1.
      int lastDepth = mySubtreeStarts.size();
      if (depth > lastDepth) {
        // Row i is a child of row i - 1, so row i - 1's subtree stays open.
        mySubtreeStarts.add(i - 1);
      } else {
        // Row i is not deeper than row i - 1, so row i - 1 has no children: its subtree is itself.
        mySubtreeEnds.set(i - 1, i);
        if (depth < lastDepth) {
          // Every open subtree at relative depth >= depth ends right before row i.
          for (int j = lastDepth - 1; j >= depth; --j) {
            mySubtreeEnds.set(mySubtreeStarts.get(j), i);
          }
          mySubtreeStarts.removeRange(depth, lastDepth);
        }
      }
      if (depth == 0) {
        assert mySubtreeStarts.isEmpty();
        return i;
      }
    }
    // Unreachable: the sentinel at i == n always has relative depth 0. Fail loudly rather than
    // return an out-of-range index when assertions are disabled.
    throw new AssertionError("subtree scan from row " + start + " did not terminate");
  }

  /**
   * Returns the index of the parent row of {@code idx}, or -1 if {@code idx} is a root (depth 0).
   * The answer is cached after the first computation.
   *
   * @param idx row index
   * @return parent row index, or -1 for a root
   */
  public int parent(int idx) {
    if (myParents == null) {
      myParents = new IntArray(IntCollections.repeat(-1, myForest.size()));
    }
    if (myForest.getDepth(idx) == 0) return -1;
    int parent = myParents.get(idx);
    if (parent < 0) {
      parent = calcParent(idx);
    }
    assert parent >= 0;
    return parent;
  }

  /**
   * Walks backward from {@code idx} until it reaches a row shallower than {@code idx}, which is its
   * parent. Along the way it caches the parent of every row whose parent is passed during the walk.
   * When a visited row's parent is already cached, the walk jumps straight to it. That skips the
   * row's earlier siblings and their subtrees.
   *
   * @param idx non-root row index whose parent is not cached yet
   * @return the parent index of {@code idx}
   */
  private int calcParent(int idx) {
    if (mySiblingBuffers == null) {
      mySiblingBuffers = new ArrayList<>(4);
    }

    int i = idx;
    int targetDepth = myForest.getDepth(idx);
    // Relative depth of the previously visited row; -1 before the first step.
    int lastRelDepth = -1;
    while (true) {
      int relDepth = myForest.getDepth(i) - targetDepth;
      if (relDepth < lastRelDepth) {
        // Walking backward, depth can only drop by one, onto the parent of the buffered rows.
        assert relDepth + 1 == lastRelDepth;
        // Found parent for siblings at lastRelDepth
        for (IntIterator siblingIdx : mySiblingBuffers.get(lastRelDepth)) {
          myParents.set(siblingIdx.value(), i);
        }
      }
      if (relDepth < 0) {
        // Found parent for idx (buffer 0, which contains idx, was just assigned above)
        break;
      }
      // Entering deeper levels: drop stale entries there (rows whose parent was already assigned,
      // or leftovers from a previous call), growing the buffer list on demand.
      for (int j = lastRelDepth + 1; j <= relDepth; ++j) {
        if (j == mySiblingBuffers.size()) {
          mySiblingBuffers.add(new IntArray());
        }
        mySiblingBuffers.get(j).clear();
      }
      mySiblingBuffers.get(relDepth).add(i);
      lastRelDepth = relDepth;
      int parent = myParents.get(i);
      if (parent >= 0) {
        // Jump: rows between the cached parent and i are earlier siblings of i or their
        // descendants, none of which can be an ancestor of idx.
        i = parent;
      } else {
        // Scroll
        i -= 1;
      }
    }

    return myParents.get(idx);
  }

  /**
   * Returns the depth of the row at {@code idx}, as reported by the underlying forest.
   *
   * @param idx row index
   * @return depth of the row (0 for roots)
   */
  public int depth(int idx) {
    return myForest.getDepth(idx);
  }

  /**
   * Returns the row id stored at {@code idx} in the underlying forest.
   *
   * @param idx row index
   * @return row id
   */
  public long row(int idx) {
    return myForest.getRow(idx);
  }

  /**
   * Returns a lazy iterator over the row ids at the given indices, in the order {@code indices}
   * produces them. {@code indices} is consumed as the returned iterator is advanced.
   *
   * @param indices row indices to map to row ids
   * @return iterator over the corresponding row ids
   */
  public LongIterator rows(final IntIterator indices) {
    return new LongFindingIterator() {
      @Override
      protected boolean findNext() {
        if (indices.hasNext()) {
          myNext = myForest.getRow(indices.nextValue());
          return true;
        }
        return false;
      }
    };
  }

  /**
   * Returns the number of rows in the underlying forest.
   *
   * @return row count
   */
  public int size() {
    return myForest.size();
  }

  /**
   * Returns the forest this index was created over.
   *
   * @return the wrapped forest
   */
  public Forest getForest() {
    return myForest;
  }
}
