/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.example;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.tribuo.DataSource;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataProvenance;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.util.Util;

/**
 * Generates small synthetic clustering datasets and single examples labelled with {@link ClusterID}.
 * <p>
 * The class is abstract and exposes only static members.
 */
public abstract class ClusteringDataGenerator {
    /** Output factory shared by every dataset and provenance object built by this class. */
    private static final ClusteringFactory clusteringFactory = new ClusteringFactory();

    /**
     * Samples a two-feature dataset from a fixed mixture of five bivariate Gaussians.
     * <p>
     * Each example has features {@code "A"} and {@code "B"}. It is labelled with the index
     * ({@code 0}-{@code 4}) of the mixture component it was drawn from. The mixing weights are
     * {@code 0.1, 0.35, 0.05, 0.25, 0.25}. All randomness is derived from {@code seed}, so the
     * same {@code (size, seed)} pair produces the same sampled examples.
     *
     * @param size number of examples to generate; must be at least 1
     * @param seed seed for the random number generators
     * @return a dataset built from {@code size} sampled examples
     * @throws IllegalArgumentException if {@code size < 1}
     */
    public static Dataset<ClusterID> gaussianClusters(long size, long seed) {
        if (size < 1L) {
            throw new IllegalArgumentException("Size must be a positive number, received " + size);
        }
        Random rng = new Random(seed);
        String[] featureNames = new String[]{"A", "B"};
        double[] mixingPMF = new double[]{0.1, 0.35, 0.05, 0.25, 0.25};
        double[] mixingCDF = Util.generateCDF(mixingPMF);
        // Array initializers evaluate left to right, so the components draw their seeds from rng
        // in index order, keeping the output a pure function of the seed.
        MultivariateNormalDistribution[] gaussians = new MultivariateNormalDistribution[]{
            gaussian(rng, new double[]{0.0, 0.0}, new double[][]{{1.0, 0.0}, {0.0, 1.0}}),
            gaussian(rng, new double[]{5.0, 5.0}, new double[][]{{1.0, 0.0}, {0.0, 1.0}}),
            gaussian(rng, new double[]{2.5, 2.5}, new double[][]{{1.0, 0.5}, {0.5, 1.0}}),
            gaussian(rng, new double[]{10.0, 0.0}, new double[][]{{0.1, 0.0}, {0.0, 0.1}}),
            gaussian(rng, new double[]{-1.0, 0.0}, new double[][]{{1.0, 0.0}, {0.0, 0.1}})
        };
        List<Example<ClusterID>> trainingData = new ArrayList<>();
        // A long counter matches the type of size; an int counter would wrap before it could
        // reach a size above Integer.MAX_VALUE.
        for (long i = 0; i < size; i++) {
            // Pick a mixture component, then draw the point from that component's Gaussian.
            int centroid = Util.sampleFromCDF(mixingCDF, rng);
            double[] sample = gaussians[centroid].sample();
            trainingData.add(new ArrayExample<>(new ClusterID(centroid), featureNames, sample));
        }
        SimpleDataSourceProvenance trainingProvenance =
                new SimpleDataSourceProvenance("Generated clustering data", clusteringFactory);
        return new MutableDataset<>(new ListDataSource<>(trainingData, clusteringFactory, trainingProvenance));
    }

    /**
     * Builds one mixture component whose generator is seeded from the next int of {@code rng}.
     *
     * @param rng        source of the component's seed; advanced by exactly one {@code nextInt()}
     * @param mean       component mean
     * @param covariance component covariance matrix
     * @return the component distribution
     */
    private static MultivariateNormalDistribution gaussian(Random rng, double[] mean, double[][] covariance) {
        return new MultivariateNormalDistribution(new JDKRandomGenerator(rng.nextInt()), mean, covariance);
    }

    /**
     * Returns the dense train/test pair from {@link #denseTrainTest(double)} with {@code negate = -1.0}.
     *
     * @return a pair of (training dataset, test dataset)
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> denseTrainTest() {
        return denseTrainTest(-1.0);
    }

    /**
     * Builds a small dense train/test pair with four clusters over features {@code "A"}-{@code "D"}.
     * <p>
     * Every example sets all four features. The training set holds three examples for each of
     * cluster IDs 0-3, and the test set holds one example per cluster. Selected coordinates of
     * clusters 0, 1 and 2 are multiplied by {@code negate}. Both datasets record
     * {@link OffsetDateTime#now()} in their provenance.
     *
     * @param negate multiplier applied to selected feature values
     * @return a pair of (training dataset, test dataset)
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> denseTrainTest(double negate) {
        String[] names = new String[]{"A", "B", "C", "D"};

        MutableDataset<ClusterID> train = newDataset("TrainingData");
        addExample(train, 1, names, new double[]{1.0, 0.5, 1.0, negate * 1.0});
        addExample(train, 1, names, new double[]{1.5, 0.35, 1.3, negate * 1.2});
        addExample(train, 1, names, new double[]{1.2, 0.45, 1.5, negate * 1.0});
        addExample(train, 2, names, new double[]{negate * 1.1, 0.55, negate * 1.5, 0.5});
        addExample(train, 2, names, new double[]{negate * 1.5, 0.25, negate * 1.0, 0.125});
        addExample(train, 2, names, new double[]{negate * 1.0, 0.5, negate * 1.123, 0.123});
        addExample(train, 3, names, new double[]{1.5, 5.0, 0.5, 4.5});
        addExample(train, 3, names, new double[]{1.234, 5.1235, 0.1235, 6.0});
        addExample(train, 3, names, new double[]{1.734, 4.5, 0.5123, 5.5});
        addExample(train, 0, names, new double[]{negate * 1.0, 0.25, 5.0, 10.0});
        addExample(train, 0, names, new double[]{negate * 1.4, 0.55, 5.65, 12.0});
        addExample(train, 0, names, new double[]{negate * 1.9, 0.25, 5.9, 15.0});

        MutableDataset<ClusterID> test = newDataset("TestingData");
        addExample(test, 1, names, new double[]{2.0, 0.45, 3.5, negate * 2.0});
        addExample(test, 2, names, new double[]{negate * 2.0, 0.55, negate * 2.5, 2.5});
        addExample(test, 3, names, new double[]{1.75, 5.0, 1.0, 6.5});
        addExample(test, 0, names, new double[]{negate * 1.5, 0.25, 5.0, 20.0});

        return new Pair<>(train, test);
    }

    /**
     * Returns the sparse train/test pair from {@link #sparseTrainTest(double)} with {@code negate = -1.0}.
     *
     * @return a pair of (training dataset, test dataset)
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> sparseTrainTest() {
        return sparseTrainTest(-1.0);
    }

    /**
     * Builds a small sparse train/test pair with four clusters.
     * <p>
     * Each example sets four features, but the feature names vary from example to example. The
     * training set holds three examples for each of cluster IDs 0-3, and the test set holds one
     * example per cluster. Some test examples use feature names ({@code "AA"}, {@code "BB"},
     * {@code "CC"}, {@code "DD"}, {@code "EE"}) that never occur in the training set. Selected
     * values of clusters 0, 1 and 2 are multiplied by {@code negate}. Both datasets record
     * {@link OffsetDateTime#now()} in their provenance.
     *
     * @param negate multiplier applied to selected feature values
     * @return a pair of (training dataset, test dataset)
     */
    public static Pair<Dataset<ClusterID>, Dataset<ClusterID>> sparseTrainTest(double negate) {
        MutableDataset<ClusterID> train = newDataset("TrainingData");
        addExample(train, 1, new String[]{"A", "B", "C", "D"}, new double[]{1.0, 0.5, 1.0, negate * 1.0});
        addExample(train, 1, new String[]{"B", "D", "F", "H"}, new double[]{1.5, 0.35, 1.3, negate * 1.2});
        addExample(train, 1, new String[]{"A", "J", "D", "M"}, new double[]{1.2, 0.45, 1.5, negate * 1.0});
        addExample(train, 2, new String[]{"C", "E", "F", "H"}, new double[]{negate * 1.1, 0.55, negate * 1.5, 0.5});
        addExample(train, 2, new String[]{"E", "G", "F", "I"}, new double[]{negate * 1.5, 0.25, negate * 1.0, 0.125});
        addExample(train, 2, new String[]{"J", "K", "C", "E"}, new double[]{negate * 1.0, 0.5, negate * 1.123, 0.123});
        addExample(train, 3, new String[]{"E", "A", "K", "J"}, new double[]{1.5, 5.0, 0.5, 4.5});
        addExample(train, 3, new String[]{"B", "C", "E", "H"}, new double[]{1.234, 5.1235, 0.1235, 6.0});
        addExample(train, 3, new String[]{"A", "M", "I", "J"}, new double[]{1.734, 4.5, 0.5123, 5.5});
        addExample(train, 0, new String[]{"Z", "A", "B", "C"}, new double[]{negate * 1.0, 0.25, 5.0, 10.0});
        addExample(train, 0, new String[]{"K", "V", "E", "D"}, new double[]{negate * 1.4, 0.55, 5.65, 12.0});
        addExample(train, 0, new String[]{"B", "G", "E", "A"}, new double[]{negate * 1.9, 0.25, 5.9, 15.0});

        MutableDataset<ClusterID> test = newDataset("TestingData");
        addExample(test, 1, new String[]{"AA", "B", "C", "D"}, new double[]{2.0, 0.45, 3.5, negate * 2.0});
        addExample(test, 2, new String[]{"B", "BB", "F", "E"}, new double[]{negate * 2.0, 0.55, negate * 2.5, 2.5});
        addExample(test, 3, new String[]{"B", "E", "G", "H"}, new double[]{1.75, 5.0, 1.0, 6.5});
        addExample(test, 0, new String[]{"B", "CC", "DD", "EE"}, new double[]{negate * 1.5, 0.25, 5.0, 20.0});

        return new Pair<>(train, test);
    }

    /**
     * Returns an example labelled {@code ClusterID(1)} with features {@code "1"}, {@code "5"} and
     * {@code "8"}.
     * <p>
     * NOTE(review): presumably "invalid" because none of these feature names occur in the
     * datasets built by this class; confirm against callers.
     *
     * @return the example
     */
    public static Example<ClusterID> invalidSparseExample() {
        return new ArrayExample<>(new ClusterID(1), new String[]{"1", "5", "8"}, new double[]{1.0, 5.0, 8.0});
    }

    /**
     * Returns an example labelled {@code ClusterID(1)} that has no features.
     *
     * @return the featureless example
     */
    public static Example<ClusterID> emptyExample() {
        return new ArrayExample<>(new ClusterID(1), new String[0], new double[0]);
    }

    /**
     * Creates an empty dataset whose provenance records {@code description} and the current time.
     *
     * @param description provenance description, e.g. {@code "TrainingData"}
     * @return a new empty dataset
     */
    private static MutableDataset<ClusterID> newDataset(String description) {
        return new MutableDataset<>(
                new SimpleDataSourceProvenance(description, OffsetDateTime.now(), clusteringFactory),
                clusteringFactory);
    }

    /**
     * Appends one example to {@code dataset}.
     * <p>
     * {@code names} is copied, so callers can safely pass the same array for several examples.
     *
     * @param dataset   dataset to add to
     * @param clusterId label of the new example
     * @param names     feature names, parallel to {@code values}
     * @param values    feature values
     */
    private static void addExample(MutableDataset<ClusterID> dataset, int clusterId, String[] names, double[] values) {
        dataset.add(new ArrayExample<>(new ClusterID(clusterId), names.clone(), values));
    }
}

