package cc.polymorphism.obfuscator.mutator.impl.encryption.string;

import cc.polymorphism.assembly.std.Compiler;
import cc.polymorphism.obfuscator.Polymorphism;
import cc.polymorphism.obfuscator.asm.wrapper.ClassWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.InsnListWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.MethodWrapper;
import cc.polymorphism.obfuscator.dictionary.Dictionary;
import cc.polymorphism.obfuscator.dictionary.DictionaryFactory;
import cc.polymorphism.obfuscator.engine.hash.Hash;
import cc.polymorphism.obfuscator.engine.seed.Seed;
import cc.polymorphism.obfuscator.exceptions.obfuscation.FatalObfuscationException;
import cc.polymorphism.obfuscator.mutator.Mutator;
import cc.polymorphism.obfuscator.mutator.common.Executor;
import cc.polymorphism.obfuscator.util.ASMUtils;
import cc.polymorphism.obfuscator.util.RandomUtils;
import cc.polymorphism.obfuscator.util.StringConcatFactoryUtils;
import lombok.Getter;
import lombok.Setter;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Encrypts strings and replaces their ldc instruction with an invokestatic to the decrypt method
 * and a getstatic to a field, since the decrypt method does not return anything.
 *
 * @author Liticane
 * @since 1.0 - 07:32 (EST) 02/28/2025 (MM/DD/YYYY)
 */
public class StringEncryptionMutator extends Mutator {
    private final Dictionary dictionary = DictionaryFactory.dictionaryOf("alphabetical");

    /**
     * Creates the mutator, passing {@code ctx} on to the {@link Mutator} constructor.
     *
     * @param ctx the {@link Polymorphism} instance handed to the superclass
     */
    public StringEncryptionMutator(Polymorphism ctx) {
        super(ctx);
    }

    /**
     * Encrypts the string constants of every included class that is not an enum, an annotation or an interface.
     *
     * <p>Each class is processed independently by {@link #encryptClass(ClassWrapper)}.
     */
    @Override
    public void transform() {
        classStream()
                .filter(o -> !o.isEnum() && !o.isAnnotation() && !o.isInterface())
                .filter(this::isIncluded)
                .forEach(this::encryptClass);
    }

    /**
     * Builds the random chain of executors used to encrypt the strings of one class.
     *
     * @return a mutable list of randomly chosen {@link XorIntegerStrAction} / {@link XorSwitchIntegerStrAction} instances
     * @throws FatalObfuscationException if the random selector yields a value other than 0 or 1
     */
    private List<Executor<Integer, Data>> createExecutors() {
        final var executors = new ArrayList<Executor<Integer, Data>>();

        // Draw the chain length once; re-rolling the bound on every iteration would skew the length.
        final int count = RandomUtils.randomInt(1, 8);
        for (int i = 0; i < count; i++) {
            switch (RandomUtils.randomInt(2)) {
                case 0 -> executors.add(new XorIntegerStrAction());
                case 1 -> executors.add(new XorSwitchIntegerStrAction(false));
                default -> throw new FatalObfuscationException("Failed to match random value!");
            }
        }

        return executors;
    }

    /**
     * Encrypts the {@code String} constants of a single class.
     *
     * <p>Every {@code String} {@code ldc} in a method accepted by {@link #shouldObfuscate} is replaced by: an
     * {@code ldc} of the encrypted string, the method's seed variable xor-ed with a random key, the cache index, an
     * {@code invokestatic} of the generated decryption method and a {@code getstatic} of the field that receives the
     * decrypted value. If at least one string was replaced, the two helper fields and the decryption method are
     * added to the class and the cache array is allocated at the start of the class initializer.
     *
     * @param classWrapper the class whose string constants are encrypted
     */
    private void encryptClass(final ClassWrapper classWrapper) {
        final var classNode = classWrapper.getClassNode();

        final var executors = createExecutors();

        final var cacheField = new FieldNode(ACC_PRIVATE | ACC_STATIC, dictionary.next(), "[Ljava/lang/String;", null, null);
        final var currentField = new FieldNode(ACC_PRIVATE | ACC_STATIC, dictionary.next(), "Ljava/lang/String;", null, null);

        final var decryptMethod = generateDecryption(executors, cacheField.name, currentField.name, dictionary.next());

        // The method was compiled inside a placeholder class; point its field accesses at the target class.
        for (final var ain : decryptMethod.instructions.toArray()) {
            if (ain instanceof FieldInsnNode fin) {
                fin.owner = classNode.name;
            }
        }

        // Computed lazily, at most once per class, when the first string constant is found.
        final AtomicReference<Hash> hash = new AtomicReference<>(null);

        // Next free slot in the cache array; after the pass it equals the number of encrypted strings.
        final AtomicInteger indexCounter = new AtomicInteger(0);
        classWrapper.methodStream().filter(MethodWrapper::hasInstructions).filter(o -> shouldObfuscate(classWrapper, o)).forEach(methodWrapper -> {
            StringConcatFactoryUtils.removeStringConcatFactory(methodWrapper);

            final var methodNode = methodWrapper.getMethodNode();

            // Iterate over a snapshot because the live instruction list is edited inside the loop.
            final var snapshot = methodNode.instructions.toArray();
            final var content = methodNode.instructions;

            // Generated lazily, at most once per method.
            Seed seed = null;

            int leeway = methodWrapper.getLeewaySize();
            for (final AbstractInsnNode ain : snapshot) {
                // Stop once the reported leeway drops below 10000
                // (presumably the remaining method-size budget - confirm against MethodWrapper).
                if (leeway < 10000) {
                    break;
                }

                if (!(ain instanceof LdcInsnNode lin && lin.cst instanceof String string)) {
                    continue;
                }

                if (hash.get() == null) {
                    hash.set(hashGenerator().generate(classWrapper));
                }

                if (seed == null) {
                    seed = seedGenerator().generateSeed(hash.get(), classWrapper, methodWrapper);
                }

                // At runtime (seedVariable ^ key) reproduces `value`, the seed the string was encrypted with.
                final int key = RandomUtils.randomInt();
                final int value = seed.getValue() ^ key;

                final var list = new InsnListWrapper();
                list.add(new LdcInsnNode(
                        encrypt(
                                executors, string,
                                classNode.name.replace("/", ".").hashCode(),
                                methodNode.name.hashCode(),
                                value
                        )
                ));
                list.add(new VarInsnNode(ILOAD, seed.getVariable().getSlot()));
                list.add(new LdcInsnNode(key));
                list.add(IXOR);
                list.add(new LdcInsnNode(indexCounter.get()));
                list.add(new MethodInsnNode(INVOKESTATIC, classNode.name, decryptMethod.name, decryptMethod.desc));
                list.add(new FieldInsnNode(GETSTATIC, classNode.name, currentField.name, currentField.desc));

                content.insertBefore(ain, list);
                content.remove(ain);

                indexCounter.incrementAndGet();

                leeway = methodWrapper.getLeewaySize();
            }
        });

        final int encryptedCount = indexCounter.get();
        if (encryptedCount != 0) {
            classWrapper.addField(cacheField);
            classWrapper.addField(currentField);

            final var initializer = ASMUtils.findOrCreateInitializer(classNode);

            // Indices 0..encryptedCount-1 were handed out, so the cache needs exactly encryptedCount slots.
            final var list = new InsnList();
            list.add(new LdcInsnNode(encryptedCount));
            list.add(new TypeInsnNode(ANEWARRAY, "java/lang/String"));
            list.add(new FieldInsnNode(PUTSTATIC, classNode.name, cacheField.name, cacheField.desc));
            // insert(InsnList) prepends and also works when the initializer has no instructions yet.
            initializer.instructions.insert(list);
            classNode.methods.add(decryptMethod);
        }
    }

    /**
     * Encrypts a string character by character.
     *
     * <p>Each character is passed through every executor in list order and the result is xor-ed with {@code seed}
     * before being narrowed back to a {@code char}.
     *
     * @param executors      the executor chain applied to every character
     * @param string         the plain text to encrypt
     * @param classNameHash  hash value exposed to the executors through {@link Data#classNameHash()}
     * @param methodNameHash hash value exposed to the executors through {@link Data#methodNameHash()}
     * @param seed           value xor-ed into every character after the executors ran
     * @return the encrypted string, with the same length as {@code string}
     */
    private String encrypt(final List<Executor<Integer, Data>> executors,
                           final String string,
                           final int classNameHash,
                           final int methodNameHash,
                           final int seed) {
        final char[] characters = string.toCharArray();
        for (int index = 0; index < characters.length; index++) {
            // The context only depends on the character index, so build it once per character
            // instead of once per executor.
            final Data data = new Data(index, classNameHash, methodNameHash, seed);

            int result = characters[index];
            for (final Executor<Integer, Data> executor : executors) {
                result = executor.execute(result, data);
            }
            characters[index] = (char) (result ^ seed);
        }

        return new String(characters);
    }

    /**
     * Generates the decryption method by compiling a Java source template.
     *
     * <p>The code of every executor is appended in reverse list order and spliced into the template's per-character
     * loop. The template is compiled as a class named {@code Decryption}, and the method called
     * {@code decryptMethodName} is taken from the result.
     *
     * @param executors         the executor chain whose {@code code()} forms the body of the decryption loop
     * @param cacheFieldName    name substituted for the {@code String[]} cache field
     * @param currentFieldName  name substituted for the field that receives the decrypted string
     * @param decryptMethodName name substituted for the decryption method
     * @return the compiled decryption method
     * @throws FatalObfuscationException if the template fails to compile or the method is missing from the result
     */
    private MethodNode generateDecryption(final List<Executor<Integer, Data>> executors,
                                          final String cacheFieldName,
                                          final String currentFieldName,
                                          final String decryptMethodName) {
        final var compiler = new Compiler();

        final var code = new StringBuilder();
        for (Executor<Integer, Data> executor : executors.reversed()) {
            code.append(executor.code());
        }

        // String.replace substitutes literally; replaceAll would treat '$' and '\' in the replacement specially.
        final var compilation = """
                public class Decryption {
                    private static String[] REPLACE_CACHE_NAME = new String[1024];
                    \t
                    private static String REPLACE_CURRENT_NAME;
                    \t
                    private static void REPLACE_DECRYPT_NAME(String string, int seed, int index) {
                        if (REPLACE_CACHE_NAME[index] == null) {
                            final char[] characters = string.toCharArray();
                            for (int i = 0; i < characters.length; i++) {
                                int result = characters[i] ^ seed;
                REPLACE_CODE
                                characters[i] = (char) result;
                            }
                            REPLACE_CACHE_NAME[index] = new String(characters);
                        }
                        REPLACE_CURRENT_NAME = REPLACE_CACHE_NAME[index];
                    }
                }
                """
                .replace("REPLACE_CACHE_NAME", cacheFieldName)
                .replace("REPLACE_CURRENT_NAME", currentFieldName)
                .replace("REPLACE_DECRYPT_NAME", decryptMethodName)
                .replace("REPLACE_CODE", code.toString());

        final var classNode = compiler.compile("Decryption", compilation);

        if (classNode == null) {
            // Dump the generated source on the error stream so the failure can be diagnosed.
            System.err.println(compilation);
            throw new FatalObfuscationException("Could not compile string encryption logic!");
        }

        // Read-only scan, so no defensive copy of the method list is needed.
        for (final var node : classNode.methods) {
            if (Objects.equals(node.name, decryptMethodName)) {
                return node;
            }
        }

        throw new FatalObfuscationException("Decryption Method not found in compiled string encryption logic!");
    }

    /**
     * Decides whether the strings of a method may be encrypted.
     *
     * @param classWrapper  the class that owns the method
     * @param methodWrapper the method to check
     * @return {@code false} if the class or the method carries a visible
     *         {@code cc.polymorphism.annot.ExcludeConstant} annotation, {@code true} otherwise
     */
    private boolean shouldObfuscate(ClassWrapper classWrapper, MethodWrapper methodWrapper) {
        // A class-level exclusion wins; the method is not inspected in that case.
        if (hasExcludeConstant(classWrapper.getClassNode().visibleAnnotations)) {
            return false;
        }

        return !hasExcludeConstant(methodWrapper.getMethodNode().visibleAnnotations);
    }

    /**
     * Tells whether an annotation list contains the {@code ExcludeConstant} annotation.
     *
     * @param annotations the annotation list to scan, may be {@code null}
     * @return {@code true} if any entry has the {@code ExcludeConstant} descriptor
     */
    private static boolean hasExcludeConstant(final List<AnnotationNode> annotations) {
        if (annotations == null) {
            return false;
        }

        for (final AnnotationNode annotation : annotations) {
            if (annotation.desc.equals("Lcc/polymorphism/annot/ExcludeConstant;")) {
                return true;
            }
        }

        return false;
    }

    /**
     * Returns the configuration name of this mutator.
     *
     * @return the constant {@code "string_encryption"}
     */
    @Override
    public String getConfigName() {
        return "string_encryption";
    }

    /**
     * Per-character context handed to the executors while a string is encrypted.
     *
     * @param index          position of the current character inside the string
     * @param classNameHash  hash code of the owning class' dotted name
     * @param methodNameHash hash code of the enclosing method's name
     * @param seed           seed value the string is encrypted with
     */
    record Data(int index, int classNameHash, int methodNameHash, int seed) {
    }

    /**
     * Executor that xors a value with a random key, optionally mixed with the class-name or method-name hash.
     *
     * <p>{@link #execute} performs the operation at obfuscation time; {@link #code()} emits the Java statement that
     * performs the same xor on the variable {@code result} inside the generated decryption method.
     */
    @Getter
    @Setter
    static class XorIntegerStrAction extends Executor<Integer, Data> {
        private final Type type;
        private final int key;

        /** Creates an action with a randomly picked {@link Type} and a random key. */
        public XorIntegerStrAction() {
            final Type[] candidates = Type.values();
            this.type = candidates[RandomUtils.randomInt(candidates.length)];
            this.key = RandomUtils.randomInt();
        }

        /**
         * Creates an action of the given type with a random key.
         *
         * @param type which extra value, if any, is mixed into the xor
         */
        public XorIntegerStrAction(Type type) {
            this.type = type;
            this.key = RandomUtils.randomInt();
        }

        /**
         * Emits the statement that applies this action to {@code result} in the generated decryption code.
         *
         * @return a single {@code result ^= ...;} statement followed by a line break
         */
        @Override
        public String code() {
            // The hashes are recomputed at runtime from the stack trace: frame 0 supplies the class name,
            // frame 1 the method name.
            final String operand = switch (type) {
                case CLASS_HASH -> "new Throwable().getStackTrace()[0].getClassName().hashCode() ^ " + key;
                case METHOD_HASH -> "new Throwable().getStackTrace()[1].getMethodName().hashCode() ^ " + key;
                default -> String.valueOf(key);
            };

            return "result ^= " + operand + ";\n";
        }

        /**
         * Applies the xor at obfuscation time.
         *
         * @param value the value to transform
         * @param data  context providing the class-name and method-name hashes
         * @return {@code value} xor-ed with the key and, depending on the type, with one of the hashes
         */
        @Override
        public Integer execute(Integer value, Data data) {
            final int mask = switch (type) {
                case CLASS_HASH -> data.classNameHash() ^ key;
                case METHOD_HASH -> data.methodNameHash() ^ key;
                default -> key;
            };

            return value ^ mask;
        }

        /** Selects what is mixed into the xor besides the key. */
        enum Type {
            NORMAL, CLASS_HASH, METHOD_HASH
        }
    }

    /**
     * Executor that picks one of several nested executors based on the character index.
     *
     * <p>{@link #execute} delegates to the executor at {@code index % length}; {@link #code()} emits the matching
     * {@code switch (i % length)} statement for the generated decryption method.
     */
    @Getter
    @Setter
    static class XorSwitchIntegerStrAction extends Executor<Integer, Data> {
        @SuppressWarnings({"unchecked"})
        private final Executor<Integer, Data>[] key = new Executor[RandomUtils.randomInt(2, 7) + 1];

        /**
         * Fills every slot with a randomly chosen executor.
         *
         * @param inner {@code true} when this action is itself nested in another switch action; its slots are then
         *              filled with {@link XorIntegerStrAction.Type#NORMAL} xor actions only, so nesting stops here
         * @throws IllegalStateException if the random selector yields a value other than 0 or 1
         */
        public XorSwitchIntegerStrAction(final boolean inner) {
            for (int i = 0; i < key.length; i++) {
                Executor<Integer, Data> executor;
                if (inner) {
                    executor = new XorIntegerStrAction(XorIntegerStrAction.Type.NORMAL);
                } else {
                    // Keep the drawn value so the error message reports the value that actually failed to match.
                    final int choice = RandomUtils.randomInt(2);
                    executor = switch (choice) {
                        case 0 -> new XorIntegerStrAction();
                        case 1 -> new XorSwitchIntegerStrAction(true);
                        default -> throw new IllegalStateException("Unexpected value: " + choice);
                    };
                }
                this.key[i] = executor;
            }
        }

        /**
         * Emits a {@code switch} over {@code i % length} with one case per nested executor.
         *
         * @return the Java source of the switch statement
         */
        @Override
        public String code() {
            // @formatter:off
            final var cases = new StringBuilder();
            for (var i = 0; i < this.key.length; i++) {
                // Literal substitution: the nested code must not be parsed as a regex replacement pattern.
                cases.append("""
                                case REPLACE_VALUE:
                                    REPLACE_EXECUTOR
                                    break;
                             """
                        .replace("REPLACE_VALUE", String.valueOf(i))
                        .replace("REPLACE_EXECUTOR", key[i].code())
                );
            }

            return  """
                    switch (i % REPLACE_LENGTH) {
                        REPLACE_CASES
                    }
                    """
                    .replace("REPLACE_LENGTH", String.valueOf(this.key.length))
                    .replace("REPLACE_CASES", cases.toString());
            // @formatter:on
        }

        /**
         * Delegates to the nested executor selected by the character index.
         *
         * @param value the value to transform
         * @param data  context whose {@code index()} selects the nested executor
         * @return the result of the selected nested executor
         */
        @Override
        public Integer execute(Integer value, Data data) {
            return this.key[data.index() % this.key.length].execute(value, data);
        }
    }
}