package war.metaphor.tree;

import org.objectweb.asm.Type;
import war.metaphor.base.ObfuscatorContext;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

/**
 * Inheritance graph over the classes and libraries exposed by
 * {@link ObfuscatorContext}.
 *
 * <p>The graph is built lazily by {@link #ensureGraphBuilt()} and provides:
 * <ul>
 *   <li>for every method/field declared in a non-library class, the set of
 *       same-signature members declared in classes on the same inheritance
 *       chain ({@link #getMethodHierarchy}, {@link #getFieldHierarchy});</li>
 *   <li>ancestors and descendants of a class ({@link #getClassHierarchy});</li>
 *   <li>a cached common-super-class lookup ({@link #getCommonSuperClass}).</li>
 * </ul>
 *
 * <p>Graph construction and {@link #reset()} are serialized on this instance's
 * monitor. The published lookup maps are concurrent maps, so lookups do not
 * take the lock once the graph has been built.
 */
public class Hierarchy {

    /** Most recently constructed instance; overwritten by every constructor call. */
    public static Hierarchy INSTANCE;

    /** Context captured at construction time; source of classes, libraries and class loading. */
    private final ObfuscatorContext core = ObfuscatorContext.INSTANCE;

    // Published lookup tables (read without the lock after the graph is built).
    private final ConcurrentMap<ClassMethod, Set<ClassMethod>> methodHierarchy = new ConcurrentHashMap<>();
    private final ConcurrentMap<ClassField,  Set<ClassField>>  fieldHierarchy  = new ConcurrentHashMap<>();
    /** Keyed by {@code t1 + '#' + t2}; both orderings of a pair are stored. */
    private final ConcurrentMap<String, String> commonSuperCache = new ConcurrentHashMap<>();

    // Build-time scratch structures; only touched while holding this instance's monitor.
    /** Members grouped by {@code name + desc}. */
    private final Map<String, List<ClassMethod>> methodIndex   = new HashMap<>();
    private final Map<String, List<ClassField>>  fieldIndex    = new HashMap<>();
    /** For each class, the transitive closure of {@code getParents()}. */
    private final Map<JClassNode, Set<JClassNode>> ancestorsCache = new HashMap<>();

    /** Set to {@code true} only after {@link #buildHierarchy()} has completed. */
    private volatile boolean graphBuilt = false;

    /** Creates a hierarchy and registers it as {@link #INSTANCE}. */
    public Hierarchy() {
        INSTANCE = this;
    }

    /**
     * Discards all cached hierarchy data so the next lookup rebuilds the graph.
     *
     * <p>Calls {@code resetHierarchy()} on every class and library known to the
     * context and clears every cache and index held by this object.
     * Synchronized so that it cannot clear the non-concurrent index maps while
     * {@link #ensureGraphBuilt()} is in the middle of filling them.
     */
    public synchronized void reset() {
        Set<JClassNode> all = new HashSet<>(core.getClasses());
        all.addAll(core.getLibraries());
        all.forEach(JClassNode::resetHierarchy);

        methodHierarchy.clear();
        fieldHierarchy.clear();
        commonSuperCache.clear();

        methodIndex.clear();
        fieldIndex.clear();
        ancestorsCache.clear();

        graphBuilt = false;
    }

    /**
     * Builds the hierarchy graph if it has not been built since construction or
     * the last {@link #reset()}.
     *
     * <p>The volatile flag is checked first without locking so that lookups on
     * an already-built graph do not contend on the monitor; the flag is
     * re-checked under the lock so that only one thread performs the build.
     */
    public void ensureGraphBuilt() {
        if (graphBuilt) return;
        synchronized (this) {
            if (!graphBuilt) {
                buildHierarchy();
                graphBuilt = true;
            }
        }
    }

    /**
     * Rebuilds parent/child links, the ancestor cache and the method and field
     * hierarchies from the context's current classes and libraries.
     */
    private void buildHierarchy() {

        // First pass: walking the supertypes of the context's classes goes
        // through core.loadClass for every supertype name.
        core.getClasses().forEach(this::iterateClass);

        // Second pass: take the full set (classes + libraries) only after the
        // first pass, then link every member of it as well.
        Set<JClassNode> all = new HashSet<>(core.getClasses());
        all.addAll(core.getLibraries());

        all.forEach(this::iterateClass);
        computeAncestors(all);

        methodIndex.clear();

        all.forEach(cls ->
                cls.methods.forEach(m -> {
                    String sig = m.name + m.desc;
                    methodIndex
                            .computeIfAbsent(sig, _ -> new ArrayList<>())
                            .add(ClassMethod.of(cls, m));
                })
        );
        buildMethodHierarchy();

        fieldIndex.clear();
        all.forEach(cls ->
                cls.fields.forEach(f -> {
                    String sig = f.name + f.desc;
                    fieldIndex
                            .computeIfAbsent(sig, _ -> new ArrayList<>())
                            .add(ClassField.of(cls, f));
                })
        );
        buildFieldHierarchy();

        // Results cached against a previous graph are no longer trustworthy.
        commonSuperCache.clear();
    }

    /**
     * Links {@code cls} to every supertype reachable through {@code superName}
     * and {@code interfaces}: each resolvable ancestor gets {@code cls} as a
     * child and {@code cls} gets that ancestor as a parent.
     *
     * <p>Supertype names that {@code core.loadClass} cannot resolve (returns
     * {@code null}) are skipped. Each supertype name is visited at most once,
     * so diamond-shaped interface hierarchies are not re-walked and a cyclic
     * (malformed) hierarchy cannot cause an endless loop.
     *
     * @param cls the class whose ancestors should be linked
     */
    public void iterateClass(JClassNode cls) {
        Deque<String> queue = new ArrayDeque<>();
        Set<String> seen = new HashSet<>();
        enqueueSupertypes(cls, queue, seen);

        while (!queue.isEmpty()) {
            String parentName = queue.poll();
            JClassNode parent = core.loadClass(parentName);
            if (parent == null) continue;

            parent.addChild(cls);
            cls.addParent(parent);

            enqueueSupertypes(parent, queue, seen);
        }
    }

    /** Queues the direct supertype names of {@code node} that have not been queued before. */
    private static void enqueueSupertypes(JClassNode node, Deque<String> queue, Set<String> seen) {
        if (node.superName != null && seen.add(node.superName)) {
            queue.add(node.superName);
        }
        if (node.interfaces != null) {
            for (String itf : node.interfaces) {
                if (seen.add(itf)) queue.add(itf);
            }
        }
    }

    /**
     * Recomputes {@link #ancestorsCache}: for each class, the set of classes
     * reachable by repeatedly following {@code getParents()}.
     */
    private void computeAncestors(Set<JClassNode> classes) {
        ancestorsCache.clear();
        for (JClassNode cls : classes) {
            Set<JClassNode> anc = new HashSet<>();
            Deque<JClassNode> stack = new ArrayDeque<>(cls.getParents());
            while (!stack.isEmpty()) {
                JClassNode p = stack.pop();
                // Only expand a node the first time it is seen.
                if (anc.add(p)) {
                    stack.addAll(p.getParents());
                }
            }
            ancestorsCache.put(cls, anc);
        }
    }

    /**
     * Fills {@link #methodHierarchy}: every method of a non-library class maps
     * to the unmodifiable set of same-signature methods (itself included) whose
     * declaring class is on the same inheritance chain.
     */
    private void buildMethodHierarchy() {
        methodHierarchy.clear();
        methodIndex.values().forEach(cluster -> {
            for (ClassMethod self : cluster) {
                JClassNode owner = self.getClassNode();
                // Library members are never used as keys, but may appear in groups.
                if (owner.isLibrary()) continue;

                Set<ClassMethod> group = new HashSet<>();
                for (ClassMethod other : cluster) {
                    if (sameChain(owner, other.getClassNode())) {
                        group.add(other);
                    }
                }
                methodHierarchy.put(self, Collections.unmodifiableSet(group));
            }
        });
    }

    /**
     * Fills {@link #fieldHierarchy}: every field of a non-library class maps to
     * the unmodifiable set of same-signature fields (itself included) whose
     * declaring class is on the same inheritance chain.
     */
    private void buildFieldHierarchy() {
        fieldHierarchy.clear();
        fieldIndex.values().forEach(cluster -> {
            for (ClassField self : cluster) {
                JClassNode owner = self.getClassNode();
                // Library members are never used as keys, but may appear in groups.
                if (owner.isLibrary()) continue;

                Set<ClassField> group = new HashSet<>();
                for (ClassField other : cluster) {
                    if (sameChain(owner, other.getClassNode())) {
                        group.add(other);
                    }
                }
                fieldHierarchy.put(self, Collections.unmodifiableSet(group));
            }
        });
    }

    /**
     * Returns whether {@code a} and {@code b} are equal or one is a cached
     * ancestor of the other. Classes missing from {@link #ancestorsCache} are
     * treated as having no ancestors.
     */
    private boolean sameChain(JClassNode a, JClassNode b) {
        if (a.equals(b)) return true;
        Set<JClassNode> ancB = ancestorsCache.get(b);
        if (ancB != null && ancB.contains(a)) return true;
        Set<JClassNode> ancA = ancestorsCache.get(a);
        return ancA != null && ancA.contains(b);
    }

    /**
     * Returns the common super type of two type names, caching the answer for
     * both argument orders.
     *
     * <p>The cache is read with a single {@code get} rather than
     * {@code containsKey} followed by {@code get}, so a concurrent
     * {@link #reset()} cannot make this method return {@code null}.
     * {@code computeIfAbsent} is deliberately not used: the computation
     * re-enters this method for array types, and {@link ConcurrentHashMap}
     * does not permit its mapping function to update the map.
     *
     * @param t1 first type name
     * @param t2 second type name
     * @return the common super type name; never {@code null}
     */
    public String getCommonSuperClass(String t1, String t2) {
        ensureGraphBuilt();

        String key = t1 + '#' + t2;
        String cached = commonSuperCache.get(key);
        if (cached != null) return cached;

        String result = computeCommonSuperClass(t1, t2);
        commonSuperCache.put(key, result);
        commonSuperCache.put(t2 + '#' + t1, result);
        return result;
    }

    /**
     * Computes the common super type without consulting the cache for the
     * top-level pair.
     *
     * <p>Falls back to {@code java/lang/Object} when either type cannot be
     * loaded, when either type is an interface and neither is assignable from
     * the other, or when the superclass chain of {@code type1} cannot be
     * followed any further.
     */
    private String computeCommonSuperClass(String type1, String type2) {
        // Array operands: recurse on the element type and prefix the result with "[".
        // NOTE(review): the returned string is "[" + internal name, which is not
        // descriptor form ("[L...;") - confirm this is what callers expect.
        if (type1.startsWith("[")) {
            Type et1 = Type.getType(type1);
            return "[" + getCommonSuperClass(et1.getElementType().getInternalName(), type2);
        }
        if (type2.startsWith("[")) {
            Type et2 = Type.getType(type2);
            return "[" + getCommonSuperClass(type1, et2.getElementType().getInternalName());
        }

        JClassNode c1 = loadOrPrimitive(type1);
        JClassNode c2 = loadOrPrimitive(type2);
        if (c1 == null || c2 == null) return "java/lang/Object";

        if (c1.isAssignableFrom(c2)) return type1;
        if (c2.isAssignableFrom(c1)) return type2;

        if (!c1.isInterface() && !c2.isInterface()) {
            // Walk up c1's superclass chain until a class that c2 can be assigned to.
            while (!c1.isAssignableFrom(c2)) {
                // A missing or unresolvable superclass used to cause a
                // NullPointerException here; use the same fallback as for
                // unresolvable operands instead.
                if (c1.superName == null) return "java/lang/Object";
                c1 = core.loadClass(c1.superName);
                if (c1 == null) return "java/lang/Object";
            }
            return c1.name.replace('.', '/');
        }
        return "java/lang/Object";
    }

    /**
     * Loads {@code name} through the context; if that fails and {@code name} is
     * a one-letter primitive descriptor, loads the corresponding
     * {@code java/lang} wrapper class instead.
     *
     * @return the loaded class, or whatever {@code core.loadClass} returns for
     *         the second attempt (which repeats {@code name} for non-primitives)
     */
    private JClassNode loadOrPrimitive(String name) {
        JClassNode cls = core.loadClass(name);
        if (cls != null) return cls;
        String boxed = switch (name) {
            case "B" -> "java/lang/Byte";
            case "C" -> "java/lang/Character";
            case "D" -> "java/lang/Double";
            case "F" -> "java/lang/Float";
            case "I" -> "java/lang/Integer";
            case "J" -> "java/lang/Long";
            case "S" -> "java/lang/Short";
            case "Z" -> "java/lang/Boolean";
            default -> name;
        };
        return core.loadClass(boxed);
    }

    /**
     * Returns the same-signature methods on the inheritance chain of {@code m}.
     *
     * @return an unmodifiable set, or an empty set when {@code m} has no entry
     *         (entries exist only for methods of non-library classes)
     */
    public Set<ClassMethod> getMethodHierarchy(ClassMethod m) {
        ensureGraphBuilt();
        return methodHierarchy.getOrDefault(m, Collections.emptySet());
    }

    /**
     * Returns the same-signature fields on the inheritance chain of {@code f}.
     *
     * @return an unmodifiable set, or an empty set when {@code f} has no entry
     *         (entries exist only for fields of non-library classes)
     */
    public Set<ClassField> getFieldHierarchy(ClassField f) {
        ensureGraphBuilt();
        return fieldHierarchy.getOrDefault(f, Collections.emptySet());
    }

    /**
     * Collects every class reachable from {@code start} by following
     * {@code getParents()} transitively, plus every class reachable by
     * following {@code getChildren()} transitively.
     *
     * <p>{@code start} itself is only part of the result if it is reachable
     * through those links.
     *
     * @param start the class whose relatives are requested
     * @return a new mutable set owned by the caller
     */
    public Set<JClassNode> getClassHierarchy(JClassNode start) {
        ensureGraphBuilt();

        Set<JClassNode> result = new HashSet<>();

        // Upwards: ancestors.
        Deque<JClassNode> ancQueue = new ArrayDeque<>(start.getParents());
        while (!ancQueue.isEmpty()) {
            JClassNode cls = ancQueue.poll();
            if (!result.add(cls)) continue;
            ancQueue.addAll(cls.getParents());
        }

        // Downwards: descendants. Shares the visited set with the upward walk.
        Deque<JClassNode> descQueue = new ArrayDeque<>(start.getChildren());
        while (!descQueue.isEmpty()) {
            JClassNode cls = descQueue.poll();
            if (!result.add(cls)) continue;
            descQueue.addAll(cls.getChildren());
        }

        return result;
    }
}