package org.jnt.crackme.vm;

import org.jnt.crackme.vm.data.*;
import org.jnt.crackme.vm.data.RegisterGroup.RegisterFamily;
import org.jnt.crackme.vm.ex.IllegalJumpException;
import org.jnt.crackme.vm.ex.UnknownRegisterException;
import org.jnt.crackme.vm.struct.Function;

import java.lang.invoke.MethodHandle;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.Optional;

import static org.jnt.crackme.vm.def.Opcodes.*;

/**
 * A small stack-based bytecode interpreter.
 *
 * <p>Functions are registered with {@link #reg(Function)} and receive sequential ids starting at
 * 0. {@link #call(int)} interprets a function's bytecode. Multi-byte operands are big-endian signed
 * 16-bit values. All calls share the registers held in a single {@link RegisterGroup}. A
 * {@code RET} stores its value in reserved register 0, which can be read back with {@link #rr()}.
 *
 * @author etho
 */
public class VirtualMachine {
    /** Component name passed to the VM's exception types. */
    private static final String SOURCE_NAME = "VirtualMachine";

    private final Function[] functions;
    private int fnPtr = 0;

    private final RegisterGroup regGroup = new RegisterGroup();

    /**
     * Creates a VM whose function table can hold {@code maxFns} functions.
     *
     * @param maxFns capacity of the function table
     */
    public VirtualMachine(int maxFns) {
        this.functions = new Function[maxFns];
    }

    /**
     * Registers {@code fn} under the next free id (0, 1, 2, ... in registration order).
     *
     * @param fn function to register
     * @throws ArrayIndexOutOfBoundsException if {@code maxFns} functions are already registered
     */
    public void reg(Function fn) {
        functions[fnPtr] = fn;
        fnPtr++;
    }

    /**
     * Interprets the bytecode of the function registered under {@code id}.
     *
     * <p>Execution stops at {@code RET} or when the program counter reaches the end of the code.
     * {@code CALL} re-enters this method recursively on the Java stack.
     *
     * @param id id returned by registration order of {@link #reg(Function)}
     * @throws IndexOutOfBoundsException if no function has been registered under {@code id}
     * @throws IllegalJumpException if a jump targets a position outside {@code [0, code.length]}
     * @throws UnknownRegisterException if an instruction names a register the group does not have
     * @throws IllegalStateException if an unknown opcode is encountered
     * @throws RuntimeException wrapping any failure raised while performing an {@code UPCALL}
     */
    public void call(int id) {
        // Without this check an unregistered slot would fail later with an opaque NPE.
        Objects.checkIndex(id, fnPtr);
        var fn = functions[id];

        var code = fn.code();
        var stack = fn.stack();
        var locals = fn.locals();

        int pc = 0;
        while (pc < code.length) {
            byte op = code[pc];

            switch (op) {
                case PUSH -> {
                    stack.push(readShort(code, pc + 1));
                    pc += 3;
                }
                case POP -> {
                    stack.pop();
                    pc++;
                }
                case ADD -> {
                    stack.push(stack.pop() + stack.pop());
                    pc++;
                }
                case SUB -> {
                    // Binary ops treat the top of stack as the right-hand operand.
                    int a = stack.pop();
                    int b = stack.pop();
                    stack.push(b - a);
                    pc++;
                }
                case DIV -> {
                    int a = stack.pop();
                    int b = stack.pop();
                    stack.push(b / a);
                    pc++;
                }
                case MUL -> {
                    stack.push(stack.pop() * stack.pop());
                    pc++;
                }
                case IMM0 -> {
                    stack.push(0);
                    pc++;
                }
                case IMM1 -> {
                    stack.push(1);
                    pc++;
                }
                case IMM2 -> {
                    stack.push(2);
                    pc++;
                }
                case IMM3 -> {
                    stack.push(3);
                    pc++;
                }
                case IMM4 -> {
                    stack.push(4);
                    pc++;
                }
                case IMM5 -> {
                    stack.push(5);
                    pc++;
                }
                case XOR -> {
                    int a = stack.pop();
                    int b = stack.pop();
                    stack.push(b ^ a);
                    pc++;
                }
                case NEG -> {
                    stack.push(-stack.pop());
                    pc++;
                }
                case CALL -> {
                    call(readShort(code, pc + 1));
                    pc += 3;
                }
                case LST -> {
                    int index = readShort(code, pc + 1);
                    int value = stack.pop();
                    locals[index][0] = value;
                    pc += 3;
                }
                case LLD -> {
                    int index = readShort(code, pc + 1);
                    stack.push(locals[index][0]);
                    pc += 3;
                }
                case ISLT -> {
                    // NOTE(review): this compares top < second, the reverse of the operand order
                    // used by SUB/DIV/XOR. Preserved as-is because existing bytecode depends on it.
                    int a = stack.pop();
                    int b = stack.pop();
                    stack.push(a < b ? 1 : 0);
                    pc++;
                }
                case JMP -> pc = jumpTarget(pc, readShort(code, pc + 1), code.length, "Illegal jump");
                // DYNJMP is a 3-byte instruction whose two operand bytes are ignored;
                // the relative offset comes from the stack instead.
                case DYNJMP -> pc = jumpTarget(pc, stack.pop(), code.length, "Illegal dynamic jump");
                case LDR -> {
                    // The register is resolved before popping, so a bad id leaves the stack untouched.
                    generalRegister(readShort(code, pc + 1)).value = stack.pop();
                    pc += 3;
                }
                case DMP -> {
                    stack.push(generalRegister(readShort(code, pc + 1)).value);
                    pc += 3;
                }
                case RET -> {
                    int value = stack.pop();
                    Optional<AbstractRegister> rr = regGroup.lookup(0, RegisterFamily.RESERVED);
                    ((Register) rr.orElseThrow(() -> new UnknownRegisterException(
                            "Failed to find reserved return register.", SOURCE_NAME))).value = value;
                    return;
                }
                case CPY -> {
                    int srcRegIdx = readShort(code, pc + 1);
                    int dstRegIdx = readShort(code, pc + 3);

                    Optional<AbstractRegister> srcReg = regGroup.lookup(srcRegIdx, RegisterFamily.GENERAL_PURPOSE);
                    Optional<AbstractRegister> dstReg = regGroup.lookup(dstRegIdx, RegisterFamily.GENERAL_PURPOSE);

                    // Destination is checked first, matching the original error precedence.
                    Register dst = (Register) dstReg.orElseThrow(() -> new UnknownRegisterException(
                            "Failed to find destination register for cpy.", SOURCE_NAME));
                    Register src = (Register) srcReg.orElseThrow(() -> new UnknownRegisterException(
                            "Failed to find source register for cpy.", SOURCE_NAME));

                    dst.value = src.value;
                    pc += 5;
                }
                case ISNULL -> {
                    stack.push(dumpObjReg(0) == null ? 1 : 0);
                    pc++;
                }
                case UPCALL -> {
                    int index = readShort(code, pc + 1);

                    // MethodHandle.invoke is signature-polymorphic: the static argument/result types
                    // at each call site below determine the invoked type, so keep them unchanged.
                    try {
                        MethodHandle handle = UpcallGroup.upcallHandle(index);

                        switch (index) {
                            case 0 -> {
                                handle.invoke(stack.pop());
                                pc += 3;
                            }
                            case 1 -> {
                                // Unsigned length byte followed by that many bytes of inline string data.
                                int len = code[pc + 3] & 0xFF;
                                handle.invoke(System.out, new String(code, pc + 4, len, StandardCharsets.UTF_8));
                                pc += 4 + len;
                            }
                            case 2 -> {
                                int idx = stack.pop();
                                loadObjReg(handle.invoke(dumpObjReg(0), idx), 0);
                                pc += 3;
                            }
                            case 3, 4, 5 -> {
                                loadObjReg(handle.invoke(dumpObjReg(0), dumpObjReg(1)), 2);
                                pc += 3;
                            }
                            // Without this, pc would never advance and the interpreter would hang.
                            default -> throw new IllegalStateException("Unknown upcall index: " + index);
                        }
                    } catch (Throwable t) {
                        throw new RuntimeException(t);
                    }
                }
                // Without this, an unknown opcode would leave pc unchanged and loop forever.
                default -> throw new IllegalStateException(
                        String.format("Unknown opcode 0x%02X at pc %d", op & 0xFF, pc));
            }
        }
    }

    /**
     * Decodes a big-endian signed 16-bit operand stored at {@code code[at]} (high byte) and
     * {@code code[at + 1]} (low byte).
     */
    private static int readShort(byte[] code, int at) {
        // Java bytes are signed. Without "& 0xFF", a low byte >= 0x80 would sign-extend and
        // clobber the high byte (e.g. 0x00 0x80 would decode as -128 instead of 128).
        return (code[at] << 8) | (code[at + 1] & 0xFF);
    }

    /**
     * Computes the absolute target of a 3-byte relative jump located at {@code pc}. The target may
     * equal {@code codeLength}, which simply ends execution.
     *
     * @throws IllegalJumpException if the target lies outside {@code [0, codeLength]}
     */
    private static int jumpTarget(int pc, int relative, int codeLength, String kind) {
        // long arithmetic: offsets popped from the stack may be arbitrarily large.
        long target = (long) pc + 3 + relative;
        if (target < 0 || target > codeLength) {
            throw new IllegalJumpException(
                    kind + " (target " + target + " is outside the code bounds [0, " + codeLength + "].)",
                    SOURCE_NAME);
        }
        return (int) target;
    }

    /**
     * Looks up a general-purpose register by id.
     *
     * @throws UnknownRegisterException if the register group has no such register
     */
    private Register generalRegister(int regID) {
        Optional<AbstractRegister> reg = regGroup.lookup(regID, RegisterFamily.GENERAL_PURPOSE);
        return (Register) reg.orElseThrow(() -> new UnknownRegisterException(
                "Failed to find register with ID: " + regID, SOURCE_NAME));
    }

    /**
     * Stores {@code value} into object register {@code id}.
     *
     * @throws java.util.NoSuchElementException if there is no object register with that id
     */
    public void loadObjReg(Object value, int id) {
        ((ObjectRegister) regGroup.lookup(id, RegisterFamily.OBJECT).orElseThrow()).obj = value;
    }

    /**
     * Returns the object currently held in object register {@code id} (possibly {@code null}).
     *
     * @throws java.util.NoSuchElementException if there is no object register with that id
     */
    public Object dumpObjReg(int id) {
        return ((ObjectRegister) regGroup.lookup(id, RegisterFamily.OBJECT).orElseThrow()).obj;
    }

    /**
     * Returns the value of reserved register 0, which {@code RET} writes.
     *
     * @throws java.util.NoSuchElementException if the reserved return register does not exist
     */
    public int rr() {
        return ((Register) regGroup.lookup(0, RegisterFamily.RESERVED).orElseThrow()).value;
    }
}
