/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.fedplanner;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.AFederatedPlanner;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Federated planner that marks every HOP passing the inherited {@code allowsFederated} check
 * for federated execution.
 * <p>
 * The planner walks the statement-block hierarchy and tracks, per live variable name, the
 * {@link FTypes.FType} of its federated value (or {@code null} if it is not federated). Eligible
 * HOPs are forced to {@link Types.ExecType#FED}. Those whose derived output type is non-null
 * are additionally tagged with {@link FEDInstruction.FederatedOutput#FOUT}.
 */
public class FederatedPlannerFedAll extends AFederatedPlanner {

    /**
     * Rewrites every statement block returned by {@code prog.getStatementBlocks()}, starting with
     * no known federated variables.
     *
     * @param prog       program whose statement blocks are rewritten in place
     * @param fgraph     unused by this planner
     * @param fcallSizes unused by this planner
     */
    @Override
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        Map<String, FTypes.FType> fedVars = new HashMap<>();
        rRewriteStatementBlocks(prog.getStatementBlocks(), fedVars);
    }

    /**
     * Rewrites a function body for a concrete set of runtime arguments.
     * <p>
     * Each argument that is federated {@link CacheableData} seeds the variable map with the type
     * of its federation mapping. All other arguments are recorded with a {@code null}
     * (non-federated) type.
     *
     * @param function function statement block to rewrite in place
     * @param funcArgs runtime argument values, keyed by parameter name
     */
    @Override
    public void rewriteFunctionDynamic(FunctionStatementBlock function, LocalVariableMap funcArgs) {
        Map<String, FTypes.FType> fedVars = new HashMap<>();
        for (Map.Entry<String, Data> arg : funcArgs.entrySet()) {
            Data value = arg.getValue();
            FTypes.FType fType = null;
            if (value instanceof CacheableData) {
                CacheableData<?> cacheable = (CacheableData<?>) value;
                if (cacheable.isFederated()) {
                    fType = cacheable.getFedMapping().getType();
                }
            }
            fedVars.put(arg.getKey(), fType);
        }
        rRewriteStatementBlock(function, fedVars);
    }

    /** Rewrites each statement block of a body in order, sharing one variable map. */
    private void rRewriteStatementBlocks(Iterable<? extends StatementBlock> blocks,
        Map<String, FTypes.FType> fedVars) {
        for (StatementBlock block : blocks) {
            rRewriteStatementBlock(block, fedVars);
        }
    }

    /**
     * Recursively rewrites a statement block and its nested bodies.
     * <p>
     * NOTE: {@code fedVars} is shared across if/else branches and loop bodies, without merging
     * or iterating to a fixed point. The rewrite therefore assumes that conditional control flow
     * leads to consistent federation decisions.
     *
     * @param sb      block to rewrite
     * @param fedVars variable name to federation type (null = not federated). It is updated with
     *                the transient writes of last-level blocks.
     */
    private void rRewriteStatementBlock(StatementBlock sb, Map<String, FTypes.FType> fedVars) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            rRewriteStatementBlocks(fstmt.getBody(), fedVars);
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            rewritePredicate(wsb.getPredicateHops());
            rRewriteStatementBlocks(wstmt.getBody(), fedVars);
        } else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            rewritePredicate(isb.getPredicateHops());
            rRewriteStatementBlocks(istmt.getIfBody(), fedVars);
            rRewriteStatementBlocks(istmt.getElseBody(), fedVars);
        } else if (sb instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            rewritePredicate(fsb.getFromHops());
            rewritePredicate(fsb.getToHops());
            // The increment DAG may be absent; rRewriteHop ignores null roots.
            rewritePredicate(fsb.getIncrementHops());
            rRewriteStatementBlocks(fstmt.getBody(), fedVars);
        } else {
            rewriteLastLevelBlock(sb, fedVars);
        }
    }

    /**
     * Rewrites the HOP DAG of a last-level (non-control-flow) block. Afterwards, it publishes
     * the federation type of every transient write into {@code fedVars}.
     */
    private void rewriteLastLevelBlock(StatementBlock sb, Map<String, FTypes.FType> fedVars) {
        if (sb.getHops() == null) {
            return;
        }
        Map<Long, FTypes.FType> fedHops = new HashMap<>();
        // Two separate passes on purpose. All roots are rewritten before any variable is updated,
        // so transient reads inside this DAG still see the variable state from before the block.
        for (Hop root : sb.getHops()) {
            rRewriteHop(root, fedHops, fedVars);
        }
        for (Hop root : sb.getHops()) {
            if (HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE)) {
                fedVars.put(root.getName(), fedHops.get(root.getInput(0).getHopID()));
            }
        }
    }

    /**
     * Rewrites a control-flow predicate DAG in isolation. It uses a fresh memo and no visible
     * federated variables, so transient reads inside predicates are treated as non-federated.
     */
    private void rewritePredicate(Hop predicate) {
        rRewriteHop(predicate, new HashMap<>(), Collections.emptyMap());
    }

    /**
     * Post-order rewrite of a HOP DAG. It memoizes each HOP's federated output type
     * ({@code null} = not federated).
     *
     * @param hop     DAG root; {@code null} is tolerated and ignored
     * @param memo    HOP id to derived federation type. It also serves as the visited set.
     * @param fedVars variable name to federation type, consulted for transient reads
     */
    private void rRewriteHop(Hop hop, Map<Long, FTypes.FType> memo, Map<String, FTypes.FType> fedVars) {
        // Guard against absent DAGs (e.g. a missing loop increment) and already-processed HOPs.
        if (hop == null || memo.containsKey(hop.getHopID())) {
            return;
        }
        // Children first, so their federation types are memoized before this HOP is decided.
        for (Hop input : hop.getInput()) {
            rRewriteHop(input, memo, fedVars);
        }
        FTypes.FType outType;
        if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)) {
            outType = deriveFType((DataOp) hop);
        } else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)) {
            outType = fedVars.get(hop.getName());
        } else if (allowsFederated(hop, memo)) {
            hop.setForcedExecType(Types.ExecType.FED);
            outType = getFederatedOut(hop, memo);
            if (outType != null) {
                hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
            }
        } else {
            // Recorded as processed, but not federated.
            outType = null;
        }
        memo.put(hop.getHopID(), outType);
    }
}

