/*
 * Decompiled with CFR 0.152.
 */
package se.prediktera.mda.model.tSNE;

import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.tsne.barneshut.TSneConfiguration;
import com.jujutsu.utils.MatrixOps;
import com.jujutsu.utils.MatrixUtils;
import com.jujutsu.utils.TSneUtils;
import java.io.File;
import java.io.IOException;
import java.util.Properties;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import se.prediktera.map.common.MapProperty;
import se.prediktera.map.common.NiceDataFormat;
import se.prediktera.map.common.RandomInputStream;
import se.prediktera.map.common.RandomOutputStream;
import se.prediktera.map.common.progress.ProgressManager;
import se.prediktera.map.data.modeltable.AbstractTModel;
import se.prediktera.map.datasetcontainer.DataContainer;
import se.prediktera.map.datasetcontainer.datainfo.AbstractDataInfo;
import se.prediktera.map.datasetcontainer.dataset.DataSet;
import se.prediktera.map.datasetcontainer.matrix.AbstractDataMatrix;
import se.prediktera.map.datasetcontainer.matrix.AbstractMVmatrix;
import se.prediktera.map.datasetcontainer.matrix.ModelMatrix;
import se.prediktera.map.main.IniManager;
import se.prediktera.map.main.Main_GUI;
import se.prediktera.map.main.StatusBar;
import se.prediktera.map.main.extclass.DataTreeExtLibInterface;
import se.prediktera.map.main.extclass.ExtClassLoader;
import se.prediktera.map.main.script.ScriptHelper;
import se.prediktera.map.model.AbstractModel;
import se.prediktera.map.model.AbstractModelResults;
import se.prediktera.map.model.AbstractPrediction;
import se.prediktera.mda.model.AbstractMDAModel;
import se.prediktera.mda.model.ModelException;
import se.prediktera.mda.model.PCA.PCAModel;

/**
 * t-SNE embedding model attached to a {@link DataSet}.
 *
 * <p>On construction the model builds a residual X matrix ("Xres") from matrix 0 of the data set
 * with that matrix's transformations applied (see {@link #resetModel(ProgressManager)}).
 * {@link #doUpdateLocal(ProgressManager)} runs {@link ParallelBHTsne} on that matrix and stores
 * the embedding in the score matrix "T". Prediction, results, table models and reverse
 * transformations are not implemented: those methods return {@code null} or do nothing.
 */
public class TSNEModel
extends AbstractModel
implements DataTreeExtLibInterface {
    /** Number of embedding dimensions; default 2, script key "outputdims". */
    private int outputDims = 2;
    /** Iteration limit passed to t-SNE; default 100, script key "max_iter". */
    private int max_iter = 100;
    /** Passed to t-SNE as {@code initial_dims}; script key "initial_dims". */
    private final int initial_dims;
    /** t-SNE perplexity; script key "perplexity". */
    private final double perplexity;
    /** Passed to t-SNE as its PCA flag; script key "usepca". Shown as "PCA" vs "X-Train" in info. */
    private final boolean usePca;
    /** Not referenced within this class. */
    private final boolean usePcaFromEvince = false;
    /** Residual X matrix ("Xres") that t-SNE is run on. */
    private AbstractMVmatrix X = null;
    /** Index array forwarded to {@code subMatrix} when building {@link #X}; never assigned here. */
    private int[] Xindx;
    /** Score matrix "T" holding the t-SNE embedding. */
    private final ModelMatrix result;
    private final AbstractMDAModel.MDAModelListener modelListener = new AbstractMDAModel.MDAModelListener(this);
    /** Container node "Transformed X and Y" shown under this model in the data tree. */
    private DataContainer transformedContainer = new DataContainer("Transformed X and Y", 6, false);

    /**
     * Creates a new t-SNE model for {@code dataSet}. The embedding itself is not computed here.
     *
     * @param progressManager progress reporter forwarded to {@link #resetModel(ProgressManager)}
     * @param dataSet the data set whose matrix 0 is embedded
     * @param properties script properties: "initial_dims", "perplexity" and "usepca" are read
     *     unconditionally; "outputdims" and "max_iter" are read only when present
     * @throws ModelException.ModelNoDataException if matrix 0 of {@code dataSet} has no columns
     */
    public TSNEModel(ProgressManager progressManager, DataSet dataSet, Properties properties) {
        super(dataSet);
        if (!this.resetModel(progressManager)) {
            throw new ModelException.ModelNoDataException("X and Y matrix dimensions cannot be zero");
        }
        this.addTreeNode(this.X);
        this.init(dataSet);
        this.result = new ModelMatrix((AbstractDataInfo)dataSet.getObsDataInfo(), this.X);
        this.result.setName("T");
        this.result.setMatrixInfo("Scores in <b>X</b>");
        this.addTreeNode(this.result);
        this.addTreeNode(this.transformedContainer);
        this.transformedContainer.addTreeNode(dataSet.createTransformed((byte)0, this.X, this.X));
        if (ScriptHelper.hasKey(properties, "outputdims")) {
            this.outputDims = ScriptHelper.getPropertyInt(properties, "outputdims");
        }
        this.initial_dims = ScriptHelper.getPropertyInt(properties, "initial_dims");
        this.perplexity = ScriptHelper.getPropertyDouble(properties, "perplexity");
        if (ScriptHelper.hasKey(properties, "max_iter")) {
            this.max_iter = ScriptHelper.getPropertyInt(properties, "max_iter");
        }
        this.usePca = ScriptHelper.getPropertyBool(properties, "usepca");
    }

    /**
     * Restores a model previously written by {@link #SavePropertyLocal(RandomOutputStream)}.
     * The property keys read here must match the keys written there.
     *
     * @throws IOException if reading from {@code randomInputStream} fails
     */
    public TSNEModel(RandomInputStream randomInputStream, MapProperty mapProperty) throws IOException {
        super(randomInputStream, mapProperty);
        this.outputDims = mapProperty.getPropertyInt("outputDims");
        this.initial_dims = mapProperty.getPropertyInt("initial_dims");
        this.perplexity = mapProperty.getPropertyDouble("perplexity");
        this.max_iter = mapProperty.getPropertyInt("max_iter");
        this.usePca = mapProperty.getPropertyBoolean("usePca");
        this.X = (AbstractMVmatrix)mapProperty.resolveNodeFromName(randomInputStream, "X");
        this.result = (ModelMatrix)mapProperty.resolveNodeFromName(randomInputStream, "scores");
        this.transformedContainer = (DataContainer)mapProperty.resolveNodeFromName(randomInputStream, "trfC");
        this.init(this.getDataSet());
    }

    /** Registers {@link #modelListener} as a data-tree change listener on {@code dataSet}, if non-null. */
    private void init(DataSet dataSet) {
        if (dataSet != null) {
            dataSet.addDataTreeChangeListener(this.modelListener);
        }
    }

    /**
     * Writes the model parameters and the X, score and transformed-container nodes.
     * The keys must stay in sync with the stream constructor.
     */
    @Override
    public void SavePropertyLocal(RandomOutputStream randomOutputStream) throws IOException {
        super.SavePropertyLocal(randomOutputStream);
        randomOutputStream.writePropertyInt("outputDims", this.outputDims);
        randomOutputStream.writePropertyInt("initial_dims", this.initial_dims);
        randomOutputStream.writePropertyDouble("perplexity", this.perplexity);
        randomOutputStream.writePropertyInt("max_iter", this.max_iter);
        randomOutputStream.writePropertyBoolean("usePca", this.usePca);
        randomOutputStream.writePropertyNode("X", this.X);
        randomOutputStream.writePropertyNode("scores", this.result);
        randomOutputStream.writePropertyNode("trfC", this.transformedContainer);
    }

    /**
     * Rebuilds the residual X matrix ("Xres") from matrix 0 of the data set and applies that
     * matrix's transformations to it. Non-varying variables are excluded first.
     *
     * @param progressManager progress reporter; its info text is only set when non-null
     * @return {@code false} (after posting a status-bar warning) if matrix 0 has no columns,
     *     otherwise {@code true}
     */
    protected boolean resetModel(ProgressManager progressManager) {
        this.dataSet.excludeNonVarying(progressManager, false);
        AbstractDataMatrix abstractDataMatrix = this.dataSet.getMatrix(0);
        if (abstractDataMatrix.getK() <= 0) {
            Main_GUI.setStatusMessage(StatusBar.MessageType.WARNING, "Cannot recalculate model \"" + String.valueOf(this) + "\": X matrix dimensions cannot be zero", 10);
            return false;
        }
        this.X = abstractDataMatrix.subMatrix(progressManager, this.X, this.Xindx, null, IniManager.getModelDataType());
        this.X.setName("Xres");
        this.X.setHtmlName("<html><b>X</b><sub>res</sub></html>");
        this.X.setMatrixInfo("Residual <b>X</b> matrix");
        if (progressManager != null) {
            progressManager.setInfoText("Applying transformations");
        }
        this.X.beginCalculations();
        abstractDataMatrix.applyTransformations(this.X, progressManager);
        this.X.endCalculations();
        return true;
    }

    /** Returns the score matrix "T" holding the t-SNE embedding. */
    @Override
    public ModelMatrix getScores() {
        return this.result;
    }

    /**
     * Runs t-SNE on the residual X matrix and refills the score matrix with the embedding.
     *
     * <p>The t-SNE output is transposed into the score matrix: for each output column, one row is
     * added containing that column's value for every output row.
     *
     * @param progressManager progress reporter; its info text is only set when non-null
     */
    @Override
    protected void doUpdateLocal(ProgressManager progressManager) {
        // Null-guarded for consistency with resetModel, which checks the same parameter.
        if (progressManager != null) {
            progressManager.setInfoText("Calculating " + this.getModelName());
        }
        ParallelBHTsne tsne = new ParallelBHTsne();
        // NOTE(review): in T-SNE-Java's TSneUtils.buildConfig the trailing arguments are
        // (theta, silent, print_error) — confirm against the library version in use.
        TSneConfiguration config = TSneUtils.buildConfig(this.getInputMatrix(), this.outputDims,
                this.initial_dims, this.perplexity, this.max_iter, this.usePca, 0.5, false, true);
        double[][] embedding = tsne.tsne(config);
        this.result.resize(this.X);
        if (embedding.length == 0) {
            // Nothing to copy; indexing embedding[0] would throw.
            return;
        }
        int outRows = embedding.length;
        int outCols = embedding[0].length;
        for (int c = 0; c < outCols; ++c) {
            DenseVector column = new DenseVector(outRows);
            for (int r = 0; r < outRows; ++r) {
                column.set(r, embedding[r][c]);
            }
            // The vector is freshly built and not reused, so no defensive copy is needed.
            this.result.addRow(column);
        }
    }

    /** Returns the residual X matrix as a dense {@code double[getN()][getK()]} array. */
    private double[][] getInputMatrix() {
        return this.toDoubleMatrix(this.X);
    }

    /**
     * Returns the first direct child of {@code dataSet} that is a {@link PCAModel}, or
     * {@code null} if there is none. Not referenced within this class.
     */
    private PCAModel getPcaModel(DataSet dataSet) {
        for (int i = 0; i < dataSet.getTreeNodeCount(); ++i) {
            if (!(dataSet.getTreeNodeAt(i) instanceof PCAModel)) continue;
            return (PCAModel)dataSet.getTreeNodeAt(i);
        }
        return null;
    }

    /**
     * Copies {@code matrix} into a new {@code double[getN()][getK()]} array via {@code getValue(i, j)}.
     */
    private double[][] toDoubleMatrix(AbstractDataMatrix matrix) {
        // Dimensions are loop-invariant; read them once.
        int n = matrix.getN();
        int k = matrix.getK();
        double[][] values = new double[n][k];
        for (int i = 0; i < n; ++i) {
            double[] row = values[i];
            for (int j = 0; j < k; ++j) {
                row[j] = matrix.getValue(i, j);
            }
        }
        return values;
    }

    /** Applies the residual X matrix's transformations to {@code abstractDataMatrix}, if X exists. */
    @Override
    public void applyTransformationX(AbstractDataMatrix abstractDataMatrix) {
        if (this.X != null) {
            this.X.applyTransformations(abstractDataMatrix, null, false);
        }
    }

    /** No-op: this model has no Y matrix. */
    @Override
    public void applyTransformationY(AbstractDataMatrix abstractDataMatrix) {
    }

    /** Prediction is not implemented; always returns {@code null}. */
    @Override
    public AbstractPrediction createPredictionLocal(ProgressManager progressManager, DataSet dataSet, Properties properties) {
        return null;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public AbstractTModel createTableModel() {
        return null;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public AbstractPrediction getLocalPredictionNode() {
        return null;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public AbstractDataMatrix getMatrixByName(String string) {
        return null;
    }

    /** Returns the display name "t-SNE". */
    @Override
    public String getModelName() {
        return "t-SNE";
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public AbstractModelResults getResults() {
        return null;
    }

    /** Always returns {@code null}: this model has no Y matrix. */
    @Override
    public int[] getYindx() {
        return null;
    }

    /** No-op. */
    @Override
    public void reverseTransformationsX(DenseMatrix denseMatrix, int[] nArray, boolean bl) {
    }

    /** No-op. */
    @Override
    public void reverseTransformationsY(DenseMatrix denseMatrix, int[] nArray, boolean bl) {
    }

    /** No-op. */
    @Override
    protected void setProtectedLocal(ProgressManager progressManager, boolean bl) {
    }

    @Override
    public byte getNodeType() {
        return 5;
    }

    @Override
    public ExtClassLoader.LicenseLibrary getLibrary() {
        return ExtClassLoader.LicenseLibrary.MDA;
    }

    @Override
    public byte getTreeType() {
        return 20;
    }

    /** Not implemented; always returns {@code null}. */
    @Override
    public String getComponentDescription(int n) {
        return null;
    }

    /** Returns an HTML table summarizing the model, its data set and the t-SNE parameters. */
    @Override
    public String getInformation() {
        StringBuilder stringBuilder = new StringBuilder(420);
        stringBuilder.append("<HTML><BODY>");
        stringBuilder.append("<table border=\"0\" cellpadding=\"0\" cellspacing=\"0\" >");
        stringBuilder.append("<TR><TD>Model:</TD><TD>").append(this.getModelName()).append("</TD></TR>");
        if (this.getCase() != null) {
            stringBuilder.append("<tr><td></td><td>").append(this.getCase()).append("</td></tr>");
        }
        stringBuilder.append("<TR><TD>DataSet:</TD><TD>").append(this.dataSet != null ? this.dataSet : "N/A").append("</TD></TR>");
        stringBuilder.append("<TR><TD>Output dims:</TD><TD>").append(NiceDataFormat.toString(this.outputDims)).append("</TD></TR>");
        stringBuilder.append("<TR><TD>Initial dims:</TD><TD>").append(NiceDataFormat.toString(this.initial_dims)).append("</TD></TR>");
        stringBuilder.append("<TR><TD>Perplexity:</TD><TD>").append(NiceDataFormat.toString(this.perplexity)).append("</TD></TR>");
        stringBuilder.append("<TR><TD>Max iterations:</TD><TD>").append(NiceDataFormat.toString(this.max_iter)).append("</TD></TR>");
        stringBuilder.append("<TR><TD>Source:</TD><TD>").append(this.usePca ? "PCA" : "X-Train").append("</TD></TR>");
        stringBuilder.append("</TABLE></BODY></HTML>");
        return stringBuilder.toString();
    }

    /**
     * Manual test harness: reads a matrix file (delimiter: three spaces), prints a preview of it
     * and runs t-SNE on it. The embedding is discarded.
     *
     * @param stringArray optional; {@code stringArray[0]} overrides the default input path
     *     {@code C:/Prediktera/mnist2500_X.txt}
     */
    public static void main(String[] stringArray) {
        String path = stringArray != null && stringArray.length > 0
                ? stringArray[0]
                : "C:/Prediktera/mnist2500_X.txt";
        double[][] data = MatrixUtils.simpleRead2DMatrix(new File(path), "   ");
        System.out.println(MatrixOps.doubleArrayToPrintString(data, ", ", 50, 10));
        ParallelBHTsne tsne = new ParallelBHTsne();
        // Arguments in the same order as doUpdateLocal: outputDims, initial_dims, perplexity, max_iter.
        TSneConfiguration config = TSneUtils.buildConfig(data, 2, 55, 20.0, 100);
        tsne.tsne(config);
    }
}

