package fag.ml.test.util;

import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Deobfuscation helpers for classes that hide their string constants in a static
 * {@code String[]} filled with Base64-decoded values inside {@code <clinit>}.
 *
 * <p>The pass emulates the straight-line part of {@code <clinit>} to recover the array's
 * contents, then rewrites every {@code GETSTATIC arr; <int constant>; AALOAD} sequence in the
 * same class into an {@code LDC} of the recovered string.
 */
public class SSVMUtil {
    /** Descriptor of {@code String[]}. */
    private static final String STRING_ARRAY_DESC = "[Ljava/lang/String;";

    /** Internal name of {@code java.lang.String}. */
    private static final String STRING_TYPE = "java/lang/String";

    /** Operand-stack / local-variable marker for a value the emulator does not model. */
    private static final Object UNKNOWN = new Object();

    /**
     * Runs the deobfuscation pass over every class in {@code classMap}, mutating the
     * {@link ClassNode}s in place.
     *
     * <p>A class is processed when it declares a static {@code String[]} field and its
     * {@code <clinit>} calls {@code Base64.getDecoder()}.
     *
     * @param classMap classes to process; only the values are used
     */
    public static void deobfBase64StringArrays(Map<String, ClassNode> classMap) {
        for (ClassNode cn : classMap.values()) {
            FieldNode targetField = findStaticStringArray(cn);
            if (targetField == null) {
                continue;
            }

            MethodNode clinit = findClinit(cn);
            if (clinit == null || !usesBase64(clinit)) {
                continue;
            }

            Map<Integer, String> decoded = emulateClinitToStringArray(clinit, targetField.name);
            if (!decoded.isEmpty()) {
                patchStringArrayReferences(cn, targetField, decoded);
            }
        }
    }

    /**
     * Returns the first static field of {@code cn} whose type is {@code String[]}.
     *
     * @param cn class to search
     * @return the field, or {@code null} if the class declares none
     */
    public static FieldNode findStaticStringArray(ClassNode cn) {
        for (FieldNode fn : cn.fields) {
            if ((fn.access & Opcodes.ACC_STATIC) != 0 && fn.desc.equals(STRING_ARRAY_DESC)) {
                return fn;
            }
        }
        return null;
    }

    /**
     * Returns the static initializer ({@code <clinit>}) of {@code cn}.
     *
     * @param cn class to search
     * @return the method, or {@code null} if the class has no static initializer
     */
    public static MethodNode findClinit(ClassNode cn) {
        return cn.methods.stream()
                .filter(m -> m.name.equals("<clinit>"))
                .findFirst().orElse(null);
    }

    /**
     * Returns {@code true} if {@code m} contains a call to {@code java.util.Base64.getDecoder()}.
     *
     * @param m method to scan
     */
    public static boolean usesBase64(MethodNode m) {
        return Arrays.stream(m.instructions.toArray())
                .filter(i -> i instanceof MethodInsnNode)
                .map(i -> (MethodInsnNode) i)
                .anyMatch(mi -> mi.owner.equals("java/util/Base64") && mi.name.equals("getDecoder"));
    }

    /**
     * Recovers the contents of a static {@code String[]} by emulating {@code <clinit>}.
     *
     * <p>Only straight-line code is interpreted. The modelled instructions are int/LDC constants,
     * local loads and stores, {@code DUP}/{@code POP}, {@code String[]} creation and
     * {@code AASTORE}, {@code new String(byte[] [, Charset])}, and
     * {@code Base64.get*Decoder().decode(...)}. Other calls are tracked only for their stack
     * effect, and their results are unknown. Emulation stops at the first instruction that is not
     * modelled (a jump, arithmetic, ...), because past that point the operand-stack model cannot
     * be trusted. Elements recorded up to then are kept.
     *
     * @param cl        the {@code <clinit>} method to emulate
     * @param arrayName name of the static {@code String[]} field the array is assigned to. The
     *                  array stored into it via {@code PUTSTATIC} is preferred; if no such store is
     *                  reached, the last {@code String[]} created is used.
     * @return index to recovered string for every element that was assigned a known string
     *         (never {@code null}; empty if nothing could be recovered)
     */
    public static Map<Integer, String> emulateClinitToStringArray(MethodNode cl, String arrayName) {
        ClinitEmulator emulator = new ClinitEmulator(arrayName);
        for (AbstractInsnNode insn : cl.instructions.toArray()) {
            if (insn.getOpcode() < 0) {
                continue; // label / line-number / frame pseudo-instructions have no runtime effect
            }
            if (!emulator.step(insn)) {
                break; // unmodelled instruction: the stack model is no longer trustworthy
            }
        }

        Map<Integer, String> map = new HashMap<>();
        String[] recovered = emulator.recoveredArray();
        if (recovered != null) {
            for (int i = 0; i < recovered.length; i++) {
                if (recovered[i] != null) {
                    map.put(i, recovered[i]);
                }
            }
        }
        return map;
    }

    /**
     * Replaces every {@code GETSTATIC fn; <int constant>; AALOAD} sequence in the methods of
     * {@code cn} with {@code LDC decoded.get(index)}, when a decoded value exists for that index.
     *
     * @param cn      class whose methods are rewritten in place
     * @param fn      the static {@code String[]} field declared by {@code cn}
     * @param decoded index to string mapping produced by {@link #emulateClinitToStringArray}
     */
    public static void patchStringArrayReferences(ClassNode cn, FieldNode fn, Map<Integer,String> decoded) {
        if (decoded.isEmpty()) {
            return;
        }
        for (MethodNode m : cn.methods) {
            InsnList insns = m.instructions;
            AbstractInsnNode insn = insns.getFirst();
            while (insn != null) {
                // Capture the successor now: InsnList.set() unlinks the replaced node, so calling
                // insn.getNext() after a replacement would end the scan early.
                AbstractInsnNode idxNode = insn.getNext();
                AbstractInsnNode resumeAt = idxNode;

                if (idxNode != null && isStaticLoadOf(insn, cn, fn)) {
                    Integer index = intConstant(idxNode);
                    AbstractInsnNode aaload = idxNode.getNext();
                    String value = index == null ? null : decoded.get(index);
                    if (value != null && aaload != null && aaload.getOpcode() == Opcodes.AALOAD) {
                        LdcInsnNode ldc = new LdcInsnNode(value);
                        insns.set(insn, ldc);
                        insns.remove(idxNode);
                        insns.remove(aaload);
                        resumeAt = ldc.getNext();
                    }
                }
                insn = resumeAt;
            }
        }
    }

    /**
     * Converts an emulated stack value to an {@code int}.
     *
     * @param o value expected to be a {@link Number}
     * @return its {@code int} value
     * @throws IllegalArgumentException if {@code o} is not a {@link Number}
     */
    public static int resolveInt(Object o) {
        if (o instanceof Integer) return (Integer) o;
        if (o instanceof Number) return ((Number)o).intValue();
        throw new IllegalArgumentException("Expected integer value on stack, but got: " + o);
    }

    /**
     * Returns the {@code int} pushed by a constant-loading instruction ({@code ICONST_*},
     * {@code BIPUSH}, {@code SIPUSH}, or {@code LDC} of an {@link Integer}).
     *
     * @return the constant, or {@code null} if {@code insn} does not push an int constant
     */
    private static Integer intConstant(AbstractInsnNode insn) {
        int op = insn.getOpcode();
        if (op >= Opcodes.ICONST_M1 && op <= Opcodes.ICONST_5) {
            return op - Opcodes.ICONST_0; // ICONST_M1..ICONST_5 are contiguous opcodes
        }
        // NEWARRAY is also an IntInsnNode, so check the opcode rather than the node type.
        if ((op == Opcodes.BIPUSH || op == Opcodes.SIPUSH) && insn instanceof IntInsnNode intInsn) {
            return intInsn.operand;
        }
        if (insn instanceof LdcInsnNode ldc && ldc.cst instanceof Integer value) {
            return value;
        }
        return null;
    }

    /** Returns {@code true} if {@code insn} is {@code GETSTATIC owner.field} with a matching descriptor. */
    private static boolean isStaticLoadOf(AbstractInsnNode insn, ClassNode owner, FieldNode field) {
        return insn instanceof FieldInsnNode fi
                && fi.getOpcode() == Opcodes.GETSTATIC
                && fi.owner.equals(owner.name)
                && fi.name.equals(field.name)
                && fi.desc.equals(field.desc);
    }

    /**
     * Stack entry for an object created by {@code NEW java/lang/String} whose constructor may not
     * have run yet. {@code NEW; DUP; ...; INVOKESPECIAL <init>} leaves a second reference to the
     * same holder on the stack. Filling in {@link #value} therefore "initializes" that copy too.
     */
    private static final class PendingString {
        String value;
    }

    /**
     * Minimal abstract interpreter for straight-line {@code <clinit>} code.
     *
     * <p>The operand stack holds one entry per value, including long/double. Instructions whose
     * effect depends on value categories ({@code DUP2}, {@code POP2}, ...) are therefore not
     * modelled and stop emulation.
     */
    private static final class ClinitEmulator {
        private final Deque<Object> stack = new ArrayDeque<>();
        private final Map<Integer, Object> locals = new HashMap<>();
        private final String arrayName;
        /** Most recently created {@code String[]}; used when no store to {@link #arrayName} is seen. */
        private String[] lastCreated;
        /** Array most recently stored into the static field named {@link #arrayName}. */
        private String[] assigned;

        ClinitEmulator(String arrayName) {
            this.arrayName = arrayName;
        }

        /** Returns the array stored into the target field if seen, else the last String[] created. */
        String[] recoveredArray() {
            return assigned != null ? assigned : lastCreated;
        }

        /**
         * Applies one real (opcode >= 0) instruction to the abstract state.
         *
         * @return {@code false} if the instruction is not modelled or would underflow the stack
         */
        boolean step(AbstractInsnNode insn) {
            return switch (insn.getOpcode()) {
                case Opcodes.ACONST_NULL -> push(UNKNOWN);
                case Opcodes.ICONST_M1, Opcodes.ICONST_0, Opcodes.ICONST_1, Opcodes.ICONST_2,
                        Opcodes.ICONST_3, Opcodes.ICONST_4, Opcodes.ICONST_5,
                        Opcodes.BIPUSH, Opcodes.SIPUSH -> push(intConstant(insn));
                case Opcodes.LDC -> push(((LdcInsnNode) insn).cst);
                case Opcodes.ILOAD, Opcodes.LLOAD, Opcodes.FLOAD, Opcodes.DLOAD, Opcodes.ALOAD ->
                        push(locals.getOrDefault(((VarInsnNode) insn).var, UNKNOWN));
                case Opcodes.ISTORE, Opcodes.LSTORE, Opcodes.FSTORE, Opcodes.DSTORE, Opcodes.ASTORE ->
                        store(((VarInsnNode) insn).var);
                case Opcodes.DUP -> !stack.isEmpty() && push(stack.peek());
                case Opcodes.POP -> pop(1);
                case Opcodes.CHECKCAST -> !stack.isEmpty(); // operand stays on the stack
                case Opcodes.NEWARRAY -> pop(1) && push(UNKNOWN); // primitive array, contents untracked
                case Opcodes.ANEWARRAY -> newObjectArray(((TypeInsnNode) insn).desc);
                case Opcodes.AASTORE -> arrayStore();
                case Opcodes.NEW -> push(STRING_TYPE.equals(((TypeInsnNode) insn).desc)
                        ? new PendingString() : UNKNOWN);
                case Opcodes.GETSTATIC -> push(staticValue((FieldInsnNode) insn));
                case Opcodes.PUTSTATIC -> putStatic((FieldInsnNode) insn);
                case Opcodes.GETFIELD -> pop(1) && push(UNKNOWN);
                case Opcodes.PUTFIELD -> pop(2);
                case Opcodes.INVOKEVIRTUAL, Opcodes.INVOKESPECIAL, Opcodes.INVOKESTATIC,
                        Opcodes.INVOKEINTERFACE -> invoke((MethodInsnNode) insn);
                case Opcodes.INVOKEDYNAMIC -> invokeDynamic((InvokeDynamicInsnNode) insn);
                default -> false; // jumps, arithmetic, returns, ...: stop emulating
            };
        }

        private boolean push(Object value) {
            stack.push(value != null ? value : UNKNOWN); // ArrayDeque rejects null elements
            return true;
        }

        private boolean pop(int count) {
            if (stack.size() < count) {
                return false;
            }
            for (int i = 0; i < count; i++) {
                stack.pop();
            }
            return true;
        }

        private boolean store(int var) {
            if (stack.isEmpty()) {
                return false;
            }
            locals.put(var, stack.pop());
            return true;
        }

        /** ANEWARRAY: only {@code String[]} with a known, non-negative length is materialized. */
        private boolean newObjectArray(String elementType) {
            if (stack.isEmpty()) {
                return false;
            }
            Object size = stack.pop();
            if (STRING_TYPE.equals(elementType) && size instanceof Integer length && length >= 0) {
                lastCreated = new String[length];
                return push(lastCreated);
            }
            return push(UNKNOWN);
        }

        /** AASTORE: records the element only when array, index and value are all known. */
        private boolean arrayStore() {
            if (stack.size() < 3) {
                return false;
            }
            Object value = resolve(stack.pop());
            Object index = stack.pop();
            Object array = stack.pop();
            if (array instanceof String[] target && value instanceof String element
                    && index instanceof Integer i && i >= 0 && i < target.length) {
                target[i] = element;
            }
            return true;
        }

        private boolean putStatic(FieldInsnNode fi) {
            if (stack.isEmpty()) {
                return false;
            }
            Object value = stack.pop();
            if (fi.name.equals(arrayName) && STRING_ARRAY_DESC.equals(fi.desc)
                    && value instanceof String[] array) {
                assigned = array;
            }
            return true;
        }

        /** Generic call: pops receiver/arguments, interprets the few calls it knows, pushes the result. */
        private boolean invoke(MethodInsnNode mi) {
            boolean isStatic = mi.getOpcode() == Opcodes.INVOKESTATIC;
            int argCount = Type.getArgumentTypes(mi.desc).length;
            if (stack.size() < argCount + (isStatic ? 0 : 1)) {
                return false;
            }
            Object[] args = new Object[argCount];
            for (int i = argCount - 1; i >= 0; i--) {
                args[i] = resolve(stack.pop());
            }
            Object receiver = isStatic ? null : stack.pop();

            if ("<init>".equals(mi.name)) {
                // Constructors return void; only java.lang.String construction is interpreted.
                if (receiver instanceof PendingString pending && STRING_TYPE.equals(mi.owner)) {
                    pending.value = constructString(mi.desc, args);
                }
                return true;
            }
            Object result = callValue(mi, resolve(receiver), args);
            return Type.getReturnType(mi.desc).getSort() == Type.VOID || push(result);
        }

        private boolean invokeDynamic(InvokeDynamicInsnNode indy) {
            if (!pop(Type.getArgumentTypes(indy.desc).length)) {
                return false;
            }
            return Type.getReturnType(indy.desc).getSort() == Type.VOID || push(UNKNOWN);
        }

        /** Evaluates the Base64 calls this pass cares about; everything else is UNKNOWN. */
        private static Object callValue(MethodInsnNode mi, Object receiver, Object[] args) {
            try {
                if (mi.getOpcode() == Opcodes.INVOKESTATIC && "java/util/Base64".equals(mi.owner)) {
                    return switch (mi.name) {
                        case "getDecoder" -> Base64.getDecoder();
                        case "getUrlDecoder" -> Base64.getUrlDecoder();
                        case "getMimeDecoder" -> Base64.getMimeDecoder();
                        default -> UNKNOWN;
                    };
                }
                if ("java/util/Base64$Decoder".equals(mi.owner) && "decode".equals(mi.name)
                        && receiver instanceof Base64.Decoder decoder && args.length == 1) {
                    if (args[0] instanceof String encoded) {
                        return decoder.decode(encoded);
                    }
                    if (args[0] instanceof byte[] encodedBytes) {
                        return decoder.decode(encodedBytes);
                    }
                }
            } catch (IllegalArgumentException malformed) {
                // Malformed Base64 would make <clinit> throw at runtime; treat the value as unknown.
            }
            return UNKNOWN;
        }

        /** Resolves {@code StandardCharsets.X} to a {@link Charset}; other fields are UNKNOWN. */
        private static Object staticValue(FieldInsnNode fi) {
            if ("java/nio/charset/StandardCharsets".equals(fi.owner)) {
                try {
                    // Field names map to charset names: UTF_8 -> UTF-8, ISO_8859_1 -> ISO-8859-1, ...
                    return Charset.forName(fi.name.replace('_', '-'));
                } catch (IllegalArgumentException unknownName) {
                    return UNKNOWN;
                }
            }
            return UNKNOWN;
        }

        /** Interprets supported {@code String} constructors; returns {@code null} when unsupported. */
        private static String constructString(String desc, Object[] args) {
            if ("([B)V".equals(desc) && args[0] instanceof byte[] raw) {
                // new String(byte[]) uses the platform default charset. UTF-8 is assumed here
                // (the default since JDK 18, JEP 400) so results do not depend on this JVM.
                return new String(raw, StandardCharsets.UTF_8);
            }
            if ("([BLjava/nio/charset/Charset;)V".equals(desc)
                    && args[0] instanceof byte[] rawBytes && args[1] instanceof Charset charset) {
                return new String(rawBytes, charset);
            }
            if ("(Ljava/lang/String;)V".equals(desc) && args[0] instanceof String copy) {
                return copy;
            }
            return null;
        }

        /** Unwraps a {@link PendingString}; uninitialized ones become UNKNOWN. */
        private static Object resolve(Object value) {
            if (value instanceof PendingString pending) {
                return pending.value != null ? pending.value : UNKNOWN;
            }
            return value;
        }
    }
}