/*
 * Decompiled with CFR 0.152.
 */
package com.neuro.pack.network;

import com.neuro.pack.enums.ParamsType;
import com.neuro.pack.intefaces.OnNetworkParamsListener;
import com.neuro.pack.intefaces.OnUiParamsListener;
import com.neuro.pack.network.NeuralNetwork;
import com.neuro.pack.network.TrainModel;
import com.neuro.pack.network.TrainRead;
import com.neuro.pack.ui.buttons.TypeClick;
import java.awt.Color;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/**
 * Bridges the UI and a {@link NeuralNetwork}: collects hyper-parameters and the drawn image from
 * the UI, trains a network on the samples returned by {@link TrainRead#getBin()} on a background
 * thread, and reports progress and answers through an {@link OnNetworkParamsListener}.
 *
 * <p>The listener must be registered via {@link #setOnNetworkParamsListener} before any UI
 * callback arrives; every callback path reports through it.
 */
public class MainNetwork
implements OnUiParamsListener {
    /** Input layer size; presumably one neuron per pixel of a 28x28 image — TODO confirm. */
    private static final int INPUT_NODES = 784;
    /** Output layer size; the index of the strongest output is reported via {@code onAnswer}. */
    private static final int OUTPUT_NODES = 10;
    private static final int DEFAULT_HIDDEN_NODES = 50;
    private static final int DEFAULT_EPOCH = 1;
    private static final double DEFAULT_LEARNING_RATE = 0.28;
    /** Training target for the output neuron matching the sample's number. */
    private static final double TARGET_HIT = 0.99;
    /** Training target for every other output neuron. */
    private static final double TARGET_MISS = 0.01;

    private int hiddenNodes;
    private int epoch;
    private double learningRate;
    /** Last image converted by {@link #bufImage}; {@code null} until the UI sends one. */
    private double[] testData;
    /**
     * Fully trained network, or {@code null} before the first successful training and after a
     * reset. Volatile because it is published by the training thread and read by
     * {@link #definingNetwork} on the caller's thread.
     */
    private volatile NeuralNetwork network;
    private OnNetworkParamsListener listener;

    public MainNetwork() {
        this.applyDefaultParams();
    }

    /** Restores the user-tunable hyper-parameters to their defaults. */
    private void applyDefaultParams() {
        this.hiddenNodes = DEFAULT_HIDDEN_NODES;
        this.epoch = DEFAULT_EPOCH;
        this.learningRate = DEFAULT_LEARNING_RATE;
    }

    /**
     * Stores a hyper-parameter sent by the UI. The new values take effect on the next training
     * run.
     *
     * @param type which parameter changed
     * @param result the new value; for {@code PARAMS_RATIO} it is interpreted as the digits after
     *     the decimal point (e.g. 28 becomes 0.28)
     */
    @Override
    public void onUiParams(ParamsType type, int result) {
        switch (type) {
            case PARAMS_HIDDEN: {
                this.hiddenNodes = result;
                break;
            }
            case PARAMS_EPOCH: {
                this.epoch = result;
                break;
            }
            case PARAMS_RATIO: {
                // NOTE(review): 5 and 50 both map to 0.5, and a negative result makes
                // parseDouble throw NumberFormatException ("0.-5").
                this.learningRate = Double.parseDouble("0." + result);
                break;
            }
            default:
                break;
        }
    }

    @Override
    public void onUiImage(BufferedImage icon) {
        this.testData = this.bufImage(icon);
    }

    /**
     * Flattens an image row by row into inverted red-channel intensities ({@code 255 - red}), so
     * dark pixels become large values.
     *
     * @param icon the image to convert
     * @return one value in [0, 255] per pixel, or an empty array (after notifying
     *     {@code onErrorImage}) if the image is null or cannot be read
     */
    private double[] bufImage(BufferedImage icon) {
        if (icon == null) {
            this.listener.onErrorImage();
            return new double[0];
        }
        try {
            int width = icon.getWidth();
            int height = icon.getHeight();
            double[] res = new double[width * height];
            int n = 0;
            for (int y = 0; y < height; ++y) {
                for (int x = 0; x < width; ++x) {
                    // Only the red channel is read — assumes a grayscale drawing; TODO confirm.
                    Color c = new Color(icon.getRGB(x, y));
                    res[n++] = 255 - c.getRed();
                }
            }
            return res;
        } catch (RuntimeException e) {
            this.listener.onErrorImage();
            return new double[0];
        }
    }

    /**
     * Dispatches a UI button press. Training runs on its own thread so the caller is not blocked;
     * recognition and reset run synchronously.
     */
    @Override
    public void onUiClick(TypeClick type) {
        switch (type) {
            case LEARNING_BTN: {
                new Thread(this::learningNetwork, "network-training").start();
                break;
            }
            case DEFINING_BTN: {
                this.definingNetwork();
                break;
            }
            case RESET_BTN: {
                this.resetNetwork();
                break;
            }
            default:
                break;
        }
    }

    /**
     * Discards the trained network and restores default hyper-parameters. A training run that is
     * already in progress is not cancelled and will still publish its network when it finishes.
     */
    private void resetNetwork() {
        this.network = null;
        this.applyDefaultParams();
    }

    /**
     * Queries the trained network with the last image and reports the raw outputs and the index
     * of the strongest output. Does nothing if no network is trained yet or no usable image has
     * been supplied.
     */
    private void definingNetwork() {
        NeuralNetwork current = this.network;
        double[] image = this.testData;
        // An empty array means bufImage already reported the error; copyOfRange(arr, 1, 0)
        // would throw IllegalArgumentException on it.
        if (current == null || image == null || image.length == 0) {
            return;
        }
        // NOTE(review): the first pixel is dropped here, so the query gets image.length - 1
        // values while the network was built with INPUT_NODES inputs. This looks like a leftover
        // from a label-first data format; confirm against the UI image size before changing.
        double[] pixels = Arrays.copyOfRange(image, 1, image.length);
        // Scale exactly like the training inputs so the query sees the same value range.
        double[] output = current.query(this.biasedInputValues(pixels));
        int answer = this.maxIndex(output);
        this.listener.onResultNetwork(output);
        this.listener.onAnswer(answer);
    }

    /**
     * Returns the index of the largest element; the first one wins on ties. Returns 0 for an
     * empty array.
     */
    private int maxIndex(double[] arr) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIdx = 0;
        for (int i = 0; i < arr.length; ++i) {
            if (arr[i] > max) {
                max = arr[i];
                maxIdx = i;
            }
        }
        return maxIdx;
    }

    /**
     * Trains a fresh network for the configured number of epochs and reports progress
     * percentages. Runs on a background thread. The network is built and trained locally and is
     * published to {@link #network} only after training completes. Concurrent queries therefore
     * never see a half-trained network, and a failed training leaves the previous one in place.
     */
    private void learningNetwork() {
        // Snapshot the parameters so UI changes during training cannot alter this run.
        int epochs = this.epoch;
        NeuralNetwork candidate =
                new NeuralNetwork(INPUT_NODES, this.hiddenNodes, OUTPUT_NODES, this.learningRate);
        TrainRead trainRead = new TrainRead();
        this.listener.onOffLearningBtn();
        try {
            List<TrainModel> models = trainRead.getBin();
            double onePercent = 100.0 / (double) (models.size() * epochs);
            int current = 1;
            for (int i = 0; i < epochs; ++i) {
                for (TrainModel model : models) {
                    this.listener.onCurrentProgress((int) (onePercent * (double) current));
                    double[] biased = this.biasedInputValues(model.getBuffer());
                    double[] target = this.fillArray(OUTPUT_NODES);
                    target[model.getNumber()] = TARGET_HIT;
                    candidate.train(biased, target);
                    ++current;
                }
            }
            this.network = candidate;
            this.listener.onNetworkTrained();
        } catch (IOException e) {
            this.listener.onNotTrainFile();
        }
    }

    /** Returns an array of {@code count} elements, all set to the "miss" training target. */
    private double[] fillArray(int count) {
        double[] res = new double[count];
        Arrays.fill(res, TARGET_MISS);
        return res;
    }

    /** Maps integer inputs in [0, 255] to [0.01, 1.0]; see {@link #scaleInput}. */
    private double[] biasedInputValues(int[] data) {
        double[] res = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            res[i] = scaleInput(data[i]);
        }
        return res;
    }

    /** Maps double inputs in [0, 255] to [0.01, 1.0]; see {@link #scaleInput}. */
    private double[] biasedInputValues(double[] data) {
        double[] res = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            res[i] = scaleInput(data[i]);
        }
        return res;
    }

    /** Linearly maps 0 to 0.01 and 255 to 1.0, keeping inputs away from exactly zero. */
    private static double scaleInput(double value) {
        return value / 255.0 * 0.99 + 0.01;
    }

    /** Registers the callback that receives all progress, result and error notifications. */
    public void setOnNetworkParamsListener(OnNetworkParamsListener netListener) {
        this.listener = netListener;
    }
}

