package cc.polymorphism.obfuscator.mutator.impl.encryption.string;

import cc.polymorphism.assembly.WrappedType;
import cc.polymorphism.assembly.expressions.IRVariable;
import cc.polymorphism.obfuscator.Polymorphism;
import cc.polymorphism.obfuscator.asm.wrapper.ClassWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.InsnListWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.MethodWrapper;
import cc.polymorphism.obfuscator.engine.hash.Hash;
import cc.polymorphism.obfuscator.engine.seed.Seed;
import cc.polymorphism.obfuscator.mutator.Mutator;
import cc.polymorphism.obfuscator.mutator.impl.encryption.common.IntAlgorithm;
import cc.polymorphism.obfuscator.mutator.impl.encryption.common.IntCryptData;
import cc.polymorphism.obfuscator.util.StringConcatFactoryUtils;
import org.objectweb.asm.tree.*;

/**
 * Mutator that replaces string {@code LDC} constants with an encrypted constant followed by inline
 * bytecode that decrypts it character by character into a {@link StringBuilder} at runtime.
 *
 * <p>Interfaces and annotation types are skipped. A class or method annotated with
 * {@code cc.polymorphism.annot.ExcludeConstant} is also skipped.
 */
public class PolymorphicStringEncryptionMutator extends Mutator {
    /** Descriptor of the annotation that opts a class or method out of string encryption. */
    private static final String EXCLUDE_ANNOTATION_DESC = "Lcc/polymorphism/annot/ExcludeConstant;";

    /**
     * Once {@link MethodWrapper#getLeewaySize()} drops below this value, no further constants in
     * that method are encrypted.
     * NOTE(review): presumably the leeway is the remaining code-size budget of the method — confirm.
     */
    private static final int MIN_LEEWAY = 10000;

    /** Largest byte length a {@code CONSTANT_Utf8} entry can have (its length field is a u2). */
    private static final long MAX_UTF8_CONSTANT_BYTES = 0xFFFF;

    public PolymorphicStringEncryptionMutator(Polymorphism ctx) {
        super(ctx);
    }

    @Override
    public void transform() {
        classStream().filter(o -> !o.isInterface() && !o.isAnnotation()).filter(this::isIncluded).forEach(classWrapper -> {
            final Hash hash = hashGenerator().generate(classWrapper);

            classWrapper.methodStream().filter(MethodWrapper::hasInstructions).filter(o -> shouldObfuscate(classWrapper, o)).forEach(methodWrapper -> {
                StringConcatFactoryUtils.removeStringConcatFactory(methodWrapper);
                encryptMethodStrings(hash, classWrapper, methodWrapper);
            });
        });
    }

    /**
     * Replaces every string {@code LDC} in the method with its inline decryption sequence.
     *
     * <p>Processing stops as soon as the method's leeway falls below {@link #MIN_LEEWAY}. A constant
     * whose encrypted form would not fit in the constant pool is left unchanged.
     */
    private void encryptMethodStrings(final Hash hash, final ClassWrapper classWrapper, final MethodWrapper methodWrapper) {
        final InsnList content = methodWrapper.getMethodNode().instructions;
        // Iterate over a snapshot: the live list is modified while constants are replaced.
        final AbstractInsnNode[] snapshot = content.toArray();

        int leeway = methodWrapper.getLeewaySize();
        for (final AbstractInsnNode ain : snapshot) {
            if (leeway < MIN_LEEWAY)
                break;

            if (!(ain instanceof LdcInsnNode lin && lin.cst instanceof String str)) {
                continue;
            }

            final InsnList list = generateDecryption(hash, classWrapper, methodWrapper, str);
            if (list == null) {
                // The encrypted constant would be too large for a CONSTANT_Utf8 entry; keep the original.
                continue;
            }

            content.insertBefore(ain, list);
            content.remove(ain);

            leeway = methodWrapper.getLeewaySize();
        }
    }

    /**
     * Applies {@code algorithm.encrypt} to every character of {@code string}.
     *
     * @return a string of the same length whose chars are the encrypted values truncated to
     *     {@code char}
     */
    private String encryptString(final IntAlgorithm algorithm, final String string) {
        final StringBuilder stringBuilder = new StringBuilder(string);
        for (int i = 0; i < stringBuilder.length(); i++) {
            int result = stringBuilder.charAt(i);
            result = algorithm.encrypt(result);
            stringBuilder.setCharAt(i, (char) result);
        }
        return stringBuilder.toString();
    }

    /**
     * Builds the instruction sequence that replaces {@code LDC string}. The sequence:
     * <ol>
     *   <li>loads the encrypted constant into a new {@code StringBuilder};</li>
     *   <li>runs the rendered algorithm on each char and writes the result back with
     *       {@code setCharAt};</li>
     *   <li>leaves {@code builder.toString()} on the stack.</li>
     * </ol>
     *
     * <p>Three new local slots are allocated by bumping {@code maxLocals}.
     *
     * @return the decryption sequence, or {@code null} if the encrypted constant is too large
     *     (in modified UTF-8) to be stored in the constant pool
     */
    private InsnList generateDecryption(final Hash hash, final ClassWrapper classWrapper, final MethodWrapper methodWrapper, final String string) {
        final Seed seed = seedGenerator().generateSeed(hash, classWrapper, methodWrapper);

        final MethodNode methodNode = methodWrapper.getMethodNode();

        // Fresh slots past every existing local, so the inserted code cannot clobber method state.
        final IRVariable varBuilder = new IRVariable(WrappedType.from(StringBuilder.class), methodNode.maxLocals++);
        final IRVariable varIndex = new IRVariable(WrappedType.from(int.class), methodNode.maxLocals++);
        final IRVariable varCharacter = new IRVariable(WrappedType.from(int.class), methodNode.maxLocals++);

        final IntCryptData data = new IntCryptData(hash, seed, classWrapper, methodWrapper, varCharacter);
        final IntAlgorithm algorithm = new IntAlgorithm(data);

        final String encryptedString = encryptString(algorithm, string);
        // Encrypted chars may need more bytes than the originals (up to 3 each), so a constant that
        // fit before can exceed the CONSTANT_Utf8 limit and make class writing fail. The slots
        // reserved above are simply left unused in that case.
        if (modifiedUtf8Length(encryptedString) > MAX_UTF8_CONSTANT_BYTES) {
            return null;
        }

        final InsnListWrapper instructions = new InsnListWrapper();

        // A: builder = new StringBuilder(encryptedString)
        instructions.add(new TypeInsnNode(NEW, "java/lang/StringBuilder"));
        instructions.add(DUP);
        instructions.add(new LdcInsnNode(encryptedString));
        instructions.add(new MethodInsnNode(INVOKESPECIAL, "java/lang/StringBuilder", "<init>", "(Ljava/lang/String;)V", false));
        instructions.add(new VarInsnNode(ASTORE, varBuilder.getSlot()));

        // B: index = 0
        instructions.add(ICONST_0);
        instructions.add(new VarInsnNode(ISTORE, varIndex.getSlot()));

        // C: loop head — exit to F when index >= builder.length()
        final LabelNode labelC = new LabelNode();
        instructions.add(labelC);
        instructions.add(new VarInsnNode(ILOAD, varIndex.getSlot()));
        instructions.add(new VarInsnNode(ALOAD, varBuilder.getSlot()));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "length", "()I", false));

        final LabelNode labelF = new LabelNode();
        instructions.add(new JumpInsnNode(IF_ICMPGE, labelF));

        // D: character = builder.charAt(index); <algorithm>; builder.setCharAt(index, (char) character)
        {
            instructions.add(new VarInsnNode(ALOAD, varBuilder.getSlot()));
            instructions.add(new VarInsnNode(ILOAD, varIndex.getSlot()));
            instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "charAt", "(I)C", false));
            instructions.add(new VarInsnNode(ISTORE, varCharacter.getSlot()));

            // The rendered algorithm is built around varCharacter (passed to IntCryptData above).
            instructions.add(algorithm.render().compile());

            instructions.add(new VarInsnNode(ALOAD, varBuilder.getSlot()));
            instructions.add(new VarInsnNode(ILOAD, varIndex.getSlot()));
            instructions.add(new VarInsnNode(ILOAD, varCharacter.getSlot()));
            instructions.add(I2C);
            instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "setCharAt", "(IC)V", false));
        }

        // E: index++; jump back to C
        instructions.add(new IincInsnNode(varIndex.getSlot(), 1));
        instructions.add(new JumpInsnNode(GOTO, labelC));

        // F: push builder.toString() in place of the original constant
        instructions.add(labelF);
        instructions.add(new VarInsnNode(ALOAD, varBuilder.getSlot()));
        instructions.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;", false));

        return instructions;
    }

    /**
     * Returns the number of bytes {@code value} occupies in the class-file "modified UTF-8"
     * encoding:
     * <ul>
     *   <li>U+0001..U+007F take one byte;</li>
     *   <li>U+0000 and U+0080..U+07FF take two bytes;</li>
     *   <li>every other UTF-16 code unit, including each surrogate half, takes three bytes.</li>
     * </ul>
     */
    private static long modifiedUtf8Length(final String value) {
        long bytes = 0;
        for (int i = 0; i < value.length(); i++) {
            final char c = value.charAt(i);
            if (c >= 0x0001 && c <= 0x007F) {
                bytes += 1;
            } else if (c <= 0x07FF) {
                bytes += 2;
            } else {
                bytes += 3;
            }
        }
        return bytes;
    }

    /**
     * Returns {@code false} when either the class or the method carries the
     * {@code ExcludeConstant} annotation.
     *
     * <p>Both the runtime-visible and runtime-invisible annotation lists are checked, so the
     * opt-out works whichever retention policy the annotation is declared with.
     */
    private boolean shouldObfuscate(ClassWrapper classWrapper, MethodWrapper methodWrapper) {
        final ClassNode classNode = classWrapper.getClassNode();

        if (hasExcludeAnnotation(classNode.visibleAnnotations) || hasExcludeAnnotation(classNode.invisibleAnnotations))
            return false;

        final MethodNode methodNode = methodWrapper.getMethodNode();

        return !hasExcludeAnnotation(methodNode.visibleAnnotations)
                && !hasExcludeAnnotation(methodNode.invisibleAnnotations);
    }

    /** Returns whether {@code annotations} (ASM leaves absent lists {@code null}) contains ExcludeConstant. */
    private static boolean hasExcludeAnnotation(final Iterable<AnnotationNode> annotations) {
        if (annotations == null)
            return false;

        for (final AnnotationNode annotation : annotations) {
            if (EXCLUDE_ANNOTATION_DESC.equals(annotation.desc))
                return true;
        }
        return false;
    }

    @Override
    public String getConfigName() {
        return "polymorphic_string_encryption";
    }
}