import insane.ActivationFunction;
import insane.NetworkLayerProperties;
import insane.NeuralNetwork;
import insane.activation.Sigmoid;
import insane.training.NetworkInitializer;
import insane.training.TrainingConstraints;
import insane.training.TrainingInformation;
import insane.training.TrainingResults;
import insane.training.TrainingMethod;
import insane.training.distribution.NormalDistribution;
import insane.training.method.rprop.ResilientPropagation;
import insane.training.method.rprop.ResilientPropagationConfiguration;
import java.util.Random;

/**
 * Simple "exclusive OR" training with a multi-layered perceptron (MLP) composed
 * of 2 layers and a sigmoid activation function for each layer.
 * The neural network is initialized with a normal (Gaussian) distribution and
 * trained with the resilient propagation (RPROP) method.
 *
 * @author ncottin
 */
public final class ExclusiveOrResilientPropagation extends GenericExample {

    /**
     * Trains a small network on the XOR truth table with resilient
     * propagation, then prints the evaluated output for every training sample.
     * <p>
     * If the training does not satisfy the constraints (expected MSE not
     * reached within the maximum number of epochs), a warning is printed and
     * the method returns without evaluating the samples. Any exception raised
     * while building, training or evaluating the network is caught and its
     * stack trace is printed.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        try {
            // Supervised training data: the four rows of the XOR truth table
            // (two input values, one expected output value per row)
            TrainingInformation[] info = new TrainingInformation[] {
                new TrainingInformation(new double[]{0.0, 0.0}, new double[]{0.0}),
                new TrainingInformation(new double[]{0.0, 1.0}, new double[]{1.0}),
                new TrainingInformation(new double[]{1.0, 0.0}, new double[]{1.0}),
                new TrainingInformation(new double[]{1.0, 1.0}, new double[]{0.0})
            };

            // The same sigmoid activation instance is shared by both layers
            ActivationFunction activation = new Sigmoid();

            // Set layers properties
            // NOTE(review): the first argument is presumably the neuron count
            // of the layer (2 then 1) -- confirm against NetworkLayerProperties
            NetworkLayerProperties[] props = {
                new NetworkLayerProperties(2, activation),
                new NetworkLayerProperties(1, activation)
            };

            // NOTE(review): the first argument is presumably the number of
            // network inputs (matches the 2 input values of each sample)
            NeuralNetwork nnet = new NeuralNetwork(2, props);

            // Configure resilient propagation training.
            // The configuration is fully tuned BEFORE it is handed to the
            // training method, so the settings apply whether or not the
            // training method keeps a reference to this object.
            // The positive momentum is presumably the RPROP step increase
            // factor; the best value may vary depending on the nature of the
            // source training data. The other parameters keep their defaults.
            Random rand = new Random();
            ResilientPropagationConfiguration configuration = new ResilientPropagationConfiguration();
            configuration.setPositiveMomentum(1.5);
            TrainingMethod trainingMethod = new ResilientPropagation(rand, configuration);

            // Set the training constraints (stop conditions)
            TrainingConstraints constraints = new TrainingConstraints();
            constraints.setMaxError(1E-4); // Indicate the expected MSE
            constraints.setMaxEpochs(1500); // Make sure the training process ends

            // Initialize the network using a normal distribution, then train it
            NetworkInitializer.initialize(new NormalDistribution(rand), nnet);
            TrainingResults results = trainingMethod.train(constraints, TrainingMethod.ALL_ITEMS, nnet, null, info);

            // It may happen that the training gets stuck in a local minimum
            // In this case, the required maximum error cannot be reached
            // even after running all epochs
            if (!results.valid(constraints)) {
                System.out.println("MSE is " + results.getMeanSquareError());
                System.out.println("Caution: the expected MSE is not reached");
                System.out.println("The evaluated values may not be accurate");
                System.out.println("Please run again");
                System.out.println();
                return;
            }

            System.out.println("MSE:              " + results.getMeanSquareError());
            System.out.println("Best epoch:       " + results.getBestEpoch());

            for (TrainingInformation ti : info) {
                System.out.println();
                System.out.print("Input values:     ");
                println(ti.getInputValues());
                System.out.print("Expected output:  ");
                println(ti.getExpectedOutputValues());
                System.out.print("Evaluated output: ");

                // Use the trained network to produce output (evaluated) values
                double[] outputs = nnet.evaluate(ti.getInputValues());
                println(outputs);
            }
        } catch (Exception e) {
            // Example program: report any failure on the error stream
            e.printStackTrace();
        }
    }
}
