package insane.training.method;

import insane.NetworkLayer;
import insane.NeuralNetwork;
import insane.training.TrainingConstraints;
import insane.training.TrainingException;
import insane.training.TrainingInformation;
import insane.training.TrainingMethod;
import insane.training.TrainingResults;
import insane.training.io.TrainingInformationLoader;
import insane.training.io.TrainingInformationManager;
import insane.training.pruning.PruningMethod;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.Random;

/* loaded from: input_file:insane/training/method/AbstractTrainingMethod.class */
/**
 * Base class for epoch-based training methods.
 *
 * <p>It drives the epoch loop, splits every epoch into blocks that are permuted before their
 * samples are handed to {@link #train(TrainingInformation, ErrorInformation)}, keeps track of
 * the best weights/biases when no maximum error is configured, and finally applies the optional
 * pruning method. Subclasses supply the actual learning steps through the abstract hooks
 * declared at the end of this class.
 */
public abstract class AbstractTrainingMethod implements TrainingMethod {
    /** Random source used to permute the samples of a block before it is trained. */
    private final Random rand;

    /**
     * Creates a training method that uses the given random source for permutations.
     *
     * @param random random source handed to {@link #permute}
     */
    public AbstractTrainingMethod(Random random) {
        this.rand = random;
    }

    /**
     * Returns the random source supplied at construction time.
     *
     * @return the random source
     */
    public final Random getRandom() {
        return this.rand;
    }

    /**
     * Permutes, in place, the elements of {@code trainingInformationArr} located in the range
     * {@code [i, i2)}. Elements outside that range are never touched.
     *
     * <p>Walking down from the last index of the range, every element is swapped with a partner
     * chosen at a strictly lower index of the same range. For {@code i == 0} the sequence of
     * random draws is the same as {@code random.nextInt(index)}.
     *
     * @param random                 random source for the swap partners
     * @param trainingInformationArr array whose range is permuted
     * @param i                      first index of the range (inclusive)
     * @param i2                     end of the range (exclusive)
     */
    public static final void permute(Random random, TrainingInformation[] trainingInformationArr, int i, int i2) {
        for (int index = i2 - 1; index > i; index--) {
            // Offset by the range start so the partner can never lie below i; without the
            // offset, samples belonging to other blocks would be swapped into this one.
            int partner = i + random.nextInt(index - i);
            TrainingInformation swapped = trainingInformationArr[index];
            trainingInformationArr[index] = trainingInformationArr[partner];
            trainingInformationArr[partner] = swapped;
        }
    }

    /**
     * Trains the network with samples read from {@code file}.
     *
     * <p>If {@code i} is below 1, or a positive limit {@code i2} does not exceed {@code i}, the
     * samples are loaded in one go and training is delegated to the array-based overload.
     * Otherwise the file is re-read every epoch in blocks of {@code i} samples, which requires a
     * loader that allows single training information loading.
     *
     * @param file                      file holding the training information
     * @param i                         block size; values below 1 select the load-everything path
     * @param i2                        maximum number of samples to use per epoch; values below 1
     *                                  disable the limit in the block-wise path
     * @param trainingInformationLoader loader used to read the samples
     * @param trainingConstraints       maximum error / maximum epochs constraints
     * @param neuralNetwork             network to train
     * @param pruningMethod             pruning applied after training, may be {@code null}
     * @return the result produced by {@link #terminateTraining(double, int)}
     * @throws IOException       declared for the reader and loader calls (opening, loading,
     *                           closing)
     * @throws TrainingException if the loader does not allow block-wise loading, the constraints
     *                           or network are invalid, or a training hook fails
     */
    @Override // insane.training.TrainingMethod
    public final TrainingResults train(File file, int i, int i2, TrainingInformationLoader trainingInformationLoader, TrainingConstraints trainingConstraints, NeuralNetwork neuralNetwork, PruningMethod pruningMethod) throws IOException, TrainingException {
        boolean limited = i2 > 0;
        if (i < 1 || (limited && i >= i2)) {
            return train(trainingConstraints, i, neuralNetwork, pruningMethod, TrainingInformationManager.toArray(trainingInformationLoader.load(file, i2)));
        }
        if (!trainingInformationLoader.isSingleTrainingInformationLoadAllowed()) {
            throw new TrainingException("Loader does not allow single training information parsing");
        }
        TrainingMethodInformation init = init(trainingConstraints, neuralNetwork);
        double maxError = trainingConstraints.getMaxError();
        int maxEpochs = trainingConstraints.getMaxEpochs();
        boolean hasNoMaxError = init.hasNoMaxError();
        boolean hasNoMaxEpochs = init.hasNoMaxEpochs();
        NetworkLayer[] layers = neuralNetwork.getLayers();
        int epoch = 0;
        double mse;
        do {
            ErrorInformation errorInformation = new ErrorInformation();
            epoch++;
            BufferedReader bufferedReader = TrainingInformationLoader.getBufferedReader(file);
            try {
                initializeEpochTraining();
                trainBlocksFromReader(bufferedReader, trainingInformationLoader, i, i2, errorInformation);
            } finally {
                // Release the file handle even when loading or training fails mid-epoch.
                bufferedReader.close();
            }
            terminateEpochTraining(errorInformation);
            mse = errorInformation.getMse();
            saveBest(init, mse, epoch, layers);
            // Keep going while the error target (if any) is not reached and the epoch budget
            // (if any) is not used up.
        } while ((hasNoMaxError || mse > maxError) && (hasNoMaxEpochs || epoch < maxEpochs));
        return restoreAndPruneBest(init, mse, epoch, layers, pruningMethod);
    }

    /**
     * Reads one epoch worth of samples from {@code reader} block by block and trains each block.
     *
     * <p>Without a limit the epoch ends when the loader returns an empty block. With a limit the
     * last request is shortened to the remaining quota and the epoch ends after it, or earlier
     * if a full-size request returns an empty block.
     *
     * @param reader           open reader positioned at the start of the training data
     * @param loader           loader used to read each block
     * @param blockSize        number of samples requested per block (at least 1)
     * @param limit            maximum number of samples for the epoch; below 1 means no limit
     * @param errorInformation accumulator passed on to the training hook
     */
    private void trainBlocksFromReader(BufferedReader reader, TrainingInformationLoader loader, int blockSize, int limit, ErrorInformation errorInformation) throws IOException, TrainingException {
        boolean limited = limit > 0;
        int request = blockSize;
        int remaining = limit;
        boolean more = true;
        while (more) {
            if (limited) {
                if (remaining < blockSize) {
                    if (remaining == 0) {
                        // The limit was an exact multiple of the block size: the quota is used
                        // up, so no zero-sized request is issued. NOTE(review): how the loader
                        // treats a count of 0 is not visible here - confirm.
                        return;
                    }
                    request = remaining;
                    more = false;
                } else {
                    remaining -= blockSize;
                }
            }
            TrainingInformation[] block = TrainingInformationManager.toArray(loader.load(reader, request));
            if (more && block.length == 0) {
                // Nothing left to read: the epoch is complete.
                more = false;
            } else {
                performPermutedBlockTraining(block, 0, block.length, errorInformation);
            }
        }
    }

    /**
     * Trains the network with the given in-memory samples.
     *
     * <p>Every epoch walks the array in consecutive blocks of {@code i} samples (the last block
     * may be shorter); each block is permuted and then trained sample by sample. Note that the
     * permutation reorders {@code trainingInformationArr} in place.
     *
     * @param trainingConstraints    maximum error / maximum epochs constraints
     * @param i                      block size; values below 1 or above the number of samples
     *                               select a single block spanning all samples
     * @param neuralNetwork          network to train
     * @param pruningMethod          pruning applied after training, may be {@code null}
     * @param trainingInformationArr training samples, must not be empty
     * @return the result produced by {@link #terminateTraining(double, int)}
     * @throws TrainingException if no samples are supplied, the constraints or network are
     *                           invalid, or a training hook fails
     */
    @Override // insane.training.TrainingMethod
    public final TrainingResults train(TrainingConstraints trainingConstraints, int i, NeuralNetwork neuralNetwork, PruningMethod pruningMethod, TrainingInformation... trainingInformationArr) throws TrainingException {
        int total = trainingInformationArr.length;
        if (total == 0) {
            // Fail with a meaningful error instead of an index-out-of-bounds crash further down.
            throw new TrainingException("No training information available");
        }
        TrainingMethodInformation init = init(trainingConstraints, neuralNetwork);
        double maxError = trainingConstraints.getMaxError();
        int maxEpochs = trainingConstraints.getMaxEpochs();
        boolean hasNoMaxError = init.hasNoMaxError();
        boolean hasNoMaxEpochs = init.hasNoMaxEpochs();
        NetworkLayer[] layers = neuralNetwork.getLayers();
        int blockSize = (i < 1 || i > total) ? total : i;
        ErrorInformation errorInformation = new ErrorInformation();
        int epoch = 0;
        double mse;
        do {
            epoch++;
            errorInformation.reset();
            initializeEpochTraining();
            int blockStart = 0;
            for (int blockEnd = blockSize; blockEnd < total; blockEnd += blockSize) {
                performPermutedBlockTraining(trainingInformationArr, blockStart, blockEnd, errorInformation);
                blockStart = blockEnd;
            }
            // Trailing block: either the remainder or, for a single block, the whole array.
            performPermutedBlockTraining(trainingInformationArr, blockStart, total, errorInformation);
            terminateEpochTraining(errorInformation);
            mse = errorInformation.getMse();
            saveBest(init, mse, epoch, layers);
            // Keep going while the error target (if any) is not reached and the epoch budget
            // (if any) is not used up.
        } while ((hasNoMaxError || mse > maxError) && (hasNoMaxEpochs || epoch < maxEpochs));
        return restoreAndPruneBest(init, mse, epoch, layers, pruningMethod);
    }

    /**
     * Permutes the range {@code [i, i2)} of the array and trains every sample of that range in
     * its new order. An empty range is a no-op.
     *
     * @param trainingInformationArr samples containing the block
     * @param i                      first index of the block (inclusive)
     * @param i2                     end of the block (exclusive)
     * @param errorInformation       accumulator passed on to the training hook
     */
    private void performPermutedBlockTraining(TrainingInformation[] trainingInformationArr, int i, int i2, ErrorInformation errorInformation) throws TrainingException {
        permute(this.rand, trainingInformationArr, i, i2);
        for (int index = i; index < i2; index++) {
            train(trainingInformationArr[index], errorInformation);
        }
    }

    /**
     * Validates the constraints and the network, creates the bookkeeping object for this
     * training run and calls {@link #initializeTraining(TrainingConstraints, NeuralNetwork)}.
     *
     * <p>A maximum error {@code <= 0} counts as "no maximum error" and a maximum epoch count
     * {@code < 1} counts as "no maximum epochs"; at least one of the two must be set.
     *
     * @param trainingConstraints constraints to validate
     * @param neuralNetwork       network to train
     * @return the bookkeeping object for this run
     * @throws TrainingException if neither constraint is set, or the network has no layers
     */
    protected final TrainingMethodInformation init(TrainingConstraints trainingConstraints, NeuralNetwork neuralNetwork) throws TrainingException {
        boolean noMaxError = trainingConstraints.getMaxError() <= 0.0d;
        boolean noMaxEpochs = trainingConstraints.getMaxEpochs() < 1;
        if (noMaxEpochs && noMaxError) {
            throw new TrainingException("Invalid training constraints");
        }
        NetworkLayer[] layers = neuralNetwork.getLayers();
        if (layers == null || layers.length == 0) {
            throw new TrainingException("Neural network must contain at least one layer");
        }
        TrainingMethodInformation trainingMethodInformation = new TrainingMethodInformation(noMaxError, noMaxEpochs, layers);
        initializeTraining(trainingConstraints, neuralNetwork);
        return trainingMethodInformation;
    }

    /**
     * Records the current epoch as the best one when no maximum error is configured and the
     * epoch's error is strictly below the best error seen so far; the layers' weights and
     * biases are copied into the bookkeeping object. Does nothing otherwise.
     *
     * @param trainingMethodInformation bookkeeping object of this run
     * @param d                         error of the epoch that just finished
     * @param i                         number of the epoch that just finished
     * @param networkLayerArr           layers whose current values may be saved
     */
    private void saveBest(TrainingMethodInformation trainingMethodInformation, double d, int i, NetworkLayer[] networkLayerArr) {
        if (!trainingMethodInformation.hasNoMaxError() || d >= trainingMethodInformation.getBestMse()) {
            return;
        }
        trainingMethodInformation.setBestMse(d);
        trainingMethodInformation.setBestEpoch(i);
        double[][][] bestWeights = trainingMethodInformation.getBestWeights();
        double[][] bestBiases = trainingMethodInformation.getBestBiases();
        for (int layer = 0; layer < networkLayerArr.length; layer++) {
            NetworkLayer networkLayer = networkLayerArr[layer];
            // Direction: network -> saved best values.
            updateValues(networkLayer.getWeights(), bestWeights[layer], networkLayer.getBiases(), bestBiases[layer]);
        }
    }

    /**
     * Finishes a training run.
     *
     * <p>With a maximum error configured, the last epoch is reported as the best one. Without
     * it, the saved best weights and biases are copied back into the layers unless the last
     * epoch already is the best one. Afterwards the optional pruning method is applied and
     * {@link #terminateTraining(double, int)} produces the result.
     *
     * @param trainingMethodInformation bookkeeping object of this run
     * @param d                         error of the last epoch
     * @param i                         number of the last epoch
     * @param networkLayerArr           layers of the trained network
     * @param pruningMethod             pruning to apply, may be {@code null}
     * @return the result produced by {@link #terminateTraining(double, int)}
     */
    private TrainingResults restoreAndPruneBest(TrainingMethodInformation trainingMethodInformation, double d, int i, NetworkLayer[] networkLayerArr, PruningMethod pruningMethod) throws TrainingException {
        if (!trainingMethodInformation.hasNoMaxError()) {
            trainingMethodInformation.setBestEpoch(i);
            trainingMethodInformation.setBestMse(d);
        } else if (d != trainingMethodInformation.getBestMse()) {
            double[][][] bestWeights = trainingMethodInformation.getBestWeights();
            double[][] bestBiases = trainingMethodInformation.getBestBiases();
            for (int layer = 0; layer < networkLayerArr.length; layer++) {
                NetworkLayer networkLayer = networkLayerArr[layer];
                // Direction: saved best values -> network.
                updateValues(bestWeights[layer], networkLayer.getWeights(), bestBiases[layer], networkLayer.getBiases());
            }
        }
        if (pruningMethod != null) {
            pruningMethod.prune(networkLayerArr);
        }
        return terminateTraining(trainingMethodInformation.getBestMse(), trainingMethodInformation.getBestEpoch());
    }

    /**
     * Copies every row of {@code dArr} into the matching row of {@code dArr2}, and
     * {@code dArr3} into {@code dArr4}.
     *
     * @param dArr  source weights
     * @param dArr2 destination weights
     * @param dArr3 source biases
     * @param dArr4 destination biases
     */
    private static void updateValues(double[][] dArr, double[][] dArr2, double[] dArr3, double[] dArr4) {
        for (int row = 0; row < dArr.length; row++) {
            double[] source = dArr[row];
            System.arraycopy(source, 0, dArr2[row], 0, source.length);
        }
        System.arraycopy(dArr3, 0, dArr4, 0, dArr3.length);
    }

    /** Hook called once per training run from {@link #init}, after validation succeeded. */
    protected abstract void initializeTraining(TrainingConstraints trainingConstraints, NeuralNetwork neuralNetwork) throws TrainingException;

    /** Hook called at the start of every epoch, before any sample is trained. */
    protected abstract void initializeEpochTraining() throws TrainingException;

    /** Hook called for every sample of a block, in permuted order. */
    protected abstract void train(TrainingInformation trainingInformation, ErrorInformation errorInformation) throws TrainingException;

    /** Hook called at the end of every epoch, before the epoch's error is read. */
    protected abstract void terminateEpochTraining(ErrorInformation errorInformation) throws TrainingException;

    /**
     * Hook called once at the end of a training run to build the result.
     *
     * @param d best error of the run
     * @param i epoch at which the best error was recorded
     */
    protected abstract TrainingResults terminateTraining(double d, int i) throws TrainingException;
}
