/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * A trained model made of a single tree whose nodes are stored in a flat list.
 *
 * <p>Child links ({@link TreeNode#getLeftChild()} / {@link TreeNode#getRightChild()}) are indices into
 * that list, and index {@code 0} is treated as the root by the traversal in {@link #validate()}.
 * Instances are created through {@link Builder}, the XContent parsers, or the stream constructor; the
 * lists held by an instance are wrapped as unmodifiable views.
 */
public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel, Accountable {

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Tree.class);

    public static final ParseField NAME = new ParseField("tree");
    public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
    public static final ParseField TREE_STRUCTURE = new ParseField("tree_structure");
    public static final ParseField TARGET_TYPE = new ParseField("target_type");
    public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");

    // The parsers must be declared after the ParseField constants they reference.
    private static final ObjectParser<Builder, Void> LENIENT_PARSER = createParser(true);
    private static final ObjectParser<Builder, Void> STRICT_PARSER = createParser(false);

    private final List<String> featureNames;
    private final List<TreeNode> nodes;
    private final TargetType targetType;
    private final List<String> classificationLabels;

    /**
     * Builds the XContent parser that populates a {@link Builder}.
     *
     * @param lenient forwarded both to the {@link ObjectParser} constructor and to
     *                {@link TreeNode#fromXContent} for every entry of {@code tree_structure}
     * @return a parser declaring {@code feature_names}, {@code tree_structure}, {@code target_type}
     *         and {@code classification_labels}
     */
    private static ObjectParser<Builder, Void> createParser(boolean lenient) {
        ObjectParser<Builder, Void> parser = new ObjectParser<>(NAME.getPreferredName(), lenient, Builder::new);
        parser.declareStringArray(Builder::setFeatureNames, FEATURE_NAMES);
        parser.declareObjectArray(Builder::setNodes, (p, c) -> TreeNode.fromXContent(p, lenient), TREE_STRUCTURE);
        // Explicit lambda: the String argument selects the private String overload of setTargetType.
        parser.declareString((builder, value) -> builder.setTargetType(value), TARGET_TYPE);
        parser.declareStringArray(Builder::setClassificationLabels, CLASSIFICATION_LABELS);
        return parser;
    }

    /**
     * Parses a tree with the parser created with {@code lenient == false}.
     *
     * @param parser the XContent source
     * @return the built tree
     */
    public static Tree fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null).build();
    }

    /**
     * Parses a tree with the parser created with {@code lenient == true}.
     *
     * @param parser the XContent source
     * @return the built tree
     */
    public static Tree fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null).build();
    }

    /**
     * Creates a tree from already-built components. The given lists are wrapped, not copied.
     *
     * @param featureNames         required feature names
     * @param nodes                required, non-empty node list
     * @param targetType           required target type
     * @param classificationLabels optional labels; may be {@code null}
     * @throws IllegalArgumentException if {@code nodes} is empty
     */
    Tree(List<String> featureNames, List<TreeNode> nodes, TargetType targetType, List<String> classificationLabels) {
        this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES));
        if (ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) {
            throw new IllegalArgumentException("[tree_structure] must not be empty");
        }
        this.nodes = Collections.unmodifiableList(nodes);
        this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE);
        this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels);
    }

    /**
     * Reads a tree from the wire. The read order mirrors {@link #writeTo(StreamOutput)}: feature names,
     * nodes, target type, then a presence flag followed by the optional classification labels.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Tree(StreamInput in) throws IOException {
        this.featureNames = Collections.unmodifiableList(in.readStringList());
        this.nodes = Collections.unmodifiableList(in.readList(TreeNode::new));
        this.targetType = TargetType.fromStream(in);
        this.classificationLabels = in.readBoolean() ? Collections.unmodifiableList(in.readStringList()) : null;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public TargetType targetType() {
        return targetType;
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /**
     * Writes this tree to the wire in the order expected by {@link #Tree(StreamInput)}.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeStringCollection(featureNames);
        out.writeCollection(nodes);
        targetType.writeTo(out);
        // The labels are optional, so a presence flag precedes them.
        boolean hasLabels = classificationLabels != null;
        out.writeBoolean(hasLabels);
        if (hasLabels) {
            out.writeStringCollection(classificationLabels);
        }
    }

    /**
     * Renders this tree as an object; {@code classification_labels} is emitted only when present.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(FEATURE_NAMES.getPreferredName(), featureNames);
        builder.field(TREE_STRUCTURE.getPreferredName(), nodes);
        builder.field(TARGET_TYPE.getPreferredName(), targetType.toString());
        if (classificationLabels != null) {
            builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
        }
        builder.endObject();
        return builder;
    }

    @Override
    public String toString() {
        return Strings.toString(this);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Tree that = (Tree) o;
        return Objects.equals(featureNames, that.featureNames)
            && Objects.equals(nodes, that.nodes)
            && Objects.equals(targetType, that.targetType)
            && Objects.equals(classificationLabels, that.classificationLabels);
    }

    @Override
    public int hashCode() {
        return Objects.hash(featureNames, nodes, targetType, classificationLabels);
    }

    /**
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Checks the structural consistency of this tree, in this order: feature indices are within
     * {@code feature_names}; a tree with more than one node has feature names; the target type agrees
     * with the labels and leaf values; every child index refers to an existing node; no node is
     * reachable more than once from the root; and all leaves carry the same number of values.
     * The first failing check throws the exception produced by
     * {@link ExceptionsHelper#badRequestException}.
     */
    @Override
    public void validate() {
        int maxFeatureIndex = maxFeatureIndex();
        if (maxFeatureIndex >= featureNames.size()) {
            throw ExceptionsHelper.badRequestException(
                "feature index [{}] is out of bounds for the [{}] array",
                maxFeatureIndex,
                FEATURE_NAMES.getPreferredName()
            );
        }
        if (nodes.size() > 1 && featureNames.isEmpty()) {
            throw ExceptionsHelper.badRequestException(
                "[{}] is empty and the tree has > 1 nodes; num nodes [{}]. The model Must have features if tree is not a stump",
                FEATURE_NAMES.getPreferredName(),
                nodes.size()
            );
        }
        checkTargetType();
        // Missing nodes must be ruled out first: detectCycle() indexes into the node list.
        detectMissingNodes();
        detectCycle();
        verifyLeafNodeUniformity();
    }

    /**
     * @return {@code ceil(ln(number of nodes))} plus the number of feature names
     */
    @Override
    public long estimatedNumOperations() {
        return (long) Math.ceil(Math.log(nodes.size())) + featureNames.size();
    }

    /**
     * @return the largest {@link TreeNode#getSplitFeature()} over all nodes, never less than {@code -1}
     */
    int maxFeatureIndex() {
        int maxFeatureIndex = -1;
        for (TreeNode node : nodes) {
            maxFeatureIndex = Math.max(maxFeatureIndex, node.getSplitFeature());
        }
        return maxFeatureIndex;
    }

    /**
     * Rejects a non-classification target type when classification labels are present or when any
     * node carries more than one leaf value.
     */
    private void checkTargetType() {
        if (classificationLabels != null && targetType != TargetType.CLASSIFICATION) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if [classification_labels] are provided"
            );
        }
        if (targetType != TargetType.CLASSIFICATION && nodes.stream().anyMatch(n -> n.getLeafValue().length > 1)) {
            throw ExceptionsHelper.badRequestException(
                "[target_type] should be [classification] if leaf nodes have multiple values"
            );
        }
    }

    /**
     * Walks the tree breadth-first from node {@code 0} and fails as soon as a node index is dequeued
     * a second time. This covers cycles and also nodes referenced by more than one parent.
     * Negative child indices are treated as "no child".
     */
    private void detectCycle() {
        HashSet<Integer> visited = new HashSet<>(nodes.size());
        ArrayDeque<Integer> toVisit = new ArrayDeque<>(nodes.size());
        toVisit.add(0);
        while (toVisit.isEmpty() == false) {
            Integer nodeIdx = toVisit.remove();
            if (visited.contains(nodeIdx)) {
                throw ExceptionsHelper.badRequestException("[tree] contains cycle at node {}", nodeIdx);
            }
            visited.add(nodeIdx);
            TreeNode treeNode = nodes.get(nodeIdx);
            if (treeNode.getLeftChild() >= 0) {
                toVisit.add(treeNode.getLeftChild());
            }
            if (treeNode.getRightChild() >= 0) {
                toVisit.add(treeNode.getRightChild());
            }
        }
    }

    /**
     * Collects every child index that points past the end of the node list and fails if any exist.
     */
    private void detectMissingNodes() {
        List<Integer> missingNodes = new ArrayList<>();
        for (TreeNode currentNode : nodes) {
            if (currentNode == null) {
                continue;
            }
            if (nodeMissing(currentNode.getLeftChild(), nodes)) {
                missingNodes.add(currentNode.getLeftChild());
            }
            if (nodeMissing(currentNode.getRightChild(), nodes)) {
                missingNodes.add(currentNode.getRightChild());
            }
        }
        if (missingNodes.isEmpty() == false) {
            throw ExceptionsHelper.badRequestException("[tree] contains missing nodes {}", missingNodes);
        }
    }

    /**
     * Requires every leaf to carry the same number of leaf values as the first leaf encountered.
     */
    private void verifyLeafNodeUniformity() {
        Integer leafValueLengths = null;
        for (TreeNode node : nodes) {
            if (node.isLeaf() == false) {
                continue;
            }
            int length = node.getLeafValue().length;
            if (leafValueLengths == null) {
                // The first leaf fixes the expected length.
                leafValueLengths = length;
            } else if (leafValueLengths.intValue() != length) {
                throw ExceptionsHelper.badRequestException(
                    "[tree.tree_structure] all leaf nodes must have the same number of values"
                );
            }
        }
    }

    /**
     * @return {@code true} when {@code nodeIdx} is at or beyond the end of {@code nodes}; negative
     *         indices are never reported as missing
     */
    private static boolean nodeMissing(int nodeIdx, List<TreeNode> nodes) {
        return nodeIdx >= nodes.size();
    }

    /**
     * @return the shallow size of this instance plus the estimated sizes of its three collections
     */
    @Override
    public long ramBytesUsed() {
        long size = SHALLOW_SIZE;
        // NOTE(review): classificationLabels may be null here and is passed through unchanged.
        size += RamUsageEstimator.sizeOfCollection(classificationLabels);
        size += RamUsageEstimator.sizeOfCollection(featureNames);
        size += RamUsageEstimator.sizeOfCollection(nodes);
        return size;
    }

    /**
     * @return one named accountable per node, labelled {@code tree_node_<nodeIndex>}
     */
    @Override
    public Collection<Accountable> getChildResources() {
        List<Accountable> accountables = new ArrayList<>(nodes.size());
        for (TreeNode node : nodes) {
            accountables.add(Accountables.namedAccountable("tree_node_" + node.getNodeIndex(), node));
        }
        return Collections.unmodifiableCollection(accountables);
    }

    /**
     * @return {@link Version#V_7_7_0} when any leaf carries more than one value, otherwise
     *         {@link Version#V_7_6_0}
     */
    @Override
    public Version getMinimalCompatibilityVersion() {
        if (nodes.stream().filter(TreeNode::isLeaf).anyMatch(t -> t.getLeafValue().length > 1)) {
            return Version.V_7_7_0;
        }
        return Version.V_7_6_0;
    }

    /**
     * Mutable builder for {@link Tree}. A new builder starts with a single leaf at index {@code 0}
     * with value {@code 0.0} and a target type of {@link TargetType#REGRESSION}.
     */
    public static class Builder {
        private List<String> featureNames;
        // Kept as ArrayList because addJunction relies on ensureCapacity.
        private ArrayList<TreeNode.Builder> nodes;
        // Next free child index handed out by addJunction; not updated by setNodes/addNode.
        private int numNodes;
        private TargetType targetType = TargetType.REGRESSION;
        private List<String> classificationLabels;

        public Builder() {
            this.nodes = new ArrayList<>();
            // Reserve slot 0 so that addLeaf can set the root in place.
            this.nodes.add(null);
            addLeaf(0, 0.0);
            this.numNodes = 1;
        }

        public Builder setFeatureNames(List<String> featureNames) {
            this.featureNames = featureNames;
            return this;
        }

        /**
         * Replaces the node at index {@code 0}.
         */
        public Builder setRoot(TreeNode.Builder root) {
            nodes.set(0, root);
            return this;
        }

        /**
         * Appends a node at the end of the node list.
         */
        public Builder addNode(TreeNode.Builder node) {
            nodes.add(node);
            return this;
        }

        /**
         * Replaces all nodes with a copy of the given list.
         *
         * @param nodes required node builders
         */
        public Builder setNodes(List<TreeNode.Builder> nodes) {
            this.nodes = new ArrayList<>(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE.getPreferredName()));
            return this;
        }

        public Builder setNodes(TreeNode.Builder... nodes) {
            return setNodes(Arrays.asList(nodes));
        }

        public Builder setTargetType(TargetType targetType) {
            this.targetType = targetType;
            return this;
        }

        public Builder setClassificationLabels(List<String> classificationLabels) {
            this.classificationLabels = classificationLabels;
            return this;
        }

        // Used by the XContent parser, which supplies the target type as a string.
        private void setTargetType(String targetType) {
            this.targetType = TargetType.fromString(targetType);
        }

        /**
         * Places a split node at {@code nodeIndex} and reserves the next two free indices for its
         * left and right children, padding the node list with {@code null} placeholders as needed.
         * The placeholders must be filled before {@link #build()} is called.
         *
         * @param nodeIndex         index at which the junction is stored
         * @param featureIndex      split feature of the junction
         * @param isDefaultLeft     default-left flag of the junction
         * @param decisionThreshold threshold of the junction
         * @return the builder of the newly placed junction
         */
        public TreeNode.Builder addJunction(int nodeIndex, int featureIndex, boolean isDefaultLeft, double decisionThreshold) {
            int leftChild = numNodes++;
            int rightChild = numNodes++;
            nodes.ensureCapacity(nodeIndex + 1);
            while (nodes.size() < nodeIndex + 1) {
                nodes.add(null);
            }
            TreeNode.Builder node = TreeNode.builder(nodeIndex)
                .setDefaultLeft(isDefaultLeft)
                .setLeftChild(leftChild)
                .setRightChild(rightChild)
                .setSplitFeature(featureIndex)
                .setThreshold(decisionThreshold);
            nodes.set(nodeIndex, node);
            // Make room for both children so they can be set in place later.
            while (nodes.size() <= rightChild) {
                nodes.add(null);
            }
            return node;
        }

        /**
         * Places a leaf with a single value at {@code nodeIndex}.
         */
        public Builder addLeaf(int nodeIndex, double value) {
            return addLeaf(nodeIndex, Arrays.asList(value));
        }

        /**
         * Places a leaf with the given values at {@code nodeIndex}, padding the node list with
         * {@code null} placeholders if the index is beyond the current end.
         */
        public Builder addLeaf(int nodeIndex, List<Double> value) {
            while (nodes.size() < nodeIndex + 1) {
                nodes.add(null);
            }
            nodes.set(nodeIndex, TreeNode.builder(nodeIndex).setLeafValue(value));
            return this;
        }

        /**
         * Builds the tree. Fails if any placeholder slot was never filled.
         *
         * @return the built tree
         */
        public Tree build() {
            if (nodes.stream().anyMatch(Objects::isNull)) {
                throw ExceptionsHelper.badRequestException("[tree] cannot contain null nodes");
            }
            List<TreeNode> builtNodes = nodes.stream().map(TreeNode.Builder::build).collect(Collectors.toList());
            return new Tree(featureNames, builtNodes, targetType, classificationLabels);
        }
    }
}

