package tech.atani.client.util.game.render.shader;

import lombok.Getter;
import net.minecraft.client.renderer.GlStateManager;
import org.joml.Matrix4f;
import org.lwjgl.BufferUtils;
import org.lwjgl.opengl.GL45C;
import tech.atani.client.util.Util;
import tech.atani.client.util.client.interfaces.ILogger;
import tech.atani.client.util.game.render.shader.data.ShaderData;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

@Getter
@SuppressWarnings("unused")
public abstract class Shader extends Util implements AutoCloseable {
    private final List<Integer> shaderPrograms = new ArrayList<>();
    public static volatile int currentlyBoundProgram = 0;
    private int currentProgramIndex = 0;

    private int vertexArrayObject = 0;
    private int vertexBufferObject = 0;
    private int elementBufferObject = 0;
    private int uniformBufferObject = 0;

    private static final Map<String, String> shaderSourceCache = new ConcurrentHashMap<>();
    private final Map<String, Integer> uniformLocations = new HashMap<>();
    private final Map<String, float[]> cachedUniformFloats = new HashMap<>();
    private final Map<String, int[]> cachedUniformInts = new HashMap<>();
    private final Map<String, float[]> cachedUniformMatrices = new HashMap<>();

    private final FloatBuffer matrixUploadBuffer = BufferUtils.createFloatBuffer(16);
    private final FloatBuffer floatUniformUploadBuffer = BufferUtils.createFloatBuffer(4);
    private final FloatBuffer vertexUploadBuffer = BufferUtils.createFloatBuffer(16);
    private final ShortBuffer elementIndexBuffer = BufferUtils.createShortBuffer(6);
    private final FloatBuffer uboBuffer = BufferUtils.createFloatBuffer(16);

    private static final int MAX_UNIFORM_COMPONENTS = 4;
    private static final int VERTEX_SIZE_FLOATS = 4;
    private static final int VERTEX_SIZE_BYTES = VERTEX_SIZE_FLOATS * Float.BYTES;
    private static final int UBO_BINDING_POINT = 0;

    public Shader() {
        short[] indices = {0, 1, 2, 2, 1, 3};
        elementIndexBuffer.put(indices).flip();
        initializePrograms();
        initializeQuadResources();
        initializeUBO();
    }

    public void bind(int programIndex) {
        if (programIndex < 0 || programIndex >= shaderPrograms.size()) {
            ILogger.logger.error("Invalid shader program index: {}", programIndex);
            return;
        }
        this.currentProgramIndex = programIndex;
        int programId = shaderPrograms.get(programIndex);
        if (currentlyBoundProgram != programId) {
            GL45C.glUseProgram(programId);
            currentlyBoundProgram = programId;
            if (uniformBufferObject != 0) {
                GL45C.glBindBufferBase(GL45C.GL_UNIFORM_BUFFER, UBO_BINDING_POINT, uniformBufferObject);
            }
        }
    }

    public void bind() {
        bind(this.currentProgramIndex);
    }

    public void unbind() {
        if (currentlyBoundProgram != 0) {
            GL45C.glUseProgram(0);
            GL45C.glBindBuffer(GL45C.GL_UNIFORM_BUFFER, 0);
            currentlyBoundProgram = 0;
        }
    }

    public void destroy() {
        unbind();
        for (int programId : shaderPrograms) {
            GL45C.glDeleteProgram(programId);
        }
        shaderPrograms.clear();
        if (vertexArrayObject != 0) {
            GL45C.glDeleteVertexArrays(vertexArrayObject);
            vertexArrayObject = 0;
        }
        if (vertexBufferObject != 0) {
            GL45C.glDeleteBuffers(vertexBufferObject);
            vertexBufferObject = 0;
        }
        if (elementBufferObject != 0) {
            GL45C.glDeleteBuffers(elementBufferObject);
            elementBufferObject = 0;
        }
        if (uniformBufferObject != 0) {
            GL45C.glDeleteBuffers(uniformBufferObject);
            uniformBufferObject = 0;
        }
        uniformLocations.clear();
        cachedUniformFloats.clear();
        cachedUniformInts.clear();
        cachedUniformMatrices.clear();
    }

    public void reload() {
        destroy();
        shaderSourceCache.clear();
        initializePrograms();
        initializeQuadResources();
        initializeUBO();
    }

    private String preprocessShaderSource(String source, String basePath) {
        StringBuilder processedSource = new StringBuilder();
        Pattern includePattern = Pattern.compile("^\\s*#include\\s+\"([^\"]+)\"\\s*$", Pattern.MULTILINE);
        Matcher matcher = includePattern.matcher(source);
        int lastPos = 0;

        while (matcher.find()) {
            processedSource.append(source, lastPos, matcher.start());
            String includeFile = matcher.group(1);
            String includePath = "/assets/atani/shaders/" + includeFile;
            String includeSource = shaderSourceCache.computeIfAbsent(includePath, path -> {
                try (InputStream inputStream = getClass().getResourceAsStream(path)) {
                    if (inputStream == null) {
                        ILogger.logger.error("Included shader resource not found: {}", path);
                        return "";
                    }
                    try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
                        return reader.lines().collect(Collectors.joining("\n"));
                    }
                } catch (Exception e) {
                    ILogger.logger.error("Failed to read included shader source: {}", path, e);
                    return "";
                }
            });
            processedSource.append(includeSource);
            lastPos = matcher.end();
        }
        processedSource.append(source, lastPos, source.length());
        return processedSource.toString();
    }

    private void initializePrograms() {
        ShaderData shaderData = getClass().getAnnotation(ShaderData.class);
        if (shaderData == null || shaderData.frag().length == 0 || shaderData.vertex() == null || shaderData.vertex().isEmpty()) {
            ILogger.logger.error("ShaderData annotation missing or incomplete for {}", getClass().getName());
            return;
        }
        int vertexShaderId = compileShader(shaderData.vertex(), GL45C.GL_VERTEX_SHADER);
        if (vertexShaderId == 0) {
            ILogger.logger.error("Failed to compile vertex shader: {}", shaderData.vertex());
            return;
        }
        try {
            for (String fragmentShaderPath : shaderData.frag()) {
                int fragmentShaderId = compileShader(fragmentShaderPath, GL45C.GL_FRAGMENT_SHADER);
                if (fragmentShaderId == 0) {
                    ILogger.logger.error("Failed to compile fragment shader: {}", fragmentShaderPath);
                    continue;
                }
                int programId = GL45C.glCreateProgram();
                if (programId == 0) {
                    ILogger.logger.error("Failed to create shader program.");
                    GL45C.glDeleteShader(fragmentShaderId);
                    continue;
                }
                GL45C.glAttachShader(programId, vertexShaderId);
                GL45C.glAttachShader(programId, fragmentShaderId);
                GL45C.glLinkProgram(programId);
                if (GL45C.glGetProgrami(programId, GL45C.GL_LINK_STATUS) == GL45C.GL_FALSE) {
                    String log = GL45C.glGetProgramInfoLog(programId);
                    ILogger.logger.error("Shader program linking failed for {} + {}: {}", shaderData.vertex(), fragmentShaderPath, log);
                    GL45C.glDeleteProgram(programId);
                } else {
                    int uboIndex = GL45C.glGetUniformBlockIndex(programId, "ShaderUBO");
                    if (uboIndex != -1) {
                        GL45C.glUniformBlockBinding(programId, uboIndex, UBO_BINDING_POINT);
                    }
                    shaderPrograms.add(programId);
                }
                GL45C.glDetachShader(programId, fragmentShaderId);
                GL45C.glDeleteShader(fragmentShaderId);
            }
        } finally {
            GL45C.glDeleteShader(vertexShaderId);
        }
        if (shaderPrograms.isEmpty()) {
            ILogger.logger.error("No shader programs were successfully created for {}", getClass().getName());
        }
    }

    private int compileShader(String resourcePath, int type) {
        String fullResourcePath = "/assets/atani/shaders/" + resourcePath;
        String source = shaderSourceCache.computeIfAbsent(fullResourcePath, path -> {
            try (InputStream inputStream = getClass().getResourceAsStream(path)) {
                if (inputStream == null) {
                    ILogger.logger.error("Shader resource not found: {}", path);
                    return null;
                }
                try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
                    return reader.lines().collect(Collectors.joining("\n"));
                }
            } catch (Exception e) {
                ILogger.logger.error("Failed to read shader source: {}", path, e);
                return null;
            }
        });
        if (source == null || source.isEmpty()) {
            return 0;
        }
        source = preprocessShaderSource(source, fullResourcePath);
        int shaderId = GL45C.glCreateShader(type);
        if (shaderId == 0) {
            ILogger.logger.error("Failed to create shader object for type: {}", type);
            return 0;
        }
        GL45C.glShaderSource(shaderId, source);
        GL45C.glCompileShader(shaderId);
        if (GL45C.glGetShaderi(shaderId, GL45C.GL_COMPILE_STATUS) == GL45C.GL_FALSE) {
            String log = GL45C.glGetShaderInfoLog(shaderId);
            String shaderTypeName = (type == GL45C.GL_VERTEX_SHADER) ? "Vertex" : (type == GL45C.GL_FRAGMENT_SHADER) ? "Fragment" : "Unknown";
            ILogger.logger.error("{} shader compilation failed for {}: {}", shaderTypeName, resourcePath, log);
            GL45C.glDeleteShader(shaderId);
            return 0;
        }
        return shaderId;
    }

    private void initializeQuadResources() {
        if (!shaderPrograms.isEmpty()) {
            vertexArrayObject = GL45C.glGenVertexArrays();
            vertexBufferObject = GL45C.glGenBuffers();
            elementBufferObject = GL45C.glGenBuffers();

            GL45C.glBindVertexArray(vertexArrayObject);
            GL45C.glBindBuffer(GL45C.GL_ARRAY_BUFFER, vertexBufferObject);
            GL45C.glBufferData(GL45C.GL_ARRAY_BUFFER, (long) vertexUploadBuffer.capacity() * Float.BYTES, GL45C.GL_STATIC_DRAW);
            GL45C.glBindBuffer(GL45C.GL_ELEMENT_ARRAY_BUFFER, elementBufferObject);
            GL45C.glBufferData(GL45C.GL_ELEMENT_ARRAY_BUFFER, elementIndexBuffer, GL45C.GL_STATIC_DRAW);

            GL45C.glEnableVertexAttribArray(0);
            GL45C.glVertexAttribPointer(0, 2, GL45C.GL_FLOAT, false, VERTEX_SIZE_BYTES, 0);
            GL45C.glEnableVertexAttribArray(1);
            GL45C.glVertexAttribPointer(1, 2, GL45C.GL_FLOAT, false, VERTEX_SIZE_BYTES, 2 * Float.BYTES);

            GL45C.glBindVertexArray(0);
            GL45C.glBindBuffer(GL45C.GL_ARRAY_BUFFER, 0);
            GL45C.glBindBuffer(GL45C.GL_ELEMENT_ARRAY_BUFFER, 0);
        }
    }

    private void initializeUBO() {
        if (!shaderPrograms.isEmpty()) {
            uniformBufferObject = GL45C.glGenBuffers();
            GL45C.glBindBuffer(GL45C.GL_UNIFORM_BUFFER, uniformBufferObject);
            GL45C.glBufferData(GL45C.GL_UNIFORM_BUFFER, (long) uboBuffer.capacity() * Float.BYTES, GL45C.GL_STATIC_DRAW);
            GL45C.glBindBuffer(GL45C.GL_UNIFORM_BUFFER, 0);
        }
    }

    private int getUniformLocation(String uniformName) {
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return -1;
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        return uniformLocations.computeIfAbsent(cacheKey, k -> GL45C.glGetUniformLocation(programId, uniformName));
    }

    protected void setUniformFloats(String uniformName, float... values) {
        if (values.length == 0 || values.length > MAX_UNIFORM_COMPONENTS) return;
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return;
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        float[] cached = cachedUniformFloats.get(cacheKey);
        if (cached != null && Arrays.equals(cached, values)) return;
        cachedUniformFloats.put(cacheKey, values.clone());
        int location = getUniformLocation(uniformName);
        if (location != -1) {
            if (currentlyBoundProgram != programId) {
                bind();
            }
            floatUniformUploadBuffer.clear().put(values).flip();
            switch (values.length) {
                case 1 -> GL45C.glUniform1fv(location, floatUniformUploadBuffer);
                case 2 -> GL45C.glUniform2fv(location, floatUniformUploadBuffer);
                case 3 -> GL45C.glUniform3fv(location, floatUniformUploadBuffer);
                case 4 -> GL45C.glUniform4fv(location, floatUniformUploadBuffer);
            }
        }
    }

    protected float[] getCachedUniformFloats(String uniformName) {
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return new float[0];
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        return cachedUniformFloats.getOrDefault(cacheKey, new float[0]).clone();
    }

    protected void setUniformInts(String uniformName, int... values) {
        if (values.length == 0 || values.length > MAX_UNIFORM_COMPONENTS) return;
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return;
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        int[] cached = cachedUniformInts.get(cacheKey);
        if (cached != null && Arrays.equals(cached, values)) return;
        cachedUniformInts.put(cacheKey, values.clone());
        int location = getUniformLocation(uniformName);
        if (location != -1) {
            if (currentlyBoundProgram != programId) {
                bind();
            }
            switch (values.length) {
                case 1 -> GL45C.glUniform1i(location, values[0]);
                case 2 -> GL45C.glUniform2i(location, values[0], values[1]);
                case 3 -> GL45C.glUniform3i(location, values[0], values[1], values[2]);
                case 4 -> GL45C.glUniform4i(location, values[0], values[1], values[2], values[3]);
            }
        }
    }

    protected int[] getCachedUniformInts(String uniformName) {
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return new int[0];
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        return cachedUniformInts.getOrDefault(cacheKey, new int[0]).clone();
    }

    protected void setUniformMatrix4fv(float[] matrix) {
        if (matrix.length != 16) return;
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return;
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_mvpMatrix";
        float[] cached = cachedUniformMatrices.get(cacheKey);
        if (cached != null && Arrays.equals(cached, matrix)) return;
        cachedUniformMatrices.put(cacheKey, matrix.clone());
        int location = getUniformLocation("mvpMatrix");
        if (location != -1) {
            if (currentlyBoundProgram != programId) {
                bind();
            }
            matrixUploadBuffer.clear().put(matrix).flip();
            GL45C.glUniformMatrix4fv(location, false, matrixUploadBuffer);
        }
    }

    protected void updateUBO(float[] resolution, float[] loc, float[] size) {
        if (uniformBufferObject == 0 || resolution.length != 2 || loc.length != 2 || size.length != 2) return;

        uboBuffer.clear();
        uboBuffer.put(resolution);
        uboBuffer.put(loc);
        uboBuffer.put(size);
        uboBuffer.flip();

        GL45C.glBindBuffer(GL45C.GL_UNIFORM_BUFFER, uniformBufferObject);
        GL45C.glBufferSubData(GL45C.GL_UNIFORM_BUFFER, 0, uboBuffer);
        GL45C.glBindBuffer(GL45C.GL_UNIFORM_BUFFER, 0);
    }

    protected float[] getCachedUniformMatrix(String uniformName) {
        if (shaderPrograms.isEmpty() || currentProgramIndex >= shaderPrograms.size()) return new float[16];
        int programId = shaderPrograms.get(currentProgramIndex);
        String cacheKey = programId + "_" + uniformName;
        float[] cached = cachedUniformMatrices.get(cacheKey);
        return cached != null ? cached.clone() : new float[16];
    }

    protected float[] createOrthographicMatrix(float right, float bottom) {
        Matrix4f matrix = new Matrix4f().ortho(0, right, bottom, 0, -1, 1);
        float[] matrixArray = new float[16];
        matrix.get(matrixArray);
        return matrixArray;
    }

    public void drawQuad(float x, float y, float width, float height) {
        if (vertexArrayObject == 0 || vertexBufferObject == 0 || elementBufferObject == 0) return;
        if (currentlyBoundProgram == 0) {
            ILogger.logger.error("Cannot draw quad without a bound shader program.");
            return;
        }

        float[] vertices = {
                x,         y,          0.0f, 0.0f,
                x,         y + height, 0.0f, 1.0f,
                x + width, y,          1.0f, 0.0f,
                x + width, y + height, 1.0f, 1.0f
        };

        vertexUploadBuffer.clear();
        vertexUploadBuffer.put(vertices).flip();

        GL45C.glBindBuffer(GL45C.GL_ARRAY_BUFFER, vertexBufferObject);
        GL45C.glBufferSubData(GL45C.GL_ARRAY_BUFFER, 0, vertexUploadBuffer);
        GL45C.glBindVertexArray(vertexArrayObject);
        GL45C.glDrawElements(GL45C.GL_TRIANGLES, 6, GL45C.GL_UNSIGNED_SHORT, 0);
        GL45C.glBindVertexArray(0);
        GL45C.glBindBuffer(GL45C.GL_ARRAY_BUFFER, 0);
    }

    public void setBlend(boolean enabled) {
        if (enabled) {
            GlStateManager.enableBlend();
            GlStateManager.blendFunc(GL45C.GL_SRC_ALPHA, GL45C.GL_ONE_MINUS_SRC_ALPHA);
        } else {
            GlStateManager.disableBlend();
        }
    }

    @Override
    public void close() {
        destroy();
    }
}