package rip.marie.mutator.data.string.task;

import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;
import rip.marie.util.wrapper.ClassWrapper;

import java.util.Arrays;

import static org.objectweb.asm.Opcodes.*;

public class RemoveConcatenationTask extends ClassWrapperTask {
    private static final char STACK_ARG_CONSTANT = '\u0001';
    private static final char BSM_ARG_CONSTANT = '\u0002';

    private static InsnList convertStringConcatFactory(final String pattern, final Type[] stackArgs, final int[] stackIndices, final Object[] bsmArgs) {
        final InsnList replacement = new InsnList();
        final char[] arr = pattern.toCharArray();

        int stackArgsIndex = 0;
        int bsmArgsIndex = 0;

        StringBuilder assembler = new StringBuilder();

        replacement.add(new TypeInsnNode(NEW, "java/lang/StringBuilder"));
        replacement.add(new InsnNode(DUP));
        replacement.add(new MethodInsnNode(INVOKESPECIAL, "java/lang/StringBuilder", "<init>", "()V"));
        for (char c : arr) {
            if (c == STACK_ARG_CONSTANT) {
                if (!assembler.isEmpty()) {
                    replacement.add(new LdcInsnNode(assembler.toString()));
                    replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
                    assembler = new StringBuilder();
                }

                final Type stackArg = stackArgs[stackArgsIndex++];
                final int stackIndex = stackIndices[stackArgsIndex - 1];

                if (stackArg.getSort() == Type.OBJECT) {
                    replacement.add(new VarInsnNode(ALOAD, stackIndex));
                    replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/Object;)Ljava/lang/StringBuilder;"));
                } else if (stackArg.getSort() == Type.ARRAY) {
                    replacement.add(new VarInsnNode(ALOAD, stackIndex));
                    replacement.add(new MethodInsnNode(INVOKESTATIC, "java/util/Arrays", "toString", "([Ljava/lang/Object;)Ljava/lang/String;"));
                    replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
                } else {
                    replacement.add(new VarInsnNode(stackArg.getOpcode(ILOAD), stackIndex));

                    String adaptedDescriptor = stackArg.getDescriptor();
                    if (adaptedDescriptor.equals("B")
                            || adaptedDescriptor.equals("S")) {
                        adaptedDescriptor = "I";
                    }

                    replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(" + adaptedDescriptor + ")Ljava/lang/StringBuilder;"));
                }
            } else if (c == BSM_ARG_CONSTANT) {
                replacement.add(new LdcInsnNode(bsmArgs[bsmArgsIndex++]));
                replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/Object;)Ljava/lang/StringBuilder;"));
            } else {
                assembler.append(c);
            }
        }

        if (!assembler.isEmpty()) {
            replacement.add(new LdcInsnNode(assembler.toString()));
            replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "append", "(Ljava/lang/String;)Ljava/lang/StringBuilder;"));
        }

        replacement.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/StringBuilder", "toString", "()Ljava/lang/String;"));

        return replacement;
    }

    private static int count(final String string, final char query) {
        final char[] arr = string.toCharArray();
        int count = 0;

        for (char c : arr) {
            if (c == query)
                count++;
        }

        return count;
    }

    @Override
    public void execute(ClassWrapper wrapper) {
        wrapper.getMethods().forEach(methodWrapper -> {
            final MethodNode methodNode = methodWrapper.getBase();

            for (AbstractInsnNode instruction : methodNode.instructions.toArray()) {
                if (instruction.getOpcode() == INVOKEDYNAMIC) {
                    final InvokeDynamicInsnNode indy = (InvokeDynamicInsnNode) instruction;
                    if (indy.bsm.getOwner().equals("java/lang/invoke/StringConcatFactory") && indy.bsm.getName().equals("makeConcatWithConstants")) {
                        final String pattern = (String) indy.bsmArgs[0];
                        final Type[] stackArgs = Type.getArgumentTypes(indy.desc);
                        final Object[] bsmArgs = Arrays.copyOfRange(indy.bsmArgs, 1, indy.bsmArgs.length);
                        final int stackArgsCount = count(pattern, STACK_ARG_CONSTANT);
                        final int bsmArgsCount = count(pattern, BSM_ARG_CONSTANT);

                        if (stackArgs.length != stackArgsCount)
                            return;
                        if (bsmArgs.length != bsmArgsCount)
                            return;

                        int freeVarIndex = methodNode.maxLocals++;
                        final int[] stackIndices = new int[stackArgsCount];

                        for (int i = 0; i < stackArgs.length; i++) {
                            stackIndices[i] = freeVarIndex;
                            freeVarIndex += stackArgs[i].getSize();
                        }

                        for (int i = stackIndices.length - 1; i >= 0; i--) {
                            methodNode.instructions.insertBefore(indy, new VarInsnNode(stackArgs[i].getOpcode(ISTORE), stackIndices[i]));
                        }

                        final InsnList converted = convertStringConcatFactory(pattern, stackArgs, stackIndices, bsmArgs);
                        methodNode.instructions.insertBefore(indy, converted);
                        methodNode.instructions.remove(indy);
                    }
                }
            }
        });
    }
}