package cern.colt.matrix.tfloat.algo.solver.preconditioner;

import cern.colt.matrix.Norm;
import cern.colt.matrix.tfloat.FloatMatrix1D;
import cern.colt.matrix.tfloat.FloatMatrix2D;
import cern.colt.matrix.tfloat.algo.DenseFloatAlgebra;
import cern.colt.matrix.tfloat.algo.FloatProperty;
import cern.colt.matrix.tfloat.impl.DenseFloatMatrix1D;
import cern.colt.matrix.tfloat.impl.SparseFloatMatrix1D;
import cern.colt.matrix.tfloat.impl.SparseRCMFloatMatrix2D;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:cern/colt/matrix/tfloat/algo/solver/preconditioner/FloatILUT.class */
/**
 * Incomplete LU preconditioner with threshold dropping (ILUT).
 * <p>
 * Both factors are stored in a single sparse matrix {@code LU}:
 * <ul>
 * <li>the strict lower triangle holds the unit-lower factor L (its unit diagonal is implicit);</li>
 * <li>the diagonal and upper triangle hold U.</li>
 * </ul>
 * <p>
 * During factorisation, row {@code i} is reduced against the already-finished rows above it. Two
 * dropping rules apply:
 * <ul>
 * <li>an entry is discarded when its magnitude is at most {@code tau * ||row_i||_2}, where the norm
 * is taken before elimination;</li>
 * <li>in each of the L and U parts only the largest entries are kept, up to (the number of nonzeros
 * the working row had in that part) + {@code p}.</li>
 * </ul>
 * <p>
 * {@link #apply} and {@link #transApply} share an internal work vector, so a single instance must not
 * be used from several threads at the same time.
 */
public class FloatILUT implements FloatPreconditioner {
    /** Combined factors: strict lower triangle = L (unit diagonal implicit), diagonal and upper = U. */
    private SparseRCMFloatMatrix2D LU;
    /** Work vector for the intermediate result between the two triangular solves. */
    private final FloatMatrix1D y;
    /** Relative drop tolerance (scaled by the 2-norm of each row). */
    private final float tau;
    /** Scratch list of candidate L entries for the row currently being gathered. */
    private final List<IntFloatEntry> lower;
    /** Scratch list of candidate U entries for the row currently being gathered. */
    private final List<IntFloatEntry> upper;
    /** Extra entries allowed per row in each of the L and U parts. */
    private final int p;
    /** Order of the matrices this preconditioner accepts. */
    private final int n;

    /**
     * A (column index, value) pair, ordered by decreasing absolute value so that sorting a list of
     * entries puts the largest magnitudes first. This ordering is not consistent with {@code equals}.
     */
    public static class IntFloatEntry implements Comparable<IntFloatEntry> {
        public int index;
        public float value;

        public IntFloatEntry(int index, float value) {
            this.index = index;
            this.value = value;
        }

        @Override // java.lang.Comparable
        public int compareTo(IntFloatEntry other) {
            // Operands deliberately swapped: a larger magnitude sorts first.
            return Float.compare(Math.abs(other.value), Math.abs(this.value));
        }

        @Override
        public String toString() {
            return "(" + this.index + "=" + this.value + ")";
        }
    }

    /**
     * Creates an ILUT preconditioner.
     *
     * @param n   order of the matrix that will be passed to {@link #setMatrix}
     * @param tau relative drop tolerance
     * @param p   number of extra entries kept per row in each of the L and U parts
     */
    public FloatILUT(int n, float tau, int p) {
        this.n = n;
        this.tau = tau;
        this.p = p;
        this.lower = new ArrayList<IntFloatEntry>(n);
        this.upper = new ArrayList<IntFloatEntry>(n);
        this.y = new DenseFloatMatrix1D(n);
    }

    /**
     * Creates an ILUT preconditioner with {@code tau = 1e-6} and {@code p = 25}.
     *
     * @param n order of the matrix that will be passed to {@link #setMatrix}
     */
    public FloatILUT(int n) {
        this(n, 1.0E-6f, 25);
    }

    /**
     * Applies the preconditioner: solves {@code L U x = b} by a forward solve with the unit-lower
     * factor followed by a backward solve with U.
     *
     * @param b right-hand side (left unchanged)
     * @param x output vector, or {@code null} to allocate one with {@code b.like()}
     * @return {@code x}
     * @throws IllegalStateException if {@link #setMatrix} has not been called
     */
    @Override // cern.colt.matrix.tfloat.algo.solver.preconditioner.FloatPreconditioner
    public FloatMatrix1D apply(FloatMatrix1D b, FloatMatrix1D x) {
        checkFactored();
        if (x == null) {
            x = b.like();
        }
        unitLowerSolve(b, this.y);
        return upperSolve(this.y, x);
    }

    /**
     * Applies the transposed preconditioner: solves {@code (L U)^T x = b}, i.e. a solve with U^T
     * followed by a solve with the unit-lower L^T.
     *
     * @param b right-hand side (left unchanged)
     * @param x output vector, or {@code null} to allocate one with {@code b.like()}
     * @return {@code x}
     * @throws IllegalStateException if {@link #setMatrix} has not been called
     */
    @Override // cern.colt.matrix.tfloat.algo.solver.preconditioner.FloatPreconditioner
    public FloatMatrix1D transApply(FloatMatrix1D b, FloatMatrix1D x) {
        checkFactored();
        if (x == null) {
            x = b.like();
        }
        upperTransSolve(b, this.y);
        return unitLowerTransSolve(this.y, x);
    }

    /**
     * Copies {@code A} into sparse row storage and computes its incomplete LU factorisation.
     *
     * @param A square matrix of order {@code n}
     * @throws IllegalArgumentException if {@code A} is not square or its order differs from {@code n}
     * @throws RuntimeException         if a zero pivot is produced
     */
    @Override // cern.colt.matrix.tfloat.algo.solver.preconditioner.FloatPreconditioner
    public void setMatrix(FloatMatrix2D A) {
        // Explicit check: FloatProperty.isSquare only reports the result, it does not throw.
        if (A.rows() != A.columns()) {
            throw new IllegalArgumentException("A is not square");
        }
        if (A.rows() != this.n) {
            throw new IllegalArgumentException("A.rows() != n");
        }
        this.LU = new SparseRCMFloatMatrix2D(this.n, this.n);
        this.LU.assign(A);
        this.LU.trimToSize();
        factor();
    }

    /** Guards the apply methods against use before a matrix has been factored. */
    private void checkFactored() {
        if (this.LU == null) {
            throw new IllegalStateException("setMatrix must be called before applying the preconditioner");
        }
    }

    /**
     * In-place row-wise ILUT elimination of {@code LU}. Row 0 is left as is; each later row is
     * reduced by all earlier rows and then pruned by {@link #gather}.
     */
    private void factor() {
        int rows = this.LU.rows();
        if (rows > 0) {
            checkDiagonal(this.LU.viewRow(0), 0);
        }
        for (int i = 1; i < rows; i++) {
            SparseFloatMatrix1D rowI = this.LU.viewRow(i);
            float dropTol = DenseFloatAlgebra.DEFAULT.norm(rowI, Norm.Two) * this.tau;
            for (int k = 0; k < i; k++) {
                float aik = rowI.getQuick(k);
                if (aik == 0.0f) {
                    continue; // nothing to eliminate in this column
                }
                // Row k is final here: it was gathered (and its pivot checked) in an earlier iteration.
                SparseFloatMatrix1D rowK = this.LU.viewRow(k);
                float multiplier = aik / rowK.getQuick(k);
                if (Math.abs(multiplier) > dropTol) {
                    int size = (int) rowK.size();
                    for (int j = k + 1; j < size; j++) {
                        float ukj = rowK.getQuick(j);
                        if (ukj != 0.0f) {
                            rowI.setQuick(j, rowI.getQuick(j) - (multiplier * ukj));
                        }
                    }
                    rowI.setQuick(k, multiplier);
                } else {
                    // Dropping rule: discard the small multiplier. Leaving the undivided a_ik in place
                    // would let gather() keep it as an L entry that is off by a factor of U(k,k).
                    rowI.setQuick(k, 0.0f);
                }
            }
            gather(rowI, dropTol, i);
            checkDiagonal(rowI, i);
        }
    }

    /** Throws if the pivot {@code row[i]} is zero; {@code i} is reported 1-based. */
    private static void checkDiagonal(SparseFloatMatrix1D row, int i) {
        if (row.getQuick(i) == 0.0f) {
            throw new RuntimeException("Zero diagonal entry on row " + (i + 1) + " during ILU process");
        }
    }

    /**
     * Prunes row {@code diag} in place.
     * <p>
     * The diagonal entry is always kept. Off-diagonal entries with magnitude at most {@code dropTol}
     * are removed. In each of the lower and upper parts, only the largest-magnitude survivors are
     * kept, up to (the part's current nonzero count) + {@code p}.
     *
     * @param row     the working row (modified in place)
     * @param dropTol absolute drop threshold for this row
     * @param diag    the row's index, i.e. the column of its diagonal entry
     */
    private void gather(SparseFloatMatrix1D row, float dropTol, int diag) {
        int lowerCount = 0;
        int upperCount = 0;
        long[] columns = row.elements().keys().elements();
        for (long column : columns) {
            if (column < diag) {
                lowerCount++;
            } else if (column > diag) {
                upperCount++;
            }
        }

        float[] dense = row.toArray();
        row.assign(0.0f);

        collect(dense, 0, diag, dropTol, this.lower);
        collect(dense, diag + 1, dense.length, dropTol, this.upper);

        row.setQuick(diag, dense[diag]);
        keepLargest(row, this.lower, lowerCount + this.p);
        keepLargest(row, this.upper, upperCount + this.p);
    }

    /** Refills {@code out} with the entries of {@code dense[from, to)} above {@code dropTol}, largest first. */
    private static void collect(float[] dense, int from, int to, float dropTol, List<IntFloatEntry> out) {
        out.clear();
        for (int j = from; j < to; j++) {
            if (Math.abs(dense[j]) > dropTol) {
                out.add(new IntFloatEntry(j, dense[j]));
            }
        }
        Collections.sort(out);
    }

    /** Writes the first {@code min(limit, entries.size())} entries back into {@code row}. */
    private static void keepLargest(SparseFloatMatrix1D row, List<IntFloatEntry> entries, int limit) {
        int count = Math.min(limit, entries.size());
        for (int k = 0; k < count; k++) {
            IntFloatEntry entry = entries.get(k);
            row.setQuick(entry.index, entry.value);
        }
    }

    /**
     * Forward substitution {@code L x = b} with the implicit unit diagonal. Uses element accessors so
     * that views and non-dense vectors are handled correctly.
     */
    private FloatMatrix1D unitLowerSolve(FloatMatrix1D b, FloatMatrix1D x) {
        int rows = this.LU.rows();
        for (int i = 0; i < rows; i++) {
            SparseFloatMatrix1D row = this.LU.viewRow(i);
            float sum = 0.0f;
            for (int j = 0; j < i; j++) {
                sum += row.getQuick(j) * x.getQuick(j);
            }
            x.setQuick(i, b.getQuick(i) - sum);
        }
        return x;
    }

    /** Solves {@code L^T x = b} (unit diagonal) by column-oriented backward substitution. */
    private FloatMatrix1D unitLowerTransSolve(FloatMatrix1D b, FloatMatrix1D x) {
        x.assign(b);
        for (int i = this.LU.rows() - 1; i >= 0; i--) {
            SparseFloatMatrix1D row = this.LU.viewRow(i);
            float xi = x.getQuick(i); // final: only entries j < i are updated below
            for (int j = 0; j < i; j++) {
                x.setQuick(j, x.getQuick(j) - (row.getQuick(j) * xi));
            }
        }
        return x;
    }

    /** Backward substitution {@code U x = b}, dividing by the stored diagonal. */
    private FloatMatrix1D upperSolve(FloatMatrix1D b, FloatMatrix1D x) {
        for (int i = this.LU.rows() - 1; i >= 0; i--) {
            SparseFloatMatrix1D row = this.LU.viewRow(i);
            int size = (int) row.size();
            float sum = 0.0f;
            for (int j = i + 1; j < size; j++) {
                sum += row.getQuick(j) * x.getQuick(j);
            }
            x.setQuick(i, (b.getQuick(i) - sum) / row.getQuick(i));
        }
        return x;
    }

    /** Solves {@code U^T x = b} by column-oriented forward substitution. */
    private FloatMatrix1D upperTransSolve(FloatMatrix1D b, FloatMatrix1D x) {
        x.assign(b);
        int rows = this.LU.rows();
        for (int i = 0; i < rows; i++) {
            SparseFloatMatrix1D row = this.LU.viewRow(i);
            int size = (int) row.size();
            float xi = x.getQuick(i) / row.getQuick(i);
            x.setQuick(i, xi);
            for (int j = i + 1; j < size; j++) {
                x.setQuick(j, x.getQuick(j) - (row.getQuick(j) * xi));
            }
        }
        return x;
    }
}
