package war.metaphor.analysis.callgraph;

import lombok.Getter;
import org.objectweb.asm.Handle;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.InvokeDynamicInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import war.jnt.dash.Ansi;
import war.jnt.dash.Level;
import war.jnt.dash.Logger;
import war.jnt.dash.Origin;
import war.metaphor.base.ObfuscatorContext;
import war.metaphor.tree.ClassMethod;
import war.metaphor.tree.JClassNode;
import war.metaphor.util.asm.BytecodeUtil;

import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static war.jnt.dash.Ansi.Color.BRIGHT_YELLOW;
import static war.jnt.dash.Ansi.Color.YELLOW;

/**
 * Cross-reference graph mapping each resolved callee to the set of call sites that reference it.
 * <p>
 * Edges are collected from {@link MethodInsnNode} instructions and from
 * {@link InvokeDynamicInsnNode} instructions (both the bootstrap method handle and every
 * {@link Handle} found among the bootstrap arguments).
 * <p>
 * NOTE(review): the outer map is a {@link ConcurrentHashMap} but the per-callee sets are plain
 * {@link HashSet}s, so concurrent mutation of a single callee's set is not safe — confirm that
 * {@link #buildGraph()} is never run concurrently with readers.
 */
public class CallGraph {

    private final Logger logger = Logger.INSTANCE;

    /** Callee -> call sites referencing it. Cleared and repopulated by {@link #buildGraph()}. */
    @Getter
    private final Map<ClassMethod, Set<CallGraphNode>> xrefs;

    public CallGraph() {
        this.xrefs = new ConcurrentHashMap<>();
    }

    /**
     * Clears {@link #xrefs} and rebuilds it by scanning every instruction of every method of every
     * class returned by {@code ObfuscatorContext.INSTANCE.getClasses()}.
     */
    public void buildGraph() {
        ObfuscatorContext core = ObfuscatorContext.INSTANCE;
        xrefs.clear();
        core.getClasses().forEach(classNode -> classNode.methods.forEach(method -> {
            // The caller is identical for every instruction of this method, so build it once.
            ClassMethod caller = ClassMethod.of(classNode, method);
            method.instructions.stream().forEach(instruction -> visitInstruction(core, caller, instruction));
        }));
    }

    /** Records the edge(s) contributed by a single instruction of {@code caller}, if any. */
    private void visitInstruction(ObfuscatorContext core, ClassMethod caller, AbstractInsnNode instruction) {
        if (instruction instanceof MethodInsnNode insn) {
            ClassMethod called = BytecodeUtil.getMethodNode(insn.owner, insn.name, insn.desc);
            if (called != null) {
                addNode(caller, called, insn);
            }
        } else if (instruction instanceof InvokeDynamicInsnNode insn) {
            visitInvokeDynamic(core, caller, insn);
        }
    }

    /**
     * Records edges for an {@code invokedynamic} instruction: one for the bootstrap method and one
     * for each resolvable {@link Handle} bootstrap argument.
     * <p>
     * The bootstrap method and the handle arguments are resolved independently. An unresolvable
     * bootstrap (e.g. its owner class is not loaded) must not hide the handle arguments, which may
     * point at methods inside the obfuscated code, such as a lambda's implementation method.
     */
    private void visitInvokeDynamic(ObfuscatorContext core, ClassMethod caller, InvokeDynamicInsnNode insn) {
        Handle bsm = insn.bsm;
        JClassNode bsmOwner = core.loadClass(bsm.getOwner());
        if (bsmOwner != null) {
            ClassMethod bootstrap = BytecodeUtil.getMethodNode(bsmOwner, bsm.getName(), bsm.getDesc());
            if (bootstrap != null) {
                addNode(caller, bootstrap, insn);
            }
        }

        if (insn.bsmArgs == null) {
            return;
        }
        for (Object bsmArg : insn.bsmArgs) {
            if (!(bsmArg instanceof Handle handle)) {
                continue;
            }
            JClassNode owner = core.loadClass(handle.getOwner());
            if (owner == null) {
                continue;
            }
            ClassMethod called = BytecodeUtil.getMethodNode(owner, handle.getName(), handle.getDesc());
            if (called == null) {
                logger.logln(Level.WARNING, Origin.CORE, String.format("Failed to find method: %s", new Ansi().c(YELLOW).s(String.format("%s%s", handle.getName(), handle.getDesc())).r(false).c(BRIGHT_YELLOW)));
                continue;
            }
            addNode(caller, called, insn);
        }
    }

    /** Adds a call-site edge {@code caller --instruction--> called}. */
    private void addNode(ClassMethod caller, ClassMethod called, AbstractInsnNode instruction) {
        CallGraphNode node = new CallGraphNode(caller, instruction);
        xrefs.computeIfAbsent(called, _ -> new HashSet<>()).add(node);
    }

    /**
     * Returns the call sites referencing {@code method}, or an empty set if none were recorded.
     * <p>
     * The returned set is the graph's live backing set, not a copy.
     *
     * @param method the callee to look up
     * @return the recorded call sites; never {@code null}
     */
    public Set<CallGraphNode> getNodes(ClassMethod method) {
        Set<CallGraphNode> direct = xrefs.get(method);
        if (direct != null) {
            return direct;
        }
        // Fallback linear scan by equals(). Presumably this guards against ClassMethod instances
        // that compare equal but hash differently — TODO confirm; if hashCode is consistent with
        // equals, this path only ever yields the empty set.
        return xrefs.entrySet().stream()
                .filter(entry -> method.equals(entry.getKey()))
                .map(Map.Entry::getValue)
                .findAny()
                .orElse(Collections.emptySet());
    }

    /**
     * A single call site: the calling method and the instruction within it that references the
     * callee.
     */
    public record CallGraphNode(ClassMethod member, AbstractInsnNode instruction) {
        /** @return the calling method */
        public ClassMethod getMember() {
            return member;
        }

        /** @return the caller's ASM {@link MethodNode} */
        public MethodNode getMethod() {
            return member.getMember();
        }

        /** @return the referencing instruction inside the caller */
        public AbstractInsnNode getInstruction() {
            return instruction;
        }

        /**
         * @return {@code true} when the caller's class is not a library class and the reference
         *         is not an {@code invokedynamic} instruction
         */
        public boolean isEdit() {
            return !member.getClassNode().isLibrary() && !isInvokeDynamic();
        }

        /** @return {@code true} if the referencing instruction is an {@code invokedynamic} */
        public boolean isInvokeDynamic() {
            return instruction instanceof InvokeDynamicInsnNode;
        }
    }
}
