package rip.marie.mutator.data.string;

import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.*;
import rip.marie.mutator.IMutator;
import rip.marie.mutator.data.string.modifier.CharModifier;
import rip.marie.mutator.data.string.modifier.CharModifierData;
import rip.marie.mutator.data.string.modifier.impl.*;
import rip.marie.mutator.data.string.task.ClassWrapperTask;
import rip.marie.mutator.data.string.task.RemoveConcatenationTask;
import rip.marie.obfuscator.ZywcfuscatorBase;
import rip.marie.util.asm.BytecodeUtil;
import rip.marie.util.string.StringTemplate;
import rip.marie.util.string.StringUtil;
import rip.marie.util.wrapper.FieldWrapper;
import rip.marie.util.wrapper.MethodWrapper;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.objectweb.asm.Opcodes.*;

/**
 * Replaces string constants in whitelisted classes with calls to a generated, per-class decryption
 * routine.
 *
 * <p>For every non-interface class that contains at least one {@code LDC "..."} the mutator:
 * <ol>
 *   <li>encrypts each literal at obfuscation time with a random pipeline of {@link CharModifier}s
 *       ({@link #encrypt}), optionally followed by an alternating two-key XOR layer
 *       ({@link #encryptLight}) when the class's dotted name ends with a hard-coded marker;</li>
 *   <li>replaces the literal with bytecode that calls the generated decrypt method and then reads
 *       the plaintext back out of a private {@code static String[]} cache field;</li>
 *   <li>adds the cache field, the decrypt method and (when used) the light-decrypt method to the
 *       class, and allocates the cache array at the very start of the static initializer.</li>
 * </ol>
 */
@SuppressWarnings("SpellCheckingInspection")
public class StringEncryptionMutator implements IMutator {
    /** Dotted-class-name suffix that enables the extra {@link #encryptLight} layer for a class. */
    private static final String ADDITIONAL_PROTECTION_SUFFIX = "RetardedHardcodedCheckLMFAO";

    /** Descriptor of the per-class plaintext cache field. */
    private static final String STRING_ARRAY_DESC = "[Ljava/lang/String;";

    /** Pre-processing tasks run on each class before its strings are encrypted. */
    private final List<ClassWrapperTask> tasks = new ArrayList<>(List.of(
            new RemoveConcatenationTask()
    ));

    /**
     * Encrypts {@code string} by passing every character through {@code variation}, first modifier to
     * last.
     *
     * <p>The decrypt method generated by {@code makeInternalDecryptMethod} emits the same modifiers
     * in the same order, so the pipeline is presumably self-inverse (the modifiers are XOR-based by
     * name) — NOTE(review): confirm for every {@link CharModifier} implementation.
     *
     * @param string    plaintext literal
     * @param variation modifier pipeline, applied in list order
     * @param data      per-literal keys handed to every modifier
     * @return the encrypted string, same length as {@code string}
     */
    public static String encrypt(String string, List<CharModifier> variation, StringEncryptionData data) {
        // toCharArray() returns a fresh copy, so it can be transformed in place; calling it once
        // (instead of once per character) avoids O(n^2) copying.
        final char[] chars = string.toCharArray();
        for (int i = 0; i < chars.length; i++) {
            char character = chars[i];

            for (CharModifier charModifier : variation) {
                character = charModifier.perform(data, i, character);
            }

            chars[i] = character;
        }

        return new String(chars);
    }

    /**
     * XORs characters at even indices with {@code key1} and at odd indices with {@code key2},
     * truncating each result to a {@code char}.
     *
     * <p>Applying it twice with the same keys restores the input, so the same routine (emitted as
     * bytecode by {@code makeLightMethod}) decrypts at runtime.
     *
     * @param string input text
     * @param key1   key for even indices
     * @param key2   key for odd indices
     * @return the transformed string, same length as {@code string}
     */
    public static String encryptLight(String string, int key1, int key2) {
        final char[] chars = string.toCharArray();
        for (int i = 0; i < chars.length; i++) {
            chars[i] = (char) (chars[i] ^ (i % 2 == 0 ? key1 : key2));
        }
        return new String(chars);
    }

    /**
     * Emits {@code public static String name(String s, int key1, int key2)}, a bytecode translation
     * of the original {@link #encryptLight} loop (it re-reads {@code s.toCharArray()} per character).
     *
     * @param name name of the generated method
     * @return the generated method (not yet added to any class)
     */
    private static MethodNode makeLightMethod(final String name) {
        MethodNode method = new MethodNode(ACC_PUBLIC | ACC_STATIC, name, "(Ljava/lang/String;II)Ljava/lang/String;", null, null);
        method.instructions = new InsnList();

        final InsnList context = method.instructions;

        // Local slots: 0 s, 1 key1, 2 key2, 3 char[] out, 4 i, 5 current char (as int).
        LabelNode label0 = new LabelNode();
        context.add(label0);
        context.add(new VarInsnNode(ALOAD, 0));
        context.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "length", "()I", false));
        context.add(new IntInsnNode(NEWARRAY, T_CHAR));
        context.add(new VarInsnNode(ASTORE, 3));
        LabelNode label1 = new LabelNode();
        context.add(label1);
        context.add(new InsnNode(ICONST_0));
        context.add(new VarInsnNode(ISTORE, 4));
        // Loop header: while (i < s.length())
        LabelNode label2 = new LabelNode();
        context.add(label2);
        context.add(new FrameNode(Opcodes.F_APPEND, 2, new Object[]{"[C", Opcodes.INTEGER}, 0, null));
        context.add(new VarInsnNode(ILOAD, 4));
        context.add(new VarInsnNode(ALOAD, 0));
        context.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "length", "()I", false));
        LabelNode label3 = new LabelNode();
        context.add(new JumpInsnNode(IF_ICMPGE, label3));
        LabelNode label4 = new LabelNode();
        context.add(label4);
        context.add(new VarInsnNode(ALOAD, 0));
        context.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "toCharArray", "()[C", false));
        context.add(new VarInsnNode(ILOAD, 4));
        context.add(new InsnNode(CALOAD));
        context.add(new VarInsnNode(ISTORE, 5));
        // if (i % 2 == 0) c ^= key1; else c ^= key2;
        LabelNode label5 = new LabelNode();
        context.add(label5);
        context.add(new VarInsnNode(ILOAD, 4));
        context.add(new InsnNode(ICONST_2));
        context.add(new InsnNode(IREM));
        LabelNode label6 = new LabelNode();
        context.add(new JumpInsnNode(IFNE, label6));
        LabelNode label7 = new LabelNode();
        context.add(label7);
        context.add(new VarInsnNode(ILOAD, 5));
        context.add(new VarInsnNode(ILOAD, 1));
        context.add(new InsnNode(IXOR));
        context.add(new VarInsnNode(ISTORE, 5));
        LabelNode label8 = new LabelNode();
        context.add(new JumpInsnNode(GOTO, label8));
        context.add(label6);
        context.add(new FrameNode(Opcodes.F_APPEND, 1, new Object[]{Opcodes.INTEGER}, 0, null));
        context.add(new VarInsnNode(ILOAD, 5));
        context.add(new VarInsnNode(ILOAD, 2));
        context.add(new InsnNode(IXOR));
        context.add(new VarInsnNode(ISTORE, 5));
        // out[i] = (char) c; i++;
        context.add(label8);
        context.add(new FrameNode(Opcodes.F_SAME, 0, null, 0, null));
        context.add(new VarInsnNode(ALOAD, 3));
        context.add(new VarInsnNode(ILOAD, 4));
        context.add(new VarInsnNode(ILOAD, 5));
        context.add(new InsnNode(I2C));
        context.add(new InsnNode(CASTORE));
        LabelNode label9 = new LabelNode();
        context.add(label9);
        context.add(new IincInsnNode(4, 1));
        context.add(new JumpInsnNode(GOTO, label2));
        // return new String(out);
        context.add(label3);
        context.add(new FrameNode(Opcodes.F_CHOP, 2, null, 0, null));
        context.add(new TypeInsnNode(NEW, "java/lang/String"));
        context.add(new InsnNode(DUP));
        context.add(new VarInsnNode(ALOAD, 3));
        context.add(new MethodInsnNode(INVOKESPECIAL, "java/lang/String", "<init>", "([C)V", false));
        context.add(new InsnNode(ARETURN));

        // Exact values for the code above (deepest point: NEW/DUP/ALOAD or ALOAD/ILOAD/ILOAD = 3;
        // slots 0..5). Integer.MAX_VALUE does not fit the class file's u2 max_stack field.
        method.maxStack = 3;
        method.maxLocals = 6;

        return method;
    }

    /**
     * Builds a random modifier pipeline. The length is drawn once, uniformly from [8, 24); re-drawing
     * it in the loop condition would bias the pipeline toward short lengths.
     *
     * @return a new mutable list of modifiers
     */
    private static List<CharModifier> randomVariation() {
        final int count = RANDOM.nextInt(8, 24);
        final List<CharModifier> variation = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            CharModifier charModifier = switch (RANDOM.nextInt(5)) {
                case 1 -> new RandomXorCharModifier();
                case 2 -> new MethodSaltXorCharModifier();
                case 3 -> new ClassSaltXorCharModifier();
                case 4 -> new IndexSwitchXorCharModifier();
                default -> new ConstantXorCharModifier();
            };
            variation.add(charModifier);
        }
        return variation;
    }

    @Override
    public void run(ZywcfuscatorBase base) {

        base.getWhitelistedClasses().forEach(classWrapper -> {
            final ClassNode classNode = classWrapper.getBase();
            final String className = classWrapper.getDottedName();

            tasks.forEach(task -> task.execute(classWrapper));

            // JVMS 4.5: every interface field must be public static final, so the private cache
            // field added below would make an interface fail to load. Leave interfaces untouched.
            if ((classNode.access & ACC_INTERFACE) != 0) {
                return;
            }

            final boolean additionalProtection = className.endsWith(ADDITIONAL_PROTECTION_SUFFIX);

            final AtomicBoolean shouldAddDecryption = new AtomicBoolean();

            final FieldNode cacheFieldNode = new FieldNode(ACC_PRIVATE | ACC_STATIC | ACC_FINAL,
                    StringUtil.generateString(StringTemplate.ABC, 8), STRING_ARRAY_DESC, null, null);

            final List<CharModifier> variation = randomVariation();

            final String decryptMethodName = StringUtil.generateString(StringTemplate.ABC, 8);
            final MethodNode decryptMethodNode = makeInternalDecryptMethod(decryptMethodName, variation,
                    classNode, cacheFieldNode);

            final String lightMethodName = StringUtil.generateString(StringTemplate.ABC, 8);
            final MethodNode lightMethodNode = makeLightMethod(lightMethodName);

            // Next free slot in the cache array; shared by all methods of this class.
            final AtomicInteger index = new AtomicInteger();

            classWrapper.getMethods().forEach(methodWrapper -> {
                final MethodNode methodNode = methodWrapper.getBase();
                final InsnList context = methodNode.instructions;

                // Iterate over a snapshot: LDCs are removed from the live list as we go.
                for (AbstractInsnNode ain : context.toArray()) {
                    if (ain instanceof LdcInsnNode lin
                            && lin.cst instanceof String string) {
                        final int localIndex = index.getAndIncrement();

                        final int key = RANDOM.nextInt(Short.MIN_VALUE, Short.MAX_VALUE);

                        // The decrypt method recomputes these two keys at runtime from its stack
                        // trace (caller's method name, its own class name), so they must be derived
                        // from the same inputs here. NOTE(review): assumes getDottedName() matches
                        // StackTraceElement.getClassName() — confirm.
                        final int keyCalling = methodNode.name.hashCode() << key;
                        final int keyOwner = className.hashCode() ^ key;

                        final StringEncryptionData data = new StringEncryptionData(
                                localIndex, key, keyCalling, keyOwner
                        );

                        String encryptedString = encrypt(string, variation, data);

                        final int key1 = RANDOM.nextInt(Short.MIN_VALUE, Short.MAX_VALUE);
                        final int key2 = RANDOM.nextInt(Short.MIN_VALUE, Short.MAX_VALUE);

                        if (additionalProtection) {
                            encryptedString = encryptLight(encryptedString, key1, key2);
                        }

                        final InsnList replacement = new InsnList();

                        // decrypt(new Object[] {localIndex, key, encryptedString});
                        replacement.add(BytecodeUtil.makeInteger(3));
                        replacement.add(new TypeInsnNode(ANEWARRAY, "java/lang/Object"));
                        replacement.add(new InsnNode(DUP));
                        replacement.add(BytecodeUtil.makeInteger(0));
                        replacement.add(BytecodeUtil.makeInteger(localIndex));
                        replacement.add(new MethodInsnNode(INVOKESTATIC, "java/lang/Integer", "valueOf", "(I)Ljava/lang/Integer;", false));
                        replacement.add(new InsnNode(AASTORE));
                        replacement.add(new InsnNode(DUP));
                        replacement.add(BytecodeUtil.makeInteger(1));
                        replacement.add(BytecodeUtil.makeInteger(key));
                        replacement.add(new MethodInsnNode(INVOKESTATIC, "java/lang/Integer", "valueOf", "(I)Ljava/lang/Integer;", false));
                        replacement.add(new InsnNode(AASTORE));
                        replacement.add(new InsnNode(DUP));
                        replacement.add(BytecodeUtil.makeInteger(2));
                        replacement.add(new LdcInsnNode(encryptedString));
                        replacement.add(new InsnNode(AASTORE));
                        replacement.add(new MethodInsnNode(INVOKESTATIC, classNode.name, decryptMethodNode.name, decryptMethodNode.desc, false));

                        // cache[localIndex]
                        replacement.add(new FieldInsnNode(GETSTATIC, classNode.name, cacheFieldNode.name, cacheFieldNode.desc));
                        replacement.add(BytecodeUtil.makeInteger(localIndex));
                        replacement.add(new InsnNode(AALOAD));

                        if (additionalProtection) {
                            // light(cache[localIndex], key1, key2) undoes the encryptLight layer.
                            replacement.add(BytecodeUtil.makeInteger(key1));
                            replacement.add(BytecodeUtil.makeInteger(key2));
                            replacement.add(new MethodInsnNode(INVOKESTATIC,
                                    classNode.name, lightMethodNode.name, lightMethodNode.desc, false));
                        }

                        context.insertBefore(ain, replacement);
                        context.remove(ain);

                        shouldAddDecryption.set(true);
                    }
                }
            });

            if (shouldAddDecryption.get()) {
                final MethodNode initializer = classWrapper.getStaticInit().getBase();

                // cache = new String[<number of encrypted literals>]; prepended so it runs before any
                // encrypted literal inside <clinit> itself is decrypted.
                final InsnList blockArray = new InsnList();
                blockArray.add(BytecodeUtil.makeInteger(index.get()));
                blockArray.add(new TypeInsnNode(ANEWARRAY, "java/lang/String"));
                blockArray.add(new FieldInsnNode(PUTSTATIC, classNode.name, cacheFieldNode.name, cacheFieldNode.desc));
                // insert(InsnList) prepends like insertBefore(getFirst(), ...) but does not throw
                // when the initializer has no instructions yet.
                initializer.instructions.insert(blockArray);

                classWrapper.getFields().add(new FieldWrapper(classWrapper, cacheFieldNode));

                classWrapper.getMethods().add(new MethodWrapper(classWrapper, decryptMethodNode));

                if (additionalProtection) {
                    classWrapper.getMethods().add(new MethodWrapper(classWrapper, lightMethodNode));
                }
            }
        });
    }

    /**
     * Generates the per-class decrypt method {@code private static final void name(Object[] args)}.
     *
     * <p>{@code args} is {@code {Integer cacheIndex, Integer key, String ciphertext}}. If
     * {@code cache[cacheIndex]} is already non-null the method returns immediately. Otherwise it
     * derives {@code keyCalling} (caller's method-name hash shifted left by {@code key}) and
     * {@code keyOwner} (its own class-name hash XOR {@code key}) from a fresh stack trace. It then runs
     * every ciphertext character through the code emitted by {@code modifiers}, in list order, and
     * stores the resulting string in {@code cache[cacheIndex]}.
     *
     * @param name      name of the generated method
     * @param modifiers modifier pipeline used to encrypt this class's literals
     * @param classNode class that will own the method and the cache field
     * @param fieldNode the {@code String[]} cache field
     * @return the generated method (not yet added to the class)
     */
    private MethodNode makeInternalDecryptMethod(String name, List<CharModifier> modifiers,
                                                 ClassNode classNode, FieldNode fieldNode) {
        // Local slots: 0 args (Object[]), 1 cache index, 2 key, 3 StackTraceElement[],
        // 4 keyCalling, 5 keyOwner, 6 ciphertext, 7 output char[], 8 loop counter, 9 current char.
        final int paramIndex = 0;
        final int stringIndexIndex = 1;
        final int stringKeyIndex = 2;
        final int traceArrayIndex = 3;
        final int keyCallingIndex = 4;
        final int keyOwnerIndex = 5;
        final int stringIndex = 6;
        final int arrayIndex = 7;
        final int loopIntegerIndex = 8;
        final int currentCharacterIndex = 9;

        MethodNode method = new MethodNode(ACC_PRIVATE | ACC_STATIC | ACC_FINAL,
                name, "([Ljava/lang/Object;)V", null, null);
        method.instructions = new InsnList();
        final InsnList instructions = method.instructions;

        // int cacheIndex = (Integer) args[0];
        LabelNode lblExtractIndex = new LabelNode();
        instructions.add(lblExtractIndex);
        instructions.add(new VarInsnNode(ALOAD, paramIndex));
        instructions.add(BytecodeUtil.makeInteger(0));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new TypeInsnNode(CHECKCAST, "java/lang/Integer"));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/Integer", "intValue", "()I", false));
        instructions.add(new VarInsnNode(ISTORE, stringIndexIndex));

        // int key = (Integer) args[1];
        LabelNode lblExtractKey = new LabelNode();
        instructions.add(lblExtractKey);
        instructions.add(new VarInsnNode(ALOAD, paramIndex));
        instructions.add(BytecodeUtil.makeInteger(1));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new TypeInsnNode(CHECKCAST, "java/lang/Integer"));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/Integer", "intValue", "()I", false));
        instructions.add(new VarInsnNode(ISTORE, stringKeyIndex));

        LabelNode lblStarting = new LabelNode();
        LabelNode lblStore = new LabelNode();
        LabelNode lblWrite = new LabelNode();
        LabelNode lblReturn = new LabelNode();

        // if (cache[cacheIndex] != null) return;
        instructions.add(lblStarting);
        instructions.add(new FieldInsnNode(GETSTATIC, classNode.name, fieldNode.name, "[Ljava/lang/String;"));
        instructions.add(new VarInsnNode(ILOAD, stringIndexIndex));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new JumpInsnNode(IFNONNULL, lblReturn));

        // StackTraceElement[] trace = new Throwable().getStackTrace();
        LabelNode label2 = new LabelNode();
        instructions.add(label2);
        instructions.add(new TypeInsnNode(NEW, "java/lang/Throwable"));
        instructions.add(new InsnNode(DUP));
        instructions.add(new MethodInsnNode(INVOKESPECIAL, "java/lang/Throwable", "<init>", "()V", false));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/Throwable", "getStackTrace", "()[Ljava/lang/StackTraceElement;", false));
        instructions.add(new VarInsnNode(ASTORE, traceArrayIndex));

        // keyCalling = trace[1].getMethodName().hashCode() << key;  (frame 1 = the caller)
        LabelNode label3 = new LabelNode();
        instructions.add(label3);
        instructions.add(new VarInsnNode(ALOAD, traceArrayIndex));
        instructions.add(new InsnNode(ICONST_1));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StackTraceElement", "getMethodName", "()Ljava/lang/String;", false));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "hashCode", "()I", false));
        instructions.add(new VarInsnNode(ILOAD, stringKeyIndex));
        instructions.add(new InsnNode(ISHL));
        instructions.add(new VarInsnNode(ISTORE, keyCallingIndex));

        // keyOwner = trace[0].getClassName().hashCode() ^ key;  (frame 0 = this method)
        LabelNode label4 = new LabelNode();
        instructions.add(label4);
        instructions.add(new VarInsnNode(ALOAD, traceArrayIndex));
        instructions.add(new InsnNode(ICONST_0));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StackTraceElement", "getClassName", "()Ljava/lang/String;", false));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "hashCode", "()I", false));
        instructions.add(new VarInsnNode(ILOAD, stringKeyIndex));
        instructions.add(new InsnNode(IXOR));
        instructions.add(new VarInsnNode(ISTORE, keyOwnerIndex));

        // String ciphertext = (String) args[2];
        LabelNode lblExtractString = new LabelNode();
        instructions.add(lblExtractString);
        instructions.add(new VarInsnNode(ALOAD, paramIndex));
        instructions.add(BytecodeUtil.makeInteger(2));
        instructions.add(new InsnNode(AALOAD));
        instructions.add(new TypeInsnNode(CHECKCAST, "java/lang/String"));
        instructions.add(new VarInsnNode(ASTORE, stringIndex));

        // char[] out = new char[ciphertext.length()];
        LabelNode label5 = new LabelNode();
        instructions.add(label5);
        instructions.add(new VarInsnNode(ALOAD, stringIndex));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "length", "()I", false));
        instructions.add(new IntInsnNode(NEWARRAY, T_CHAR));
        instructions.add(new VarInsnNode(ASTORE, arrayIndex));

        // int i = 0;
        LabelNode label6 = new LabelNode();
        instructions.add(label6);
        instructions.add(new InsnNode(ICONST_0));
        instructions.add(new VarInsnNode(ISTORE, loopIntegerIndex));

        // Loop header: while (i < ciphertext.length())
        LabelNode label7 = new LabelNode();
        instructions.add(label7);
        instructions.add(new VarInsnNode(ILOAD, loopIntegerIndex));
        instructions.add(new VarInsnNode(ALOAD, stringIndex));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "length", "()I", false));
        instructions.add(new JumpInsnNode(IF_ICMPGE, lblWrite));

        // int c = ciphertext.toCharArray()[i];
        LabelNode label9 = new LabelNode();
        instructions.add(label9);
        instructions.add(new VarInsnNode(ALOAD, stringIndex));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/String", "toCharArray", "()[C", false));
        instructions.add(new VarInsnNode(ILOAD, loopIntegerIndex));
        instructions.add(new InsnNode(CALOAD));
        instructions.add(new VarInsnNode(ISTORE, currentCharacterIndex));

        // Each modifier emits its own transform of `c`, followed by its private lblNext label;
        // lblStore is shared (presumably an early exit to the store below — confirm per modifier).
        for (CharModifier charModifier : modifiers) {
            final LabelNode lblNext = new LabelNode();
            final CharModifierData data = new CharModifierData(
                    stringIndex,
                    stringIndexIndex,
                    stringKeyIndex,
                    keyCallingIndex,
                    keyOwnerIndex,
                    arrayIndex,
                    loopIntegerIndex,
                    currentCharacterIndex,
                    lblNext,
                    lblStore
            );
            final InsnList addition = charModifier.generate(method, data);
            instructions.add(addition);
            instructions.add(lblNext);
        }

        // out[i] = (char) c;
        instructions.add(lblStore);
        instructions.add(new VarInsnNode(ALOAD, arrayIndex));
        instructions.add(new VarInsnNode(ILOAD, loopIntegerIndex));
        instructions.add(new VarInsnNode(ILOAD, currentCharacterIndex));
        instructions.add(new InsnNode(I2C));
        instructions.add(new InsnNode(CASTORE));

        // i++; continue;
        LabelNode lblRepeat = new LabelNode();
        instructions.add(lblRepeat);
        instructions.add(new IincInsnNode(loopIntegerIndex, 1));
        instructions.add(new JumpInsnNode(GOTO, label7));

        // cache[cacheIndex] = new String(out);
        instructions.add(lblWrite);
        instructions.add(new FieldInsnNode(GETSTATIC, classNode.name, fieldNode.name, "[Ljava/lang/String;"));
        instructions.add(new VarInsnNode(ILOAD, stringIndexIndex));
        instructions.add(new TypeInsnNode(NEW, "java/lang/String"));
        instructions.add(new InsnNode(DUP));
        instructions.add(new VarInsnNode(ALOAD, arrayIndex));
        instructions.add(new MethodInsnNode(INVOKESPECIAL, "java/lang/String", "<init>", "([C)V", false));
        instructions.add(new InsnNode(AASTORE));

        instructions.add(lblReturn);
        instructions.add(new InsnNode(RETURN));

        // Generous fixed bound: modifiers are handed slots up to 9 and may allocate more.
        // NOTE(review): maxStack stays 0 and no FrameNodes are emitted, so this presumably relies on
        // the ClassWriter running with COMPUTE_FRAMES/COMPUTE_MAXS — confirm.
        method.maxLocals = 100;

        return method;
    }
}