/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.dataframe.inference;

import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.inference.TestDocsIterator;
import org.elasticsearch.xpack.ml.dataframe.stats.DataCountsTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.utils.persistence.LimitAwareBulkIndexer;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

public class InferenceRunner {
    private static final Logger LOGGER = LogManager.getLogger(InferenceRunner.class);
    private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
    private final Settings settings;
    private final Client client;
    private final ModelLoadingService modelLoadingService;
    private final ResultsPersisterService resultsPersisterService;
    private final TaskId parentTaskId;
    private final DataFrameAnalyticsConfig config;
    private final ExtractedFields extractedFields;
    private final ProgressTracker progressTracker;
    private final DataCountsTracker dataCountsTracker;
    private volatile boolean isCancelled;

    public InferenceRunner(Settings settings, Client client, ModelLoadingService modelLoadingService, ResultsPersisterService resultsPersisterService, TaskId parentTaskId, DataFrameAnalyticsConfig config, ExtractedFields extractedFields, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) {
        this.settings = Objects.requireNonNull(settings);
        this.client = Objects.requireNonNull(client);
        this.modelLoadingService = Objects.requireNonNull(modelLoadingService);
        this.resultsPersisterService = Objects.requireNonNull(resultsPersisterService);
        this.parentTaskId = Objects.requireNonNull(parentTaskId);
        this.config = Objects.requireNonNull(config);
        this.extractedFields = Objects.requireNonNull(extractedFields);
        this.progressTracker = Objects.requireNonNull(progressTracker);
        this.dataCountsTracker = Objects.requireNonNull(dataCountsTracker);
    }

    public void cancel() {
        this.isCancelled = true;
    }

    public void run(String modelId) {
        if (this.isCancelled) {
            return;
        }
        LOGGER.info("[{}] Started inference on test data against model [{}]", (Object)this.config.getId(), (Object)modelId);
        try {
            PlainActionFuture localModelPlainActionFuture = new PlainActionFuture();
            this.modelLoadingService.getModelForPipeline(modelId, (ActionListener<LocalModel>)localModelPlainActionFuture);
            TestDocsIterator testDocsIterator = new TestDocsIterator(new OriginSettingClient(this.client, "ml"), this.config, this.extractedFields);
            try (LocalModel localModel = (LocalModel)localModelPlainActionFuture.actionGet();){
                LOGGER.debug("Loaded inference model [{}]", (Object)localModel);
                this.inferTestDocs(localModel, testDocsIterator);
            }
        }
        catch (Exception e) {
            LOGGER.error((Message)new ParameterizedMessage("[{}] Error during inference against model [{}]", (Object)this.config.getId(), (Object)modelId), (Throwable)e);
            throw ExceptionsHelper.serverError((String)"[{}] failed running inference on model [{}]; cause was [{}]", (Throwable)e, (Object[])new Object[]{this.config.getId(), modelId, e.getMessage()});
        }
    }

    void inferTestDocs(LocalModel model, TestDocsIterator testDocsIterator) {
        long totalDocCount = 0L;
        long processedDocCount = 0L;
        try (LimitAwareBulkIndexer bulkIndexer = new LimitAwareBulkIndexer(this.settings, this::executeBulkRequest);){
            while (testDocsIterator.hasNext()) {
                if (this.isCancelled) {
                    break;
                }
                Deque batch = testDocsIterator.next();
                if (totalDocCount == 0L) {
                    totalDocCount = testDocsIterator.getTotalHits();
                }
                for (SearchHit doc : batch) {
                    this.dataCountsTracker.incrementTestDocsCount();
                    InferenceResults inferenceResults = model.inferNoStats(this.featuresFromDoc(doc));
                    bulkIndexer.addAndExecuteIfNeeded(this.createIndexRequest(doc, inferenceResults, this.config.getDest().getResultsField()));
                    int progressPercent = Math.min((int)((double)(++processedDocCount) * 100.0 / (double)totalDocCount), 98);
                    this.progressTracker.updateInferenceProgress(progressPercent);
                }
            }
        }
        if (!this.isCancelled) {
            this.progressTracker.updateInferenceProgress(100);
        }
    }

    private Map<String, Object> featuresFromDoc(SearchHit doc) {
        HashMap<String, Object> features = new HashMap<String, Object>();
        for (ExtractedField extractedField : this.extractedFields.getAllFields()) {
            Object[] values = extractedField.value(doc);
            if (values.length != 1) continue;
            features.put(extractedField.getName(), values[0]);
        }
        return features;
    }

    private IndexRequest createIndexRequest(SearchHit hit, InferenceResults results, String resultField) {
        LinkedHashMap<String, Boolean> resultsMap = new LinkedHashMap<String, Boolean>(results.asMap());
        resultsMap.put("is_training", false);
        LinkedHashMap<String, LinkedHashMap<String, Boolean>> source = new LinkedHashMap<String, LinkedHashMap<String, Boolean>>(hit.getSourceAsMap());
        source.put(resultField, resultsMap);
        IndexRequest indexRequest = new IndexRequest(hit.getIndex());
        indexRequest.id(hit.getId());
        indexRequest.source(source);
        indexRequest.opType(DocWriteRequest.OpType.INDEX);
        indexRequest.setParentTask(this.parentTaskId);
        return indexRequest;
    }

    private void executeBulkRequest(BulkRequest bulkRequest) {
        this.resultsPersisterService.bulkIndexWithHeadersWithRetry(this.config.getHeaders(), bulkRequest, this.config.getId(), () -> !this.isCancelled, errorMsg -> {});
    }
}

