package org.jnt.matrix.ir.lift;

import org.jnt.matrix.ir.def.Stack;
import org.jnt.matrix.ir.expr.IExpr;
import org.jnt.matrix.ir.expr.impl.*;
import org.jnt.matrix.ir.print.IRPrinter;
import org.jnt.matrix.ir.stmt.IStmt;
import org.jnt.matrix.ir.stmt.impl.*;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.function.Predicate;

import static org.objectweb.asm.Opcodes.*;

/**
 * Lifts the instruction list of an ASM {@link MethodNode} into a flat list of
 * {@link IStmt}s by walking the instructions in order and simulating the
 * operand stack with a {@link Stack} of {@link IExpr}s.
 *
 * <p>The class also exposes a set of static helpers that classify single
 * instructions (constants, loads, stores, returns, arithmetic) and extract
 * constant values from them.
 *
 * @author etho
 */
public class IRLifter {
    /** Printer handed in at construction; not read by any code in this class. */
    private final IRPrinter printer;

    /**
     * @param printer printer to associate with this lifter
     */
    public IRLifter(IRPrinter printer) {
        this.printer = printer;
    }

    /**
     * Tells whether {@code ain} pushes an {@code int} constant.
     *
     * @param ain instruction to inspect
     * @return {@code true} for an LDC whose constant is an {@link Integer}, for
     *         {@code ICONST_M1..ICONST_5}, {@code BIPUSH} and {@code SIPUSH};
     *         {@code false} otherwise (including an LDC of any other constant type)
     */
    public static boolean isInteger(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Integer;
        }

        int opcode = ain.getOpcode();

        return (opcode >= ICONST_M1 && opcode <= ICONST_5) || opcode == BIPUSH || opcode == SIPUSH;
    }

    /**
     * Extracts the {@code int} constant pushed by {@code ain}.
     *
     * @param ain an instruction, normally one accepted by {@link #isInteger}
     * @return the Integer LDC constant, the operand of an {@link IntInsnNode},
     *         or the value encoded by an {@code ICONST_*} opcode
     * @throws IllegalStateException if none of the above applies
     */
    public static int getInteger(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Integer) {
            return (int) ldc.cst;
        }

        if (ain instanceof IntInsnNode iin) {
            return iin.operand;
        }

        return switch (ain.getOpcode()) {
            case ICONST_M1 -> -1;
            case ICONST_0 -> 0;
            case ICONST_1 -> 1;
            case ICONST_2 -> 2;
            case ICONST_3 -> 3;
            case ICONST_4 -> 4;
            case ICONST_5 -> 5;
            default -> throw new IllegalStateException("Unexpected value: " + ain.getOpcode());
        };
    }

    /**
     * Tells whether {@code ain} pushes a {@code float} constant.
     *
     * @param ain instruction to inspect
     * @return {@code true} for an LDC whose constant is a {@link Float} or for
     *         {@code FCONST_0..FCONST_2}
     */
    public static boolean isFloat(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Float;
        }

        int opcode = ain.getOpcode();
        return (opcode >= FCONST_0 && opcode <= FCONST_2);
    }

    /**
     * Tells whether {@code ain} pushes a {@code double} constant.
     *
     * @param ain instruction to inspect
     * @return {@code true} for an LDC whose constant is a {@link Double} or for
     *         {@code DCONST_0}/{@code DCONST_1}
     */
    public static boolean isDouble(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Double;
        }

        int opcode = ain.getOpcode();
        return opcode == DCONST_0 || opcode == DCONST_1;
    }

    /**
     * Tells whether {@code ain} pushes a {@code long} constant.
     *
     * @param ain instruction to inspect
     * @return {@code true} for an LDC whose constant is a {@link Long} or for
     *         {@code LCONST_0}/{@code LCONST_1}
     */
    public static boolean isLong(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Long;
        }

        int opcode = ain.getOpcode();
        return opcode == LCONST_0 || opcode == LCONST_1;
    }

    /**
     * Extracts the {@code float} constant pushed by {@code ain}.
     *
     * @param ain an instruction, normally one accepted by {@link #isFloat}
     * @return the Float LDC constant or the value of an {@code FCONST_*} opcode
     * @throws RuntimeException if {@code ain} is neither (including an LDC
     *                          whose constant is not a {@link Float})
     */
    public static float getFloat(AbstractInsnNode ain) {
        // Check the constant's runtime type instead of casting blindly, so a
        // mismatched LDC ends in the descriptive failure below.
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Float value) {
            return value;
        }

        return switch (ain.getOpcode()) {
            case FCONST_0 -> 0f;
            case FCONST_1 -> 1f;
            case FCONST_2 -> 2f;
            default -> throw new RuntimeException("Failed to extract float.");
        };
    }

    /**
     * Extracts the {@code double} constant pushed by {@code ain}.
     *
     * @param ain an instruction, normally one accepted by {@link #isDouble}
     * @return the Double LDC constant or the value of a {@code DCONST_*} opcode
     * @throws RuntimeException if {@code ain} is neither (including an LDC
     *                          whose constant is not a {@link Double})
     */
    public static double getDouble(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Double value) {
            return value;
        }

        return switch (ain.getOpcode()) {
            case DCONST_0 -> 0D;
            case DCONST_1 -> 1D;
            default -> throw new RuntimeException("Failed to extract double.");
        };
    }

    /**
     * Extracts the {@code long} constant pushed by {@code ain}.
     *
     * @param ain an instruction, normally one accepted by {@link #isLong}
     * @return the Long LDC constant or the value of an {@code LCONST_*} opcode
     * @throws RuntimeException if {@code ain} is neither (including an LDC
     *                          whose constant is not a {@link Long})
     */
    public static long getLong(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Long value) {
            return value;
        }

        return switch (ain.getOpcode()) {
            case LCONST_0 -> 0L;
            case LCONST_1 -> 1L;
            default -> throw new RuntimeException("Failed to extract long.");
        };
    }

    /**
     * Tells whether {@code ain} is an LDC whose constant is a {@link String}.
     *
     * @param ain instruction to inspect
     * @return {@code true} only for a String LDC
     */
    public static boolean isString(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof String;
        }
        return false;
    }

    /**
     * Extracts the {@link String} constant of an LDC instruction.
     *
     * @param ain instruction to inspect
     * @return the String constant, or {@code null} when {@code ain} is not a
     *         String LDC
     */
    public static String getString(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            if (ldc.cst instanceof String) {
                return (String) ldc.cst;
            }
        }
        return null;
    }

    /**
     * @param insn instruction to inspect
     * @return {@code true} when the opcode lies in {@code IRETURN..RETURN}
     */
    public static boolean isReturning(AbstractInsnNode insn) {
        int op = insn.getOpcode();
        return (op >= IRETURN && op <= RETURN);
    }

    /**
     * @param insn instruction to inspect
     * @return {@code true} when the opcode lies in {@code ISTORE..ASTORE}
     */
    public static boolean isStore(AbstractInsnNode insn) {
        int op = insn.getOpcode();
        return (op >= ISTORE && op <= ASTORE);
    }

    /**
     * @param insn instruction to inspect
     * @return {@code true} when the opcode lies in {@code ILOAD..ALOAD}
     */
    public static boolean isLoad(AbstractInsnNode insn) {
        int op = insn.getOpcode();
        return (op >= ILOAD && op <= ALOAD);
    }

    /**
     * Tells whether {@code insn} is a binary arithmetic, shift or bitwise
     * instruction. The accepted opcodes are exactly those that
     * {@link #fromArithmetic} can translate, so a {@code true} result guarantees
     * that {@code fromArithmetic(insn)} does not throw.
     *
     * @param insn instruction to inspect
     * @return {@code true} for the int/long/float/double add, sub, mul, div and
     *         rem opcodes and the int/long shift, and, or and xor opcodes
     */
    public static boolean isArithmetic(AbstractInsnNode insn) {
        return switch (insn.getOpcode()) {
            case IADD, LADD, FADD, DADD,
                 ISUB, LSUB, FSUB, DSUB,
                 IMUL, LMUL, FMUL, DMUL,
                 IDIV, LDIV, FDIV, DDIV,
                 IREM, LREM, FREM, DREM,
                 ISHL, LSHL, ISHR, LSHR,
                 IUSHR, LUSHR,
                 IAND, LAND, IOR, LOR,
                 IXOR, LXOR -> true;
            default -> false;
        };
    }

    /**
     * Maps a binary arithmetic opcode to its operator symbol.
     *
     * @param insn an instruction accepted by {@link #isArithmetic}
     * @return the operator text, e.g. {@code "+"} or {@code ">>>"}
     * @throws IllegalStateException if the opcode is not a supported operator
     */
    public static String fromArithmetic(AbstractInsnNode insn) {
        return switch (insn.getOpcode()) {
            case IADD, DADD, FADD, LADD -> "+";
            case ISUB, DSUB, FSUB, LSUB -> "-";
            case IMUL, DMUL, FMUL, LMUL -> "*";
            case IDIV, DDIV, FDIV, LDIV -> "/";
            case IXOR, LXOR -> "^";
            case ISHL, LSHL -> "<<";
            case ISHR, LSHR -> ">>";
            case IAND, LAND -> "&";
            case IOR, LOR -> "|";
            case IUSHR, LUSHR -> ">>>";
            case IREM, LREM, FREM, DREM -> "%";
            default -> throw new IllegalStateException("Unexpected value: " + insn.getOpcode());
        };
    }

    /**
     * Lifts every instruction of {@code method} into IR statements.
     *
     * <p>Value-producing instructions push an {@link IExpr} onto a simulated
     * operand stack; consuming instructions pop their operands from it and
     * either push a combined expression or append a statement. Instructions not
     * matched by any branch below are skipped without touching the simulated
     * stack.
     *
     * @param method method whose {@code instructions} are lifted
     * @return the statements in instruction order
     */
    public List<IStmt> lift(MethodNode method) {
        final List<IStmt> stmts = new ArrayList<>();
        final Stack stack = new Stack();

        for (var insn : method.instructions) {
            // Constants are handled up front, independent of the node class.
            processIfCst(insn, stack);

            if (insn.getOpcode() == DUP) {
                IExpr expr = stack.pop();

                // The same expression instance ends up on the stack twice.
                stack.push(expr);
                stack.push(expr);

                stmts.add(new NewVarStmt(expr, expr.getType(), -1, true));
            } else if (insn.getOpcode() == POP) {
                stmts.add(new PopStmt(stack.pop()));
            }

            switch (insn) {
                case LabelNode lbl -> {
                    stmts.add(new NewBlockStmt(lbl));
                }
                case TypeInsnNode typeInsn -> {
                    if (typeInsn.getOpcode() == NEW) {
                        // TypeInsnNode.desc is an internal name unless it is an
                        // array descriptor; wrap it so Type.getType accepts it.
                        String td = typeInsn.desc;
                        if (td.charAt(0) != '[') {
                            td = "L" + td + ";";
                        }

                        Type t = Type.getType(td);
                        stack.push(new AllocateObjectExpr(t));
                    }
                }
                case JumpInsnNode jmpInsn -> {
                    LabelNode target = jmpInsn.label;

                    switch (insn.getOpcode()) {
                        case GOTO -> stmts.add(new RawJumpStmt(target));
                        case IFEQ, IFNE -> {
                            stmts.add(new IfStmt(insn.getOpcode(), target, stack.pop()));
                        }
                        case IF_ICMPEQ, IF_ICMPNE,
                             IF_ICMPGT, IF_ICMPLT,
                             IF_ICMPGE, IF_ICMPLE -> {
                            // NOTE(review): the first pop is the top of the stack,
                            // i.e. the right-hand operand of the comparison, so the
                            // operands are passed right-then-left. Confirm this is
                            // the order IfStmt expects.
                            stmts.add(new IfStmt(insn.getOpcode(), target,
                                    stack.pop(), stack.pop()));
                        }
                    }
                }
                case FieldInsnNode fieldInsn -> {
                    if (insn.getOpcode() == GETSTATIC || insn.getOpcode() == GETFIELD) {
                        stack.push(new FieldGetExpr(
                                    insn.getOpcode(),
                                    fieldInsn.owner,
                                    fieldInsn.name,
                                    fieldInsn.desc
                                )
                        );
                    }
                }
                case MethodInsnNode methodInsn -> {
                    int argc = Type.getArgumentCount(methodInsn.desc);
                    Type rType = Type.getReturnType(methodInsn.desc);

                    // NOTE(review): arguments are stored in pop order, so args[0]
                    // is the value that was on top of the stack (the LAST declared
                    // argument). Confirm the call expressions expect that order.
                    IExpr[] args = new IExpr[argc];
                    for (int i = 0; i < argc; i++) {
                        args[i] = stack.pop();
                    }

                    IExpr expr;
                    if (insn.getOpcode() == INVOKESTATIC) {
                        expr = new StaticCallExpr(
                                insn.getOpcode(),
                                methodInsn.owner,
                                methodInsn.name,
                                methodInsn.desc,
                                args
                        );
                    } else {
                        // Every non-static invoke takes its receiver from below
                        // the arguments.
                        expr = new VirtualCallExpr(
                                stack.pop(),
                                insn.getOpcode(),
                                methodInsn.owner,
                                methodInsn.name,
                                methodInsn.desc,
                                args
                        );
                    }

                    // A void call leaves nothing on the stack, so it becomes a
                    // statement right away.
                    if (rType.getSort() == Type.VOID) {
                        stmts.add(new PopStmt(expr));
                    } else {
                        stack.push(expr);
                    }
                }
                case InsnNode insnNode -> {
                    if (isReturning(insn)) {
                        if (insn.getOpcode() != RETURN) {
                            IExpr expr = stack.pop();
                            stmts.add(new ReturnStmt(expr));
                        } else {
                            stmts.add(new ReturnStmt(null));
                        }
                    } else if (isArithmetic(insn)) {
                        String op = fromArithmetic(insn);

                        // Top of stack is the right operand, below it the left.
                        IExpr a = stack.pop();
                        IExpr b = stack.pop();

                        stack.push(new ArithmeticExpr(b, a, op, b.getType()));
                    }
                }
                case VarInsnNode varInsn -> {
                    if (isStore(insn)) {
                        IExpr expr = stack.pop();

                        // Any expression can be stored; take its type through the
                        // IExpr interface rather than enumerating subclasses.
                        Type type = expr.getType();

                        stmts.add(new NewVarStmt(expr, type, varInsn.var, false));
                    } else if (isLoad(insn)) {
                        Type type = switch (insn.getOpcode()) {
                            case ILOAD -> Type.INT_TYPE;
                            case LLOAD -> Type.LONG_TYPE;
                            case FLOAD -> Type.FLOAT_TYPE;
                            case DLOAD -> Type.DOUBLE_TYPE;
                            // ALOAD: the static type is unknown here.
                            default -> Type.getType(Object.class);
                        };

                        stack.push(new VarExpr(varInsn.var, type));
                    }
                }
                default -> {
                }
            }
        }

        return stmts;
    }

    /**
     * Pushes a {@link ConstantExpr} onto {@code stack} when {@code insn} is a
     * constant-pushing instruction recognised by one of the
     * {@link ConstantType}s; does nothing otherwise. At most one constant is
     * pushed per instruction.
     *
     * @param insn  instruction to inspect
     * @param stack simulated operand stack to push onto
     */
    private void processIfCst(AbstractInsnNode insn, Stack stack) {
        for (ConstantType type : ConstantType.values()) {
            if (type.isTypeMatch(insn)) {
                Object value = type.getValue(insn);
                stack.push(new ConstantExpr(value, type.getType()));
                return;
            }
        }
    }

    /**
     * Pairs each supported constant kind with its ASM {@link Type}, the
     * predicate that recognises it and the function that extracts its value.
     */
    private enum ConstantType {
        INTEGER(Type.INT_TYPE, IRLifter::isInteger, IRLifter::getInteger),
        FLOAT(Type.FLOAT_TYPE, IRLifter::isFloat, IRLifter::getFloat),
        LONG(Type.LONG_TYPE, IRLifter::isLong, IRLifter::getLong),
        DOUBLE(Type.DOUBLE_TYPE, IRLifter::isDouble, IRLifter::getDouble),
        STRING(Type.getType(String.class), IRLifter::isString, IRLifter::getString);

        private final Type type;
        private final Predicate<AbstractInsnNode> typePredicate;
        private final Function<AbstractInsnNode, Object> valueGetter;

        ConstantType(Type type, Predicate<AbstractInsnNode> typePredicate, Function<AbstractInsnNode, Object> valueGetter) {
            this.type = type;
            this.typePredicate = typePredicate;
            this.valueGetter = valueGetter;
        }

        /** @return the ASM type of constants of this kind */
        public Type getType() {
            return type;
        }

        /**
         * @param insn instruction to test
         * @return whether {@code insn} pushes a constant of this kind
         */
        public boolean isTypeMatch(AbstractInsnNode insn) {
            return typePredicate.test(insn);
        }

        /**
         * @param insn an instruction for which {@link #isTypeMatch} is true
         * @return the (boxed) constant value
         */
        public Object getValue(AbstractInsnNode insn) {
            return valueGetter.apply(insn);
        }
    }
}
