/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryWCeMMFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryWDivMMFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryWSLossFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryWSigmoidFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryWUMMFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;

/**
 * Common base of the federated quaternary instructions ({@code wsloss},
 * {@code wsigmoid}, {@code wdivmm}, {@code wcemm} and {@code wumm}).
 *
 * <p>Besides holding the optional fourth input operand, this class provides
 * the factory methods that turn a CP/Spark quaternary instruction, or a raw
 * instruction string, into the matching federated subclass.
 */
public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction {
    /** Optional fourth input operand; stays {@code null} for three-input instructions. */
    protected CPOperand _input4 = null;

    /**
     * Creates a quaternary federated instruction with three inputs; {@link #_input4} stays {@code null}.
     *
     * @param type            federated instruction type
     * @param operator        operator forwarded to the super class
     * @param in1             first input operand
     * @param in2             second input operand
     * @param in3             third input operand
     * @param out             output operand
     * @param opcode          instruction opcode
     * @param instruction_str full instruction string
     */
    protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String instruction_str) {
        super(type, operator, in1, in2, in3, out, opcode, instruction_str);
    }

    /**
     * Creates a quaternary federated instruction with four inputs.
     *
     * @param type            federated instruction type
     * @param operator        operator forwarded to the super class
     * @param in1             first input operand
     * @param in2             second input operand
     * @param in3             third input operand
     * @param in4             fourth input operand, stored in {@link #_input4}
     * @param out             output operand
     * @param opcode          instruction opcode
     * @param instruction_str full instruction string
     */
    protected QuaternaryFEDInstruction(FEDInstruction.FEDType type, Operator operator, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instruction_str) {
        super(type, operator, in1, in2, in3, out, opcode, instruction_str);
        this._input4 = in4;
    }

    /**
     * Converts a CP quaternary instruction into its federated counterpart when applicable.
     *
     * @param inst CP instruction to convert
     * @param ec   execution context used to look up the variable bound to {@code inst.input1}
     * @return the federated instruction, or {@code null} if the first input is not a
     *         {@link MatrixObject} for which {@code isFederatedExcept(BROADCAST)} holds
     *         (or if the operator carries none of the known weight types)
     */
    public static QuaternaryFEDInstruction parseInstruction(QuaternaryCPInstruction inst, ExecutionContext ec) {
        Data firstInput = ec.getVariable(inst.input1);
        boolean federated = firstInput instanceof MatrixObject
            && ((MatrixObject) firstInput).isFederatedExcept(FTypes.FType.BROADCAST);
        return federated ? QuaternaryFEDInstruction.parseInstruction(inst) : null;
    }

    /**
     * Converts a Spark quaternary instruction into its federated counterpart when applicable.
     *
     * @param inst Spark instruction to convert
     * @param ec   execution context used to look up the variable bound to {@code inst.input1}
     * @return the federated instruction, or {@code null} if the first input is not a
     *         {@link MatrixObject} for which {@code isFederated()} holds
     *         (or if the operator carries none of the known weight types)
     */
    public static QuaternaryFEDInstruction parseInstruction(QuaternarySPInstruction inst, ExecutionContext ec) {
        Data firstInput = ec.getVariable(inst.input1);
        boolean federated = firstInput instanceof MatrixObject
            && ((MatrixObject) firstInput).isFederated();
        return federated ? QuaternaryFEDInstruction.parseInstruction(inst) : null;
    }

    /**
     * Dispatches a CP instruction to the federated subclass selected by the first
     * non-null weight type of its {@link QuaternaryOperator} (checked in the order
     * {@code wtype1} .. {@code wtype5}).
     *
     * @param instr CP instruction whose operator is a {@link QuaternaryOperator}
     * @return the subclass-specific federated instruction, or {@code null} if no weight type is set
     */
    private static QuaternaryFEDInstruction parseInstruction(QuaternaryCPInstruction instr) {
        QuaternaryOperator qop = (QuaternaryOperator) instr.getOperator();
        if (qop.wtype1 != null) {
            return QuaternaryWSLossFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype2 != null) {
            return QuaternaryWSigmoidFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype3 != null) {
            return QuaternaryWDivMMFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype4 != null) {
            return QuaternaryWCeMMFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype5 != null) {
            return QuaternaryWUMMFEDInstruction.parseInstruction(instr);
        }
        return null;
    }

    /**
     * Dispatches a Spark instruction to the federated subclass selected by the first
     * non-null weight type of its {@link QuaternaryOperator} (checked in the order
     * {@code wtype1} .. {@code wtype5}).
     *
     * @param instr Spark instruction whose operator is a {@link QuaternaryOperator}
     * @return the subclass-specific federated instruction, or {@code null} if no weight type is set
     */
    private static QuaternaryFEDInstruction parseInstruction(QuaternarySPInstruction instr) {
        QuaternaryOperator qop = (QuaternaryOperator) instr.getOperator();
        if (qop.wtype1 != null) {
            return QuaternaryWSLossFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype2 != null) {
            return QuaternaryWSigmoidFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype3 != null) {
            return QuaternaryWDivMMFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype4 != null) {
            return QuaternaryWCeMMFEDInstruction.parseInstruction(instr);
        }
        if (qop.wtype5 != null) {
            return QuaternaryWUMMFEDInstruction.parseInstruction(instr);
        }
        return null;
    }

    /**
     * Parses an instruction string into the matching federated quaternary instruction.
     *
     * <p>Strings starting with the Spark exec-type name are first rewritten via
     * {@link #rewriteSparkInstructionToCP(String)}. The expected field layout, as
     * read by this method, is:
     * <ul>
     *   <li>{@code wcemm}, {@code wsloss}, {@code wdivmm}: opcode, in1, in2, in3, in4, out, type, ...</li>
     *   <li>{@code wsigmoid}: opcode, in1, in2, in3, out, type, ...</li>
     *   <li>{@code wumm}: opcode, unary opcode, in1, in2, in3, out, type, ...</li>
     * </ul>
     *
     * @param str instruction string
     * @return the parsed federated instruction
     * @throws DMLRuntimeException if an operand has an unsupported data type or the opcode is unknown
     */
    public static QuaternaryFEDInstruction parseInstruction(String str) {
        if (str.startsWith(Types.ExecType.SPARK.name())) {
            str = QuaternaryFEDInstruction.rewriteSparkInstructionToCP(str);
        }

        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];

        // wcemm/wsloss/wdivmm carry a fourth input; wumm carries an extra unary opcode at parts[1]
        boolean hasInput4 = opcode.equals("wcemm") || opcode.equals("wsloss") || opcode.equals("wdivmm");
        int addInput4 = hasInput4 ? 1 : 0;
        int addUOpcode = opcode.equals("wumm") ? 1 : 0;
        InstructionUtils.checkNumFields(parts, 6 + addInput4 + addUOpcode);

        CPOperand in1 = new CPOperand(parts[1 + addUOpcode]);
        CPOperand in2 = new CPOperand(parts[2 + addUOpcode]);
        CPOperand in3 = new CPOperand(parts[3 + addUOpcode]);
        CPOperand out = new CPOperand(parts[4 + addInput4 + addUOpcode]);
        QuaternaryFEDInstruction.checkDataTypes(Types.DataType.MATRIX, in1, in2, in3);

        if (hasInput4) {
            CPOperand in4 = new CPOperand(parts[4]);
            if (opcode.equals("wcemm")) {
                WeightedCrossEntropy.WCeMMType wcemmType = WeightedCrossEntropy.WCeMMType.valueOf(parts[6]);
                QuaternaryOperator qop;
                if (wcemmType.hasFourInputs()) {
                    QuaternaryFEDInstruction.checkDataTypes(new Types.DataType[]{Types.DataType.SCALAR, Types.DataType.MATRIX}, in4);
                    // NOTE(review): the operand name is parsed as a double even though the check
                    // above also admits MATRIX operands; a non-numeric name would raise
                    // NumberFormatException here - confirm callers only pass literal scalars.
                    qop = new QuaternaryOperator(wcemmType, Double.parseDouble(in4.getName()));
                } else {
                    qop = new QuaternaryOperator(wcemmType);
                }
                return new QuaternaryWCeMMFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
            }
            if (opcode.equals("wdivmm")) {
                WeightedDivMM.WDivMMType wdivmmType = WeightedDivMM.WDivMMType.valueOf(parts[6]);
                if (wdivmmType.hasFourInputs()) {
                    QuaternaryFEDInstruction.checkDataTypes(new Types.DataType[]{Types.DataType.SCALAR, Types.DataType.MATRIX}, in4);
                }
                QuaternaryOperator qop = new QuaternaryOperator(wdivmmType);
                return new QuaternaryWDivMMFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
            }
            if (opcode.equals("wsloss")) {
                WeightedSquaredLoss.WeightsType weightsType = WeightedSquaredLoss.WeightsType.valueOf(parts[6]);
                if (weightsType.hasFourInputs()) {
                    QuaternaryFEDInstruction.checkDataTypes(Types.DataType.MATRIX, in4);
                }
                QuaternaryOperator qop = new QuaternaryOperator(weightsType);
                return new QuaternaryWSLossFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
            }
        } else {
            if (opcode.equals("wsigmoid")) {
                WeightedSigmoid.WSigmoidType wsigmoidType = WeightedSigmoid.WSigmoidType.valueOf(parts[5]);
                QuaternaryOperator qop = new QuaternaryOperator(wsigmoidType);
                return new QuaternaryWSigmoidFEDInstruction(qop, in1, in2, in3, out, opcode, str);
            }
            if (opcode.equals("wumm")) {
                WeightedUnaryMM.WUMMType wummType = WeightedUnaryMM.WUMMType.valueOf(parts[6]);
                String uopcode = parts[1];
                QuaternaryOperator qop = new QuaternaryOperator(wummType, uopcode);
                return new QuaternaryWUMMFEDInstruction(qop, in1, in2, in3, out, opcode, str);
            }
        }
        throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
    }

    /**
     * Verifies that every given operand has exactly the given data type.
     *
     * @param data_type   the only accepted data type
     * @param cp_operands operands to check
     * @throws DMLRuntimeException if any operand has a different data type
     */
    protected static void checkDataTypes(Types.DataType data_type, CPOperand ... cp_operands) {
        QuaternaryFEDInstruction.checkDataTypes(new Types.DataType[]{data_type}, cp_operands);
    }

    /**
     * Verifies that every given operand has one of the accepted data types.
     *
     * @param data_types  accepted data types
     * @param cp_operands operands to check
     * @throws DMLRuntimeException on the first operand whose data type is not accepted
     */
    protected static void checkDataTypes(Types.DataType[] data_types, CPOperand ... cp_operands) {
        for (CPOperand operand : cp_operands) {
            if (!QuaternaryFEDInstruction.checkDataType(data_types, operand)) {
                throw new DMLRuntimeException("Federated quaternary operations only supported with matrix inputs and scalar epsilon.");
            }
        }
    }

    /**
     * Tests whether the operand's data type is contained in the accepted types.
     *
     * @param data_types accepted data types
     * @param cp_operand operand to test
     * @return {@code true} if the operand's data type equals one of {@code data_types}
     */
    private static boolean checkDataType(Types.DataType[] data_types, CPOperand cp_operand) {
        for (Types.DataType accepted : data_types) {
            if (cp_operand.getDataType() == accepted) {
                return true;
            }
        }
        return false;
    }

    /**
     * Rewrites a Spark quaternary instruction string into the CP form parsed by
     * {@link #parseInstruction(String)}.
     *
     * <p>Steps, in order: the Spark exec-type name is replaced by the CP one; the
     * first matching Spark opcode ({@code mapwcemm}, {@code mapwdivmm},
     * {@code mapwsigmoid}, {@code mapwsloss}, {@code mapwumm}) is renamed to its
     * CP opcode; for {@code redwdivmm}/{@code redwsloss} the opcode is renamed and
     * the boolean fields (delimiter followed by {@code true}/{@code false}) are
     * dropped; finally a field with value {@code 1} is appended (presumably the
     * thread count expected by the CP format - TODO confirm).
     *
     * <p>NOTE(review): all replacements are global {@code String.replace} calls, so
     * operand text containing one of these substrings would be rewritten as well -
     * confirm that instruction strings cannot contain them outside the intended fields.
     *
     * @param inst_str Spark instruction string
     * @return the rewritten instruction string
     */
    protected static String rewriteSparkInstructionToCP(String inst_str) {
        inst_str = inst_str.replace(Types.ExecType.SPARK.name(), Types.ExecType.CP.name());
        if (inst_str.contains("mapwcemm")) {
            inst_str = inst_str.replace("mapwcemm", "wcemm");
        } else if (inst_str.contains("mapwdivmm")) {
            inst_str = inst_str.replace("mapwdivmm", "wdivmm");
        } else if (inst_str.contains("mapwsigmoid")) {
            inst_str = inst_str.replace("mapwsigmoid", "wsigmoid");
        } else if (inst_str.contains("mapwsloss")) {
            inst_str = inst_str.replace("mapwsloss", "wsloss");
        } else if (inst_str.contains("mapwumm")) {
            inst_str = inst_str.replace("mapwumm", "wumm");
        } else if (inst_str.contains("redwdivmm") || inst_str.contains("redwsloss")) {
            inst_str = inst_str.replace("redwdivmm", "wdivmm");
            inst_str = inst_str.replace("redwsloss", "wsloss");
            // drop the boolean fields that the CP field layout does not have
            inst_str = inst_str.replace("\u00b0true", "");
            inst_str = inst_str.replace("\u00b0false", "");
        }
        return inst_str + "\u00b01";
    }

    /**
     * Sets the dimensions and block size of the output matrix from the inputs.
     *
     * <p>Rows come from {@code X} if it has more than one row, otherwise from
     * {@code U}. Columns come from {@code X} if it has more than one column;
     * otherwise from {@code V}: its column count when {@code U}'s column count
     * equals {@code V}'s row count, else its row count. The block size is taken
     * from {@code X}.
     *
     * @param X  first input matrix
     * @param U  second input matrix
     * @param V  third input matrix
     * @param ec execution context used to look up the output matrix object
     */
    protected void setOutputDataCharacteristics(MatrixObject X, MatrixObject U, MatrixObject V, ExecutionContext ec) {
        long rows = X.getNumRows() > 1L ? X.getNumRows() : U.getNumRows();
        long cols = X.getNumColumns() > 1L ? X.getNumColumns() : (U.getNumColumns() == V.getNumRows() ? V.getNumColumns() : V.getNumRows());
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(rows, cols, X.getBlocksize());
    }
}

