package insane.training.method.rprop;

import insane.ActivationFunction;
import insane.NetworkLayer;
import insane.NeuralNetwork;
import insane.training.TrainingConstraints;
import insane.training.TrainingException;
import insane.training.TrainingInformation;
import insane.training.TrainingResults;
import insane.training.method.ErrorInformation;
import insane.training.method.SequentialTrainingMethod;
import java.util.Arrays;
import java.util.Random;

/**
 * Batch resilient back-propagation (Rprop) training method.
 *
 * <p>During an epoch, {@link #train} back-propagates each sample and accumulates per-parameter error
 * gradients. {@link #terminateEpochTraining} then updates every weight and bias using only the sign
 * of its accumulated gradient compared with the previous epoch's:
 *
 * <ul>
 *   <li>same sign: the step size is multiplied by the positive momentum (capped at
 *       {@code deltaMax}) and a step is taken against the gradient;
 *   <li>sign flipped: the step size is multiplied by the negative momentum (floored at
 *       {@code deltaMin}), the previous change is reverted and the gradient is zeroed so that the
 *       next epoch takes a plain step;
 *   <li>either gradient zero: a step of the current size is taken against the gradient (no change
 *       when the current gradient is itself zero).
 * </ul>
 *
 * <p>Instances keep mutable per-training state in unsynchronized fields.
 */
public final class ResilientPropagation extends SequentialTrainingMethod {
    public static final String IDENTIFIER = "rprop";

    private final ResilientPropagationConfiguration configuration;
    private NeuralNetwork network;
    private NetworkLayer[] layers;
    /** {@code true} unless the network consists of a single layer. */
    private boolean hasManyLayers;
    private int lastLayerIndex;
    /** Per layer, per neuron: back-propagated error term for the current sample. */
    private double[][] deltaErrors;
    /** Per layer: last change applied to each weight; reverted when the gradient sign flips. */
    private double[][][] deltaWeights;
    /** Per layer: last change applied to each bias; reverted when the gradient sign flips. */
    private double[][] deltaBiases;
    /** Per layer: current Rprop step size of each weight. */
    private double[][][] deltaWeightsUpdates;
    /** Per layer: current Rprop step size of each bias. */
    private double[][] deltaBiasesUpdates;
    /** Per layer: weight gradients accumulated during the current epoch. */
    private double[][][] weightsGradients;
    /** Per layer: bias gradients accumulated during the current epoch. */
    private double[][] biasesGradients;
    /** Per layer: weight gradients accumulated during the previous epoch. */
    private double[][][] previousWeightsGradients;
    /** Per layer: bias gradients accumulated during the previous epoch. */
    private double[][] previousBiasesGradients;

    /**
     * Creates an Rprop training method.
     *
     * @param random random source handed to the base training method
     * @param resilientPropagationConfiguration step-size parameters (initial delta, momentums and
     *     step bounds)
     */
    public ResilientPropagation(Random random, ResilientPropagationConfiguration resilientPropagationConfiguration) {
        super(random);
        this.configuration = resilientPropagationConfiguration;
    }

    /**
     * Binds this method to {@code neuralNetwork} and allocates all per-layer buffers.
     *
     * <p>Weight buffers are shaped row-by-row like each layer's weight matrix, and bias buffers have
     * one entry per neuron. Every step size starts at the configured initial delta, while gradients
     * and last changes start at zero.
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void initializeTraining(TrainingConstraints trainingConstraints, NeuralNetwork neuralNetwork) throws TrainingException {
        this.network = neuralNetwork;
        this.layers = neuralNetwork.getLayers();
        int layerCount = this.layers.length;
        this.lastLayerIndex = layerCount - 1;
        this.hasManyLayers = this.lastLayerIndex != 0;
        // Only the outer (per-layer) dimension here; the inner arrays are sized per layer below.
        this.deltaErrors = new double[layerCount][];
        this.deltaWeights = new double[layerCount][][];
        this.deltaBiases = new double[layerCount][];
        this.deltaWeightsUpdates = new double[layerCount][][];
        this.deltaBiasesUpdates = new double[layerCount][];
        this.weightsGradients = new double[layerCount][][];
        this.biasesGradients = new double[layerCount][];
        this.previousWeightsGradients = new double[layerCount][][];
        this.previousBiasesGradients = new double[layerCount][];
        double initialDelta = this.configuration.getInitialDelta();
        for (int i = 0; i < layerCount; i++) {
            NetworkLayer layer = this.layers[i];
            double[][] weights = layer.getWeights();
            int neurons = layer.getNeurons();
            this.deltaErrors[i] = new double[weights.length];
            this.deltaWeights[i] = matrixLike(weights, 0.0d);
            this.deltaBiases[i] = new double[neurons];
            this.deltaWeightsUpdates[i] = matrixLike(weights, initialDelta);
            this.deltaBiasesUpdates[i] = vectorOf(neurons, initialDelta);
            this.weightsGradients[i] = matrixLike(weights, 0.0d);
            this.biasesGradients[i] = new double[neurons];
            this.previousWeightsGradients[i] = matrixLike(weights, 0.0d);
            this.previousBiasesGradients[i] = new double[neurons];
        }
    }

    /**
     * Evaluates one sample, back-propagates its error through every layer and adds the resulting
     * gradients to the current epoch's accumulators. No parameter is changed here.
     *
     * <p>The sum of squared output errors and the output count are reported to
     * {@code errorInformation}.
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void train(TrainingInformation trainingInformation, ErrorInformation errorInformation) throws TrainingException {
        double[] networkInputs = trainingInformation.getInputValues();
        double[] outputs = this.network.evaluate(networkInputs);
        double[] expectedOutputs = trainingInformation.getExpectedOutputValues();
        // The output layer is fed the raw inputs in a single-layer net, else the previous layer's outputs.
        double[] outputLayerInputs = this.hasManyLayers
                ? this.layers[this.lastLayerIndex - 1].getOutputs()
                : networkInputs;

        double[][] outputWeightGradients = this.weightsGradients[this.lastLayerIndex];
        double[] outputBiasGradients = this.biasesGradients[this.lastLayerIndex];
        double[] outputDeltas = this.deltaErrors[this.lastLayerIndex];
        double squaredErrorSum = 0.0d;
        for (int i = 0; i < outputs.length; i++) {
            double error = outputs[i] - expectedOutputs[i];
            squaredErrorSum += error * error;
            // NOTE(review): the output delta is the raw error with no activation derivative. That is
            // exact only for a linear output activation (or a matching loss) -- confirm.
            outputDeltas[i] = error;
            addScaled(outputWeightGradients[i], error, outputLayerInputs);
            // Bias gradients accumulate the negated delta, presumably because the layer subtracts its
            // bias. NOTE(review): confirm against NetworkLayer.
            outputBiasGradients[i] -= error;
        }

        // Propagate deltas backwards: layer k receives them through the weights of layer k + 1 and is
        // fed by the outputs of layer k - 1 (or by the raw inputs when k == 0).
        for (int layer = this.lastLayerIndex - 1; layer >= 0; layer--) {
            double[] layerInputs = layer > 0 ? this.layers[layer - 1].getOutputs() : networkInputs;
            updateValues(this.weightsGradients[layer], this.biasesGradients[layer],
                    this.layers[layer + 1].getWeights(), this.deltaErrors[layer + 1],
                    this.deltaErrors[layer], this.layers[layer], layerInputs);
        }
        errorInformation.add(squaredErrorSum, outputs.length);
    }

    /**
     * Computes the deltas of one hidden layer and accumulates its gradients.
     *
     * @param weightGradients accumulated weight gradients of this layer (updated in place)
     * @param biasGradients accumulated bias gradients of this layer (updated in place)
     * @param nextWeights weight matrix of the following layer, indexed {@code [nextNeuron][neuron]}
     * @param nextDeltas deltas already computed for the following layer
     * @param deltas receives this layer's deltas
     * @param networkLayer the layer being processed
     * @param layerInputs values that were fed into {@code networkLayer}
     */
    private void updateValues(double[][] weightGradients, double[] biasGradients, double[][] nextWeights,
            double[] nextDeltas, double[] deltas, NetworkLayer networkLayer, double[] layerInputs) {
        double[] outputs = networkLayer.getOutputs();
        ActivationFunction activationFunction = networkLayer.getActivationFunction();
        for (int neuron = 0; neuron < outputs.length; neuron++) {
            double propagatedError = 0.0d;
            for (int next = 0; next < nextDeltas.length; next++) {
                propagatedError += nextWeights[next][neuron] * nextDeltas[next];
            }
            // The activation derivative is evaluated from the neuron's output value.
            double delta = activationFunction.computeDerivated(outputs[neuron]) * propagatedError;
            deltas[neuron] = delta;
            addScaled(weightGradients[neuron], delta, layerInputs);
            biasGradients[neuron] -= delta;
        }
    }

    /**
     * Starts a new epoch.
     *
     * <p>The current gradient buffers become the "previous" ones, and the old "previous" buffers are
     * reused and zeroed as the new accumulators. This avoids reallocating them every epoch.
     */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void initializeEpochTraining() throws TrainingException {
        double[][][] finishedWeightsGradients = this.weightsGradients;
        double[][] finishedBiasesGradients = this.biasesGradients;
        this.weightsGradients = this.previousWeightsGradients;
        this.biasesGradients = this.previousBiasesGradients;
        this.previousWeightsGradients = finishedWeightsGradients;
        this.previousBiasesGradients = finishedBiasesGradients;
        for (double[][] layerGradients : this.weightsGradients) {
            for (double[] row : layerGradients) {
                Arrays.fill(row, 0.0d);
            }
        }
        for (double[] layerGradients : this.biasesGradients) {
            Arrays.fill(layerGradients, 0.0d);
        }
    }

    /** Applies one Rprop update to every weight and bias of every layer. */
    @Override // insane.training.method.AbstractTrainingMethod
    protected void terminateEpochTraining(ErrorInformation errorInformation) throws TrainingException {
        double positiveMomentum = this.configuration.getPositiveMomentum();
        double negativeMomentum = this.configuration.getNegativeMomentum();
        double deltaMin = this.configuration.getDeltaMin();
        double deltaMax = this.configuration.getDeltaMax();
        for (int i = 0; i < this.layers.length; i++) {
            NetworkLayer layer = this.layers[i];
            double[][] weights = layer.getWeights();
            for (int row = 0; row < weights.length; row++) {
                updateParameters(weights[row], this.previousWeightsGradients[i][row],
                        this.weightsGradients[i][row], this.deltaWeightsUpdates[i][row],
                        this.deltaWeights[i][row], positiveMomentum, negativeMomentum, deltaMin, deltaMax);
            }
            updateParameters(layer.getBiases(), this.previousBiasesGradients[i], this.biasesGradients[i],
                    this.deltaBiasesUpdates[i], this.deltaBiases[i], positiveMomentum, negativeMomentum,
                    deltaMin, deltaMax);
        }
    }

    /**
     * Applies one Rprop+ (weight-backtracking) step to each element of {@code parameters}.
     *
     * @param parameters values to update in place (a weight row or a bias vector)
     * @param previousGradients gradients accumulated in the previous epoch
     * @param gradients gradients accumulated in this epoch; zeroed where the sign flipped
     * @param stepSizes per-parameter step sizes, grown or shrunk in place
     * @param lastChanges per-parameter last applied change, updated in place
     */
    private static void updateParameters(double[] parameters, double[] previousGradients, double[] gradients,
            double[] stepSizes, double[] lastChanges, double positiveMomentum, double negativeMomentum,
            double deltaMin, double deltaMax) {
        for (int k = 0; k < parameters.length; k++) {
            double signProduct = previousGradients[k] * gradients[k];
            if (signProduct > 0.0d) {
                // Same direction as last epoch: accelerate.
                double step = Math.min(stepSizes[k] * positiveMomentum, deltaMax);
                stepSizes[k] = step;
                double change = invSign(gradients[k], step);
                lastChanges[k] = change;
                parameters[k] += change;
            } else if (signProduct == 0.0d) {
                // No usable history (first epoch or just after a backtrack): plain step.
                double change = invSign(gradients[k], stepSizes[k]);
                lastChanges[k] = change;
                parameters[k] += change;
            } else {
                // Overshot a minimum: shrink the step, undo the last change, and clear the gradient
                // so that the next epoch takes the plain-step branch.
                stepSizes[k] = Math.max(stepSizes[k] * negativeMomentum, deltaMin);
                parameters[k] -= lastChanges[k];
                gradients[k] = 0.0d;
            }
        }
    }

    /**
     * Returns a change of size {@code step} against {@code gradient}: {@code +step} for a negative
     * gradient, {@code -step} for a positive one, and {@code 0} for a zero gradient. Returning 0 in
     * the last case keeps parameters with no gradient from drifting.
     */
    private static double invSign(double gradient, double step) {
        if (gradient < 0.0d) {
            return step;
        }
        return gradient > 0.0d ? -step : 0.0d;
    }

    /** Adds {@code scale * source[k]} to {@code target[k]} for every index of {@code target}. */
    private static void addScaled(double[] target, double scale, double[] source) {
        for (int k = 0; k < target.length; k++) {
            target[k] += scale * source[k];
        }
    }

    /** Allocates a matrix with the same row lengths as {@code template}, every cell set to {@code value}. */
    private static double[][] matrixLike(double[][] template, double value) {
        double[][] matrix = new double[template.length][];
        for (int row = 0; row < template.length; row++) {
            matrix[row] = vectorOf(template[row].length, value);
        }
        return matrix;
    }

    /** Allocates a vector of {@code length} elements, all set to {@code value}. */
    private static double[] vectorOf(int length, double value) {
        double[] vector = new double[length];
        Arrays.fill(vector, value);
        return vector;
    }

    @Override // insane.training.method.AbstractTrainingMethod
    protected TrainingResults terminateTraining(double d, int i) throws TrainingException {
        return new TrainingResults(d, i);
    }
}
