package de.unijena.bioinf.canopus.dnn;

import de.unijena.bioinf.canopus.dnn.ActivationFunction;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import org.ejml.data.FMatrixRMaj;
import org.ejml.dense.row.CommonOps_FDRM;

/* loaded from: input_file:de/unijena/bioinf/canopus/dnn/FullyConnectedLayer.class */
/**
 * A fully connected (dense) layer that computes {@code activation(x * W + B)} for every
 * row {@code x} of an input matrix.
 *
 * <p>{@code W} has one row per input feature and one column per output feature; {@code B}
 * holds one bias value per output feature. Both constructors enforce
 * {@code B.length == W.numCols}, which {@link #eval}, {@link #dump} and {@link #load} rely on.
 */
public class FullyConnectedLayer {
    /** Weight matrix, (input size) x (output size), stored row-major. */
    protected FMatrixRMaj W;
    /** Bias vector with one entry per output (weight-matrix column). */
    protected float[] B;
    /** Activation applied to the affine output in {@link #eval}. */
    protected ActivationFunction activationFunction;

    /**
     * Creates a layer from a two-dimensional weight array.
     *
     * @param fArr weights indexed as {@code fArr[input][output]}
     * @param fArr2 bias vector, one entry per output; stored by reference
     * @param activationFunction activation applied to {@code x * W + B}
     * @throws IllegalArgumentException if {@code fArr2.length} differs from the number of weight columns
     */
    public FullyConnectedLayer(float[][] fArr, float[] fArr2, ActivationFunction activationFunction) {
        this.W = new FMatrixRMaj(fArr);
        this.B = fArr2;
        this.activationFunction = activationFunction;
        checkDimensions();
    }

    /**
     * Creates a layer from a flat row-major weight array.
     *
     * @param i number of weight rows (input size)
     * @param i2 number of weight columns (output size)
     * @param fArr weights in row-major order, at least {@code i * i2} entries
     * @param fArr2 bias vector, one entry per output; stored by reference
     * @param activationFunction activation applied to {@code x * W + B}
     * @throws IllegalArgumentException if {@code fArr2.length != i2}
     */
    public FullyConnectedLayer(int i, int i2, float[] fArr, float[] fArr2, ActivationFunction activationFunction) {
        this.W = new FMatrixRMaj(i, i2, true, fArr);
        this.B = fArr2;
        this.activationFunction = activationFunction;
        checkDimensions();
    }

    /**
     * Fails fast when the bias length does not match the weight columns. A mismatch would
     * otherwise only surface later: as a dimension error in {@link #eval}, or as a corrupted
     * stream, because {@link #dump} writes {@code B.length} biases but {@link #load} reads
     * {@code numCols} of them.
     */
    private void checkDimensions() {
        if (this.B.length != this.W.numCols) {
            throw new IllegalArgumentException("Bias length " + this.B.length
                    + " does not match number of weight columns " + this.W.numCols);
        }
    }

    /**
     * Returns a copy of the row-major weight data.
     *
     * @return a fresh array; modifying it does not affect this layer
     */
    public float[] getWeightMatrixCopy() {
        return this.W.data.clone();
    }

    /**
     * Returns a copy of the bias vector.
     *
     * @return a fresh array; modifying it does not affect this layer
     */
    public float[] getBiasVectorCopy() {
        return this.B.clone();
    }

    /** @return the number of input features (rows of {@code W}) */
    public int getInputSize() {
        return this.W.numRows;
    }

    /** @return the number of output features (length of {@code B}) */
    public int getOutputSize() {
        return this.B.length;
    }

    /** Replaces the activation applied by subsequent {@link #eval} calls. */
    public void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /**
     * Evaluates the layer on a batch of inputs.
     *
     * @param fMatrixRMaj input matrix with {@link #getInputSize()} columns; each row is one sample
     * @return a new {@code numRows x getOutputSize()} matrix holding {@code activation(x * W + B)}
     */
    public FMatrixRMaj eval(FMatrixRMaj fMatrixRMaj) {
        final int rows = fMatrixRMaj.numRows;
        final int outputSize = this.B.length;
        final float[] buffer = new float[outputSize * rows];
        // Pre-fill every output row with the bias so that multAdd (C += A * B) yields x * W + B.
        for (int row = 0, offset = 0; row < rows; row++, offset += outputSize) {
            System.arraycopy(this.B, 0, buffer, offset, outputSize);
        }
        final FMatrixRMaj result = FMatrixRMaj.wrap(rows, outputSize, buffer);
        CommonOps_FDRM.multAdd(fMatrixRMaj, this.W, result);
        // The buffer is returned via `result` right after this call, so the activation is
        // expected to modify it in place.
        this.activationFunction.eval(buffer);
        return result;
    }

    /**
     * Serializes this layer in the format read by {@link #load}: an activation code (int),
     * {@code numCols} (int), {@code numRows} (int), the {@code numRows * numCols} weights
     * (floats, storage order), then the bias values (floats).
     *
     * <p>Activations other than Identity, Tanh, ReLu and SELU are written as code 1000,
     * which {@link #load} rejects.
     *
     * @throws IOException if writing to the stream fails
     */
    public void dump(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(activationCode(this.activationFunction));
        objectOutputStream.writeInt(this.W.numCols);
        objectOutputStream.writeInt(this.W.numRows);
        final int weightCount = this.W.numCols * this.W.numRows;
        for (int k = 0; k < weightCount; k++) {
            objectOutputStream.writeFloat(this.W.data[k]);
        }
        for (float bias : this.B) {
            objectOutputStream.writeFloat(bias);
        }
    }

    /** Maps an activation to its serialization code (0..3), or 1000 for unsupported types. */
    private static int activationCode(ActivationFunction function) {
        if (function instanceof ActivationFunction.Identity) {
            return 0;
        }
        if (function instanceof ActivationFunction.Tanh) {
            return 1;
        }
        if (function instanceof ActivationFunction.ReLu) {
            return 2;
        }
        if (function instanceof ActivationFunction.SELU) {
            return 3;
        }
        return 1000;
    }

    /**
     * Reads a layer written by {@link #dump}.
     *
     * @throws IOException if reading from the stream fails
     * @throws IllegalArgumentException if the activation code is unknown or the stored
     *         dimensions are negative or too large for a single array
     */
    public static FullyConnectedLayer load(ObjectInputStream objectInputStream) throws IOException {
        final int code = objectInputStream.readInt();
        final ActivationFunction activation;
        switch (code) {
            case 0:
                activation = new ActivationFunction.Identity();
                break;
            case 1:
                activation = new ActivationFunction.Tanh();
                break;
            case 2:
                activation = new ActivationFunction.ReLu();
                break;
            case 3:
                activation = new ActivationFunction.SELU();
                break;
            default:
                throw new IllegalArgumentException("Unknown activation function with code " + code);
        }
        final int numCols = objectInputStream.readInt();
        final int numRows = objectInputStream.readInt();
        // Reject corrupt headers before allocating: negative sizes or an int-overflowing product.
        if (numCols < 0 || numRows < 0 || (long) numCols * numRows > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Invalid weight matrix dimensions " + numRows + "x" + numCols);
        }
        final float[] weights = new float[numCols * numRows];
        for (int k = 0; k < weights.length; k++) {
            weights[k] = objectInputStream.readFloat();
        }
        final float[] bias = new float[numCols];
        for (int k = 0; k < numCols; k++) {
            bias[k] = objectInputStream.readFloat();
        }
        return new FullyConnectedLayer(numRows, numCols, weights, bias, activation);
    }

    /** Returns a compact formula-like description of the layer. */
    @Override
    public String toString() {
        String str = "W[" + this.W.numRows + "," + this.W.numCols + "]x + B[" + this.B.length + "]";
        if (this.activationFunction instanceof ActivationFunction.Identity) {
            return str;
        }
        if (this.activationFunction instanceof ActivationFunction.Tanh) {
            return "tanh( " + str + " )";
        }
        return "f( " + this.activationFunction.getClass().getName() + ", " + str + ")";
    }
}
