package war.ice;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

public class Interpreter implements Expression.Visitor<Object>, Statement.Visitor<Void> {
    private Map<String, Object> environment = new HashMap<>();
    private final Map<String, Callable> functions = new HashMap<>();

    public Interpreter() {
        functions.put("println", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 1) {
                    throw new RuntimeException("println() expects exactly 1 argument.");
                }
                System.out.println(stringify(arguments.get(0)));
                return null;
            }
        });

        functions.put("time_ms", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 0) {
                    throw new RuntimeException("time_ms() expects no arguments.");
                }
                return (double)System.currentTimeMillis();
            }
        });

        functions.put("time_ns", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 0) {
                    throw new RuntimeException("time_ns() expects no arguments.");
                }
                return (double)System.nanoTime();
            }
        });

        functions.put("sleep_ms", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 1) {
                    throw new RuntimeException("sleep_ms() expects exactly 1 argument (milliseconds).");
                }
                if (!(arguments.get(0) instanceof Double)) {
                    throw new RuntimeException("sleep_ms() argument must be a number.");
                }
                try {
                    Thread.sleep(((Double)arguments.get(0)).longValue());
                } catch (InterruptedException e) {
                    throw new RuntimeException("Sleep interrupted: " + e.getMessage());
                }
                return null;
            }
        });

        functions.put("read_file", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 1) {
                    throw new RuntimeException("read_file() expects exactly 1 argument (file path).");
                }
                String path = stringify(arguments.get(0));
                try {
                    return new String(java.nio.file.Files.readAllBytes(java.nio.file.Paths.get(path)));
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error reading file: " + e.getMessage());
                }
            }
        });

        functions.put("write_file", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 2) {
                    throw new RuntimeException("write_file() expects exactly 2 arguments (file path, content).");
                }
                String path = stringify(arguments.get(0));
                String content = stringify(arguments.get(1));
                try {
                    java.nio.file.Files.write(java.nio.file.Paths.get(path), content.getBytes());
                    return null;
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error writing file: " + e.getMessage());
                }
            }
        });

        functions.put("append_file", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 2) {
                    throw new RuntimeException("append_file() expects exactly 2 arguments (file path, content).");
                }
                String path = stringify(arguments.get(0));
                String content = stringify(arguments.get(1));
                try {
                    java.nio.file.Files.write(
                        java.nio.file.Paths.get(path),
                        content.getBytes(),
                        java.nio.file.StandardOpenOption.APPEND
                    );
                    return null;
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error appending to file: " + e.getMessage());
                }
            }
        });

        functions.put("file_exists", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 1) {
                    throw new RuntimeException("file_exists() expects exactly 1 argument (file path).");
                }
                String path = stringify(arguments.get(0));
                return java.nio.file.Files.exists(java.nio.file.Paths.get(path));
            }
        });

        functions.put("delete_file", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() != 1) {
                    throw new RuntimeException("delete_file() expects exactly 1 argument (file path).");
                }
                String path = stringify(arguments.get(0));
                try {
                    java.nio.file.Files.delete(java.nio.file.Paths.get(path));
                    return null;
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error deleting file: " + e.getMessage());
                }
            }
        });

        functions.put("input", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() > 1) {
                    throw new RuntimeException("input() expects at most 1 argument (prompt).");
                }
                if (arguments.size() == 1) {
                    System.out.print(stringify(arguments.get(0)));
                }
                try {
                    java.io.BufferedReader reader = new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
                    return reader.readLine();
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error reading input: " + e.getMessage());
                }
            }
        });

        functions.put("input_number", new Callable() {
            @Override
            public Object call(List<Object> arguments) {
                if (arguments.size() > 1) {
                    throw new RuntimeException("input_number() expects at most 1 argument (prompt).");
                }
                if (arguments.size() == 1) {
                    System.out.print(stringify(arguments.get(0)));
                }
                try {
                    java.io.BufferedReader reader = new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
                    String input = reader.readLine();
                    return Double.parseDouble(input);
                } catch (java.io.IOException e) {
                    throw new RuntimeException("Error reading input: " + e.getMessage());
                } catch (NumberFormatException e) {
                    throw new RuntimeException("Invalid number input.");
                }
            }
        });
    }

    public void interpret(List<Statement> statements) {
        try {
            for (Statement statement : statements) {
                execute(statement);
            }
        } catch (RuntimeException error) {
            System.err.println(error.getMessage());
        }
    }

    private void execute(Statement stmt) {
        stmt.accept(this);
    }

    @Override
    public Void visitVarStmt(Statement.Var stmt) {
        Object value = null;
        if (stmt.initializer != null) {
            value = evaluate(stmt.initializer);
        }

        environment.put(stmt.name.lexeme, value);
        return null;
    }

    @Override
    public Void visitPrintStmt(Statement.Print stmt) {
        Object value = evaluate(stmt.expression);
        System.out.print(stringify(value));
        return null;
    }

    @Override
    public Void visitPrintlnStmt(Statement.Println stmt) {
        Object value = evaluate(stmt.expression);
        System.out.println(stringify(value));
        return null;
    }

    @Override
    public Object visitLiteralExpr(Expression.Literal expr) {
        if (expr.value instanceof String str) {
            StringBuilder result = new StringBuilder();
            int i = 0;
            while (i < str.length()) {
                if (str.charAt(i) == '$' && i + 1 < str.length() && isAlpha(str.charAt(i + 1))) {
                    int end = i + 1;
                    while (end < str.length() && isAlphaNumeric(str.charAt(end))) {
                        end++;
                    }
                    String varName = str.substring(i + 1, end);
                    Object value = lookUpVariable(new Token(TokenType.IDENTIFIER, varName, null, 0));
                    result.append(stringify(value));
                    i = end;
                } else {
                    result.append(str.charAt(i));
                    i++;
                }
            }
            return result.toString();
        }
        return expr.value;
    }

    private boolean isAlpha(char c) {
        return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_';
    }

    private boolean isAlphaNumeric(char c) {
        return isAlpha(c) || (c >= '0' && c <= '9');
    }

    @Override
    public Object visitVariableExpr(Expression.Variable expr) {
        return lookUpVariable(expr.name);
    }

    @Override
    public Object visitAssignExpr(Expression.Assign expr) {
        Object value = evaluate(expr.value);
        environment.put(expr.name.lexeme, value);
        return value;
    }

    @Override
    public Object visitCallExpr(Expression.Call expr) {
        Object callee = evaluate(expr.callee);
        List<Object> arguments = new ArrayList<>();
        for (Expression argument : expr.arguments) {
            arguments.add(evaluate(argument));
        }

        if (!(callee instanceof Callable function)) {
            throw new RuntimeException("Can only call functions.");
        }

        return function.call(arguments);
    }

    @Override
    public Void visitReturnStmt(Statement.Return stmt) {
        Object value = null;
        if (stmt.value != null) value = evaluate(stmt.value);
        throw new Return(value);
    }

    @Override
    public Void visitBlockStmt(Statement.Block stmt) {
        executeBlock(stmt.statements, environment);
        return null;
    }

    @Override
    public Void visitExpressionStmt(Statement.Expression stmt) {
        evaluate(stmt.expression);
        return null;
    }

    @Override
    public Void visitFunctionStmt(Statement.Function stmt) {
        functions.put(stmt.name.lexeme, new Function(stmt));
        return null;
    }

    @Override
    public Object visitBinaryExpr(Expression.Binary expr) {
        Object left = evaluate(expr.left);
        Object right = evaluate(expr.right);

        switch (expr.operator.type) {
            case MINUS:
                checkNumberOperands(expr.operator, left, right);
                return (double)left - (double)right;
            case SLASH:
                checkNumberOperands(expr.operator, left, right);
                if ((double)right == 0) throw new RuntimeException("Division by zero.");
                return (double)left / (double)right;
            case STAR:
                checkNumberOperands(expr.operator, left, right);
                return (double)left * (double)right;
            case PERCENT:
                checkNumberOperands(expr.operator, left, right);
                if ((double)right == 0) throw new RuntimeException("Modulo by zero.");
                return (double)left % (double)right;
            case PLUS:
                if (left instanceof Double && right instanceof Double) {
                    return (double)left + (double)right;
                }
                if (left instanceof String && right instanceof String) {
                    return left + (String)right;
                }
                throw new RuntimeException("Operands must be two numbers or two strings.");
            case GREATER:
                checkNumberOperands(expr.operator, left, right);
                return (double)left > (double)right;
            case GREATER_EQUAL:
                checkNumberOperands(expr.operator, left, right);
                return (double)left >= (double)right;
            case LESS:
                checkNumberOperands(expr.operator, left, right);
                return (double)left < (double)right;
            case LESS_EQUAL:
                checkNumberOperands(expr.operator, left, right);
                return (double)left <= (double)right;
            case BANG_EQUAL: return !isEqual(left, right);
            case EQUAL_EQUAL: return isEqual(left, right);
        }

        return null;
    }

    @Override
    public Object visitUnaryExpr(Expression.Unary expr) {
        Object right = evaluate(expr.right);

        switch (expr.operator.type) {
            case MINUS:
                checkNumberOperand(expr.operator, right);
                return -(Double)right;
            default:
                throw new RuntimeException("Unknown operator.");
        }
    }

    @Override
    public Object visitGroupingExpr(Expression.Grouping expr) {
        return evaluate(expr.expression);
    }

    @Override
    public Void visitIfStmt(Statement.If stmt) {
        if (isTruthy(evaluate(stmt.condition))) {
            execute(stmt.thenBranch);
        } else if (stmt.elseBranch != null) {
            execute(stmt.elseBranch);
        }
        return null;
    }

    @Override
    public Object visitLogicalExpr(Expression.Logical expr) {
        Object left = evaluate(expr.left);

        if (expr.operator.type == TokenType.OR) {
            if (isTruthy(left)) return left;
        } else {
            if (!isTruthy(left)) return left;
        }

        return evaluate(expr.right);
    }

    private boolean isTruthy(Object object) {
        if (object == null) return false;
        if (object instanceof Boolean) return (Boolean) object;
        return true;
    }

    private Object evaluate(Expression expr) {
        return expr.accept(this);
    }

    private Object lookUpVariable(Token name) {
        if (environment.containsKey(name.lexeme)) {
            return environment.get(name.lexeme);
        }
        if (functions.containsKey(name.lexeme)) {
            return functions.get(name.lexeme);
        }
        throw new RuntimeException("Undefined variable '" + name.lexeme + "'.");
    }

    private String stringify(Object object) {
        if (object == null) return "nil";
        if (object instanceof Double) {
            String text = object.toString();
            if (text.endsWith(".0")) {
                text = text.substring(0, text.length() - 2);
            }
            return text;
        }
        return object.toString();
    }

    private void executeBlock(List<Statement> statements, Map<String, Object> environment) {
        Map<String, Object> previous = this.environment;
        try {
            this.environment = environment;
            for (Statement statement : statements) {
                execute(statement);
            }
        } finally {
            this.environment = previous;
        }
    }

    private static class Return extends RuntimeException {
        final Object value;
        Return(Object value) {
            super(null, null, false, false);
            this.value = value;
        }
    }

    private class Function implements Callable {
        private final Statement.Function declaration;
        Function(Statement.Function declaration) {
            this.declaration = declaration;
        }

        @Override
        public Object call(List<Object> arguments) {
            Map<String, Object> environment = new HashMap(Interpreter.this.environment);
            for (int i = 0; i < declaration.params.size(); i++) {
                environment.put(declaration.params.get(i).lexeme, arguments.get(i));
            }

            try {
                executeBlock(declaration.body, environment);
            } catch (Return returnValue) {
                return returnValue.value;
            }
            return null;
        }
    }

    private interface Callable {
        Object call(List<Object> arguments);
    }

    private void checkNumberOperand(Token operator, Object operand) {
        if (operand instanceof Double) return;
        throw new RuntimeException("Operand must be a number.");
    }

    private void checkNumberOperands(Token operator, Object left, Object right) {
        if (left instanceof Double && right instanceof Double) return;
        throw new RuntimeException("Operands must be numbers.");
    }

    private boolean isEqual(Object a, Object b) {
        if (a == null && b == null) return true;
        if (a == null) return false;
        return a.equals(b);
    }

    @Override
    public Void visitImportStmt(Statement.Import stmt) {
        String modulePath = stmt.moduleName.lexeme;
        try {
            String content = new String(java.nio.file.Files.readAllBytes(
                java.nio.file.Paths.get(modulePath + ".ice")));

            Lexer lexer = new Lexer(content);
            List<Token> tokens = lexer.tokenize();
            Parser parser = new Parser(tokens);
            List<Statement> statements = parser.parse();

            for (Statement statement : statements) {
                execute(statement);
            }
        } catch (java.io.IOException e) {
            throw new RuntimeException("Error importing module '" + modulePath + "': " + e.getMessage());
        }
        return null;
    }

    @Override
    public Void visitWhileStmt(Statement.While stmt) {
        while (isTruthy(evaluate(stmt.condition))) {
            executeBlock(((Statement.Block)stmt.body).statements, environment);
        }
        return null;
    }

    @Override
    public Object visitArrayExpr(Expression.Array expr) {
        List<Object> elements = new ArrayList<>();
        for (Expression element : expr.elements) {
            elements.add(evaluate(element));
        }
        return elements;
    }

    @Override
    public Object visitArrayGetExpr(Expression.ArrayGet expr) {
        Object array = evaluate(expr.array);
        Object index = evaluate(expr.index);

        if (!(index instanceof Double)) {
            throw new RuntimeException("Index must be a number.");
        }

        int idx = ((Double)index).intValue();
        
        if (array instanceof List) {
            List<Object> list = (List<Object>)array;
            if (idx < 0 || idx >= list.size()) {
                throw new RuntimeException("Array index out of bounds.");
            }
            return list.get(idx);
        } else if (array instanceof String str) {
            if (idx < 0 || idx >= str.length()) {
                throw new RuntimeException("String index out of bounds.");
            }
            return String.valueOf(str.charAt(idx));
        }
        
        throw new RuntimeException("Only arrays and strings can be indexed.");
    }

    @Override
    public Object visitArraySetExpr(Expression.ArraySet expr) {
        Object array = evaluate(expr.array);
        Object index = evaluate(expr.index);
        Object value = evaluate(expr.value);

        if (!(array instanceof List)) {
            throw new RuntimeException("Only arrays can be indexed.");
        }
        if (!(index instanceof Double)) {
            throw new RuntimeException("Array index must be a number.");
        }

        List<Object> list = (List<Object>)array;
        int idx = ((Double)index).intValue();
        
        if (idx < 0 || idx >= list.size()) {
            throw new RuntimeException("Array index out of bounds.");
        }

        list.set(idx, value);
        return value;
    }

    @Override
    public Object visitPropertyExpr(Expression.Property expr) {
        Object object = evaluate(expr.object);
        
        if (object instanceof List) {
            if (expr.name.lexeme.equals("length")) {
                return (double)((List<?>)object).size();
            }
            throw new RuntimeException("Arrays only have 'length' property.");
        } else if (object instanceof String) {
            if (expr.name.lexeme.equals("length")) {
                return (double)((String)object).length();
            }
            throw new RuntimeException("Strings only have 'length' property.");
        }
        
        throw new RuntimeException("Only arrays and strings have properties.");
    }
} 