/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

public final class CLALibTSMM {
    private static final Log LOG = LogFactory.getLog(CLALibTSMM.class.getName());

    private CLALibTSMM() {
        // static utility class; not instantiable
    }

    /**
     * Computes the transpose-self matrix multiplication {@code t(X) %*% X} of the compressed matrix {@code X = cmb}
     * into {@code ret}.
     * <p>
     * If there are at least as many column groups as columns, the matrix is decompressed and the uncompressed
     * {@link LibMatrixMult#matrixMultTransposeSelf} kernel is used instead. Otherwise the result is accumulated per
     * column group (self products) and per pair of distinct column groups, and finally the upper triangle is mirrored
     * into the lower triangle.
     * <p>
     * When {@link CLALibUtils#shouldPreFilter} says so, the groups are first filtered into a reduced list plus a vector
     * {@code constV} of length numColumns. The constant part's contribution is then added by
     * {@link #addCorrectionLayer(double[], double[], int, double[])}.
     *
     * @param cmb the compressed input matrix
     * @param ret the output block of size numColumns x numColumns; on the filtered path its dense value array is
     *            written directly (assumes the caller allocated it dense -- TODO confirm)
     * @param k   the degree of parallelism; values {@code <= 1} run single threaded
     */
    public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
        final List<AColGroup> groups = cmb.getColGroups();
        final int numColumns = cmb.getNumColumns();
        if (groups.size() >= numColumns) {
            // Not fewer groups than columns: fall back to the uncompressed TSMM kernel.
            final MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k);
            LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k);
            return;
        }
        final int numRows = cmb.getNumRows();
        final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups);
        final boolean overlapping = cmb.isOverlapping();
        if (shouldFilter) {
            final double[] constV = new double[numColumns];
            final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
            tsmmColGroups(filteredGroups, ret, numRows, overlapping, k);
            addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV);
        } else {
            tsmmColGroups(groups, ret, numRows, overlapping, k);
        }
        // The correction layer writes only col >= row (presumably the group kernels also fill only the upper
        // triangle); mirror it into the lower triangle and record the resulting non-zero count.
        ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret));
        ret.examSparsity();
    }

    /**
     * Adds the contribution of the constant vector {@code constV} extracted by filtering to the dense values of
     * {@code result}, using the column sums of the remaining (filtered) groups.
     *
     * @param filteredGroups the column groups left after filtering
     * @param result         the output block whose dense value array is updated in place
     * @param nRows          number of rows of the compressed matrix
     * @param nCols          number of columns of the compressed matrix
     * @param constV         the constant vector produced by {@link CLALibUtils#filterGroups}
     */
    private static void addCorrectionLayer(List<AColGroup> filteredGroups, MatrixBlock result, int nRows, int nCols,
        double[] constV) {
        final double[] retV = result.getDenseBlockValues();
        final double[] filteredColSum = CLALibUtils.getColSum(filteredGroups, nCols, nRows);
        addCorrectionLayer(constV, filteredColSum, nRows, retV);
    }

    /**
     * Dispatches the column-group TSMM to the single-threaded, overlapping, or multi-threaded implementation.
     *
     * @param groups      the column groups to multiply
     * @param ret         the output block accumulated into
     * @param nRows       number of rows of the compressed matrix
     * @param overlapping whether the column groups overlap
     * @param k           the degree of parallelism
     */
    private static void tsmmColGroups(List<AColGroup> groups, MatrixBlock ret, int nRows, boolean overlapping, int k) {
        if (k <= 1) {
            tsmmColGroupsSingleThread(groups, ret, nRows);
        } else if (overlapping) {
            tsmmColGroupsMultiThreadOverlapping(groups, ret, nRows, k);
        } else {
            tsmmColGroupsMultiThread(groups, ret, nRows, k);
        }
    }

    /**
     * Sequentially computes each group's self product and the product of every pair {@code (i, j)} with
     * {@code i < j}.
     */
    private static void tsmmColGroupsSingleThread(List<AColGroup> groups, MatrixBlock ret, int nRows) {
        for (int i = 0; i < groups.size(); i++) {
            final AColGroup g = groups.get(i);
            g.tsmm(ret, nRows);
            for (int j = i + 1; j < groups.size(); j++) {
                g.tsmmAColGroup(groups.get(j), ret);
            }
        }
    }

    /**
     * Multi-threaded variant for overlapping column groups. It is not implemented yet: it logs a warning and falls
     * back to the single-threaded path ({@code k} is ignored).
     */
    private static void tsmmColGroupsMultiThreadOverlapping(List<AColGroup> groups, MatrixBlock ret, int nRows,
        int k) {
        LOG.warn("fallback to single threaded for now");
        tsmmColGroupsSingleThread(groups, ret, nRows);
    }

    /**
     * Runs one task per group self product and one task per pair {@code (i, j)} with {@code i < j} on a thread pool
     * of size {@code k}. All tasks write into the shared {@code ret}. This presumably relies on non-overlapping groups
     * writing disjoint regions of {@code ret} -- confirm.
     *
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted (the interrupt flag is
     *                             restored in that case)
     */
    private static void tsmmColGroupsMultiThread(List<AColGroup> groups, MatrixBlock ret, int nRows, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final int nGroups = groups.size();
            // one self-product task per group plus one task per unordered pair of distinct groups
            final List<Callable<MatrixBlock>> tasks = new ArrayList<>(nGroups * (nGroups + 1) / 2);
            for (int i = 0; i < nGroups; i++) {
                final AColGroup g = groups.get(i);
                tasks.add(new TSMMTask(g, ret, nRows));
                for (int j = i + 1; j < nGroups; j++) {
                    tasks.add(new TSMMColGroupTask(g, groups.get(j), ret));
                }
            }
            // get() surfaces any task failure as an ExecutionException
            for (Future<MatrixBlock> future : pool.invokeAll(tasks)) {
                future.get();
            }
        }
        catch (InterruptedException e) {
            // preserve the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Adds the constant-vector correction to the upper triangle ({@code col >= row}) of a row-major
     * {@code n x n} dense array, where {@code n = constV.length}.
     * <p>
     * With {@code c = constV}, {@code s = filteredColSum} and {@code N = nRow}, each entry {@code (r, col)} with
     * {@code col >= r} is incremented by {@code c[r]*s[col] + s[r]*c[col] + N*c[r]*c[col]}. This is the difference
     * between {@code t(F + 1 c^T) (F + 1 c^T)} and {@code t(F) F} for a matrix {@code F} whose column sums are
     * {@code s}.
     *
     * @param constV         constant vector of length n
     * @param filteredColSum column sums of the filtered part (length at least n)
     * @param nRow           number of rows N
     * @param ret            row-major n x n output array, updated in place
     */
    public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) {
        final int nColRow = constV.length;
        for (int row = 0; row < nColRow; row++) {
            final int offOut = nColRow * row;
            final double v1l = constV[row];
            final double v2l = filteredColSum[row] + constV[row] * nRow;
            for (int col = row; col < nColRow; col++) {
                ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col];
            }
        }
    }

    /** Task computing the cross product of two distinct column groups into the shared result. */
    private static class TSMMColGroupTask implements Callable<MatrixBlock> {
        private final AColGroup _g;
        private final AColGroup _h;
        private final MatrixBlock _ret;

        protected TSMMColGroupTask(AColGroup g, AColGroup h, MatrixBlock ret) {
            _g = g;
            _h = h;
            _ret = ret;
        }

        @Override
        public MatrixBlock call() {
            try {
                _g.tsmmAColGroup(_h, _ret);
                return _ret;
            }
            catch (Exception e) {
                LOG.error("TSMM between column groups failed", e);
                throw new DMLRuntimeException(e);
            }
        }
    }

    /** Task computing the self product of a single column group into the shared result. */
    private static class TSMMTask implements Callable<MatrixBlock> {
        private final AColGroup _g;
        private final MatrixBlock _ret;
        private final int _nRows;

        protected TSMMTask(AColGroup g, MatrixBlock ret, int nRows) {
            _g = g;
            _ret = ret;
            _nRows = nRows;
        }

        @Override
        public MatrixBlock call() {
            try {
                _g.tsmm(_ret, _nRows);
                return _ret;
            }
            catch (Exception e) {
                LOG.error("TSMM of column group failed", e);
                throw new DMLRuntimeException(e);
            }
        }
    }
}

