package cc.polymorphism.obfuscator.mutator.impl.encryption.ref;

import cc.polymorphism.annot.IncludeReference;
import cc.polymorphism.assembly.std.Compiler;
import cc.polymorphism.obfuscator.Polymorphism;
import cc.polymorphism.obfuscator.asm.remapper.PolymorphismRemapper;
import cc.polymorphism.obfuscator.asm.wrapper.ClassWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.FieldWrapper;
import cc.polymorphism.obfuscator.asm.wrapper.MethodWrapper;
import cc.polymorphism.obfuscator.dictionary.Dictionary;
import cc.polymorphism.obfuscator.dictionary.DictionaryFactory;
import cc.polymorphism.obfuscator.exceptions.missing.MissingClassException;
import cc.polymorphism.obfuscator.exceptions.obfuscation.FatalObfuscationException;
import cc.polymorphism.obfuscator.mutator.Mutator;
import cc.polymorphism.obfuscator.mutator.common.Executor;
import cc.polymorphism.obfuscator.util.RandomUtils;
import lombok.Getter;
import lombok.Setter;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.ClassRemapper;
import org.objectweb.asm.tree.*;

import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Replaces the following instructions with an encrypted invokedynamic equivalent:
 * <ul>
 *     <li>getstatic</li>
 *     <li>getfield</li>
 *     <li>putstatic</li>
 *     <li>putfield</li>
 *     <li>invokestatic</li>
 *     <li>invokevirtual</li>
 *     <li>invokeinterface</li>
 *     <li>invokespecial</li>
 * </ul>
 * <p>
 * Each replaced instruction becomes an invokedynamic whose name is an encrypted "destination" string of
 * the form {@code owner<>name<>kind[<>descriptor[<>specialCaller]]}. The kind is one of:
 * <ul>
 *     <li>0 — static call</li>
 *     <li>1 — virtual or interface call</li>
 *     <li>2 — special call</li>
 *     <li>3 — getstatic</li>
 *     <li>4 — getfield</li>
 *     <li>5 — putstatic</li>
 *     <li>6 — putfield</li>
 * </ul>
 * A bootstrap class is compiled from Java source at obfuscation time. It decrypts the name, resolves the
 * target through the caller's {@code MethodHandles.Lookup} and binds it in a {@code ConstantCallSite}.
 *
 * @author Reowya, Liticane (Compilation)
 * @since 1.0 - 06:58 (EST) 03/01/2025 (MM/DD/YYYY)
 */
public class ReferenceMutator extends Mutator {
    /** Descriptor of the opt-in annotation; only annotated classes/methods are transformed. */
    private static final String INCLUDE_REFERENCE_DESC = "Lcc/polymorphism/annot/IncludeReference;";

    /** Descriptor of the generated bootstrap method: (Lookup, String name, MethodType, int cacheIndex). */
    private static final String BOOTSTRAP_DESC =
            "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;I)Ljava/lang/Object;";

    private static final Type OBJECT_TYPE = Type.getType("Ljava/lang/Object;");

    private final Dictionary dictionary = DictionaryFactory.dictionaryOf("alphabetical");

    public ReferenceMutator(Polymorphism ctx) {
        super(ctx);
    }

    @Override
    public String getConfigName() {
        return "reference_encryption";
    }

    /**
     * Rewrites eligible call and field sites in every included class.
     * <p>
     * After the rewrite, it compiles and adds the bootstrap class that resolves them at runtime.
     */
    @Override
    public void transform() {
        final MemberNames memberNames = new MemberNames();
        final List<Executor<Integer, Data>> executors = createExecutors();

        final String referenceClassName = "cc/polymorphism/stdlib/" + UUID.randomUUID() + "/" + memberNames.name;

        // Each rewritten site takes the next value as its index into the bootstrap's decryption cache.
        final AtomicInteger counter = new AtomicInteger(0);

        final Handle bsmHandle = new Handle(H_INVOKESTATIC, referenceClassName, memberNames.bootstrapName,
                BOOTSTRAP_DESC, false);

        classStream()
                .filter(o -> !o.isEnum() && !o.isInterface() && !o.isAnnotation() && o.isInvokeDynamicAllowed())
                .filter(this::isIncluded)
                .forEach(classWrapper -> {
                    final ClassNode classNode = classWrapper.getClassNode();

                    classWrapper.methodStream()
                            .filter(MethodWrapper::hasInstructions)
                            .filter(o -> shouldObfuscate(classWrapper, o))
                            .forEach(methodWrapper ->
                                    mutateMethod(classNode, methodWrapper, executors, bsmHandle, counter));
                });

        final HashMap<String, String> mapping = new HashMap<>();
        mapping.put(memberNames.name, referenceClassName);
        final PolymorphismRemapper mapper = new PolymorphismRemapper(mapping);

        // The cache must be larger than every index handed out above.
        final ClassNode bootstrap = createBootstrap(executors, memberNames, counter.incrementAndGet());

        // Move the compiled class into its randomized package.
        final ClassNode referenceClass = new ClassNode();
        bootstrap.accept(new ClassRemapper(referenceClass, mapper));

        referenceClass.version = V1_8;
        addClass(referenceClass);
    }

    /**
     * Builds the random chain of character executors shared by encryption and the generated decryptor.
     *
     * @return between {@code RandomUtils.randomInt(1, 8)} executors, each a plain or index-switched XOR
     */
    private static List<Executor<Integer, Data>> createExecutors() {
        // Draw the count once: re-evaluating the loop bound on every iteration skewed it toward small values.
        final int count = RandomUtils.randomInt(1, 8);

        final List<Executor<Integer, Data>> executors = new ArrayList<>();
        for (int i = 0; i < count; i++) {
            switch (RandomUtils.randomInt(2)) {
                case 0 -> executors.add(new XorIntegerExecutor());
                case 1 -> executors.add(new XorSwitchIntegerExecutor());
                default -> throw new FatalObfuscationException("Failed to match random value!");
            }
        }
        return executors;
    }

    /**
     * Replaces the method calls and field accesses in one method with invokedynamic instructions.
     * <p>
     * Stops as soon as the method's leeway drops below 10000.
     */
    private void mutateMethod(final ClassNode classNode,
                              final MethodWrapper methodWrapper,
                              final List<Executor<Integer, Data>> executors,
                              final Handle bsmHandle,
                              final AtomicInteger counter) {
        final MethodNode methodNode = methodWrapper.getMethodNode();
        final boolean inInitializer = isInitializer(methodNode.name);

        int leeway = methodWrapper.getLeewaySize();
        // Iterate a snapshot because instructions are replaced while walking.
        for (final AbstractInsnNode ain : methodNode.instructions.toArray()) {
            // NOTE(review): presumably the remaining room before the method becomes too large — TODO confirm.
            if (leeway < 10000) {
                break;
            }

            if (ain instanceof MethodInsnNode min) {
                // Constructor/initializer invocations cannot be routed through a method handle here.
                if (!isInitializer(min.name)) {
                    final InvokeDynamicInsnNode indy = createIndy(executors, describeMethodCall(min, classNode.name),
                            getMethodSignature(min), classNode, methodNode, bsmHandle, counter);
                    methodNode.instructions.set(ain, indy);

                    final Type returnType = Type.getReturnType(min.desc);
                    if (returnType.getSort() == Type.ARRAY) {
                        methodNode.instructions.insert(indy,
                                new TypeInsnNode(CHECKCAST, returnType.getInternalName()));
                    }
                }
            } else if (ain instanceof FieldInsnNode fin && !inInitializer && !isFinalField(fin)) {
                // Field accesses inside <init>/<clinit> are left alone, as are final fields.
                final InvokeDynamicInsnNode indy = createIndy(executors, describeFieldAccess(fin),
                        getFieldSignature(fin), classNode, methodNode, bsmHandle, counter);
                methodNode.instructions.set(ain, indy);
            }

            leeway = methodWrapper.getLeewaySize();
        }
    }

    /** Returns whether {@code name} is a constructor or static-initializer name. */
    private static boolean isInitializer(final String name) {
        return "<init>".equals(name) || "<clinit>".equals(name);
    }

    /**
     * Builds the invokedynamic descriptor for a method call.
     * <p>
     * Non-static calls gain a leading receiver parameter. Object-typed parameters are erased to
     * {@code java.lang.Object}; the return type is kept as-is.
     */
    private static String getMethodSignature(final MethodInsnNode min) {
        final String withReceiver = min.getOpcode() == INVOKESTATIC
                ? min.desc
                : min.desc.replace("(", "(Ljava/lang/Object;");

        final Type[] args = Type.getArgumentTypes(withReceiver);
        for (int i = 0; i < args.length; i++) {
            if (args[i].getSort() == Type.OBJECT) {
                args[i] = OBJECT_TYPE;
            }
        }

        return Type.getMethodDescriptor(Type.getReturnType(min.desc), args);
    }

    /**
     * Builds the plaintext destination string decoded by the bootstrap for a method call.
     *
     * @param min        the call being replaced
     * @param callerName internal name of the class containing the call (used as the invokespecial caller)
     */
    private static String describeMethodCall(final MethodInsnNode min, final String callerName) {
        final StringBuilder dest = new StringBuilder()
                .append(min.owner.replace("/", "."))
                .append("<>")
                .append(min.name)
                .append("<>");

        switch (min.getOpcode()) {
            case INVOKESTATIC -> dest.append("0<>").append(min.desc);
            case INVOKEINTERFACE, INVOKEVIRTUAL -> dest.append("1<>").append(min.desc);
            case INVOKESPECIAL -> dest.append("2<>")
                    .append(min.desc)
                    .append("<>")
                    .append(callerName.replace("/", "."));
            default -> {
                // MethodInsnNode only carries the four invoke opcodes handled above.
            }
        }

        return dest.toString();
    }

    /** Builds the plaintext destination string decoded by the bootstrap for a field access. */
    private static String describeFieldAccess(final FieldInsnNode fin) {
        final String kind = switch (fin.getOpcode()) {
            case GETSTATIC -> "3";
            case GETFIELD -> "4";
            case PUTSTATIC -> "5";
            case PUTFIELD -> "6";
            default -> "";
        };
        return fin.owner.replace("/", ".") + "<>" + fin.name + "<>" + kind;
    }

    /**
     * Returns whether the field is declared {@code final} directly on its referenced owner class.
     * <p>
     * Fields inherited from supertypes are not found here and are treated as non-final.
     *
     * @throws MissingClassException if the owner class is not on the class path
     */
    private boolean isFinalField(final FieldInsnNode fin) {
        final ClassWrapper pathOwner = classPathMap().get(fin.owner);
        if (pathOwner == null) {
            throw MissingClassException.forLibraryClass(fin.owner);
        }

        return pathOwner.getFields().stream()
                .filter(fieldWrapper -> fieldWrapper.getFieldNode().name.equals(fin.name)
                        && fieldWrapper.getFieldNode().desc.equals(fin.desc))
                .findFirst()
                .map(fieldWrapper -> Modifier.isFinal(fieldWrapper.getFieldNode().access))
                .orElse(false);
    }

    /**
     * Creates the invokedynamic that replaces one call or field site.
     * <p>
     * The destination is encrypted into the call-site name. The next cache index is passed as the
     * static bootstrap argument.
     */
    private InvokeDynamicInsnNode createIndy(final List<Executor<Integer, Data>> executors,
                                             final String destination,
                                             final String descriptor,
                                             final ClassNode classNode,
                                             final MethodNode methodNode,
                                             final Handle bsmHandle,
                                             final AtomicInteger counter) {
        return new InvokeDynamicInsnNode(
                encrypt(
                        executors, destination,
                        classNode.name.replace("/", ".").hashCode(),
                        methodNode.name.hashCode()
                ),
                descriptor,
                bsmHandle,
                counter.getAndIncrement()
        );
    }

    /**
     * Builds the invokedynamic descriptor for a field access.
     * <p>
     * Setters take the value and return void; getters return the field type. Instance accesses gain a
     * leading receiver parameter.
     */
    private static String getFieldSignature(FieldInsnNode fin) {
        boolean isStatic = (fin.getOpcode() == GETSTATIC
                || fin.getOpcode() == PUTSTATIC);
        boolean isSetter = (fin.getOpcode() == PUTFIELD
                || fin.getOpcode() == PUTSTATIC);

        String signature = (isSetter) ? "(" + fin.desc + ")V" : "()" + fin.desc;
        if (!isStatic)
            signature = signature.replace("(", "(Ljava/lang/Object;");

        return signature;
    }

    /**
     * Encrypts each character by passing it through every executor in order, keyed by its index.
     * <p>
     * The generated decryptor applies the same XOR chain to recover the plaintext.
     * NOTE(review): the result is used as an invokedynamic name and may contain characters such as
     * '.', ';', '[', '/', '<' or '>' that the JVM may reject in method names — TODO confirm.
     */
    private String encrypt(final List<Executor<Integer, Data>> executors,
                           final String string,
                           final int classNameHash,
                           final int methodNameHash) {
        final char[] characters = string.toCharArray();
        for (int index = 0; index < characters.length; index++) {
            final Data data = new Data(index, classNameHash, methodNameHash);
            int result = characters[index];
            for (final Executor<Integer, Data> executor : executors) {
                result = executor.execute(result, data);
            }
            characters[index] = (char) result;
        }

        return new String(characters);
    }

    /** XORs a value with a single random key; emits the matching {@code result ^= key;} statement. */
    @Getter
    @Setter
    public static class XorIntegerExecutor extends Executor<Integer, Data> {
        private final int key;

        public XorIntegerExecutor() {
            this.key = RandomUtils.randomInt();
        }

        @Override
        public String code() {
            final var code = "REPLACE_VALUE_NAME ^= REPLACE_KEY";

            // @formatter:off
            return  """
                    REPLACE_CODE;
                    """
                    .replace("REPLACE_CODE", code)
                    .replace("REPLACE_VALUE_NAME", "result")
                    .replace("REPLACE_KEY", String.valueOf(key));
            // @formatter:on
        }

        @Override
        public Integer execute(Integer value, Data longData) {
            return value ^ key;
        }
    }

    /**
     * XORs a value with one of several random keys, selected by the character index modulo the key count.
     * <p>
     * The generated code switches on the decryptor's loop index {@code i}.
     */
    @Getter
    @Setter
    public static class XorSwitchIntegerExecutor extends Executor<Integer, Data> {
        private final XorIntegerExecutor[] key = new XorIntegerExecutor[RandomUtils.randomInt(2, 7) + 1];

        public XorSwitchIntegerExecutor() {
            for (int i = 0; i < key.length; i++) {
                this.key[i] = new XorIntegerExecutor();
            }
        }

        @Override
        public String code() {
            // @formatter:off
            final var cases = new StringBuilder();
            for (var i = 0; i < this.key.length; i++) {
                cases.append("""
                                case REPLACE_VALUE:
                                    REPLACE_EXECUTOR
                                    break;
                             """
                        .replace("REPLACE_VALUE", String.valueOf(i))
                        .replace("REPLACE_EXECUTOR", key[i].code())
                );
            }

            return  """
                    switch (i % REPLACE_LENGTH) {
                        REPLACE_CASES
                    }
                    """
                    .replace("REPLACE_LENGTH", String.valueOf(this.key.length))
                    .replace("REPLACE_CASES", cases.toString());
            // @formatter:on
        }

        @Override
        public Integer execute(Integer value, Data longData) {
            return this.key[longData.index() % this.key.length].execute(value, longData);
        }
    }

    /** Per-character context handed to executors; only {@code index} is read by the current executors. */
    public record Data(
            int index,
            int classNameHash,
            int methodNameHash
    ) {
        /* */
    }

    /**
     * Compiles the bootstrap class from Java source.
     * <p>
     * The generated class decrypts call-site names (caching the results) and resolves each target
     * through the caller's lookup.
     *
     * @param executors   the executor chain used for encryption; its generated code forms the decrypt loop body
     * @param memberNames the randomized class and member names to substitute into the template
     * @param cacheSize   length of the decrypted-name cache; must exceed every call-site index
     * @throws FatalObfuscationException if the generated source does not compile
     */
    private ClassNode createBootstrap(final List<Executor<Integer, Data>> executors,
                                      final MemberNames memberNames,
                                      final int cacheSize) {
        final var compiler = new Compiler();

        final var code = new StringBuilder();
        for (Executor<Integer, Data> executor : executors) {
            code.append(executor.code());
        }

        // Field lookup notes for REPLACE_SEARCH_NAME:
        //  - It follows JVM field-resolution order (declared, superinterfaces, superclass), so interface
        //    fields are found even when a superclass exists.
        //  - It does not call setAccessible: only the field's type is needed, and setAccessible throws for
        //    non-public fields of encapsulated JDK classes.
        final var compilation = """
                import java.lang.invoke.ConstantCallSite;
                import java.lang.invoke.MethodHandle;
                import java.lang.invoke.MethodHandles;
                import java.lang.invoke.MethodType;
                import java.lang.reflect.Field;

                public class REPLACE_NAME {
                    private static String[] REPLACE_CACHE_NAME = new String[REPLACE_CACHE_SIZE];

                    private static String REPLACE_DECRYPT_NAME(String string, int index) {
                        if (REPLACE_CACHE_NAME[index] == null) {
                            final char[] characters = string.toCharArray();
                            for (int i = 0; i < characters.length; i++) {
                                int result = characters[i];
                REPLACE_CODE
                                characters[i] = (char) result;
                            }
                            REPLACE_CACHE_NAME[index] = new String(characters);
                        }
                        return REPLACE_CACHE_NAME[index];
                    }

                    public static Object REPLACE_BOOTSTRAP_NAME(final Object arg1, final Object arg2, final Object arg3, int index) {
                        try {
                            final String string = (String) arg2;
                            final String[] longData = REPLACE_DECRYPT_NAME(string, index).split("<>");

                            final Class<?> klass = Class.forName(longData[0]);
                            final String name = longData[1];

                            final int type = Integer.parseInt(longData[2]);
                            final MethodHandles.Lookup lookup = (MethodHandles.Lookup) arg1;

                            MethodHandle methodHandle = null;
                            switch (type) {
                                case 0:
                                    methodHandle = lookup.findStatic(
                                            klass,
                                            name,
                                            MethodType.fromMethodDescriptorString(longData[3], REPLACE_NAME.class.getClassLoader())
                                    );
                                    break;
                                case 1:
                                    methodHandle = lookup.findVirtual(
                                            klass,
                                            name,
                                            MethodType.fromMethodDescriptorString(longData[3], REPLACE_NAME.class.getClassLoader())
                                    );
                                    break;
                                case 2:
                                    methodHandle = lookup.findSpecial(
                                            klass,
                                            name,
                                            MethodType.fromMethodDescriptorString(longData[3], REPLACE_NAME.class.getClassLoader()),
                                            Class.forName(longData[4])
                                    );
                                    break;
                                case 3:
                                    final Field d = REPLACE_SEARCH_NAME(klass, name);
                                    if (d != null) {
                                        methodHandle = lookup.findStaticGetter(klass, name, d.getType());
                                    }
                                    break;
                                case 4:
                                    final Field d2 = REPLACE_SEARCH_NAME(klass, name);
                                    if (d2 != null) {
                                        methodHandle = lookup.findGetter(klass, name, d2.getType());
                                    }
                                    break;
                                case 5:
                                    final Field d3 = REPLACE_SEARCH_NAME(klass, name);
                                    if (d3 != null) {
                                        methodHandle = lookup.findStaticSetter(klass, name, d3.getType());
                                    }
                                    break;
                                case 6:
                                    final Field d4 = REPLACE_SEARCH_NAME(klass, name);
                                    if (d4 != null) {
                                        methodHandle = lookup.findSetter(klass, name, d4.getType());
                                    }
                                    break;
                                default:
                                    throw new IllegalArgumentException("Invalid: " + type + "!");
                            }

                            if (methodHandle == null) {
                                throw new RuntimeException("Method Handle was undefined.");
                            }

                            return new ConstantCallSite(methodHandle.asType((MethodType) arg3));
                        } catch (final Throwable t) {
                            t.printStackTrace(System.err);
                            return null;
                        }
                    }

                    private static Field REPLACE_SEARCH_NAME(final Class<?> klass, final String name) {
                        try {
                            return klass.getDeclaredField(name);
                        } catch (final NoSuchFieldException ignored) {
                            // Not declared here; continue with the supertypes below.
                        }

                        for (final Class<?> anInterface : klass.getInterfaces()) {
                            final Field d = REPLACE_SEARCH_NAME(anInterface, name);
                            if (d != null) {
                                return d;
                            }
                        }

                        final Class<?> superclass = klass.getSuperclass();
                        return superclass == null ? null : REPLACE_SEARCH_NAME(superclass, name);
                    }
                }
                """
                // Replace names (literal replacement, so no regex/group-reference interpretation)
                .replace("REPLACE_NAME", memberNames.name)
                .replace("REPLACE_CACHE_NAME", memberNames.cacheName)
                .replace("REPLACE_DECRYPT_NAME", memberNames.decryptName)
                .replace("REPLACE_BOOTSTRAP_NAME", memberNames.bootstrapName)
                .replace("REPLACE_SEARCH_NAME", memberNames.searchName)

                // Replace code
                .replace("REPLACE_CODE", code.toString())

                // Replace sizes
                .replace("REPLACE_CACHE_SIZE", String.valueOf(cacheSize));

        final var classNode = compiler.compile(memberNames.name, compilation);

        if (classNode == null) {
            System.out.println(compilation);
            throw new FatalObfuscationException("Could not compile string encryption logic!");
        }

        classNode.version = V1_8;

        return classNode;
    }

    /** Returns whether the class or the method carries the {@code @IncludeReference} opt-in annotation. */
    private boolean shouldObfuscate(ClassWrapper classWrapper, MethodWrapper methodWrapper) {
        return hasIncludeReference(classWrapper.getClassNode().visibleAnnotations)
                || hasIncludeReference(methodWrapper.getMethodNode().visibleAnnotations);
    }

    /** Returns whether the (possibly null) runtime-visible annotation list contains {@code @IncludeReference}. */
    private static boolean hasIncludeReference(final List<AnnotationNode> annotations) {
        return annotations != null
                && annotations.stream().anyMatch(node -> node.desc.equals(INCLUDE_REFERENCE_DESC));
    }

    /** Randomized names for the generated bootstrap class and its members, drawn from the dictionary. */
    private class MemberNames {
        private final String name;
        private final String cacheName;
        private final String decryptName;
        private final String bootstrapName;
        private final String searchName;

        private MemberNames() {
            this.name = dictionary.next();
            this.cacheName = dictionary.next();
            this.decryptName = dictionary.next();
            this.bootstrapName = dictionary.next();
            this.searchName = dictionary.next();
        }
    }
}
