/*
 * Decompiled with CFR 0.152.
 */
package se.prediktera.breeze.frontend.wizard.model.panel.ml;

import java.awt.Component;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.Insets;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Objects;
import java.util.Vector;
import java.util.stream.IntStream;
import lombok.Generated;
import org.json.JSONObject;
import se.prediktera.breeze.common.entry.BreezeType;
import se.prediktera.breeze.common.util.FileHelper;
import se.prediktera.breeze.common.util.StringHelper;
import se.prediktera.breeze.entry.analyse.Analyse;
import se.prediktera.breeze.entry.analyse.model.AbstractRuntimeModel;
import se.prediktera.breeze.frontend.BreezeFrame;
import se.prediktera.breeze.frontend.common.entry.EntryMainHelper;
import se.prediktera.breeze.frontend.common.entry.TabPanel;
import se.prediktera.breeze.frontend.common.swing.BreezeButton;
import se.prediktera.breeze.frontend.common.swing.BreezeCheckBox;
import se.prediktera.breeze.frontend.common.swing.BreezeComboBoxInputField;
import se.prediktera.breeze.frontend.common.swing.BreezeInputFieldHelper;
import se.prediktera.breeze.frontend.common.swing.BreezeSpinnerInputPanel;
import se.prediktera.breeze.frontend.common.swing.dialog.AbstractBreezeDialog;
import se.prediktera.breeze.frontend.common.swing.dialog.BreezeOptionDialog;
import se.prediktera.breeze.frontend.common.swing.dialog.BreezeProgressManager;
import se.prediktera.breeze.frontend.common.swing.table.BreezeTablePanel;
import se.prediktera.breeze.frontend.common.swing.table.DefaultBreezeTableImpl;
import se.prediktera.breeze.frontend.common.view.AbstractViewPanel;
import se.prediktera.breeze.hardware.tcp.TcpManager;
import se.prediktera.map.common.BrowserLauncher;
import se.prediktera.map.model.AbstractModel;
import se.prediktera.mda.model.PLSDA.PLSDAModel;

/**
 * "Runtime" tab of the model wizard. It lets the user configure an ML training run and lists the
 * results of previous runs in a table.
 *
 * <p>The configurable settings are the algorithm, the optimizing metric, the time budget, the
 * number of cross validation folds, top-K and "cross validate all".
 *
 * <p>Every input widget writes directly into the {@link AbstractRuntimeModel.RuntimeSettings}
 * of the model passed to the constructor.
 */
public class RuntimeModelPanel
extends AbstractViewPanel
implements TabPanel.Tab {
    /** Model type string that selects the classification metrics and shows the top-K field. */
    private static final String CLASSIFICATION = "Classification";
    /** Trainer name of the top-K metric; no cross validation column is shown for it. */
    private static final String TOP_K_ACCURACY = "TopKAccuracy";
    /**
     * Top-K upper bound used when the wrapped model is not a {@link PLSDAModel}.
     * This matches the value {@link #trainModel} has always passed to the training service.
     */
    private static final int DEFAULT_MAX_TOP_K = 2;
    /** Width of the label column shared by all input rows. */
    private static final int LABEL_WIDTH = 200;
    /** Width of the unit column shared by all input rows. */
    private static final int UNIT_LABEL_WIDTH = 100;

    private final BreezeTablePanel table;
    private static final ListItem[] ClassificationOptimizers = new ListItem[]{new ListItem("Macro-accuracy", "MacroAccuracy"), new ListItem("Log-loss", "LogLoss"), new ListItem("Log-loss reduction", "LogLossReduction"), new ListItem("Micro-accuracy", "MicroAccuracy"), new ListItem("Top-K accuracy", "TopKAccuracy")};
    private static final ListItem[] QuantificationOptimizers = new ListItem[]{new ListItem("R-squared", "RSquared"), new ListItem("Mean absolute error (MAE)", "MeanAbsoluteError"), new ListItem("Mean Squared Error (MSE)", "MeanSquaredError"), new ListItem("Root Mean Squared Error (RMSE)", "RootMeanSquaredError")};
    private final BreezeSpinnerInputPanel topKInputField;
    private final AbstractRuntimeModel model;
    private final boolean isClassification;

    /**
     * Builds the panel for the given runtime model.
     *
     * @param abstractRuntimeModel model whose runtime settings are edited and whose runtime
     *                             results are listed in the table
     * @param string               model type. {@code "Classification"} selects the classification
     *                             metrics and shows the top-K field; any other value selects the
     *                             quantification metrics. The value is also passed to
     *                             {@code TcpManager.getMLMethods} to list the algorithms.
     */
    public RuntimeModelPanel(AbstractRuntimeModel abstractRuntimeModel, String string) {
        this.model = abstractRuntimeModel;
        this.setLayout(new GridBagLayout());
        AbstractRuntimeModel.RuntimeSettings settings = abstractRuntimeModel.getRuntimeSettings();

        this.addInputRow(RuntimeModelPanel.createAlgorithmField(settings, string));

        this.table = new BreezeTablePanel(false, false);
        this.isClassification = Objects.equals(string, CLASSIFICATION);
        if (settings.metric == null) {
            // Default to the first metric of the list shown for this model type.
            settings.metric = RuntimeModelPanel.optimizersFor(string)[0].TrainerName;
        }
        this.addInputRow(RuntimeModelPanel.createMetricField(settings, string));

        this.addInputRow(RuntimeModelPanel.configureSpinner(new BreezeSpinnerInputPanel("Time", "", 0.0f, 9999999.0f, 1.0f, settings.time, "Seconds", d -> {
            settings.time = d.intValue();
        }), "time"));

        this.addInputRow(RuntimeModelPanel.configureSpinner(new BreezeSpinnerInputPanel("Number of cross validation folds", "", 0.0f, 100.0f, 1.0f, settings.numberOfCvFolds, " ", d -> {
            settings.numberOfCvFolds = d.intValue();
        }), "cvFolds"));

        this.topKInputField = RuntimeModelPanel.configureSpinner(new BreezeSpinnerInputPanel("Top-K classes", "", 1.0f, 1000.0f, 1.0f, settings.topK, " ", d -> {
            settings.topK = d.intValue();
        }), "topk");
        this.topKInputField.setVisible(this.isClassification);
        this.addInputRow(this.topKInputField);

        // The check box (left-aligned) and the help button (right-aligned) share one grid row.
        int footerRow = this.getComponentCount();
        BreezeCheckBox crossValidateAllBox = new BreezeCheckBox("Cross validate all experiments", settings.crossValidateAll);
        crossValidateAllBox.addItemListener(itemEvent -> {
            settings.crossValidateAll = crossValidateAllBox.isSelected();
        });
        this.add(crossValidateAllBox, RuntimeModelPanel.footerConstraints(footerRow, GridBagConstraints.WEST));
        BreezeButton helpButton = EntryMainHelper.createButton("", "help", () -> BrowserLauncher.open("https://help.prediktera.com/breeze/machine-learning"));
        this.add(helpButton, RuntimeModelPanel.footerConstraints(footerRow, GridBagConstraints.EAST));

        List<AbstractRuntimeModel.RuntimeResult> results = abstractRuntimeModel.getRuntimeResults().results;
        // Results produced by an "External" algorithm use a fixed, simpler column layout.
        boolean external = results != null && !results.isEmpty() && "External".equals(results.getFirst().algorithmName);
        RuntimeTableModel tableModel = external
                ? new RuntimeTableModel(null, null, Arrays.asList("Name", "Accuracy", "Test Accuracy", "Cross Validation", "Runtime in Seconds"))
                : new RuntimeTableModel(string, settings.metric);
        this.table.initTable(tableModel, null, false);
        this.table.setTableList(results, null);
        this.table.setColumnWidth(0);

        GridBagConstraints tableConstraints = new GridBagConstraints();
        tableConstraints.gridy = this.getComponentCount();
        tableConstraints.weightx = 1.0;
        tableConstraints.weighty = 1.0;
        tableConstraints.fill = GridBagConstraints.BOTH;
        tableConstraints.insets = new Insets(20, 0, 0, 0);
        this.add((Component)((Object)this.table), tableConstraints);
    }

    /**
     * Creates the "Algorithm" combo box. It offers "Auto" plus every method that
     * {@code TcpManager.getMLMethods(type)} returns.
     *
     * <p>The entry whose trainer name equals {@code settings.method} is pre-selected. Changing
     * the selection updates {@code settings.method} and {@code settings.displayName}.
     */
    private static BreezeComboBoxInputField<Object> createAlgorithmField(AbstractRuntimeModel.RuntimeSettings settings, String type) {
        BreezeComboBoxInputField<Object> field = new BreezeComboBoxInputField<Object>("Algorithm");
        field.addItem(new ListItem("Auto", "Auto"));
        field.addSeparator();
        field.getComboBox().setName("algoBox");
        // Returned map: trainer name -> display name.
        HashMap<String, String> methods = TcpManager.getInstance().getMLMethods(type);
        for (String trainerName : methods.keySet()) {
            ListItem item = new ListItem(methods.get(trainerName), trainerName);
            field.addItem(item);
            if (trainerName.equals(settings.method)) {
                field.setSelectedItem(item);
            }
        }
        field.setLabelWidth(LABEL_WIDTH);
        field.setUnitLabelWidth(UNIT_LABEL_WIDTH);
        field.setUnit(" ");
        field.setChangedListener(selected -> {
            ListItem current = (ListItem)field.getSelectedItem();
            settings.method = current.getTrainerName();
            settings.displayName = current.getDisplayName();
        });
        return field;
    }

    /**
     * Creates the "Optimizing Metric" combo box with the metrics for {@code type}.
     * The metric equal to {@code settings.metric} is pre-selected, and changing the selection
     * updates {@code settings.metric}.
     */
    private static BreezeComboBoxInputField<ListItem> createMetricField(AbstractRuntimeModel.RuntimeSettings settings, String type) {
        BreezeComboBoxInputField<ListItem> field = new BreezeComboBoxInputField<ListItem>("Optimizing Metric");
        for (ListItem optimizer : RuntimeModelPanel.optimizersFor(type)) {
            field.addItem(optimizer);
            if (Objects.equals(optimizer.TrainerName, settings.metric)) {
                field.setSelectedItem(optimizer);
            }
        }
        field.setChangedListener(selected -> {
            settings.metric = ((ListItem)field.getSelectedItem()).getTrainerName();
        });
        field.setFormat(BreezeInputFieldHelper.Format.OneRow);
        field.setLabelWidth(LABEL_WIDTH);
        field.setUnit(" ");
        field.setUnitLabelWidth(UNIT_LABEL_WIDTH);
        return field;
    }

    /** Applies the name, one-row format and column widths shared by all spinner rows. */
    private static BreezeSpinnerInputPanel configureSpinner(BreezeSpinnerInputPanel spinner, String name) {
        spinner.setName(name);
        spinner.setFormat(BreezeInputFieldHelper.Format.OneRow);
        spinner.setLabelWidth(LABEL_WIDTH);
        spinner.setUnitLabelWidth(UNIT_LABEL_WIDTH);
        return spinner;
    }

    /** Appends {@code component} as a new horizontally-filled row below the existing components. */
    private void addInputRow(Component component) {
        GridBagConstraints constraints = new GridBagConstraints();
        constraints.gridy = this.getComponentCount();
        constraints.fill = GridBagConstraints.HORIZONTAL;
        constraints.insets = new Insets(10, 0, 0, 0);
        this.add(component, constraints);
    }

    /** Returns constraints for a non-filling cell in column 0 of {@code row}, with the given {@code anchor}. */
    private static GridBagConstraints footerConstraints(int row, int anchor) {
        GridBagConstraints constraints = new GridBagConstraints();
        constraints.gridx = 0;
        constraints.gridy = row;
        constraints.anchor = anchor;
        constraints.insets = new Insets(10, 0, 0, 0);
        return constraints;
    }

    /**
     * Returns the metric list for {@code type}. That is the classification metrics for
     * {@code "Classification"} (null-safe) and the quantification metrics for anything else.
     */
    private static ListItem[] optimizersFor(String type) {
        return Objects.equals(type, CLASSIFICATION) ? ClassificationOptimizers : QuantificationOptimizers;
    }

    /**
     * Returns the largest top-K value accepted for {@code model}.
     *
     * <p>For a {@link PLSDAModel} the limit is the size of its inner class column minus one,
     * and at least 1. For any other model (including {@code null}) it is
     * {@link #DEFAULT_MAX_TOP_K}.
     */
    private static int maxTopK(AbstractModel model) {
        if (model instanceof PLSDAModel plsdaModel) {
            // NOTE(review): the "- 1" presumably excludes a non-class entry of the column; confirm.
            return Math.max(plsdaModel.getInnerColumnClass().getObject().getSize(false) - 1, 1);
        }
        return DEFAULT_MAX_TOP_K;
    }

    @Override
    public String getType() {
        return "runtime";
    }

    /** No-op; this tab does not display individual entries. */
    @Override
    public void showEntry(Object object, Object object2) {
    }

    /**
     * Caps the top-K spinner at the largest K usable for the current model. If the current value
     * exceeds the new cap, it is lowered to the cap.
     *
     * <p>Only classification panels are affected. For a non-PLSDA model the cap is
     * {@link #DEFAULT_MAX_TOP_K}, the same limit {@link #trainModel} applies. Previously the cap
     * was 0 here, which was below the spinner's minimum of 1.
     */
    public void updateTopKMax() {
        if (!this.isClassification) {
            return;
        }
        int max = RuntimeModelPanel.maxTopK(this.model.getModel());
        this.topKInputField.setMax(max);
        this.topKInputField.setValue(Math.min((double)max, this.topKInputField.getValueDouble()));
    }

    @Override
    public String toString() {
        return "Runtime";
    }

    @Override
    public String getName() {
        return "runtime";
    }

    /**
     * Exports the training and test data to temporary files and asks the training service to
     * train a model. The service's JSON response is stored as the model's runtime results.
     *
     * <p>If the exported test set is empty, the method shows a warning, aborts the progress
     * manager and returns without training.
     *
     * <p>Before training, {@code runtimeSettings.topK} is clamped to {@link #maxTopK}, and that
     * limit is also passed to the service.
     *
     * @return the temporary {@code model.onnx} file the service is asked to write; it is returned
     *         (untrained) even when training is skipped because the test set is empty
     */
    public static File trainModel(BreezeProgressManager breezeProgressManager, Analyse analyse, AbstractRuntimeModel abstractRuntimeModel, BreezeFrame breezeFrame) {
        File trainFile = FileHelper.createTempFileInWorkspace("train.csv");
        abstractRuntimeModel.saveTrainingData(breezeProgressManager, trainFile, false);
        File testFile = FileHelper.createTempFileInWorkspace("test.csv");
        abstractRuntimeModel.saveTestData(breezeProgressManager, testFile, false);
        File modelFile = FileHelper.createTempFileInWorkspace("model.onnx");
        if (testFile.length() == 0L) {
            BreezeOptionDialog.showDialog(breezeFrame, BreezeType.Analyse, AbstractBreezeDialog.MessageType.WARNING_OPTION, "Add Test Set Required", "You have not selected any test set for training the ML model. Please ensure you select a test set to proceed with training.\nA test set is crucial for evaluating the model's performance and generalization.", null);
            breezeProgressManager.abort();
            return modelFile;
        }
        AbstractRuntimeModel.RuntimeSettings settings = abstractRuntimeModel.getRuntimeSettings();
        breezeProgressManager.setInfoText("Training model using " + settings.displayName);
        int maxTopK = RuntimeModelPanel.maxTopK(abstractRuntimeModel.getModel());
        settings.topK = Math.min(settings.topK, maxTopK);
        Object response = TcpManager.getInstance().train(breezeProgressManager, trainFile, testFile, modelFile, analyse.getModelMethod().getType(), settings, maxTopK);
        abstractRuntimeModel.setRuntimeResults(new AbstractRuntimeModel.RuntimeResultList(new JSONObject((String)response), analyse.getModelMethod().getType(), false));
        return modelFile;
    }

    @Override
    public AbstractViewPanel getPanel() {
        return this;
    }

    @Generated
    public BreezeTablePanel getTable() {
        return this.table;
    }

    /**
     * Combo-box entry that pairs a human-readable label with the trainer or metric identifier
     * sent to the training service. Its {@code toString()} returns the label.
     */
    public static class ListItem {
        String DisplayName;
        String TrainerName;

        /**
         * @param string  label shown in the UI
         * @param string2 identifier used in the runtime settings
         */
        public ListItem(String string, String string2) {
            this.setDisplayName(string);
            this.setTrainerName(string2);
        }

        @Override
        public String toString() {
            return this.getDisplayName();
        }

        @Generated
        public void setDisplayName(String string) {
            this.DisplayName = string;
        }

        @Generated
        public void setTrainerName(String string) {
            this.TrainerName = string;
        }

        @Generated
        public String getDisplayName() {
            return this.DisplayName;
        }

        @Generated
        public String getTrainerName() {
            return this.TrainerName;
        }
    }

    /**
     * Table model for the runtime results.
     *
     * <p>Each row contains the algorithm name, one column per metric, the test value, a cross
     * validation value (omitted when the metric is top-K accuracy) and the runtime.
     *
     * <p>When explicit {@code headers} are given, they are used as-is instead of the generated
     * header.
     */
    public static class RuntimeTableModel
    extends DefaultBreezeTableImpl {
        /** Value used in results for "no value available". */
        private static final double MISSING = -1.0;

        private final String type;
        private List<Object> headers = null;
        private String metric;

        /** Creates a model for {@code string} that uses the first metric of that type. */
        public RuntimeTableModel(String string) {
            this(string, RuntimeModelPanel.optimizersFor(string)[0].TrainerName);
        }

        /**
         * @param string  model type ({@code "Classification"} or other); {@code null} disables padding
         * @param string2 trainer name of the optimizing metric
         */
        public RuntimeTableModel(String string, String string2) {
            this.type = string;
            this.metric = string2;
        }

        /**
         * @param string  model type, or {@code null}
         * @param string2 trainer name of the optimizing metric, or {@code null}
         * @param list    fixed header row to use instead of the generated one
         */
        public RuntimeTableModel(String string, String string2, List<Object> list) {
            this.type = string;
            this.metric = string2;
            this.headers = list;
        }

        /** Returns true for a {@code null} value or the {@link #MISSING} marker. */
        private static boolean isMissing(Double value) {
            return value == null || value == MISSING;
        }

        /**
         * Converts one {@code AbstractRuntimeModel.RuntimeResult} into a table row. Missing
         * values are shown as "N/A".
         */
        @Override
        public Vector<Object> createVOVector(Object object) {
            AbstractRuntimeModel.RuntimeResult result = (AbstractRuntimeModel.RuntimeResult)object;
            Vector<Object> row = new Vector<Object>();
            // A missing first metric marks a run that did not finish.
            Double first = result.accuracy.isEmpty() ? null : result.accuracy.getFirst();
            boolean didNotFinish = first != null && first == MISSING;
            row.add((result.selected ? "(X) " : "") + String.valueOf(result));
            for (Double value : result.accuracy) {
                row.add(RuntimeTableModel.isMissing(value) ? "N/A" : value);
            }
            if (this.type != null) {
                // Pad results that carry fewer metric values than there are metric columns, so
                // the following cells stay aligned with the header from getTableHeader.
                int metricColumns = RuntimeModelPanel.optimizersFor(this.type).length;
                for (int i = result.accuracy.size(); i < metricColumns; ++i) {
                    row.add("N/A");
                }
            }
            row.add(result.accuracyTest == MISSING ? "N/A" : Double.valueOf(result.accuracyTest));
            if (!StringHelper.equals(this.metric, TOP_K_ACCURACY)) {
                row.add(result.crossValidationResult == MISSING ? "N/A" : Double.valueOf(result.crossValidationResult));
            }
            row.add(String.format("%.2f", result.runtimeInSeconds) + (didNotFinish ? " (Did not finish)" : ""));
            return row;
        }

        /**
         * Returns the explicit headers if they were given. Otherwise the header row is built to
         * match {@link #createVOVector}.
         */
        @Override
        public List<Object> getTableHeader(List list) {
            if (this.headers != null) {
                return this.headers;
            }
            ArrayList<Object> header = new ArrayList<Object>();
            header.add("Algorithm Name");
            String metricLabel = StringHelper.formatLabel(this.metric);
            for (ListItem optimizer : RuntimeModelPanel.optimizersFor(this.type)) {
                header.add(StringHelper.formatLabel(optimizer.TrainerName));
            }
            header.add(metricLabel + " Test");
            if (!StringHelper.equals(this.metric, TOP_K_ACCURACY)) {
                header.add("CV " + metricLabel);
            }
            header.add("Runtime in Seconds");
            return header;
        }

        @Generated
        public void setMetric(String string) {
            this.metric = string;
        }
    }
}

