package tech.atani.client.util.game.player.point;

import net.minecraft.util.AxisAlignedBB;
import net.minecraft.util.Vec3;
import tech.atani.client.util.Util;

import java.util.*;

import static net.minecraft.entity.Entity.isPointVisible;
import static net.minecraft.entity.Entity.squareDistance;

/**
 * Searches an axis-aligned box for the point nearest to an origin that is visible from that origin.
 *
 * <p>The search starts at the point of the box closest to the origin. It then performs a best-first walk
 * over a lattice of points spaced {@code stepSize} apart along the X, Y and Z axes, anchored at that start
 * point. It always expands the queued point with the smallest {@code squareDistance} to the origin next.
 * Neighbours are accepted whenever they stay inside the box bounds, so interior lattice points are explored
 * as well as surface points. Box faces that are not a whole number of steps from the start point are never
 * sampled exactly.
 *
 * <p>Visibility is delegated to the statically imported {@code isPointVisible(origin, point)}.
 * NOTE(review): presumably a line-of-sight raytrace — confirm against its implementation.
 */
public class PointFinder extends Util {
    /** Default upper bound on the number of queue entries polled by a single search. */
    private static final int MAX_CHECKS = 100;

    /**
     * Finds the closest visible point of {@code box} as seen from {@code origin}. Uses the default budget of
     * {@value #MAX_CHECKS} polled lattice points.
     *
     * @param origin   the viewing position; {@code null} yields {@code null}
     * @param box      the box to search; {@code null} yields {@code null}
     * @param stepSize lattice spacing along each axis; must be {@code > 0}, otherwise {@code null} is returned
     * @return the visible lattice point with the smallest {@code squareDistance} to {@code origin} found
     *         within the budget, or {@code null} if none was found
     */
    public static Vec3 findClosestVisiblePoint(Vec3 origin, AxisAlignedBB box, double stepSize) {
        return findClosestVisiblePoint(origin, box, stepSize, MAX_CHECKS);
    }

    /**
     * Finds the closest visible point of {@code box} as seen from {@code origin}, with an explicit search budget.
     *
     * @param origin    the viewing position; {@code null} yields {@code null}
     * @param box       the box to search; {@code null} yields {@code null}
     * @param stepSize  lattice spacing along each axis; must be {@code > 0}, otherwise {@code null} is returned
     * @param maxChecks maximum number of queue entries polled after the initial closest-point test;
     *                  a value {@code <= 0} limits the search to that initial test
     * @return the closest point on the box if it is visible. Otherwise, the visible lattice point with the
     *         smallest {@code squareDistance} to {@code origin} found within the budget, or {@code null} if none
     */
    public static Vec3 findClosestVisiblePoint(Vec3 origin, AxisAlignedBB box, double stepSize, int maxChecks) {
        if (origin == null || box == null || stepSize <= 0) {
            return null;
        }

        Vec3 closestOnBox = getClosestPointOnBox(origin, box);
        if (isPointVisible(origin, closestOnBox)) {
            return closestOnBox;
        }

        PriorityQueue<PointWithDistance> queue = new PriorityQueue<>(Comparator.comparingDouble(p -> p.distanceSq));
        Set<PointKey> visited = new HashSet<>();

        // The start point has just been tested and is not visible. Mark it visited and seed the queue with its
        // neighbours directly, so it is neither raytraced a second time nor re-queued when a neighbour steps
        // back onto it.
        visited.add(new PointKey(closestOnBox, stepSize));
        addNeighbors(closestOnBox, box, stepSize, queue, visited, origin);

        Vec3 bestPoint = null;
        double bestDistSq = Double.MAX_VALUE;
        int checks = 0;

        while (!queue.isEmpty() && checks++ < maxChecks) {
            PointWithDistance current = queue.poll();

            // poll() returns the least element. If even that is no closer than the best hit, every remaining
            // entry would also be pruned without being expanded, so nothing further can improve the result.
            if (current.distanceSq >= bestDistSq) {
                break;
            }

            if (isPointVisible(origin, current.point)) {
                bestPoint = current.point;
                bestDistSq = current.distanceSq;
            }

            // Keep expanding even after a hit: an axis neighbour can be closer to the origin than the hit itself.
            addNeighbors(current.point, box, stepSize, queue, visited, origin);
        }

        return bestPoint;
    }

    /**
     * Queues the six axis-aligned neighbours of {@code point} ({@code step} away on X, Y and Z) that stay within
     * {@code box}. Insertion order is x-, x+, y-, y+, z-, z+.
     */
    private static void addNeighbors(Vec3 point, AxisAlignedBB box, double step,
                                     PriorityQueue<PointWithDistance> queue, Set<PointKey> visited, Vec3 origin) {
        double x = point.xCoord;
        double y = point.yCoord;
        double z = point.zCoord;
        double[] offsets = {-step, step};

        for (double d : offsets) {
            if (inRange(x + d, box.minX, box.maxX)) {
                addIfNew(new Vec3(x + d, y, z), queue, visited, origin, step);
            }
        }
        for (double d : offsets) {
            if (inRange(y + d, box.minY, box.maxY)) {
                addIfNew(new Vec3(x, y + d, z), queue, visited, origin, step);
            }
        }
        for (double d : offsets) {
            if (inRange(z + d, box.minZ, box.maxZ)) {
                addIfNew(new Vec3(x, y, z + d), queue, visited, origin, step);
            }
        }
    }

    /** Queues {@code point}, keyed by its distance to {@code origin}, unless its lattice key was already seen. */
    private static void addIfNew(Vec3 point, PriorityQueue<PointWithDistance> queue, Set<PointKey> visited,
                                 Vec3 origin, double step) {
        // Set.add returns false when the key is already present, which replaces a separate contains() lookup.
        if (visited.add(new PointKey(point, step))) {
            queue.add(new PointWithDistance(point, squareDistance(origin, point)));
        }
    }

    /** Returns {@code true} when {@code value} lies in the closed interval [{@code min}, {@code max}]. */
    private static boolean inRange(double value, double min, double max) {
        return value >= min && value <= max;
    }

    /** Returns the point of {@code box} closest to {@code origin}, obtained by clamping each coordinate. */
    private static Vec3 getClosestPointOnBox(Vec3 origin, AxisAlignedBB box) {
        double x = clamp(origin.xCoord, box.minX, box.maxX);
        double y = clamp(origin.yCoord, box.minY, box.maxY);
        double z = clamp(origin.zCoord, box.minZ, box.maxZ);
        return new Vec3(x, y, z);
    }

    /** Clamps {@code value} to the closed interval [{@code min}, {@code max}]. */
    private static double clamp(double value, double min, double max) {
        return Math.max(min, Math.min(max, value));
    }
}