/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationFields;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationParameters;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.outlierdetection.OutlierDetection;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Base class for outlier-detection evaluation metrics that are computed from confusion-matrix
 * counts (true/false positives and negatives) at one or more probability thresholds.
 *
 * <p>Subclasses provide the search aggregations via {@link #aggsAt(String, String)}, usually by
 * combining {@link #buildAgg(String, String, double, Condition)} calls, and convert the
 * aggregation response into a result via {@link #evaluate(Aggregations)}.
 */
abstract class AbstractConfusionMatrixMetric
implements EvaluationMetric {

    /** Name of the parameter holding the list of thresholds. */
    public static final ParseField AT = new ParseField("at");

    /**
     * Thresholds applied to the predicted probability: a document counts as predicted-positive
     * when its probability is {@code >=} the threshold. When built from a {@code List} every
     * value is validated to lie in [0.0, 1.0].
     */
    protected final double[] thresholds;

    /** Result stored by {@link #process(Aggregations)}; {@code null} until then. */
    private EvaluationMetricResult result;

    /**
     * Creates the metric from user-supplied thresholds.
     *
     * <p>NOTE(review): this constructor calls the overridable {@link #getName()} when building
     * error messages, so a subclass's {@code getName()} must not depend on subclass fields that
     * are not yet initialized at this point.
     *
     * @param at thresholds; must be non-null, non-empty, and contain only non-null, non-NaN
     *           values in the closed range [0.0, 1.0]
     */
    protected AbstractConfusionMatrixMetric(List<Double> at) {
        List<Double> values = ExceptionsHelper.requireNonNull(at, AT);
        if (values.isEmpty()) {
            throw ExceptionsHelper.badRequestException(
                "[" + this.getName() + "." + AT.getPreferredName() + "] must have at least one value");
        }
        double[] parsed = new double[values.size()];
        for (int i = 0; i < parsed.length; i++) {
            Double threshold = values.get(i);
            // A null element would otherwise fail with an NPE on unboxing, and NaN compares false
            // against both bounds, so both are rejected explicitly as out-of-range values.
            if (threshold == null || threshold.isNaN() || threshold < 0.0 || threshold > 1.0) {
                throw ExceptionsHelper.badRequestException(
                    "[" + this.getName() + "." + AT.getPreferredName() + "] values must be in [0.0, 1.0]");
            }
            parsed[i] = threshold;
        }
        this.thresholds = parsed;
    }

    /**
     * Reads the thresholds previously written by {@link #writeTo(StreamOutput)}.
     * The values are not re-validated here.
     */
    protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException {
        this.thresholds = in.readDoubleArray();
    }

    /** Writes the thresholds as a double array. */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeDoubleArray(this.thresholds);
    }

    /** Renders the metric configuration as an object with a single {@code at} field. */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(AT.getPreferredName(), (Object) this.thresholds);
        builder.endObject();
        return builder;
    }

    /** Returns the names of the evaluation fields this metric needs: actual and predicted probability. */
    @Override
    public Set<String> getRequiredFields() {
        return Sets.newHashSet(
            EvaluationFields.ACTUAL_FIELD.getPreferredName(),
            EvaluationFields.PREDICTED_PROBABILITY_FIELD.getPreferredName());
    }

    /**
     * Returns the aggregations needed to compute this metric.
     *
     * <p>Once a result has been stored by {@link #process(Aggregations)}, both lists are empty.
     * Pipeline aggregations are never requested.
     */
    @Override
    public Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(EvaluationParameters parameters, EvaluationFields fields) {
        if (this.result != null) {
            return Tuple.tuple(Collections.emptyList(), Collections.emptyList());
        }
        String actualField = fields.getActualField();
        String predictedProbabilityField = fields.getPredictedProbabilityField();
        return Tuple.tuple(this.aggsAt(actualField, predictedProbabilityField), Collections.emptyList());
    }

    /** Computes the result from the aggregation response via {@link #evaluate(Aggregations)} and stores it. */
    @Override
    public void process(Aggregations aggs) {
        this.result = this.evaluate(aggs);
    }

    /** Returns the stored result, or empty if {@link #process(Aggregations)} has not been called yet. */
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(this.result);
    }

    /**
     * Builds the aggregations this metric needs for every configured threshold.
     *
     * @param actualField name of the field holding the actual (ground-truth) value
     * @param predictedProbabilityField name of the field holding the predicted probability
     */
    protected abstract List<AggregationBuilder> aggsAt(String actualField, String predictedProbabilityField);

    /** Converts the aggregation response into this metric's result. */
    protected abstract EvaluationMetricResult evaluate(Aggregations aggs);

    /**
     * Returns the aggregation name used for the given threshold and condition,
     * in the form {@code <metricName>_at_<threshold>_<CONDITION>}.
     */
    protected String aggName(double threshold, Condition condition) {
        return this.getName() + "_at_" + threshold + "_" + condition.name();
    }

    /**
     * Builds a filter aggregation that selects documents falling into {@code condition} at
     * {@code threshold}.
     *
     * <p>"Actually positive" is decided by {@link OutlierDetection#actualIsTrueQuery(String)}.
     * "Predicted positive" means the predicted probability is {@code >= threshold}. Each side
     * is required or excluded according to the condition's flags.
     */
    protected AggregationBuilder buildAgg(String actualField, String predictedProbabilityField, double threshold, Condition condition) {
        QueryBuilder actualIsTrueQuery = OutlierDetection.actualIsTrueQuery(actualField);
        QueryBuilder predictedIsTrueQuery = QueryBuilders.rangeQuery(predictedProbabilityField).gte(threshold);

        BoolQueryBuilder boolQuery = QueryBuilders.boolQuery();
        if (condition.actual) {
            boolQuery.must(actualIsTrueQuery);
        } else {
            boolQuery.mustNot(actualIsTrueQuery);
        }
        if (condition.predicted) {
            boolQuery.must(predictedIsTrueQuery);
        } else {
            boolQuery.mustNot(predictedIsTrueQuery);
        }
        return AggregationBuilders.filter(this.aggName(threshold, condition), boolQuery);
    }

    /** Confusion-matrix cells, each described by its (actual, predicted) truth values. */
    enum Condition {
        TP(true, true),
        FP(false, true),
        TN(false, false),
        FN(true, false);

        /** Whether documents in this cell are actually positive. */
        final boolean actual;
        /** Whether documents in this cell are predicted positive. */
        final boolean predicted;

        Condition(boolean actual, boolean predicted) {
            this.actual = actual;
            this.predicted = predicted;
        }
    }
}

