package com.almworks.jira.structure.api.attribute.loader.distinct;

import com.almworks.jira.structure.api.attribute.AttributeSpec;
import com.almworks.jira.structure.api.attribute.loader.*;
import com.almworks.jira.structure.api.attribute.loader.reduce.ReductionStrategy;
import com.almworks.jira.structure.api.item.ItemIdentity;
import com.almworks.jira.structure.api.row.StructureRow;
import com.google.common.collect.ImmutableSet;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Map;
import java.util.Set;

import static com.almworks.jira.structure.api.attribute.AttributeSpecBuilder.create;
import static com.almworks.jira.structure.api.attribute.CoreAttributeSpecs.Param.*;
import static com.almworks.jira.structure.api.attribute.loader.distinct.DistinctAttributes.MAP_ITEM_COUNT_FORMAT;
import static com.almworks.jira.structure.api.attribute.loader.distinct.DistinctAttributes.NON_UNIQUE_ITEMS;

/**
 * Base class for loaders that derive a "distinct" aggregate from an already computed non-distinct one.
 * <p>
 * The loader reads the value of the non-distinct attribute, converts it into an intermediate accumulator,
 * and then, for every item reported by the {@code NON_UNIQUE_ITEMS} attributes with a count greater than one,
 * asks the subclass to discount the surplus occurrences of that item's value.
 *
 * @param <T> type of the value produced by this loader (and by the non-distinct attribute)
 * @param <E> type of the per-item value of the dependent attribute
 * @param <X> type of the intermediate accumulator used while removing duplicates
 */
public abstract class AbstractDistinctSumLoader<T, E, X> extends AbstractAttributeLoader<T> implements AttributeLoader.ForestIndependent<T> {
  private final AttributeSpec<T> myNonDistinctSpec;
  private final boolean myStrict;
  private final AttributeSpec<Map<ItemIdentity, E>> myDependentValuesSpec;
  private final AttributeSpec<Map<ItemIdentity, Integer>> myNonUniqueItemsSpec;
  private final ImmutableSet<AttributeSpec<?>> myDependencies;

  /**
   * @param spec               the attribute spec served by this loader; its strategy type selects strict mode
   * @param nonDistinctSpec    spec of the attribute holding the aggregate computed without de-duplication
   * @param dependentAttribute spec of the per-item attribute whose values are being aggregated
   */
  public AbstractDistinctSumLoader(AttributeSpec<T> spec, AttributeSpec<T> nonDistinctSpec, AttributeSpec<E> dependentAttribute) {
    super(spec);
    myNonDistinctSpec = nonDistinctSpec;

    String strategyType = ReductionStrategy.getStrategyType(spec);
    myStrict = STRICT.equals(strategyType);
    // For the strict and subtree strategies the helper attributes are requested without an explicit type.
    String typeParam = myStrict || SUBTREE.equals(strategyType) ? null : strategyType;

    myDependentValuesSpec = create(NON_UNIQUE_ITEMS, DistinctAttributes.<E>nonUniqueValuesValueFormat())
      .params()
      .setAttribute(dependentAttribute)
      .set(TYPE, typeParam)
      .build();
    myNonUniqueItemsSpec = create(NON_UNIQUE_ITEMS, MAP_ITEM_COUNT_FORMAT)
      .params()
      .set(TYPE, typeParam)
      .build();
    myDependencies = ImmutableSet.of(myNonDistinctSpec, dependentAttribute, myNonUniqueItemsSpec, myDependentValuesSpec);
  }

  /** Always {@code true}: this loader does not restrict item types. */
  @Override
  public final boolean isEveryItemTypeSupported() {
    return true;
  }

  /** Always {@code true}, regardless of {@code itemType}. */
  @Override
  public final boolean isItemTypeSupported(String itemType) {
    return true;
  }

  /** @return {@link AttributeCachingStrategy#SHOULD} */
  @Override
  public AttributeCachingStrategy getCachingStrategy() {
    return AttributeCachingStrategy.SHOULD;
  }

  /**
   * @return the non-distinct spec, the dependent attribute and the two {@code NON_UNIQUE_ITEMS} specs
   * built in the constructor
   */
  @Override
  @NotNull
  public Set<? extends AttributeSpec<?>> getAttributeDependencies() {
    return myDependencies;
  }

  /**
   * Reads the map of dependent-attribute values keyed by item from the context.
   *
   * @return the map, or {@code null} if the context has no value for it
   */
  @Nullable
  protected Map<ItemIdentity, E> getChildrenValues(Context context) {
    AttributeValue<Map<ItemIdentity, E>> loaded = context.getAttributeValue(myDependentValuesSpec);
    if (loaded == null) {
      return null;
    }
    return loaded.getValue();
  }

  /**
   * Reads the map of occurrence counts keyed by item from the context.
   *
   * @return the map, or {@code null} if the context has no value for it
   */
  @Nullable
  protected Map<ItemIdentity, Integer> getChildrenCounts(Context context) {
    AttributeValue<Map<ItemIdentity, Integer>> loaded = context.getAttributeValue(myNonUniqueItemsSpec);
    if (loaded == null) {
      return null;
    }
    return loaded.getValue();
  }

  /**
   * Converts the value of the non-distinct attribute into the intermediate accumulator.
   *
   * @param value the value of the non-distinct attribute
   * @return the initial accumulator, or {@code null} to make {@link #loadValue} return an undefined value
   */
  @Nullable
  protected abstract X initializeWithNonDistinctValue(T value);

  /**
   * Discounts surplus occurrences of a single item's value from the accumulator.
   *
   * @param before              the accumulator before the adjustment
   * @param argument            the dependent-attribute value of the repeated item
   * @param extraInstancesCount the number of surplus occurrences; at least 1 when called from {@link #loadValue}
   * @return the adjusted accumulator
   */
  @NotNull
  protected abstract X removeDuplicates(@NotNull X before, @NotNull E argument, int extraInstancesCount);

  /**
   * Converts the accumulator into the resulting attribute value. {@link #loadValue} applies the additional
   * data trail of the non-distinct value to the result.
   */
  protected abstract AttributeValue<T> finalize(X value);

  /**
   * Loads the non-distinct value and removes the contribution of repeated items from it.
   *
   * @return an undefined value if the non-distinct value is missing or cannot be converted to an accumulator;
   * the non-distinct value itself if it is an error or not defined; otherwise the finalized accumulator
   */
  @Override
  public AttributeValue<T> loadValue(StructureRow row, Context context) {
    AttributeValue<T> nonDistinct = context.getAttributeValue(myNonDistinctSpec);
    if (nonDistinct == null) {
      return AttributeValue.undefined();
    }
    if (nonDistinct.isError() || !nonDistinct.isDefined()) {
      return nonDistinct;
    }

    X accumulator = initializeWithNonDistinctValue(nonDistinct.getValue());
    if (accumulator == null) {
      return AttributeValue.undefined();
    }

    Map<ItemIdentity, Integer> counts = getChildrenCounts(context);
    Map<ItemIdentity, E> values = getChildrenValues(context);
    ItemIdentity selfId = context.getRow().getItemId();
    if (counts != null && values != null) {
      accumulator = discountRepeatedItems(accumulator, counts, values, selfId);
    }

    AttributeValue<T> result = finalize(accumulator);
    return result.withTrail(nonDistinct.getAdditionalDataTrail());
  }

  /** Applies {@link #removeDuplicates} for every item that has both a count and a value and occurs more than once. */
  private X discountRepeatedItems(X accumulator, Map<ItemIdentity, Integer> counts, Map<ItemIdentity, E> values, ItemIdentity selfId) {
    for (Map.Entry<ItemIdentity, Integer> entry : counts.entrySet()) {
      Integer recorded = entry.getValue();
      if (recorded == null) {
        continue;
      }
      ItemIdentity item = entry.getKey();
      E itemValue = values.get(item);
      if (itemValue == null) {
        continue;
      }
      int occurrences = recorded;
      if (myStrict && selfId.equals(item)) {
        // In strict mode the non-distinct value does not include the row's own value,
        // so one occurrence of the row's own item is not part of the total.
        occurrences = occurrences - 1;
      }
      if (occurrences > 1) {
        accumulator = removeDuplicates(accumulator, itemValue, occurrences - 1);
      }
    }
    return accumulator;
  }
}
