/*
 * Decompiled with CFR 0.152.
 */
package se.prediktera.mda.model;

import java.io.IOException;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import se.prediktera.map.common.MapProperty;
import se.prediktera.map.common.RandomInputStream;
import se.prediktera.map.common.RandomOutputStream;
import se.prediktera.map.common.error.ErrorHandler;
import se.prediktera.map.common.progress.ProgressManager;
import se.prediktera.map.datasetcontainer.Case;
import se.prediktera.map.datasetcontainer.classes.ClassTypes;
import se.prediktera.map.datasetcontainer.classes.InnerColumnClass;
import se.prediktera.map.datasetcontainer.datainfo.AbstractDataInfo;
import se.prediktera.map.datasetcontainer.dataset.DataSet;
import se.prediktera.map.datasetcontainer.matrix.AbstractDataMatrix;
import se.prediktera.map.datasetcontainer.matrix.AbstractMVmatrix;
import se.prediktera.map.datasetcontainer.matrix.MVmatrix;
import se.prediktera.map.datasetcontainer.matrix.ModelMatrix;
import se.prediktera.map.datasetcontainer.transform.AbstractTransform;
import se.prediktera.map.datasetcontainer.transform.Transformations;
import se.prediktera.map.main.Main_GUI;
import se.prediktera.map.main.StatusBar;
import se.prediktera.mda.Statistics;
import se.prediktera.mda.model.AbstractCVResults;
import se.prediktera.mda.model.AbstractMDAModel;
import se.prediktera.mda.model.PLSDA.PLSDAModel;

/**
 * Base class for MDA models that carry a cross-validation configuration.
 *
 * <p>It stores the cross-validation type, the number of groups, an optional
 * category name and the generated groups of observation indices, and it
 * reads/writes those settings from/to the model property stream.
 */
public abstract class CrossvalidationModel
extends AbstractMDAModel {
    /**
     * Cross-validation groups: one {@code int[]} of observation indices per group.
     * {@code null} when no groups have been generated (or after clean-up). Arrays
     * created through {@code createIndxArray} are padded with {@code -1} in unused slots.
     */
    protected java.util.Vector<int[]> indxVector = null;
    /**
     * Cross-validation type code as interpreted by {@code Make_cv_group}:
     * 0 = disabled, 1 = interleaved groups, 2 = contiguous groups, 3 = one group per
     * category value, 4 = random groups, 5 = one group per observation,
     * 6 = class-matched random groups (PLS-DA only), 7 = repeated random draws.
     */
    private byte crossvalidationType = 0;
    /** Name of the category column used when {@code crossvalidationType == 3}. */
    private String cvCategoryName = null;
    /** Number of cross-validation groups (rounds); defaults to 7. */
    protected int nr_groups = 7;
    /** Whether confidence intervals are requested; exposed through {@code getCI()}. */
    protected boolean calc_ci = true;
    /** {@code true} is reported as "Full", {@code false} as "Partial" cross-validation. */
    protected boolean typeisfull = true;
    /**
     * Optional seed for the random generators created by {@code createRandom()};
     * {@code null} means a time-based seed. Not written by {@code SaveModel}.
     */
    private Long randomSeed = null;

    /**
     * Allocates an index array whose every slot holds the "unused" marker {@code -1}.
     *
     * @param n length of the array to create
     * @return a new array of length {@code n} filled with {@code -1}
     */
    private static int[] createIndxArray(int n) {
        final int[] indices = new int[n];
        Arrays.fill(indices, -1);
        return indices;
    }

    /**
     * Creates a model by forwarding all arguments unchanged to the
     * {@link AbstractMDAModel} constructor. The cross-validation settings keep
     * their field defaults until {@code setCrossvalidation} is called.
     *
     * @param dataSet passed through to the superclass constructor
     * @param nArray  passed through to the superclass constructor
     * @param nArray2 passed through to the superclass constructor
     * @param bl      passed through to the superclass constructor
     */
    public CrossvalidationModel(DataSet dataSet, int[] nArray, int[] nArray2, boolean bl) {
        super(dataSet, nArray, nArray2, bl);
    }

    /**
     * Restores a model from a stored property set, including the cross-validation
     * settings written by {@code SaveModel}.
     *
     * <p>Missing properties fall back to: 0 groups, confidence intervals off,
     * "partial" type. When no {@code cvType} property exists but a positive group
     * count was stored, the type is set to 1.
     *
     * @param randomInputStream stream handed to the superclass constructor
     * @param mapProperty       property set the settings are read from
     * @throws IOException declared by the superclass constructor and the property accessors
     */
    public CrossvalidationModel(RandomInputStream randomInputStream, MapProperty mapProperty) throws IOException {
        super(randomInputStream, mapProperty);
        this.nr_groups = mapProperty.getPropertyInt("nr_groups", 0);
        this.crossvalidationType = mapProperty.getPropertyByte("cvType");
        if (!mapProperty.hasProperty("cvType") && this.nr_groups > 0) {
            // No stored type but a stored group count: treat as type 1.
            this.crossvalidationType = 1;
        }
        this.cvCategoryName = mapProperty.getPropertyStringUTF("cvClassName");
        if (mapProperty.hasProperty("indxVsize")) {
            // NOTE(review): getLVInt()/getLVIntArray() presumably return the value of the
            // property located by the preceding hasProperty() call - confirm in MapProperty.
            int groupCount = mapProperty.getLVInt();
            // Typed instead of raw Vector so the element type is checked at compile time.
            this.indxVector = new java.util.Vector<int[]>();
            for (int i = 0; i < groupCount; ++i) {
                if (!mapProperty.hasProperty("indxVi" + i)) {
                    continue;
                }
                this.indxVector.addElement(mapProperty.getLVIntArray());
            }
        }
        this.calc_ci = mapProperty.getPropertyBoolean("ci", false);
        this.typeisfull = mapProperty.getPropertyBoolean("type", false);
    }

    /**
     * @return the current value of the confidence-interval flag
     */
    public boolean getCI() {
        return calc_ci;
    }

    /**
     * Looks up the display name of the current cross-validation type.
     *
     * @return the {@code CV_NAME} entry at index {@code crossvalidationType}
     */
    public String getExcludeBy() {
        final int typeIndex = crossvalidationType;
        return CV_NAME[typeIndex];
    }

    /**
     * Extracts the model's X rows from data-set matrix 0, labels the result "X"
     * and applies the X transformations to it.
     *
     * @param progressManager progress sink handed to the transformation step
     * @return the transformed sub-matrix
     * @throws Exception propagated from sub-matrix extraction or the transformations
     */
    public MVmatrix getXorig(ProgressManager progressManager) throws Exception {
        AbstractDataMatrix source = dataSet.getMatrix(0);
        MVmatrix xMatrix = source.subMatrix(getXindx(), null);
        xMatrix.setName("X");
        xMatrix.setHtmlName("<html><b>X</b></html>");
        X.applyTransformations(xMatrix, progressManager);
        return xMatrix;
    }

    /**
     * Extracts the model's Y values (rows from {@code getXindx()}, columns from
     * {@code getYindx()}) out of data-set matrix 2, labels the result "Y" and
     * applies the Y transformations with both boolean options enabled.
     *
     * @param progressManager progress sink handed to the transformation step
     * @return the transformed sub-matrix
     * @throws Exception propagated from sub-matrix extraction or the transformations
     */
    public MVmatrix getYorigReversibleOnly(ProgressManager progressManager) throws Exception {
        AbstractDataMatrix source = dataSet.getMatrix(2);
        MVmatrix yMatrix = source.subMatrix(getXindx(), getYindx());
        yMatrix.setName("Y");
        yMatrix.setHtmlName("<html><b>Y</b></html>");
        // NOTE(review): the two 'true' flags presumably restrict this to reversible
        // transformations (per the method name) - confirm against applyTransformations.
        Y.applyTransformations(yMatrix, progressManager, true, true);
        return yMatrix;
    }

    /**
     * Delegates to the Y block's transformation lookup.
     *
     * @param n transformation type code
     * @return whatever {@code Y.getTransformationByType(n)} returns
     */
    public AbstractTransform getYtransformationByType(int n) {
        return Y.getTransformationByType(n);
    }

    /**
     * @return {@code true} when the cross-validation type flag is "Full", {@code false} for "Partial"
     */
    public boolean isTypeFull() {
        return typeisfull;
    }

    /**
     * Writes the cross-validation settings to the stream: group count, type,
     * category name, the index groups (only when present) and the two flags.
     *
     * @param randomOutputStream destination stream
     * @throws IOException if any property write fails
     */
    @Override
    public void SaveModel(RandomOutputStream randomOutputStream) throws IOException {
        randomOutputStream.writePropertyInt("nr_groups", nr_groups);
        randomOutputStream.writePropertyByte("cvType", crossvalidationType);
        randomOutputStream.writePropertyStringUTF("cvClassName", cvCategoryName);
        if (indxVector != null) {
            randomOutputStream.writePropertyInt("indxVsize", indxVector.size());
            // Each group is stored under its own key: indxVi0, indxVi1, ...
            int position = 0;
            for (int[] group : indxVector) {
                randomOutputStream.writePropertyIntArray("indxVi" + position, group);
                ++position;
            }
        }
        randomOutputStream.writePropertyBoolean("ci", calc_ci);
        randomOutputStream.writePropertyBoolean("type", typeisfull);
    }

    /**
     * Replaces the cross-validation configuration and immediately regenerates
     * the groups through {@code resetCrossvalidation()}.
     *
     * @param by     cross-validation type code
     * @param n      number of groups
     * @param string category name (read when the type is 3)
     * @param bl     confidence-interval flag
     */
    public void setCrossvalidation(byte by, int n, String string, boolean bl) {
        crossvalidationType = by;
        nr_groups = n;
        cvCategoryName = string;
        calc_ci = bl;
        // Groups depend on all of the settings above, so rebuild them last.
        resetCrossvalidation();
    }

    /**
     * Sets the seed used by random generators created afterwards.
     *
     * @param l the seed, or {@code null} to fall back to a time-based seed
     */
    public void setRandomSeed(Long l) {
        randomSeed = l;
    }

    /**
     * @param bl {@code true} for "Full", {@code false} for "Partial" cross-validation
     */
    public void setType(boolean bl) {
        typeisfull = bl;
    }

    /**
     * Appends HTML table rows describing the cross-validation settings.
     * Nothing is appended when cross-validation is disabled (type 0); the
     * category row is only added for type 3.
     *
     * @param stringBuilder buffer receiving the rows
     */
    protected void addCrossvalidationInfo(StringBuilder stringBuilder) {
        if (crossvalidationType == 0) {
            return;
        }
        stringBuilder.append("<TR><TD colspan=2></TD></TR>");
        stringBuilder.append("<TR><TD colspan=2><B>Crossvalidation</B></TD></TR>");

        String typeLabel = typeisfull ? "Full" : "Partial";
        stringBuilder.append("<TR><TD>Type:</TD><TD>");
        stringBuilder.append(typeLabel);
        stringBuilder.append("</TD></TR>");

        stringBuilder.append("<TR><TD>Exclusion by:</TD><TD>");
        stringBuilder.append(getExcludeBy());
        stringBuilder.append("</TD></TR>");

        stringBuilder.append("<TR><TD>Rounds:</TD><TD>");
        stringBuilder.append(nr_groups);
        stringBuilder.append("</TD></TR>");

        if (crossvalidationType == 3) {
            stringBuilder.append("<TR><TD>Category:</TD><TD>");
            stringBuilder.append(cvCategoryName);
            stringBuilder.append("</TD></TR>");
        }
    }

    /**
     * Combines the matrix's last scalar with the column-0 values of all rounds
     * so far into one cumulative value: {@code 1 - (1 - last) * prod(1 - v_i)}.
     * The factor for the last scalar is capped at 1.1 and each per-round value
     * is floored at -0.1 before use.
     *
     * @param modelMatrix source of the last scalar and the per-round values
     * @return the cumulative value
     */
    protected double calc_cv_cum(ModelMatrix modelMatrix) {
        double product = Math.min(1.0 - modelMatrix.getLastScalar(), 1.1);
        int round = 0;
        while (round < this.res.getRound()) {
            double value = Math.max(modelMatrix.getValue(round, 0), -0.1);
            product *= 1.0 - value;
            ++round;
        }
        return 1.0 - product;
    }

    /**
     * Per-column variant of the cumulative calculation:
     * {@code 1 - (1 - current) * prod(1 - v_i)}, where {@code current} is element
     * {@code n} of the vector and {@code v_i} are the column-{@code n} values of
     * all rounds so far. The first factor is capped at 1.1 and each per-round
     * value is floored at -0.1 before use.
     *
     * @param denseVector vector holding the current value at index {@code n}
     * @param n           column / element index
     * @param modelMatrix source of the per-round values
     * @return the cumulative value for column {@code n}
     */
    protected double calc_cv_var_cum(DenseVector denseVector, int n, ModelMatrix modelMatrix) {
        double product = Math.min(1.0 - denseVector.get(n), 1.1);
        int round = 0;
        while (round < this.res.getRound()) {
            double value = Math.max(modelMatrix.getValue(round, n), -0.1);
            product *= 1.0 - value;
            ++round;
        }
        return 1.0 - product;
    }

    /**
     * Returns a new vector holding {@code denseVector2 - denseVector} element by
     * element, or a plain copy of {@code denseVector2} when {@code denseVector}
     * is {@code null}.
     *
     * @param denseVector  values to subtract; may be {@code null}
     * @param denseVector2 values to subtract from; determines the result length
     * @return the newly allocated result vector
     */
    protected DenseVector calc_invert_cum(DenseVector denseVector, DenseVector denseVector2) {
        final int length = denseVector2.size();
        final DenseVector result = new DenseVector(length);
        for (int i = 0; i < length; ++i) {
            double value = denseVector == null
                    ? denseVector2.get(i)
                    : denseVector2.get(i) - denseVector.get(i);
            result.set(i, value);
        }
        return result;
    }

    /**
     * Derives a confidence-interval half-width per element from the vectors
     * collected over the cross-validation rounds and attaches it to each matrix.
     *
     * <p>For every vector position {@code k}: each round's vector is first
     * sign-aligned with {@code modelMatrixArray[k].getLastRM()} (it is negated
     * <em>in place</em> when their dot product is negative), then the
     * element-wise mean and sample standard deviation over the rounds are
     * computed, and the standard deviation is multiplied by
     * {@code Statistics.getT(3, rounds - 1)}.
     *
     * @param list             one {@code DenseVector[]} per round; must not be empty.
     *                         With a single round the divisor {@code rounds - 1} is
     *                         zero, so the resulting values are not finite.
     * @param modelMatrixArray matrices receiving the intervals; entry {@code k}
     *                         corresponds to vector position {@code k}
     * @throws Exception propagated from the matrix / statistics calls
     */
    protected void calcConfidenceInterval(List<DenseVector[]> list, ModelMatrix[] modelMatrixArray) throws Exception {
        DenseVector[] firstRound = list.get(0);
        DenseVector[] intervals = new DenseVector[firstRound.length];
        double rounds = list.size();
        for (int k = 0; k < firstRound.length; ++k) {
            int length = firstRound[k].size();

            // Pass 1: align signs with the reference vector and accumulate the mean.
            DenseVector mean = new DenseVector(length);
            for (DenseVector[] round : list) {
                DenseVector current = round[k];
                DenseVector reference = modelMatrixArray[k].getLastRM();
                if (reference.dot((Vector) current) < 0.0) {
                    // Flips the caller's vector in place so all rounds point the same way.
                    current.scale(-1.0);
                }
                for (int j = 0; j < length; ++j) {
                    mean.set(j, mean.get(j) + current.get(j));
                }
            }
            for (int j = 0; j < length; ++j) {
                mean.set(j, mean.get(j) / rounds);
            }

            // Pass 2: sum of squared deviations from the mean.
            DenseVector spread = new DenseVector(length);
            for (DenseVector[] round : list) {
                DenseVector current = round[k];
                for (int j = 0; j < length; ++j) {
                    double deviation = current.get(j) - mean.get(j);
                    spread.set(j, spread.get(j) + deviation * deviation);
                }
            }

            // NOTE(review): getT(3, df) is presumably a Student-t critical value with
            // df = rounds - 1; the meaning of the first argument is not visible here.
            double tValue = Statistics.getT(3, (int) rounds - 1);
            for (int j = 0; j < length; ++j) {
                double standardDeviation = Math.sqrt(spread.get(j) / (rounds - 1.0));
                spread.set(j, standardDeviation * tValue);
            }
            intervals[k] = spread;
        }
        for (int k = 0; k < modelMatrixArray.length; ++k) {
            modelMatrixArray[k].addConfidenceInterval(intervals[k]);
        }
    }

    /**
     * Runs the superclass clean-up, then empties and drops the cross-validation groups.
     */
    @Override
    protected void cleanUpChild() {
        super.cleanUpChild();
        if (indxVector == null) {
            return;
        }
        indxVector.clear();
        indxVector = null;
    }

    /**
     * Creates the cross-validation results object for this model; implemented by subclasses.
     * NOTE(review): parameter names were lost in decompilation and the method is not
     * called from this class, so the role of each argument must be confirmed in the
     * implementing subclasses.
     *
     * @throws Exception if the results cannot be created
     */
    protected abstract AbstractCVResults createResults(Case var1, AbstractDataMatrix var2, AbstractDataMatrix var3) throws Exception;

    /**
     * Supplies the matrix whose {@code getN()} and {@code getNref()} values are used by
     * {@code resetCrossvalidation()} to build the cross-validation groups.
     *
     * @return the matrix cross-validation groups are derived from
     */
    protected abstract AbstractMVmatrix getCvMatrix();

    /**
     * Rebuilds {@code indxVector} from the current settings, using the size and
     * reference array of the matrix returned by {@code getCvMatrix()}.
     */
    @Override
    protected void resetCrossvalidation() {
        final AbstractMVmatrix cvMatrix = getCvMatrix();
        indxVector = Make_cv_group(cvMatrix.getN(), cvMatrix.getNref());
    }

    /**
     * Creates a random generator seeded with the configured seed, or with the
     * current time in milliseconds when no seed has been set.
     *
     * @return a new {@link Random}
     */
    private Random createRandom() {
        if (randomSeed == null) {
            return new Random(System.currentTimeMillis());
        }
        long seed = randomSeed;
        return new Random(seed);
    }

    /**
     * Searches the transformation node of data info 0 for a column of type 4
     * whose inner column's {@code toString()} equals the given name.
     *
     * @param string name to look for
     * @return the first matching inner column cast to {@link InnerColumnClass}, or
     *         {@code null} when there is no data info or no match
     */
    private InnerColumnClass getClassFromName(String string) {
        AbstractDataInfo dataInfo = getCase().getDataInfo(0);
        if (dataInfo == null) {
            return null;
        }
        Transformations columns = dataInfo.getTransformationNode();
        final int columnCount = columns.getColumnCount();
        int column = 0;
        while (column < columnCount) {
            boolean typeMatches = columns.getColumnType(column) == 4;
            if (typeMatches && columns.getInnerColumn(column).toString().equals(string)) {
                return (InnerColumnClass) columns.getInnerColumn(column);
            }
            ++column;
        }
        return null;
    }

    /**
     * Builds the cross-validation groups for the current settings.
     *
     * <p>Group layout per {@code crossvalidationType}:
     * <ul>
     *   <li>0 (or {@code nr_groups == 0}): disabled, returns {@code null}</li>
     *   <li>1: interleaved - observation {@code i} goes to group {@code i % nr_groups}</li>
     *   <li>2: contiguous blocks of {@code ceil(n / nr_groups)} observations</li>
     *   <li>3: one group per value of the category named {@code cvCategoryName};
     *       if that category cannot be found the type is reset to 0 and {@code null} is returned</li>
     *   <li>4: random assignment without replacement</li>
     *   <li>5: {@code nr_groups} is set to {@code n}, then handled like type 2 (one observation per group)</li>
     *   <li>6: random assignment where each pick must match the class of a randomly drawn
     *       observation; only valid for {@link PLSDAModel}</li>
     *   <li>7: {@code nr_groups * 10} independent random draws of distinct observations</li>
     *   <li>any other value: an empty group list</li>
     * </ul>
     * Unused slots of a group hold {@code -1}; the used prefix of every group is sorted.
     * If any group array is at least {@code X.getN() - 1} long, a warning is shown on the
     * status bar and {@code null} is returned.
     *
     * <p>Compared with the decompiled original this version (a) uses consistent
     * {@code Vector<int[]>} typing, (b) draws random observation indices from the full range
     * {@code [0, n)} instead of {@code [0, n - 1)}, which never selected the last observation,
     * failed for {@code n == 1} and could loop forever in type 6, and (c) makes the type 7
     * duplicate check compare against the already chosen entries.
     *
     * @param n      number of observations
     * @param nArray per-observation reference values handed to
     *               {@code InnerColumnClass.getValueShort} (read for types 3 and 6 only)
     * @return the groups, or {@code null} when cross-validation is disabled or not possible
     */
    private java.util.Vector<int[]> Make_cv_group(int n, int[] nArray) {
        if (this.crossvalidationType == 0 || this.nr_groups == 0) {
            return null;
        }
        if (this.crossvalidationType == 5) {
            this.nr_groups = n;
        }
        java.util.Vector<int[]> groups = new java.util.Vector<int[]>(this.nr_groups);
        int groupSize = (int) Math.ceil((double) n / (double) this.nr_groups);

        switch (this.crossvalidationType) {
            case 1:
                CrossvalidationModel.fillInterleavedGroups(groups, n, this.nr_groups, groupSize);
                break;
            case 2:
            case 5:
                CrossvalidationModel.fillContiguousGroups(groups, n, this.nr_groups, groupSize);
                break;
            case 3:
                groups = this.buildCategoryGroups(n, nArray);
                if (groups == null) {
                    return null;
                }
                break;
            case 4:
                this.fillRandomGroups(groups, n, groupSize);
                break;
            case 6:
                if (!(this instanceof PLSDAModel)) {
                    throw new ErrorHandler.InformationMessageException("Stratified cross validation is only applicable for PLS-DA model");
                }
                this.fillClassMatchedGroups(groups, n, nArray, ((PLSDAModel) this).getInnerColumnClass(), groupSize);
                break;
            case 7:
                this.fillRepeatedRandomGroups(groups, n, groupSize);
                break;
            default:
                // Unknown type: fall through with an empty group list.
                break;
        }

        int observationCount = this.X.getN();
        for (int[] group : groups) {
            // Only the leading non-negative entries are real indices; the rest is -1 padding.
            int used = 0;
            while (used < group.length && group[used] >= 0) {
                ++used;
            }
            Arrays.sort(group, 0, used);
            if (group.length >= observationCount - 1) {
                Main_GUI.setStatusMessage(StatusBar.MessageType.WARNING, "Cross-Validation procedure disabled due to too few training set observations", 7);
                return null;
            }
        }
        return groups;
    }

    /** Type 1: puts observation {@code g, g + groupCount, g + 2 * groupCount, ...} into group {@code g}. */
    private static void fillInterleavedGroups(java.util.Vector<int[]> groups, int n, int groupCount, int groupSize) {
        for (int g = 0; g < groupCount; ++g) {
            int[] group = CrossvalidationModel.createIndxArray(groupSize);
            int slot = 0;
            for (int observation = g; observation < n; observation += groupCount) {
                group[slot] = observation;
                ++slot;
            }
            groups.add(group);
        }
    }

    /** Types 2 and 5: puts consecutive runs of {@code groupSize} observations into each group. */
    private static void fillContiguousGroups(java.util.Vector<int[]> groups, int n, int groupCount, int groupSize) {
        for (int g = 0; g < groupCount; ++g) {
            int[] group = CrossvalidationModel.createIndxArray(groupSize);
            int start = g * groupSize;
            for (int j = 0; j < groupSize; ++j) {
                if (start + j < n) {
                    group[j] = start + j;
                }
            }
            groups.add(group);
        }
    }

    /** Type 4: distributes all observations over the groups by drawing without replacement. */
    private void fillRandomGroups(java.util.Vector<int[]> groups, int n, int groupSize) {
        LinkedList<Integer> remaining = new LinkedList<Integer>();
        Random random = this.createRandom();
        for (int observation = 0; observation < n; ++observation) {
            remaining.add(observation);
        }
        for (int g = 0; g < this.nr_groups; ++g) {
            int[] group = CrossvalidationModel.createIndxArray(groupSize);
            for (int slot = 0; slot < groupSize && !remaining.isEmpty(); ++slot) {
                int pick = random.nextInt(remaining.size());
                group[slot] = remaining.remove(pick);
            }
            groups.add(group);
        }
    }

    /** Type 3: builds one group per category value; returns {@code null} (and disables CV) if the category is unknown. */
    private java.util.Vector<int[]> buildCategoryGroups(int n, int[] nArray) {
        InnerColumnClass categoryColumn = this.getClassFromName(this.cvCategoryName);
        if (categoryColumn == null) {
            this.crossvalidationType = 0;
            return null;
        }
        ClassTypes classTypes = categoryColumn.getObject();
        java.util.Vector<int[]> groups = new java.util.Vector<int[]>(classTypes.size());
        for (int category = 0; category < classTypes.size(); ++category) {
            java.util.Vector<Integer> members = new java.util.Vector<Integer>();
            for (int observation = 0; observation < n; ++observation) {
                if (categoryColumn.getValueShort(nArray[observation]) == category) {
                    members.add(observation);
                }
            }
            int[] group = new int[members.size()];
            for (int slot = 0; slot < group.length; ++slot) {
                group[slot] = members.elementAt(slot);
            }
            groups.add(group);
        }
        return groups;
    }

    /** Type 6: draws without replacement, accepting a pick only if its class equals that of a randomly drawn observation. */
    private void fillClassMatchedGroups(java.util.Vector<int[]> groups, int n, int[] nArray, InnerColumnClass classColumn, int groupSize) {
        LinkedList<Integer> remaining = new LinkedList<Integer>();
        Random random = this.createRandom();
        for (int observation = 0; observation < n; ++observation) {
            remaining.add(observation);
        }
        for (int g = 0; g < this.nr_groups; ++g) {
            int[] group = CrossvalidationModel.createIndxArray(groupSize);
            for (int slot = 0; slot < groupSize && !remaining.isEmpty(); ++slot) {
                int targetClass;
                int pick;
                do {
                    // Bound is n (exclusive) so every observation, including the last, can
                    // supply the target class; this guarantees a match is eventually found.
                    targetClass = classColumn.getValueShort(nArray[random.nextInt(n)]);
                    pick = random.nextInt(remaining.size());
                } while (targetClass != classColumn.getValueShort(nArray[remaining.get(pick)]));
                group[slot] = remaining.remove(pick);
            }
            groups.add(group);
        }
    }

    /** Type 7: creates {@code nr_groups * 10} groups, each an independent draw of distinct observations. */
    private void fillRepeatedRandomGroups(java.util.Vector<int[]> groups, int n, int groupSize) {
        Random random = this.createRandom();
        // NOTE(review): this re-seeding overrides any seed set via setRandomSeed(); kept
        // as in the original - confirm whether reproducible draws are wanted here.
        random.setSeed(System.currentTimeMillis());
        int rounds = this.nr_groups * 10;
        for (int round = 0; round < rounds; ++round) {
            int[] group = CrossvalidationModel.createIndxArray(groupSize);
            for (int slot = 0; slot < groupSize; ++slot) {
                int candidate;
                boolean duplicate;
                do {
                    candidate = random.nextInt(n);
                    duplicate = false;
                    // Reject values already placed in this group.
                    for (int k = 0; k < slot; ++k) {
                        if (group[k] == candidate) {
                            duplicate = true;
                            break;
                        }
                    }
                } while (duplicate);
                group[slot] = candidate;
            }
            groups.add(group);
        }
    }

    /**
     * @return the current cross-validation type code (0 means disabled)
     */
    public byte getCrossvalidationType() {
        return crossvalidationType;
    }

    /**
     * @return the configured number of cross-validation groups
     */
    public int getNr_groups() {
        return nr_groups;
    }

    /**
     * @return the category name used for category-based cross-validation; may be {@code null}
     */
    public String getCvCategoryName() {
        return cvCategoryName;
    }
}

