package fag.ml.sigma.util;

import fag.ml.sigma.Main;
import fag.ml.util.Constants;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.Objects;

/**
 * Bytecode patch helpers that locate classes in {@link Main#NODES} by structural heuristics and
 * rewrite them in place using the ASM tree API.
 *
 * <p>The {@code find*} methods cache what they locate in the static fields below and the
 * {@code patch*} methods consume those caches, so callers must invoke them in dependency order.
 * {@link #cachedClass} in particular is a shared slot that several methods both read and
 * overwrite; each method documents what it reads and writes.
 *
 * <p>NOTE(review): all state is static and unsynchronized; presumably these helpers run once, on
 * a single thread, during patching — confirm against the caller.
 */
public class SigmaPatchUtil implements Constants {
    /**
     * Shared "current class" slot. Written by {@link #findClassWithDiscordRpcConstructor()},
     * {@link #patchNCPPhaseCheck()} and {@link #findNCPPhase()}; read by {@link #findNCPPhase()},
     * {@link #findStringSetting()} and {@link #findPremiumModules()}.
     */
    private static ClassNode cachedClass = null;
    /** Class instantiated (NEW) in {@link #cachedClass}'s constructor; set by {@link #findStringSetting()}. */
    private static ClassNode cachedBooleanSetting = null;
    /** Sibling of the boolean setting class with a {@code String} type argument; set by {@link #findStringSetting()}. */
    private static ClassNode cachedStringSetting = null;
    /** Direct superclass of {@link #cachedClass}; set by {@link #findPremiumModules()}. */
    private static ClassNode premiumModuleNode = null;
    /** Direct superclass of {@link #premiumModuleNode}; set by {@link #findPremiumModules()}. */
    private static ClassNode regularModuleNode = null;

    /**
     * Finds the first class containing an {@code INVOKESPECIAL DiscordRichPresence.<init>()V}
     * (typically {@code new DiscordRichPresence()}) and stores it in {@link #cachedClass}.
     *
     * @return the matching class, or {@code null} if none exists ({@link #cachedClass} is left
     *     unchanged in that case)
     */
    public static ClassNode findClassWithDiscordRpcConstructor() {
        for (var classNode : Main.NODES.values()) {
            for (var method : classNode.methods) {
                if (method.instructions == null) continue;

                for (var insn : method.instructions) {
                    if (insn.getOpcode() == Opcodes.INVOKESPECIAL
                            && isCall(insn, "club/minnced/discord/rpc/DiscordRichPresence", "<init>", "()V")) {
                        cachedClass = classNode;
                        return classNode;
                    }
                }
            }
        }
        return null;
    }

    /**
     * Scans every class for the "event handler" shape and forces its boolean getter to return
     * {@code false}.
     *
     * <p>A class matches when (1) one of its constructors calls {@code MethodHandles.lookup()},
     * (2) it declares more than six fields, exactly one of which is a {@code boolean}, and (3) it
     * has a {@code ()Z} method containing {@code aload 0; getfield <that boolean>; ireturn}. That
     * method's body is replaced with {@code iconst_0; ireturn} and {@link #cachedClass} is set to
     * the patched class. Every matching class is patched; the last one patched stays cached.
     */
    public static void patchNCPPhaseCheck() {
        for (var classNode : Main.NODES.values()) {
            // 1. A constructor must call MethodHandles.lookup().
            if (!constructorCallsLookup(classNode)) continue;

            // 2. More than six fields, exactly one of which is a boolean.
            if (classNode.fields.size() <= 6) continue;

            var booleanFields = new ArrayList<FieldNode>();
            for (var field : classNode.fields) {
                if ("Z".equals(field.desc)) {
                    booleanFields.add(field);
                }
            }
            if (booleanFields.size() != 1) continue;

            var boolField = booleanFields.getFirst();
            LOGGER.info("found event handler class (class name: " + classNode.name + ")");
            LOGGER.info("found ncp phase check (field name: " + boolField.name + boolField.desc + ")");

            // 3. A ()Z method that returns that field directly.
            MethodNode getter = findBooleanGetter(classNode, boolField);
            if (getter == null) continue;

            LOGGER.info("found ncp phase check (method name: " + getter.name + ")");

            // 4. Replace the body with "return false".
            clearBody(getter);
            getter.instructions.add(new InsnNode(Opcodes.ICONST_0)); // push false
            getter.instructions.add(new InsnNode(Opcodes.IRETURN));  // return it

            // 5. iconst_0 needs one stack slot; the getter read "aload 0", so slot 0 is "this".
            getter.maxStack = 1;
            getter.maxLocals = 1;

            cachedClass = classNode;

            LOGGER.info("patched ncp phase check (method name: " + getter.name + ")");
        }
    }

    /**
     * Looks in the constructors of {@link #cachedClass} for the sequence
     * {@code aload; aload; invokevirtual Object.getClass; invokevirtual Class.getSuperclass;
     * ldc <class>; invokevirtual Class.getSuperclass; if_acmpne}. On the first match whose LDC
     * class is present in {@link Main#NODES}, {@link #cachedClass} is replaced by that class.
     * Does nothing if {@link #cachedClass} is {@code null}.
     */
    public static void findNCPPhase() {
        if (cachedClass == null) return;

        for (var method : cachedClass.methods) {
            if (!"<init>".equals(method.name) || method.instructions == null) continue;

            var insns = method.instructions.toArray();

            // The pattern reads insns[i] .. insns[i + 6], so i may go up to insns.length - 7.
            for (var i = 0; i + 6 < insns.length; i++) {
                if (insns[i].getOpcode() == Opcodes.ALOAD
                        && insns[i + 1].getOpcode() == Opcodes.ALOAD
                        && isCall(insns[i + 2], "java/lang/Object", "getClass", "()Ljava/lang/Class;")
                        && isCall(insns[i + 3], "java/lang/Class", "getSuperclass", "()Ljava/lang/Class;")
                        && insns[i + 4] instanceof LdcInsnNode ldc
                        && isCall(insns[i + 5], "java/lang/Class", "getSuperclass", "()Ljava/lang/Class;")
                        && insns[i + 6].getOpcode() == Opcodes.IF_ACMPNE
                        && ldc.cst instanceof Type type) {
                    var phaseClass = Main.NODES.get(type.getInternalName());
                    if (phaseClass != null) {
                        cachedClass = phaseClass;
                        LOGGER.info("found ncp phase class (class name: " + phaseClass.name + ")");
                        return;
                    }
                }
            }
        }
    }

    /**
     * For every {@code NEW} of a known class inside {@link #cachedClass}'s constructors, records
     * that class in {@link #cachedBooleanSetting} and then searches its superclass's other
     * subclasses for a string setting (see {@link #findStringSettingSiblings(String)}). Later
     * matches overwrite earlier ones. Does nothing if {@link #cachedClass} is {@code null}.
     */
    public static void findStringSetting() {
        if (cachedClass == null) return;

        for (var method : cachedClass.methods) {
            if (!method.name.equals("<init>")) continue;

            for (var insn : method.instructions) {
                if (insn.getOpcode() == Opcodes.NEW && insn instanceof TypeInsnNode newInsn) {
                    var created = Main.NODES.get(newInsn.desc);
                    if (created == null) continue;

                    cachedBooleanSetting = created;
                    LOGGER.info("found boolean setting (class name: " + newInsn.desc + ")");

                    // superName is null for java/lang/Object; don't hand null to the map.
                    String superName = created.superName;
                    if (superName != null && Main.NODES.containsKey(superName)) {
                        findStringSettingSiblings(superName);
                    }
                }
            }
        }
    }

    /**
     * Replaces the body of every self-returning no-arg method ({@code ()L<setting>;}) of
     * {@link #cachedBooleanSetting} that writes a {@code boolean} field with
     * {@code aload 0; areturn}. Does nothing if no boolean setting was found.
     */
    public static void patchBooleanSetting() {
        if (cachedBooleanSetting == null) return;

        var selfReturningDesc = "()L" + cachedBooleanSetting.name + ";";
        for (var method : cachedBooleanSetting.methods) {
            if (!method.desc.equals(selfReturningDesc)) continue;
            if (!writesBooleanField(method)) continue;

            // NOTE(review): assumes an instance method, so local slot 0 is "this".
            clearBody(method);
            method.instructions.add(new VarInsnNode(Opcodes.ALOAD, 0));
            method.instructions.add(new InsnNode(Opcodes.ARETURN));

            LOGGER.info("patched boolean setting (method name: " + method.name + ")");
        }
    }

    /**
     * Replaces the body of every {@code (String[])L<setting>;} method of
     * {@link #cachedStringSetting} that calls {@code List.addAll} with {@code aload 0; areturn}.
     * Does nothing if no string setting was found.
     */
    public static void patchStringSetting() {
        if (cachedStringSetting == null) return;

        var arraySelfDesc = "([Ljava/lang/String;)" + "L" + cachedStringSetting.name + ";";
        for (var method : cachedStringSetting.methods) {
            if (!method.desc.equals(arraySelfDesc)) continue;
            if (!callsListAddAll(method)) continue;

            // NOTE(review): assumes an instance method, so local slot 0 is "this".
            clearBody(method);
            method.instructions.add(new VarInsnNode(Opcodes.ALOAD, 0));
            method.instructions.add(new InsnNode(Opcodes.ARETURN));

            LOGGER.info("patched string setting (method name: " + method.name + ")");
        }
    }

    /**
     * Resolves {@link #premiumModuleNode} as the superclass of {@link #cachedClass}, and
     * {@link #regularModuleNode} as the superclass of that. Either stays {@code null} when the
     * superclass is missing from {@link Main#NODES}. Does nothing if {@link #cachedClass} is
     * {@code null}.
     */
    public static void findPremiumModules() {
        if (cachedClass == null) return;

        // 1. The direct superclass of cachedClass is the premium module base.
        var premiumSuperName = cachedClass.superName;
        if (premiumSuperName == null) return;

        premiumModuleNode = Main.NODES.get(premiumSuperName);
        if (premiumModuleNode == null) return;
        LOGGER.info("found premium module class (class name: " + premiumSuperName + ")");

        // 2. Its superclass is the regular (abstract) module base.
        String regularSuperName = premiumModuleNode.superName;
        if (regularSuperName == null) return;

        regularModuleNode = Main.NODES.get(regularSuperName);
        if (regularModuleNode != null) {
            LOGGER.info("found abstract module class (class name: " + regularSuperName + ")");
        }
    }

    /**
     * Re-parents every direct subclass of {@link #premiumModuleNode} onto
     * {@link #regularModuleNode}. It also rewrites their super-constructor calls from
     * {@code (String, String, X)V} to {@code (X, String, String)V}, reordering the operand stack
     * with {@code DUP_X2; POP}.
     *
     * <p>NOTE(review): {@code DUP_X2; POP} moves exactly one category-1 value (not long/double)
     * below two others, so the rewrite is only stack-correct when {@code X} is a single
     * int/reference-sized argument — confirm for every premium module constructor.
     */
    public static void patchPremiumModules() {
        if (premiumModuleNode == null || regularModuleNode == null) return;

        final String premiumName = premiumModuleNode.name;
        final String stringPairPrefix = "(Ljava/lang/String;Ljava/lang/String;";

        for (var classNode : Main.NODES.values()) {
            // superName is null for java/lang/Object and module-info, so compare from the known side.
            if (!premiumName.equals(classNode.superName)) continue;

            classNode.superName = regularModuleNode.name;

            for (var method : classNode.methods) {
                if (!method.name.equals("<init>")) continue;

                var insns = method.instructions;
                for (var insn : insns) {
                    if (insn.getOpcode() != Opcodes.INVOKESPECIAL) continue;
                    if (!(insn instanceof MethodInsnNode invoke)) continue;
                    if (!invoke.owner.equals(premiumName) || !invoke.name.equals("<init>")) continue;

                    // Only descriptors of the form (String;String;...)V can be reordered below.
                    if (!invoke.desc.startsWith(stringPairPrefix) || !invoke.desc.endsWith(")V")) {
                        LOGGER.info("skipped premium module constructor call with unexpected descriptor (class name: "
                                + classNode.name + ", desc: " + invoke.desc + ")");
                        continue;
                    }

                    // Move the top stack value below the two Strings: [a, b, x] -> [x, a, b].
                    insns.insertBefore(invoke, new InsnNode(Opcodes.DUP_X2));
                    insns.insertBefore(invoke, new InsnNode(Opcodes.POP));

                    String middle = invoke.desc.substring(stringPairPrefix.length(), invoke.desc.length() - 2);
                    invoke.owner = regularModuleNode.name;
                    invoke.desc = "(" + middle + "Ljava/lang/String;Ljava/lang/String;)V";

                    LOGGER.info("patched premium module " + classNode.name);
                }
            }
        }
    }

    /** Returns whether {@code insn} is a method call to exactly {@code owner.name desc} (any invoke opcode). */
    private static boolean isCall(AbstractInsnNode insn, String owner, String name, String desc) {
        return insn instanceof MethodInsnNode call
                && owner.equals(call.owner)
                && name.equals(call.name)
                && desc.equals(call.desc);
    }

    /** Returns whether any constructor of {@code classNode} calls {@code MethodHandles.lookup()}. */
    private static boolean constructorCallsLookup(ClassNode classNode) {
        for (var method : classNode.methods) {
            if (!"<init>".equals(method.name) || method.instructions == null) continue;

            for (var insn : method.instructions) {
                if (isCall(insn, "java/lang/invoke/MethodHandles", "lookup",
                        "()Ljava/lang/invoke/MethodHandles$Lookup;")) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns the first {@code ()Z} method of {@code owner} whose code contains
     * {@code aload 0; getfield <field>; ireturn}, or {@code null} if there is none. The GETFIELD
     * is matched by field name and descriptor only, not by owner.
     */
    private static MethodNode findBooleanGetter(ClassNode owner, FieldNode field) {
        for (var method : owner.methods) {
            if (method.instructions == null || !method.desc.equals("()Z")) continue;

            var insns = method.instructions.toArray();
            for (var i = 0; i + 2 < insns.length; i++) {
                if (insns[i] instanceof VarInsnNode load
                        && load.getOpcode() == Opcodes.ALOAD
                        && load.var == 0
                        && insns[i + 1] instanceof FieldInsnNode get
                        && get.getOpcode() == Opcodes.GETFIELD
                        && get.name.equals(field.name)
                        && get.desc.equals(field.desc)
                        && insns[i + 2].getOpcode() == Opcodes.IRETURN) {
                    return method;
                }
            }
        }
        return null;
    }

    /**
     * Stores in {@link #cachedStringSetting} every direct subclass of {@code superName} whose
     * generic signature mentions {@code superName<String>} and that declares at least two
     * {@code List<String>} fields (the last match wins; each match is logged).
     */
    private static void findStringSettingSiblings(String superName) {
        var stringTypeArgument = "L" + superName + "<Ljava/lang/String;>;";
        for (var node : Main.NODES.values()) {
            if (!Objects.equals(node.superName, superName)) continue;
            if (node.signature == null || !node.signature.contains(stringTypeArgument)) continue;

            long listOfStringFields = node.fields.stream()
                    .filter(field -> "Ljava/util/List<Ljava/lang/String;>;".equals(field.signature))
                    .count();
            if (listOfStringFields < 2) continue;

            LOGGER.info("found string setting (class name: " + node.name + ")");
            cachedStringSetting = node;
        }
    }

    /** Returns whether {@code method} contains a PUTFIELD of a {@code boolean} field. */
    private static boolean writesBooleanField(MethodNode method) {
        for (var insn : method.instructions) {
            if (insn.getOpcode() == Opcodes.PUTFIELD
                    && insn instanceof FieldInsnNode fieldInsn
                    && fieldInsn.desc.equals("Z")) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether {@code method} contains an INVOKEINTERFACE of {@code java/util/List.addAll}. */
    private static boolean callsListAddAll(MethodNode method) {
        for (var insn : method.instructions) {
            if (insn.getOpcode() == Opcodes.INVOKEINTERFACE
                    && insn instanceof MethodInsnNode call
                    && call.name.equals("addAll")
                    && call.owner.equals("java/util/List")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Empties a method body so replacement instructions can be appended. Try/catch blocks and
     * local-variable debug entries (and their annotations) are dropped as well, because they
     * reference label nodes that {@link InsnList#clear()} just removed from the method and would
     * otherwise produce invalid output when the class is written.
     */
    private static void clearBody(MethodNode method) {
        method.instructions.clear();
        if (method.tryCatchBlocks != null) method.tryCatchBlocks.clear();
        if (method.localVariables != null) method.localVariables.clear();
        method.visibleLocalVariableAnnotations = null;
        method.invisibleLocalVariableAnnotations = null;
    }
}