package cc.polymorphism.obfuscator.mutator.impl.encryption.number;

import cc.polymorphism.assembly.BytecodeBlock;
import cc.polymorphism.assembly.WrappedType;
import cc.polymorphism.assembly.expressions.IRExpressions;
import cc.polymorphism.assembly.std.Compiler;
import cc.polymorphism.obfuscator.Polymorphism;
import cc.polymorphism.obfuscator.asm.wrapper.ClassWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.MethodWrapper;
import cc.polymorphism.obfuscator.dictionary.Dictionary;
import cc.polymorphism.obfuscator.dictionary.DictionaryFactory;
import cc.polymorphism.obfuscator.engine.hash.Hash;
import cc.polymorphism.obfuscator.engine.seed.Seed;
import cc.polymorphism.obfuscator.exceptions.obfuscation.FatalObfuscationException;
import cc.polymorphism.obfuscator.mutator.Mutator;
import cc.polymorphism.obfuscator.mutator.common.Executor;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.ByteData;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.XorByteExecutor;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.XorSwitchByteExecutor;
import cc.polymorphism.obfuscator.util.ASMUtils;
import cc.polymorphism.obfuscator.util.RandomUtils;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Encrypts integers and replaces their iconst, bipush, sipush or ldc instruction with an invokestatic to the
 * decrypt method followed by a getstatic of a field. The decrypt method does not return anything; it stores
 * the decrypted value in that field instead.
 *
 * <p>For every processed class a private static decrypt method, an {@code Integer[]} cache field and an
 * {@code int} "current value" field are generated. They are only added to the class when at least one
 * integer was actually replaced.
 *
 * <p>NOTE(review): the decrypted value travels through a single static field, so two threads decrypting
 * inside the same class at the same time could presumably read each other's value - confirm whether the
 * obfuscated code is expected to be thread-safe.
 *
 * @author Liticane
 * @since 1.0 - 12:17 (EST) 03/02/2025 (MM/DD/YYYY)
 */
public class IntegerMutator extends Mutator {
    /** Descriptor of the annotation that opts a class or a method out of constant encryption. */
    private static final String EXCLUDE_CONSTANT_DESC = "Lcc/polymorphism/annot/ExcludeConstant;";

    /**
     * Rewriting of a method stops once {@code getLeewaySize()} drops below this value.
     * NOTE(review): presumably the remaining bytecode budget of the method - confirm against MethodWrapper.
     */
    private static final int MIN_LEEWAY = 10000;

    /** Source of the names used for the generated fields and the generated decrypt method. */
    private final Dictionary dictionary = DictionaryFactory.dictionaryOf("alphabetical");

    public IntegerMutator(Polymorphism ctx) {
        super(ctx);
    }

    /**
     * Runs the integer encryption on every included class that is neither an interface nor an annotation.
     */
    @Override
    public void transform() {
        classStream()
                .filter(o -> !o.isInterface() && !o.isAnnotation())
                .filter(this::isIncluded)
                .forEach(this::encryptClassIntegers);
    }

    /**
     * Replaces the integer constants of one class with calls to a freshly generated decrypt method and,
     * if anything was replaced, installs that method together with its two backing fields.
     *
     * @param classWrapper the class to process
     */
    private void encryptClassIntegers(final ClassWrapper classWrapper) {
        final ClassNode classNode = classWrapper.getClassNode();

        final MemberNames memberNames = new MemberNames();

        final List<Executor<Byte, ByteData>> byteExecutors = newByteExecutors();
        final List<Executor<Integer, IntegerData>> intExecutors = newIntExecutors();

        final FieldNode cacheField = new FieldNode(ACC_PRIVATE | ACC_STATIC, memberNames.cacheName, "[Ljava/lang/Integer;",
                null, null);
        final FieldNode currentField = new FieldNode(ACC_PRIVATE | ACC_STATIC, memberNames.currentName, "I",
                null, null);

        final MethodNode decryptMethod = generateDecryption(byteExecutors, intExecutors, memberNames);

        // The method was compiled inside a template class; point its field accesses at the real owner.
        for (final AbstractInsnNode ain : decryptMethod.instructions.toArray()) {
            if (ain instanceof FieldInsnNode fin)
                fin.owner = classNode.name;
        }

        // Created lazily so that classes without any integer constant never compute a hash.
        final AtomicReference<Hash> hash = new AtomicReference<>(null);

        // Next free slot of the runtime cache; also the number of replaced constants.
        final AtomicInteger indexCounter = new AtomicInteger();

        classWrapper.methodStream()
                .filter(MethodWrapper::hasInstructions)
                .filter(o -> shouldObfuscate(classWrapper, o))
                .forEach(methodWrapper -> {
                    final MethodNode methodNode = methodWrapper.getMethodNode();

                    final InsnList content = methodNode.instructions;
                    // Snapshot, because the list is modified while it is being walked.
                    final AbstractInsnNode[] ctx = content.toArray();

                    // Created lazily, once per method, on the first integer constant.
                    Seed seed = null;

                    int leeway = methodWrapper.getLeewaySize();
                    for (final AbstractInsnNode ain : ctx) {
                        if (leeway < MIN_LEEWAY)
                            break;

                        if (!ASMUtils.isInteger(ain))
                            continue;

                        if (hash.get() == null) {
                            hash.set(hashGenerator().generate(classWrapper));
                        }

                        if (seed == null) {
                            seed = seedGenerator().generateSeed(hash.get(), classWrapper, methodWrapper);
                        }

                        // The runtime call site rebuilds the same value as "seed variable ^ key".
                        final int key = RandomUtils.randomInt();
                        final int value = seed.getValue() ^ key;

                        final int number = ASMUtils.getInteger(ain);

                        final byte[] encrypted = encrypt(
                                intExecutors,
                                byteExecutors,
                                number,
                                classWrapper.getNormalizedName().hashCode(),
                                methodNode.name.hashCode(),
                                value
                        );

                        final BytecodeBlock block = new BytecodeBlock()
                                .append(IRExpressions.invokeStatic(
                                        classNode,
                                        decryptMethod,
                                        IRExpressions.newArray(
                                                byte.class,
                                                IRExpressions.intConstant(encrypted[0]),
                                                IRExpressions.intConstant(encrypted[1]),
                                                IRExpressions.intConstant(encrypted[2]),
                                                IRExpressions.intConstant(encrypted[3])
                                        ),
                                        IRExpressions.intXor(
                                                seed.getVariable(),
                                                IRExpressions.intConstant(key)
                                        ),
                                        IRExpressions.intConstant(indexCounter.getAndIncrement())
                                ))
                                .append(IRExpressions.getStatic(WrappedType.from(classNode), currentField.name, WrappedType.from(int.class)));

                        content.insertBefore(ain, block.compile());
                        content.remove(ain);

                        leeway = methodWrapper.getLeewaySize();
                    }
                });

        if (indexCounter.get() != 0) {
            classWrapper.addField(cacheField);
            classWrapper.addField(currentField);

            final var initializer = ASMUtils.findOrCreateInitializer(classNode);
            final var list = new InsnList();
            // Cache indices run from 0 to count - 1, so exactly "count" slots are needed.
            list.add(new LdcInsnNode(indexCounter.get()));
            list.add(new TypeInsnNode(ANEWARRAY, "java/lang/Integer"));
            list.add(new FieldInsnNode(PUTSTATIC, classNode.name, cacheField.name, cacheField.desc));
            // Prepend, so the cache exists before any decrypt call placed in the initializer itself;
            // insert(InsnList) also works when the initializer has no instructions yet.
            initializer.instructions.insert(list);

            classNode.methods.add(decryptMethod);
        }
    }

    /**
     * Creates the byte-level executors of one class; each one is randomly a plain or a switch based XOR.
     * The amount is drawn once from {@code RandomUtils.randomInt(1, 8)}.
     */
    private static List<Executor<Byte, ByteData>> newByteExecutors() {
        final List<Executor<Byte, ByteData>> executors = new ArrayList<>();

        final int count = RandomUtils.randomInt(1, 8);
        for (int i = 0; i < count; i++) {
            if (RandomUtils.randomBoolean()) {
                executors.add(new XorByteExecutor());
            } else {
                executors.add(new XorSwitchByteExecutor());
            }
        }

        return executors;
    }

    /**
     * Creates the int-level executors of one class.
     * The amount is drawn once from {@code RandomUtils.randomInt(1, 8)}.
     */
    private static List<Executor<Integer, IntegerData>> newIntExecutors() {
        final List<Executor<Integer, IntegerData>> executors = new ArrayList<>();

        final int count = RandomUtils.randomInt(1, 8);
        for (int i = 0; i < count; i++) {
            executors.add(new XorIntExecutor());
        }

        return executors;
    }

    /**
     * Encrypts one integer at obfuscation time: XOR with the seed, then every int executor, then the value
     * is split into four big-endian bytes and every byte executor is applied to each byte.
     *
     * @param intExecutors   executors applied to the whole integer
     * @param byteExecutors  executors applied to each of the four bytes
     * @param value          the plain integer
     * @param classNameHash  hash of the owning class name, forwarded to the executors
     * @param methodNameHash hash of the owning method name, forwarded to the executors
     * @param seed           the value the call site passes as seed at runtime
     * @return the four encrypted bytes, most significant byte first
     */
    @SuppressWarnings("ALL")
    private byte[] encrypt(List<Executor<Integer, IntegerData>> intExecutors,
                           List<Executor<Byte, ByteData>> byteExecutors,
                           int value,
                           int classNameHash,
                           int methodNameHash,
                           int seed) {
        value ^= seed;

        // The record is immutable, so a single instance serves every executor.
        final IntegerData integerData = new IntegerData(classNameHash, methodNameHash, seed);
        for (Executor<Integer, IntegerData> intExecutor : intExecutors) {
            value = intExecutor.execute(value, integerData);
        }

        final byte[] bytes = new byte[]{
                (byte) (value >> 24),
                (byte) (value >> 16),
                (byte) (value >> 8),
                (byte) (value)
        };

        for (int i = 0; i < bytes.length; i++) {
            byte b = bytes[i];
            for (Executor<Byte, ByteData> byteExecutor : byteExecutors) {
                final ByteData data = new ByteData(i, classNameHash, methodNameHash);
                b = byteExecutor.execute(b, data);
            }
            bytes[i] = b;
        }

        return bytes;
    }

    /**
     * Java-side mirror of the generated decrypt routine: byte executors first, then the four bytes are
     * joined big-endian, then the int executors, then the XOR with the seed. Kept as a reference for
     * testing; note that {@code bytes} is modified in place.
     *
     * @param intExecutors   executors applied to the whole integer
     * @param byteExecutors  executors applied to each of the four bytes
     * @param bytes          the four encrypted bytes, most significant byte first; overwritten
     * @param classNameHash  hash of the owning class name, forwarded to the executors
     * @param methodNameHash hash of the owning method name, forwarded to the executors
     * @param seed           the seed that was used for encryption
     * @return the plain integer
     */
    @SuppressWarnings("ALL")
    private int decrypt(List<Executor<Integer, IntegerData>> intExecutors,
                        List<Executor<Byte, ByteData>> byteExecutors,
                        byte[] bytes,
                        int classNameHash,
                        int methodNameHash,
                        int seed) {
        for (int i = 0; i < bytes.length; i++) {
            byte b = bytes[i];
            for (Executor<Byte, ByteData> byteExecutor : byteExecutors) {
                final ByteData data = new ByteData(i, classNameHash, methodNameHash);
                b = byteExecutor.execute(b, data);
            }
            bytes[i] = b;
        }

        int value = ((bytes[0] & 0xFF) << 24) | ((bytes[1] & 0xFF) << 16) | ((bytes[2] & 0xFF) << 8) | (bytes[3] & 0xFF);

        // The record is immutable, so a single instance serves every executor.
        final IntegerData integerData = new IntegerData(classNameHash, methodNameHash, seed);
        for (Executor<Integer, IntegerData> intExecutor : intExecutors) {
            value = intExecutor.execute(value, integerData);
        }

        return value ^ seed;
    }

    /**
     * Builds the source of the decrypt routine from the executors' code snippets, compiles it inside a
     * template class named {@code Decryption} and returns the compiled decrypt method.
     *
     * @param byteExecutors executors whose {@code code()} is pasted into the per-byte loop
     * @param intExecutors  executors whose {@code code()} is pasted after the bytes were joined
     * @param memberNames   names for the cache field, the current-value field and the method
     * @return the compiled decrypt method; its field instructions still reference the template class
     * @throws FatalObfuscationException if compilation fails or the method is missing from the result
     */
    private MethodNode generateDecryption(final List<Executor<Byte, ByteData>> byteExecutors,
                                          final List<Executor<Integer, IntegerData>> intExecutors,
                                          final MemberNames memberNames) {
        final var compiler = new Compiler();

        final StringBuilder codeByte = new StringBuilder(), codeInt = new StringBuilder();

        for (Executor<Byte, ByteData> executor : byteExecutors) {
            codeByte.append(executor.code());
        }

        for (Executor<Integer, IntegerData> executor : intExecutors) {
            codeInt.append(executor.code());
        }

        // String.replace substitutes literally; replaceAll would interpret '$' and '\' in the
        // replacement text and could corrupt generated code or names containing them.
        final var compilation = """
                public class Decryption {
                    private static Integer[] REPLACE_CACHE_NAME = new Integer[1024];
                    private static int REPLACE_CURRENT_NAME = 0;

                    private static void REPLACE_DECRYPT_NAME(byte[] bytes, int seed, int index) {
                        if (REPLACE_CACHE_NAME[index] == null) {
                            for (int i = 0; i < bytes.length; i++) {
                                byte b = bytes[i];

                REPLACE_CODE_BYTE

                                bytes[i] = b;
                            }

                            int value = ((bytes[0] & 0xFF) << 24) | ((bytes[1] & 0xFF) << 16) | ((bytes[2] & 0xFF) << 8) | (bytes[3] & 0xFF);

                REPLACE_CODE_INT

                            REPLACE_CACHE_NAME[index] = value ^ seed;
                        }
                        REPLACE_CURRENT_NAME = REPLACE_CACHE_NAME[index];
                    }
                }
                """
                // Replace names
                .replace("REPLACE_CACHE_NAME", memberNames.cacheName)
                .replace("REPLACE_CURRENT_NAME", memberNames.currentName)
                .replace("REPLACE_DECRYPT_NAME", memberNames.decryptName)

                // Replace code
                .replace("REPLACE_CODE_BYTE", codeByte.toString())
                .replace("REPLACE_CODE_INT", codeInt.toString());

        final var classNode = compiler.compile("Decryption", compilation);

        if (classNode == null) {
            // Dump the generated source so the compilation failure can be diagnosed.
            System.out.println(compilation);
            throw new FatalObfuscationException("Could not compile integer encryption logic!");
        }

        for (final var node : new ArrayList<>(classNode.methods)) {
            if (Objects.equals(node.name, memberNames.decryptName)) {
                return node;
            }
        }

        throw new FatalObfuscationException("Decryption Method not found in compiled integer encryption logic!");
    }

    @Override
    public String getConfigName() {
        return "integer_encryption";
    }

    /**
     * Context handed to the int executors while encrypting or decrypting one integer.
     *
     * @param classNameHash  hash of the owning class name
     * @param methodNameHash hash of the owning method name
     * @param seed           the seed of the call site
     */
    record IntegerData(
            int classNameHash,
            int methodNameHash,
            int seed
    ) {
        /* */
    }

    /**
     * XORs an integer with a random key and, depending on the randomly chosen {@link Type}, additionally
     * with the class name hash or the method name hash. XOR is its own inverse, so the same operation
     * serves both encryption and decryption.
     */
    static class XorIntExecutor extends Executor<Integer, IntegerData> {
        private final Type type;
        private final int key;

        public XorIntExecutor() {
            final Type[] types = Type.values();
            this.type = types[RandomUtils.randomInt(types.length)];
            this.key = RandomUtils.randomInt();
        }

        /**
         * Returns the Java statement that performs this executor's XOR on the local variable
         * {@code value} inside the generated decrypt method.
         */
        @Override
        public String code() {
            // Stack frame 0 is the generated decrypt method, which lives in the obfuscated class itself;
            // frame 1 is the method that called it.
            final var code = switch (type) {
                case CLASS_HASH ->
                        "REPLACE_VALUE_NAME ^= new Throwable().getStackTrace()[0].getClassName().hashCode() ^ REPLACE_KEY";
                case METHOD_HASH ->
                        "REPLACE_VALUE_NAME ^= new Throwable().getStackTrace()[1].getMethodName().hashCode() ^ REPLACE_KEY";
                default -> "REPLACE_VALUE_NAME ^= REPLACE_KEY";
            };

            // REPLACE_CODE has to be substituted first: it introduces the two other placeholders.
            // @formatter:off
            return  """
                    REPLACE_CODE;
                    """
                    .replace("REPLACE_CODE", code)
                    .replace("REPLACE_VALUE_NAME", "value")
                    .replace("REPLACE_KEY", String.valueOf(key));
            // @formatter:on
        }

        /**
         * Applies the same XOR as the generated code, using the hashes supplied in {@code data}.
         */
        @Override
        public Integer execute(Integer value, IntegerData data) {
            return switch (type) {
                case CLASS_HASH -> value ^ data.classNameHash() ^ key;
                case METHOD_HASH -> value ^ data.methodNameHash() ^ key;
                default -> value ^ key;
            };
        }

        enum Type {
            NORMAL, CLASS_HASH, METHOD_HASH
        }
    }

    /**
     * Tells whether the integers of a method may be encrypted: neither the class nor the method may carry
     * a visible {@code ExcludeConstant} annotation.
     *
     * @param classWrapper  the class that declares the method
     * @param methodWrapper the method in question
     * @return {@code true} if the method should be processed
     */
    private boolean shouldObfuscate(ClassWrapper classWrapper, MethodWrapper methodWrapper) {
        return !isExcludedByAnnotation(classWrapper.getClassNode().visibleAnnotations)
                && !isExcludedByAnnotation(methodWrapper.getMethodNode().visibleAnnotations);
    }

    /**
     * Checks a (possibly {@code null}) annotation list for the {@code ExcludeConstant} annotation.
     */
    private static boolean isExcludedByAnnotation(final List<AnnotationNode> annotations) {
        return annotations != null
                && annotations.stream().anyMatch(node -> EXCLUDE_CONSTANT_DESC.equals(node.desc));
    }

    /**
     * The three names generated for one class, taken from the mutator's dictionary in the order
     * cache field, current-value field, decrypt method.
     */
    class MemberNames {
        private final String cacheName, currentName, decryptName;

        public MemberNames() {
            this.cacheName = dictionary.next();
            this.currentName = dictionary.next();
            this.decryptName = dictionary.next();
        }
    }
}
