package insane.training.method.bp;

import insane.ActivationFunction;
import insane.NetworkLayer;
import insane.NeuralNetwork;
import insane.training.TrainingConstraints;
import insane.training.TrainingException;
import insane.training.TrainingInformation;
import insane.training.TrainingResults;
import insane.training.method.ErrorInformation;
import insane.training.method.StochasticTrainingMethod;
import java.util.Random;

/* loaded from: input_file:insane/training/method/bp/BackPropagation.class */
/**
 * Back-propagation training method with a momentum term.
 *
 * <p>Each call to {@link #train} evaluates the network on one training sample,
 * propagates the output error backwards through the layers and immediately
 * adjusts every layer's weights and biases. The learning rate and momentum are
 * read from the {@link BackPropagationConfiguration} on every update.
 *
 * <p>The instance keeps per-layer scratch buffers that are mutated by every
 * {@code train} call.
 */
public final class BackPropagation extends StochasticTrainingMethod {
    public static final String IDENTIFIER = "bp";

    /** Source of the learning rate and momentum used by {@link #updateLayers}. */
    private final BackPropagationConfiguration configuration;
    /** Network being trained; set by {@link #initializeTraining}. */
    private NeuralNetwork network;
    /** Layers of {@link #network}, in the order returned by {@code getLayers()}. */
    private NetworkLayer[] layers;
    /** {@code true} when the network has more than one layer. */
    private boolean hasManyLayers;
    /** Index of the last (output) layer in {@link #layers}. */
    private int lastLayerIndex;
    /** Error term of every neuron, indexed {@code [layer][neuron]}. */
    private double[][] deltaErrors;
    /** Previous weight step (for momentum), indexed {@code [layer][neuron][input]}. */
    private double[][][] deltaWeights;
    /** Previous bias step (for momentum), indexed {@code [layer][neuron]}. */
    private double[][] deltaBiases;

    /**
     * Creates the training method.
     *
     * @param random random source forwarded to the superclass
     * @param backPropagationConfiguration provider of the learning rate and momentum
     */
    public BackPropagation(Random random, BackPropagationConfiguration backPropagationConfiguration) {
        super(random);
        this.configuration = backPropagationConfiguration;
    }

    /**
     * Binds this method to a network and allocates zero-filled per-layer buffers
     * sized from each layer's neuron and input counts.
     *
     * @param trainingConstraints not used by this implementation
     * @param neuralNetwork network whose layers will be trained
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void initializeTraining(TrainingConstraints trainingConstraints, NeuralNetwork neuralNetwork) throws TrainingException {
        this.network = neuralNetwork;
        this.layers = neuralNetwork.getLayers();
        this.lastLayerIndex = this.layers.length - 1;
        this.hasManyLayers = this.lastLayerIndex != 0;

        // Only the outer dimension is known here; the inner arrays are sized per layer below.
        int layerCount = this.layers.length;
        this.deltaErrors = new double[layerCount][];
        this.deltaWeights = new double[layerCount][][];
        this.deltaBiases = new double[layerCount][];
        for (int layerIndex = 0; layerIndex < layerCount; layerIndex++) {
            NetworkLayer layer = this.layers[layerIndex];
            int neurons = layer.getNeurons();
            this.deltaErrors[layerIndex] = new double[neurons];
            this.deltaWeights[layerIndex] = new double[neurons][layer.getInputs()];
            this.deltaBiases[layerIndex] = new double[neurons];
        }
    }

    /**
     * Performs one training step on a single sample: evaluates the network,
     * computes the error term of every neuron from the output layer backwards,
     * updates all layers, and reports the summed squared output error.
     *
     * @param trainingInformation sample providing input and expected output values
     * @param errorInformation accumulator receiving the summed squared error and
     *        the number of output values it was summed over
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void train(TrainingInformation trainingInformation, ErrorInformation errorInformation) throws TrainingException {
        double[] inputValues = trainingInformation.getInputValues();
        double[] actualOutputs = this.network.evaluate(inputValues);
        double[] expectedOutputs = trainingInformation.getExpectedOutputValues();

        // Remember the output count now: the backward pass below walks through
        // other layers' output arrays, whose lengths must not be reported as the
        // number of error terms.
        int outputCount = actualOutputs.length;

        // Output layer: the error term is the raw difference (actual - expected);
        // no activation derivative is applied at this layer.
        double squaredErrorSum = 0.0d;
        double[] currentDeltas = this.deltaErrors[this.lastLayerIndex];
        for (int neuron = 0; neuron < outputCount; neuron++) {
            double difference = actualOutputs[neuron] - expectedOutputs[neuron];
            squaredErrorSum += difference * difference;
            currentDeltas[neuron] = difference;
        }

        if (this.hasManyLayers) {
            NetworkLayer downstreamLayer = this.layers[this.lastLayerIndex];
            for (int layerIndex = this.lastLayerIndex - 1; layerIndex > -1; layerIndex--) {
                double[] downstreamDeltas = currentDeltas;
                // Indexed [downstream neuron][neuron of the current layer].
                double[][] downstreamWeights = downstreamLayer.getWeights();

                NetworkLayer layer = this.layers[layerIndex];
                currentDeltas = this.deltaErrors[layerIndex];
                // Values left in the layer by the evaluate() call above.
                double[] layerOutputs = layer.getOutputs();
                ActivationFunction activationFunction = layer.getActivationFunction();

                for (int neuron = 0; neuron < layerOutputs.length; neuron++) {
                    double weightedError = 0.0d;
                    for (int downstream = 0; downstream < downstreamDeltas.length; downstream++) {
                        weightedError += downstreamWeights[downstream][neuron] * downstreamDeltas[downstream];
                    }
                    // The derivative is evaluated on the value returned by getOutputs().
                    currentDeltas[neuron] = activationFunction.computeDerivated(layerOutputs[neuron]) * weightedError;
                }
                downstreamLayer = layer;
            }
        }

        // All error terms are computed before any weight changes, so the backward
        // pass above only ever reads the pre-update weights.
        updateLayers(inputValues);
        errorInformation.add(squaredErrorSum, outputCount);
    }

    /**
     * Applies one gradient step with momentum to the weights and biases of every
     * layer, using the error terms currently stored in {@link #deltaErrors}.
     *
     * @param networkInputs input values of the sample; they feed the first layer,
     *        and every later layer is fed by the previous layer's outputs
     */
    private void updateLayers(double[] networkInputs) {
        double learningRate = this.configuration.getLearningRate();
        double momentum = this.configuration.getMomentum();
        double[] layerInputs = networkInputs;
        for (int layerIndex = 0; layerIndex < this.layers.length; layerIndex++) {
            NetworkLayer layer = this.layers[layerIndex];
            double[] deltas = this.deltaErrors[layerIndex];
            double[][] previousWeightSteps = this.deltaWeights[layerIndex];
            double[] previousBiasSteps = this.deltaBiases[layerIndex];
            // The arrays returned by the layer are modified in place.
            double[][] weights = layer.getWeights();
            double[] biases = layer.getBiases();

            for (int neuron = 0; neuron < weights.length; neuron++) {
                double[] neuronWeights = weights[neuron];
                double[] neuronWeightSteps = previousWeightSteps[neuron];
                double scaledDelta = learningRate * deltas[neuron];

                for (int input = 0; input < layerInputs.length; input++) {
                    double weightStep = ((-scaledDelta) * layerInputs[input]) + (momentum * neuronWeightSteps[input]);
                    neuronWeightSteps[input] = weightStep;
                    neuronWeights[input] += weightStep;
                }

                // NOTE(review): the bias step is +learningRate*delta while the weight
                // step is -learningRate*delta*input. That is consistent only if
                // NetworkLayer subtracts the bias from the weighted sum - confirm
                // against NetworkLayer before changing the sign.
                double biasStep = scaledDelta + (momentum * previousBiasSteps[neuron]);
                previousBiasSteps[neuron] = biasStep;
                biases[neuron] += biasStep;
            }
            layerInputs = layer.getOutputs();
        }
    }

    /**
     * Wraps the two values supplied by the caller into a {@link TrainingResults}.
     *
     * @param d first constructor argument of {@code TrainingResults}, passed through unchanged
     * @param i second constructor argument of {@code TrainingResults}, passed through unchanged
     * @return a new {@code TrainingResults} built from {@code d} and {@code i}
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected TrainingResults terminateTraining(double d, int i) throws TrainingException {
        return new TrainingResults(d, i);
    }
}
