package cc.polymorphism.obfuscator.mutator.impl.encryption.number;

import cc.polymorphism.assembly.BytecodeBlock;
import cc.polymorphism.assembly.WrappedType;
import cc.polymorphism.assembly.expressions.IRExpressions;
import cc.polymorphism.assembly.std.Compiler;
import cc.polymorphism.obfuscator.Polymorphism;
import cc.polymorphism.obfuscator.asm.wrapper.ClassWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.MethodWrapper;
import cc.polymorphism.obfuscator.dictionary.Dictionary;
import cc.polymorphism.obfuscator.dictionary.DictionaryFactory;
import cc.polymorphism.obfuscator.engine.hash.Hash;
import cc.polymorphism.obfuscator.engine.seed.Seed;
import cc.polymorphism.obfuscator.exceptions.obfuscation.FatalObfuscationException;
import cc.polymorphism.obfuscator.mutator.Mutator;
import cc.polymorphism.obfuscator.mutator.common.Executor;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.ByteData;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.XorByteExecutor;
import cc.polymorphism.obfuscator.mutator.impl.encryption.number.common.XorSwitchByteExecutor;
import cc.polymorphism.obfuscator.util.ASMUtils;
import cc.polymorphism.obfuscator.util.RandomUtils;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Encrypts {@code long} constants and replaces their {@code lconst}/{@code ldc} instruction with an
 * {@code invokestatic} to a generated decrypt method followed by a {@code getstatic} of a holder field.
 * The decrypt method does not return anything: it writes the decrypted value into the holder field and
 * caches it per call site in a {@code Long[]} field.
 *
 * <p>NOTE(review): because the value is handed back through a single static holder field, two threads
 * decrypting different constants of the same class concurrently may observe each other's value.
 *
 * @author Reowya
 * @since 1.0 - 12:17 (EST) 03/02/2025 (MM/DD/YYYY)
 */
public class LongMutator extends Mutator {
    /** Descriptor of the annotation that excludes a class or method from constant encryption. */
    private static final String EXCLUDE_ANNOTATION_DESC = "Lcc/polymorphism/annot/ExcludeConstant;";

    /** Encryption of a method stops once {@code MethodWrapper#getLeewaySize()} drops below this value. */
    private static final int MIN_LEEWAY = 10000;

    private final Dictionary dictionary = DictionaryFactory.dictionaryOf("alphabetical");

    public LongMutator(Polymorphism ctx) {
        super(ctx);
    }

    /**
     * For every included, non-interface, non-annotation class: encrypts each eligible long constant and,
     * if at least one was encrypted, injects the cache field, the holder field, the cache initialization
     * in the static initializer, and the decrypt method.
     */
    @Override
    public void transform() {
        classStream().filter(o -> !o.isInterface() && !o.isAnnotation()).filter(this::isIncluded).forEach(classWrapper -> {
            final ClassNode classNode = classWrapper.getClassNode();

            final MemberNames memberNames = new MemberNames();

            final List<Executor<Byte, ByteData>> byteExecutors = randomByteExecutors();
            final List<Executor<Long, LongData>> longExecutors = randomLongExecutors();

            final FieldNode cacheField = new FieldNode(ACC_PRIVATE | ACC_STATIC, memberNames.cacheName, "[Ljava/lang/Long;",
                    null, null);
            final FieldNode currentField = new FieldNode(ACC_PRIVATE | ACC_STATIC, memberNames.currentName, "J",
                    null, null);

            final MethodNode decryptMethod = generateDecryption(byteExecutors, longExecutors, memberNames);

            // The decrypt method was compiled inside a throwaway "Decryption" class; retarget its field
            // accesses to the class it is about to be injected into.
            for (final AbstractInsnNode ain : decryptMethod.instructions.toArray()) {
                if (ain instanceof FieldInsnNode fin) {
                    fin.owner = classNode.name;
                }
            }

            // Generated lazily, only once the class is known to contain a long constant.
            final AtomicReference<Hash> hash = new AtomicReference<>(null);

            // Next free slot in the per-class cache; its final value is the number of encrypted sites.
            final AtomicInteger indexCounter = new AtomicInteger();

            classWrapper.methodStream()
                    .filter(MethodWrapper::hasInstructions)
                    .filter(o -> shouldObfuscate(classWrapper, o))
                    .forEach(methodWrapper -> {
                        final MethodNode methodNode = methodWrapper.getMethodNode();
                        final InsnList content = methodNode.instructions;

                        // Generated lazily per method, only if the method contains a long constant.
                        Seed seed = null;

                        int leeway = methodWrapper.getLeewaySize();
                        for (final AbstractInsnNode ain : content.toArray()) {
                            if (leeway < MIN_LEEWAY) {
                                break;
                            }

                            if (!ASMUtils.isLong(ain)) {
                                continue;
                            }

                            if (hash.get() == null) {
                                hash.set(hashGenerator().generate(classWrapper));
                            }

                            if (seed == null) {
                                seed = seedGenerator().generateSeed(hash.get(), classWrapper, methodWrapper);
                            }

                            // The runtime recomputes `value` as (seed variable ^ key); only `key` is embedded.
                            final int key = RandomUtils.randomInt();
                            final int value = seed.getValue() ^ key;

                            final long number = ASMUtils.getLong(ain);

                            final byte[] encrypted = encrypt(
                                    longExecutors, byteExecutors,
                                    number,
                                    classWrapper.getNormalizedName().hashCode(),
                                    methodNode.name.hashCode(),
                                    value
                            );

                            final BytecodeBlock block = new BytecodeBlock()
                                    .append(IRExpressions.invokeStatic(
                                            classNode,
                                            decryptMethod,
                                            IRExpressions.newArray(
                                                    byte.class,
                                                    IRExpressions.intConstant(encrypted[0]),
                                                    IRExpressions.intConstant(encrypted[1]),
                                                    IRExpressions.intConstant(encrypted[2]),
                                                    IRExpressions.intConstant(encrypted[3]),
                                                    IRExpressions.intConstant(encrypted[4]),
                                                    IRExpressions.intConstant(encrypted[5]),
                                                    IRExpressions.intConstant(encrypted[6]),
                                                    IRExpressions.intConstant(encrypted[7])
                                            ),
                                            IRExpressions.intXor(
                                                    seed.getVariable(),
                                                    IRExpressions.intConstant(key)
                                            ),
                                            IRExpressions.intConstant(indexCounter.getAndIncrement())
                                    ))
                                    .append(IRExpressions.getStatic(WrappedType.from(classNode), currentField.name, WrappedType.from(long.class)));

                            content.insertBefore(ain, block.compile());
                            content.remove(ain);

                            leeway = methodWrapper.getLeewaySize();
                        }
                    });

            final int encryptedSites = indexCounter.get();
            if (encryptedSites != 0) {
                classWrapper.addField(cacheField);
                classWrapper.addField(currentField);

                initializeCache(classNode, cacheField, encryptedSites);

                classNode.methods.add(decryptMethod);
            }
        });
    }

    /**
     * Builds a chain of randomly chosen byte executors whose length is drawn once from
     * {@code RandomUtils.randomInt(1, 8)}.
     */
    private List<Executor<Byte, ByteData>> randomByteExecutors() {
        // Draw the count once; re-evaluating the RNG in the loop condition skews the length distribution.
        final int count = RandomUtils.randomInt(1, 8);

        final List<Executor<Byte, ByteData>> executors = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            if (RandomUtils.randomBoolean()) {
                executors.add(new XorByteExecutor());
            } else {
                executors.add(new XorSwitchByteExecutor());
            }
        }
        return executors;
    }

    /**
     * Builds a chain of {@link XorIntExecutor}s whose length is drawn once from
     * {@code RandomUtils.randomInt(1, 8)}.
     */
    private List<Executor<Long, LongData>> randomLongExecutors() {
        final int count = RandomUtils.randomInt(1, 8);

        final List<Executor<Long, LongData>> executors = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            executors.add(new XorIntExecutor());
        }
        return executors;
    }

    /**
     * Prepends {@code cacheField = new Long[size]} to the class's static initializer, creating the
     * initializer through {@link ASMUtils#findOrCreateInitializer} if needed.
     *
     * @param classNode  class receiving the cache
     * @param cacheField the {@code Long[]} cache field
     * @param size       number of encrypted call sites; indices 0..size-1 are used by the decrypt calls
     */
    private void initializeCache(final ClassNode classNode, final FieldNode cacheField, final int size) {
        final var initializer = ASMUtils.findOrCreateInitializer(classNode);

        final InsnList list = new InsnList();
        list.add(new LdcInsnNode(size));
        list.add(new TypeInsnNode(ANEWARRAY, "java/lang/Long"));
        list.add(new FieldInsnNode(PUTSTATIC, classNode.name, cacheField.name, cacheField.desc));

        // insert(InsnList) prepends and, unlike insertBefore(getFirst(), ...), also works on an empty list.
        initializer.instructions.insert(list);
    }

    /**
     * Per-constant inputs shared by every {@link XorIntExecutor} in a chain.
     *
     * @param classNameHash  hash of the owning class name
     * @param methodNameHash hash of the method containing the constant
     * @param seed           per-site seed value
     */
    record LongData(
            int classNameHash,
            int methodNameHash,
            int seed
    ) {
        /* */
    }

    /*
     * Encryption side; must stay the exact inverse of the generated decrypt method and of decrypt(...).
     * Pipeline: XOR with seed -> long executors -> big-endian bytes -> byte executors.
     */
    @SuppressWarnings("ALL")
    private byte[] encrypt(List<Executor<Long, LongData>> longExecutors,
                           List<Executor<Byte, ByteData>> byteExecutors,
                           long value,
                           int classNameHash,
                           int methodNameHash,
                           int seed) {
        value ^= seed;

        // LongData is an immutable record, so one instance serves every executor.
        final LongData longData = new LongData(classNameHash, methodNameHash, seed);
        for (Executor<Long, LongData> longExecutor : longExecutors) {
            value = longExecutor.execute(value, longData);
        }

        // Big-endian split, matching the (value << 8) | (b & 0xFF) reassembly on the decryption side.
        final byte[] bytes = new byte[Long.BYTES];
        for (int i = 0; i < Long.BYTES; i++) {
            bytes[Long.BYTES - 1 - i] = (byte) (value >>> (i * Byte.SIZE));
        }

        for (int i = 0; i < bytes.length; i++) {
            byte b = bytes[i];
            for (Executor<Byte, ByteData> byteExecutor : byteExecutors) {
                final ByteData data = new ByteData(i, classNameHash, methodNameHash);
                b = byteExecutor.execute(b, data);
            }
            bytes[i] = b;
        }

        return bytes;
    }

    /*
     * Decryption side, mirroring the generated decrypt method; useful for round-trip testing of encrypt(...).
     * Note: decodes {@code bytes} in place.
     */
    @SuppressWarnings("ALL")
    private long decrypt(List<Executor<Long, LongData>> longExecutors,
                         List<Executor<Byte, ByteData>> byteExecutors,
                         byte[] bytes,
                         int classNameHash,
                         int methodNameHash,
                         int seed) {
        for (int i = 0; i < bytes.length; i++) {
            byte b = bytes[i];
            for (Executor<Byte, ByteData> byteExecutor : byteExecutors) {
                final ByteData data = new ByteData(i, classNameHash, methodNameHash);
                b = byteExecutor.execute(b, data);
            }
            bytes[i] = b;
        }

        long value = 0;
        for (int i = 0; i < Long.BYTES; i++) {
            value = (value << Byte.SIZE) | (bytes[i] & 0xFF);
        }

        final LongData longData = new LongData(classNameHash, methodNameHash, seed);
        for (Executor<Long, LongData> longExecutor : longExecutors) {
            value = longExecutor.execute(value, longData);
        }

        return value ^ seed;
    }

    /**
     * XOR step for the long pipeline (the "Int" in the name refers to its {@code int} key; it operates on
     * longs). Depending on a randomly chosen {@link Type}, the value is XORed with the key alone, or with
     * the key and the class-name or method-name hash. XOR is self-inverse, so the same step both encrypts
     * and decrypts.
     */
    static class XorIntExecutor extends Executor<Long, LongData> {
        private final Type type;
        private final int key;

        public XorIntExecutor() {
            this.type = Type.values()[RandomUtils.randomInt(Type.values().length)];
            this.key = RandomUtils.randomInt();
        }

        /**
         * Returns the Java source for this step as spliced into the decrypt method. The class and method
         * hashes are recovered at runtime from the stack trace: frame 0 is the decrypt method itself and
         * frame 1 is its caller.
         * NOTE(review): assumes {@code getNormalizedName()} yields the same string that
         * {@code StackTraceElement#getClassName()} reports — verify.
         */
        @Override
        public String code() {
            final var code = switch (type) {
                case CLASS_HASH ->
                        "REPLACE_VALUE_NAME ^= new Throwable().getStackTrace()[0].getClassName().hashCode() ^ REPLACE_KEY";
                case METHOD_HASH ->
                        "REPLACE_VALUE_NAME ^= new Throwable().getStackTrace()[1].getMethodName().hashCode() ^ REPLACE_KEY";
                default -> "REPLACE_VALUE_NAME ^= REPLACE_KEY";
            };

            // Literal (non-regex) substitution so '$' or '\' in the inserted text are taken verbatim.
            // @formatter:off
            return  """
                    REPLACE_CODE;
                    """
                    .replace("REPLACE_CODE", code)
                    .replace("REPLACE_VALUE_NAME", "value")
                    .replace("REPLACE_KEY", String.valueOf(key));
            // @formatter:on
        }

        /**
         * Applies the step at obfuscation time. The {@code int} operands are sign-extended to {@code long}
         * exactly as in the generated {@code value ^= ...} compound assignment.
         */
        @Override
        public Long execute(Long value, LongData data) {
            return switch (type) {
                case CLASS_HASH -> value ^ data.classNameHash() ^ key;
                case METHOD_HASH -> value ^ data.methodNameHash() ^ key;
                default -> value ^ key;
            };
        }

        enum Type {
            NORMAL, CLASS_HASH, METHOD_HASH
        }
    }

    /**
     * Compiles the decrypt method for one class from a Java source template with the executor chains
     * spliced in.
     *
     * @return the compiled decrypt method (field owners still point at the throwaway "Decryption" class)
     * @throws FatalObfuscationException if compilation fails or the method cannot be found
     */
    private MethodNode generateDecryption(final List<Executor<Byte, ByteData>> byteExecutors,
                                          final List<Executor<Long, LongData>> longExecutors,
                                          final MemberNames memberNames) {
        final var compiler = new Compiler();

        final StringBuilder codeByte = new StringBuilder(), codeLong = new StringBuilder();

        for (Executor<Byte, ByteData> executor : byteExecutors) {
            codeByte.append(executor.code());
        }

        for (Executor<Long, LongData> executor : longExecutors) {
            codeLong.append(executor.code());
        }

        // Literal (non-regex) substitution: generated names/code must not be interpreted as
        // regex replacement syntax ('$' group references, '\' escapes).
        final var compilation = """
                public class Decryption {
                    private static Long[] REPLACE_CACHE_NAME = new Long[1024];
                    private static long REPLACE_CURRENT_NAME = 0;
                
                    private static void REPLACE_DECRYPT_NAME(byte[] bytes, int seed, int index) {
                        if (REPLACE_CACHE_NAME[index] == null) {
                            for (int i = 0; i < bytes.length; i++) {
                                byte b = bytes[i];
                
                REPLACE_CODE_BYTE
                
                                bytes[i] = b;
                            }
                
                            long value = 0;
                            for (int i = 0; i < 8; i++) {
                                value = (value << 8) | (bytes[i] & 0xFF); // Shift and combine bytes
                            }
                
                REPLACE_CODE_INT
                
                            REPLACE_CACHE_NAME[index] = value ^ seed;
                        }
                        REPLACE_CURRENT_NAME = REPLACE_CACHE_NAME[index];
                    }
                }
                """
                // Replace names
                .replace("REPLACE_CACHE_NAME", memberNames.cacheName)
                .replace("REPLACE_CURRENT_NAME", memberNames.currentName)
                .replace("REPLACE_DECRYPT_NAME", memberNames.decryptName)

                // Replace code
                .replace("REPLACE_CODE_BYTE", codeByte.toString())
                .replace("REPLACE_CODE_INT", codeLong.toString());

        final var classNode = compiler.compile("Decryption", compilation);

        if (classNode == null) {
            // Dump the generated source so the compilation failure can be diagnosed.
            System.out.println(compilation);
            throw new FatalObfuscationException("Could not compile long encryption logic!");
        }

        for (final var node : classNode.methods) {
            if (Objects.equals(node.name, memberNames.decryptName)) {
                return node;
            }
        }

        throw new FatalObfuscationException("Decryption Method not found in compiled long encryption logic!");
    }

    /**
     * Returns {@code false} when either the class or the method carries a visible
     * {@code @cc.polymorphism.annot.ExcludeConstant} annotation.
     */
    private boolean shouldObfuscate(ClassWrapper classWrapper, MethodWrapper methodWrapper) {
        return !hasExcludeAnnotation(classWrapper.getClassNode().visibleAnnotations)
                && !hasExcludeAnnotation(methodWrapper.getMethodNode().visibleAnnotations);
    }

    /** Whether {@code annotations} (may be {@code null}) contains the exclusion annotation. */
    private static boolean hasExcludeAnnotation(final List<AnnotationNode> annotations) {
        return annotations != null
                && annotations.stream().anyMatch(node -> EXCLUDE_ANNOTATION_DESC.equals(node.desc));
    }

    /** Fresh member names (cache field, holder field, decrypt method) drawn from this mutator's dictionary. */
    class MemberNames {
        private final String cacheName, currentName, decryptName;

        public MemberNames() {
            this.cacheName = dictionary.next();
            this.currentName = dictionary.next();
            this.decryptName = dictionary.next();
        }
    }

    @Override
    public String getConfigName() {
        return "long_encryption";
    }
}
