package cc.polymorphism.assembly.expressions.flow;

import cc.polymorphism.assembly.BytecodeBlock;
import cc.polymorphism.assembly.expressions.IRExpression;
import cc.polymorphism.assembly.instructions.BytecodeLabel;
import cc.polymorphism.assembly.instructions.JumpNode;
import cc.polymorphism.assembly.instructions.SwitchNode;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Flow structure representing a switch: the operand is baked, a {@link SwitchNode}
 * dispatches on it to one label per case (or to the default label), and every case
 * body is followed by an unconditional jump to the exit label.
 *
 * <p>Cases are held in a {@link TreeSet} ordered by key, so the keys and labels passed
 * to the {@link SwitchNode} are in ascending key order regardless of the order in which
 * they were supplied to the constructor.
 */
public class IRSwitchStructure extends IRFlowStructure {
    private final IRExpression operand;
    // Ordered by key (see IRCaseStructure#compareTo); a key can appear at most once.
    private final SortedSet<IRCaseStructure> cases = new TreeSet<>();
    private final BytecodeBlock defaultBody;
    // Labels produced by the most recent bake(), parallel to the iteration order of `cases`.
    private final ArrayList<BytecodeLabel> caseLabels = new ArrayList<>();
    private final BytecodeLabel defaultLabel = new BytecodeLabel();
    private final BytecodeLabel exitLabel = new BytecodeLabel();

    /**
     * Creates a switch structure from parallel lists of keys and case bodies.
     *
     * <p>{@code keys.get(i)} is paired with {@code caseBodies.get(i)}. If a key occurs more
     * than once only its first occurrence is kept, because the backing set rejects elements
     * that compare equal. Entries of {@code caseBodies} beyond {@code keys.size()} are ignored.
     *
     * @param operand     expression whose value is switched on
     * @param keys        case keys
     * @param caseBodies  case bodies, index-aligned with {@code keys}
     * @param defaultBody body emitted under the default label
     * @throws IndexOutOfBoundsException if {@code caseBodies} has fewer elements than {@code keys}
     */
    public IRSwitchStructure(IRExpression operand, ArrayList<Integer> keys, ArrayList<BytecodeBlock> caseBodies, BytecodeBlock defaultBody) {
        this.operand = operand;
        this.defaultBody = defaultBody;

        for (int index = 0; index < keys.size(); index++) {
            cases.add(new IRCaseStructure(keys.get(index), caseBodies.get(index)));
        }
    }

    /**
     * Emits the switch as a single block laid out as: operand, switch node, each case
     * (label, body, jump to exit) in ascending key order, default label and body, exit label.
     *
     * <p>The case labels are regenerated on every call, so {@link #getCaseLabels()} always
     * holds exactly one label per case from the latest call. The default and exit labels
     * are the same instances across calls.
     *
     * @return the assembled block
     */
    @Override
    public BytecodeBlock bake() {
        // Start from a clean slate: without this, a second bake() would leave the list
        // with more labels than there are keys.
        caseLabels.clear();
        for (int count = cases.size(); count > 0; count--) {
            caseLabels.add(new BytecodeLabel());
        }

        // The switch node receives a fully populated snapshot of the labels, so it neither
        // depends on the list being filled in later nor is affected by a subsequent bake().
        var block = new BytecodeBlock()
                .append(operand.bake())
                .append(new SwitchNode(
                        cases.stream().map(IRCaseStructure::getKey).collect(Collectors.toList()),
                        new ArrayList<BytecodeLabel>(caseLabels),
                        defaultLabel));

        // Cases: `cases` and `caseLabels` are iterated in lockstep, both in key order.
        int labelIndex = 0;
        for (var caseStructure : cases) {
            block.append(caseLabels.get(labelIndex++))
                    .append(caseStructure.getBody())
                    .append(JumpNode.jumpUnconditionally(exitLabel));
        }

        // Default
        block.append(defaultLabel)
                .append(defaultBody);

        // Exit
        block.append(exitLabel);
        return block;
    }

    /**
     * @return the live list of case labels created by the most recent {@link #bake()} call,
     *         in ascending key order; empty if {@link #bake()} has not been called yet
     */
    public ArrayList<BytecodeLabel> getCaseLabels() {
        return caseLabels;
    }

    /**
     * @return the label placed immediately before the default body
     */
    public BytecodeLabel getDefaultLabel() {
        return defaultLabel;
    }

    /**
     * @return the label placed at the very end of the baked block, targeted by the jump
     *         that follows every case body
     */
    public BytecodeLabel getExitLabel() {
        return exitLabel;
    }

    /**
     * A single case: an integer key and the body to run for it. Ordering (and therefore
     * uniqueness inside the enclosing sorted set) is determined by the key alone.
     */
    static class IRCaseStructure implements Comparable<IRCaseStructure> {
        private final int key;
        private final BytecodeBlock body;

        IRCaseStructure(int key, BytecodeBlock body) {
            this.key = key;
            this.body = body;
        }

        public int getKey() {
            return key;
        }

        public BytecodeBlock getBody() {
            return body;
        }

        /**
         * Orders cases by ascending key; the body does not take part in the comparison.
         */
        @Override
        public int compareTo(@NotNull IRSwitchStructure.IRCaseStructure other) {
            return Integer.compare(key, other.key);
        }
    }
}
