package cc.polymorphism.assembly.std;

import cc.polymorphism.assembly.WrappedType;
import cc.polymorphism.assembly.instructions.BytecodeLabel;
import cc.polymorphism.assembly.instructions.ConstantNode;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.MethodNode;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.*;

/**
 * Static helpers that convert between this library's wrapper types ({@link WrappedType},
 * {@link BytecodeLabel}, {@link ConstantNode}) and their ASM / reflection counterparts.
 */
public class Utility {
    /**
     * Wraps every argument type found in an ASM method node's descriptor.
     *
     * @param methodNode the method whose {@code desc} is parsed with {@link Type#getArgumentTypes(String)}
     * @return a new mutable list holding one {@link WrappedType} per argument, in declaration order
     */
    public static List<WrappedType> wrapMethodNodeParameters(MethodNode methodNode) {
        var argumentTypes = Type.getArgumentTypes(methodNode.desc);
        var wrappedTypes = new ArrayList<WrappedType>(argumentTypes.length);
        for (var type : argumentTypes) {
            wrappedTypes.add(new WrappedType(type));
        }
        return wrappedTypes;
    }

    /**
     * Wraps every parameter type of a reflective method.
     *
     * @param method the method whose {@link Method#getParameterTypes()} are wrapped
     * @return a new mutable list holding one {@link WrappedType} per parameter, in declaration order
     */
    public static List<WrappedType> wrapMethodParameters(Method method) {
        return wrapParameterClasses(method.getParameterTypes());
    }

    /**
     * Wraps every parameter type of a reflective constructor.
     *
     * @param constructor the constructor whose {@link Constructor#getParameterTypes()} are wrapped
     * @return a new mutable list holding one {@link WrappedType} per parameter, in declaration order
     */
    public static List<WrappedType> wrapConstructorParameters(Constructor<?> constructor) {
        return wrapParameterClasses(constructor.getParameterTypes());
    }

    /** Converts each class via {@link WrappedType#from} into a new, presized, mutable list. */
    private static List<WrappedType> wrapParameterClasses(Class<?>[] parameterClasses) {
        var wrappedTypes = new ArrayList<WrappedType>(parameterClasses.length);
        for (var clazz : parameterClasses) {
            wrappedTypes.add(WrappedType.from(clazz));
        }
        return wrappedTypes;
    }

    /**
     * Builds a method descriptor string of the form {@code (<params>)<return>}.
     *
     * <p>Each component is the appended form of {@link WrappedType#unwrap()}; this presumably
     * yields the JVM type descriptor (NOTE(review): confirm against {@code WrappedType}).
     *
     * @param parameterTypes the parameter types, in order
     * @param returnType the return type
     * @return the concatenated descriptor
     */
    public static String unwrapMethodDescriptor(List<WrappedType> parameterTypes, WrappedType returnType) {
        var sb = new StringBuilder("(");
        for (var type : parameterTypes) {
            sb.append(type.unwrap());
        }
        sb.append(')').append(returnType.unwrap());
        return sb.toString();
    }

    /**
     * Extracts the underlying ASM {@link LabelNode} from each wrapped label.
     *
     * @param wrappedLabels the labels to unwrap
     * @return a new list of label nodes in the same order as the input
     */
    public static ArrayList<LabelNode> unwrapLabels(List<BytecodeLabel> wrappedLabels) {
        var unwrappedLabels = new ArrayList<LabelNode>(wrappedLabels.size());
        for (var wrappedLabel : wrappedLabels) {
            unwrappedLabels.add(wrappedLabel.getLabel());
        }
        return unwrappedLabels;
    }

    /**
     * Extracts the raw value of each constant node into an array.
     *
     * @param constants the constant nodes to unpack
     * @return an array whose i-th element is {@code constants.get(i).getValue()}
     */
    public static Object[] unpackConstants(List<ConstantNode> constants) {
        var unpacked = new Object[constants.size()];
        // Iterate rather than index so non-RandomAccess lists (e.g. LinkedList) stay O(n).
        var index = 0;
        for (var constant : constants) {
            unpacked[index++] = constant.getValue();
        }
        return unpacked;
    }

    /**
     * Returns the boxed wrapper type for a primitive type (for example {@code Integer} for {@code int}).
     *
     * @param primitive a primitive type
     * @return the corresponding {@code java.lang} wrapper type
     * @throws IllegalArgumentException if {@code primitive} is not primitive, or its sort has no wrapper here
     */
    public static WrappedType box(WrappedType primitive) {
        if (!primitive.isPrimitive()) {
            throw new IllegalArgumentException("Attempted to box non-primitive type: " + primitive);
        }

        return switch (primitive.getSort()) {
            case Type.BOOLEAN -> WrappedType.from(Boolean.class);
            case Type.CHAR -> WrappedType.from(Character.class);
            case Type.BYTE -> WrappedType.from(Byte.class);
            case Type.SHORT -> WrappedType.from(Short.class);
            case Type.INT -> WrappedType.from(Integer.class);
            case Type.LONG -> WrappedType.from(Long.class);
            case Type.FLOAT -> WrappedType.from(Float.class);
            case Type.DOUBLE -> WrappedType.from(Double.class);
            default -> throw new IllegalArgumentException("Unknown primitive type: " + primitive);
        };
    }

    /**
     * Parses class-file bytes into a {@link ClassNode}, skipping stack map frames and debug info.
     *
     * @param bytes the class-file contents
     * @return the populated class node
     */
    public static ClassNode makeClass(byte[] bytes) {
        return makeClass(bytes, ClassReader.SKIP_FRAMES | ClassReader.SKIP_DEBUG);
    }

    /**
     * Parses class-file bytes into a {@link ClassNode} using the given {@link ClassReader} parsing flags.
     *
     * @param bytes the class-file contents
     * @param flags flags passed to {@link ClassReader#accept(org.objectweb.asm.ClassVisitor, int)}
     * @return the populated class node
     */
    public static ClassNode makeClass(byte[] bytes, int flags) {
        final ClassReader reader = new ClassReader(bytes);
        final ClassNode node = new ClassNode();

        reader.accept(node, flags);

        return node;
    }

    /**
     * Reads the remaining contents of a stream into a byte array.
     *
     * <p>Not referenced elsewhere in this class.
     *
     * @param inputStream the stream to drain; a {@code null} stream is reported as a missing class
     * @param close whether to close {@code inputStream} afterwards (also on failure)
     * @return all bytes read until end of stream
     * @throws IOException if {@code inputStream} is {@code null} or reading fails
     */
    private static byte[] readStream(final InputStream inputStream, final boolean close)
            throws IOException {

        if (inputStream == null) {
            throw new IOException("Class not found");
        }

        try {
            return inputStream.readAllBytes();
        } finally {
            if (close) {
                inputStream.close();
            }
        }
    }
}
