package cc.polymorphism.obfuscator.util;

import cc.polymorphism.obfuscator.asm.StackEmulator;
import cc.polymorphism.obfuscator.exceptions.obfuscation.FatalObfuscationException;
import cc.polymorphism.obfuscator.logging.Logger;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.util.Set;

/**
 * Static helpers for building and inspecting ASM tree instructions.
 *
 * <p>The class implements {@link Opcodes} only to have the opcode constants in scope.
 */
public class ASMUtils implements Opcodes {
    /**
     * Builds the shortest instruction that pushes the given {@code int} constant.
     *
     * @param number the constant to push
     * @return an {@code ICONST_*}, {@code BIPUSH}, {@code SIPUSH} or {@code LDC} instruction
     */
    public static AbstractInsnNode getNumberInsn(int number) {
        if (number >= -1 && number <= 5) {
            // ICONST_M1..ICONST_5 are consecutive opcodes centred on ICONST_0.
            return new InsnNode(ICONST_0 + number);
        }
        if (number >= Byte.MIN_VALUE && number <= Byte.MAX_VALUE) {
            return new IntInsnNode(BIPUSH, number);
        }
        if (number >= Short.MIN_VALUE && number <= Short.MAX_VALUE) {
            return new IntInsnNode(SIPUSH, number);
        }
        return new LdcInsnNode(number);
    }

    /**
     * Builds an instruction that pushes the given {@code long} constant.
     *
     * @param number the constant to push
     * @return {@code LCONST_0}/{@code LCONST_1} for 0 and 1, otherwise an {@code LDC}
     */
    public static AbstractInsnNode getNumberInsn(long number) {
        if (number == 0L) {
            return new InsnNode(LCONST_0);
        }
        if (number == 1L) {
            return new InsnNode(LCONST_1);
        }
        return new LdcInsnNode(number);
    }

    /**
     * Builds an instruction that pushes the given {@code float} constant.
     *
     * <p>Negative zero is emitted through {@code LDC}: {@code FCONST_0} pushes positive zero,
     * so using it for {@code -0.0f} would silently drop the sign.
     *
     * @param number the constant to push
     * @return {@code FCONST_0..2} for +0, 1 and 2, otherwise an {@code LDC}
     */
    public static AbstractInsnNode getNumberInsn(float number) {
        // Float.compare distinguishes -0.0f from 0.0f, unlike the == operator.
        if (Float.compare(number, 0F) == 0) {
            return new InsnNode(FCONST_0);
        }
        if (number == 1F) {
            return new InsnNode(FCONST_1);
        }
        if (number == 2F) {
            return new InsnNode(FCONST_2);
        }
        return new LdcInsnNode(number);
    }

    /**
     * Builds an instruction that pushes the given {@code double} constant.
     *
     * <p>Negative zero is emitted through {@code LDC}: {@code DCONST_0} pushes positive zero,
     * so using it for {@code -0.0} would silently drop the sign.
     *
     * @param number the constant to push
     * @return {@code DCONST_0}/{@code DCONST_1} for +0 and 1, otherwise an {@code LDC}
     */
    public static AbstractInsnNode getNumberInsn(double number) {
        // Double.compare distinguishes -0.0 from 0.0, unlike the == operator.
        if (Double.compare(number, 0D) == 0) {
            return new InsnNode(DCONST_0);
        }
        if (number == 1D) {
            return new InsnNode(DCONST_1);
        }
        return new LdcInsnNode(number);
    }

    /**
     * Extracts the {@code int} constant pushed by an instruction.
     *
     * @param insn an {@code ICONST_*} instruction, an {@link IntInsnNode} other than
     *             {@code NEWARRAY}, or an {@link LdcInsnNode} holding an {@link Integer}
     * @return the constant value
     * @throws FatalObfuscationException if the instruction is none of the above
     */
    public static int getIntegerFromInsn(AbstractInsnNode insn) {
        final int opcode = insn.getOpcode();

        if (opcode >= ICONST_M1 && opcode <= ICONST_5) {
            return opcode - ICONST_0;
        }
        // NEWARRAY is also an IntInsnNode, but its operand is an array type code, not a value.
        if (insn instanceof IntInsnNode intInsn && opcode != NEWARRAY) {
            return intInsn.operand;
        }
        if (insn instanceof LdcInsnNode ldc && ldc.cst instanceof Integer value) {
            return value;
        }

        throw new FatalObfuscationException("Attempted to get integer constant from " + insn.getOpcode() + " : " + insn);
    }


    /**
     * Prepends a jumped-over "return a dummy value" stub to the method and returns its label.
     *
     * <p>After the call the method starts with
     * {@code GOTO skip; exit: <dummy constant>; <return>; skip: ...original code...}.
     * Jumping to the returned label therefore leaves the method with a dummy value matching
     * its return type, while normal execution skips straight over the stub.
     *
     * <p>The stub is prepended as a whole, so an empty instruction list is handled as well.
     * Callers should still only pass methods that are allowed to carry code.
     *
     * @param methodNode the method to modify
     * @return the label placed in front of the dummy return sequence
     */
    public static LabelNode exitLabel(MethodNode methodNode) {
        final LabelNode exit = new LabelNode();
        final LabelNode skip = new LabelNode();
        final Type returnType = Type.getReturnType(methodNode.desc);

        // NOTE(review): the random bounds are kept exactly as they were. The "MAX_VALUE + 1"
        // arguments suggest an exclusive upper bound; if so, the boolean case always yields 1
        // and the int/long cases never yield MAX_VALUE - confirm against RandomUtils.
        final AbstractInsnNode dummy = switch (returnType.getSort()) {
            case Type.VOID -> null;
            case Type.BOOLEAN -> getNumberInsn(RandomUtils.randomInt(1, 2));
            case Type.CHAR -> getNumberInsn(RandomUtils.randomInt(Character.MIN_VALUE, Character.MAX_VALUE + 1));
            case Type.BYTE -> getNumberInsn(RandomUtils.randomInt(Byte.MIN_VALUE, Byte.MAX_VALUE + 1));
            case Type.SHORT -> getNumberInsn(RandomUtils.randomInt(Short.MIN_VALUE, Short.MAX_VALUE + 1));
            case Type.INT -> getNumberInsn(RandomUtils.randomInt(Integer.MIN_VALUE, Integer.MAX_VALUE));
            case Type.LONG -> getNumberInsn(RandomUtils.randomLong(Long.MIN_VALUE, Long.MAX_VALUE));
            case Type.FLOAT -> getNumberInsn(RandomUtils.randomFloat(Float.MIN_VALUE, Float.MAX_VALUE));
            case Type.DOUBLE -> getNumberInsn(RandomUtils.randomDouble(Double.MIN_VALUE, Double.MAX_VALUE));
            default -> new InsnNode(ACONST_NULL);
        };

        final InsnList stub = new InsnList();
        stub.add(new JumpInsnNode(GOTO, skip));
        stub.add(exit);
        if (dummy != null) {
            stub.add(dummy);
        }
        // Type.getOpcode adapts IRETURN to RETURN/IRETURN/LRETURN/FRETURN/DRETURN/ARETURN.
        stub.add(new InsnNode(returnType.getOpcode(IRETURN)));
        stub.add(skip);

        // InsnList.insert(InsnList) puts the stub at the very beginning of the method.
        methodNode.instructions.insert(stub);

        return exit;
    }

    /**
     * Runs the stack emulator over the method and returns the instructions it reports as
     * having an empty stack.
     *
     * <p>If emulation throws a {@link RuntimeException}, the failure is only logged and the
     * method still returns whatever the emulator reports.
     *
     * @param methodNode the method to emulate
     * @param classNode  the owning class, used for the log message only
     * @return the emulator's empty-stack instruction set
     */
    public static Set<AbstractInsnNode> getEmptyStack(MethodNode methodNode, ClassNode classNode) {
        final StackEmulator emulator = new StackEmulator(methodNode, methodNode.instructions.getLast());

        try {
            emulator.emulate(false);
        } catch (RuntimeException e) {
            Logger.severe(String.format("An error occured while trying to emulate the stack of %s.%s %s!",
                    classNode.name, methodNode.name, methodNode.desc));
        }

        return emulator.getEmptyStack();
    }

    /**
     * Returns the no-argument constructor of the class, creating and registering one if absent.
     *
     * <p>A created constructor invokes the superclass no-argument constructor before
     * returning; an {@code <init>} that returns without calling another {@code <init>} is
     * rejected by the JVM verifier. The superclass is assumed to expose an accessible
     * no-argument constructor.
     *
     * @param classNode the class to search or extend
     * @return the existing or newly created {@code <init>()V} method
     */
    public static MethodNode findOrCreateConstructor(ClassNode classNode) {
        MethodNode constructor = findMethod(classNode, "<init>", "()V");
        if (constructor == null) {
            constructor = new MethodNode(ACC_PUBLIC, "<init>", "()V", null, null);
            // superName is null only for java/lang/Object, which has no super constructor.
            if (classNode.superName != null) {
                constructor.instructions.add(new VarInsnNode(ALOAD, 0));
                constructor.instructions.add(
                        new MethodInsnNode(INVOKESPECIAL, classNode.superName, "<init>", "()V", false));
            }
            constructor.instructions.add(new InsnNode(RETURN));
            // One slot for "this", both as a local and on the stack for the super call.
            constructor.maxStack = 1;
            constructor.maxLocals = 1;
            classNode.methods.add(constructor);
        }
        return constructor;
    }

    /**
     * Returns the static initializer of the class, creating and registering an empty one
     * (a lone {@code RETURN}) if absent.
     *
     * @param classNode the class to search or extend
     * @return the existing or newly created {@code <clinit>()V} method
     */
    public static MethodNode findOrCreateInitializer(ClassNode classNode) {
        MethodNode clinit = findMethod(classNode, "<clinit>", "()V");
        if (clinit == null) {
            clinit = new MethodNode(ACC_STATIC, "<clinit>", "()V", null, null);
            clinit.instructions.add(new InsnNode(RETURN));
            classNode.methods.add(clinit);
        }
        return clinit;
    }

    /**
     * Looks up a method by exact name and descriptor.
     *
     * @param classNode the class to search
     * @param name      the method name
     * @param desc      the method descriptor
     * @return the first matching method, or {@code null} if there is none
     */
    public static MethodNode findMethod(ClassNode classNode, String name, String desc) {
        return classNode.methods
                .stream()
                .filter(methodNode -> name.equals(methodNode.name) && desc.equals(methodNode.desc))
                .findFirst()
                .orElse(null);
    }

    /**
     * Tells whether the instruction pushes an {@code int} constant.
     *
     * @param ain the instruction to inspect
     * @return {@code true} for {@code ICONST_*}, {@code BIPUSH}, {@code SIPUSH} and an
     *         {@code LDC} of an {@link Integer}
     */
    public static boolean isInteger(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Integer;
        }

        final int opcode = ain.getOpcode();
        return (opcode >= ICONST_M1 && opcode <= ICONST_5) || opcode == BIPUSH || opcode == SIPUSH;
    }

    /**
     * Tells whether the instruction pushes a {@code long} constant.
     *
     * @param ain the instruction to inspect
     * @return {@code true} for {@code LCONST_0}, {@code LCONST_1} and an {@code LDC} of a
     *         {@link Long}
     */
    public static boolean isLong(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc) {
            return ldc.cst instanceof Long;
        }

        final int opcode = ain.getOpcode();
        return opcode >= LCONST_0 && opcode <= LCONST_1;
    }

    /**
     * Extracts the {@code int} constant pushed by an instruction accepted by
     * {@link #isInteger(AbstractInsnNode)}.
     *
     * @param ain the instruction to read
     * @return the constant value
     * @throws IllegalStateException if the instruction does not push an {@code int} constant
     */
    public static int getInteger(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Integer value) {
            return value;
        }

        // NEWARRAY is an IntInsnNode too, but its operand is an array type code, not a value.
        if (ain instanceof IntInsnNode iin && iin.getOpcode() != NEWARRAY) {
            return iin.operand;
        }

        final int opcode = ain.getOpcode();
        if (opcode >= ICONST_M1 && opcode <= ICONST_5) {
            return opcode - ICONST_0;
        }
        throw new IllegalStateException("Unexpected value: " + opcode);
    }


    /**
     * Extracts the {@code long} constant pushed by an instruction accepted by
     * {@link #isLong(AbstractInsnNode)}.
     *
     * @param ain the instruction to read
     * @return the constant value
     * @throws IllegalStateException if the instruction does not push a {@code long} constant
     */
    public static long getLong(AbstractInsnNode ain) {
        if (ain instanceof LdcInsnNode ldc && ldc.cst instanceof Long value) {
            return value;
        }

        return switch (ain.getOpcode()) {
            case LCONST_0 -> 0L;
            case LCONST_1 -> 1L;
            default -> throw new IllegalStateException("Unexpected value: " + ain.getOpcode());
        };
    }
}
