package war.ice;

import java.util.ArrayList;
import java.util.List;

public class Parser {
    private final List<Token> tokens;
    private int current = 0;

    public Parser(List<Token> tokens) {
        this.tokens = tokens;
    }

    public List<Statement> parse() {
        List<Statement> statements = new ArrayList<>();
        while (!isAtEnd()) {
            statements.add(declaration());
        }
        return statements;
    }

    private Statement declaration() {
        if (match(TokenType.IMPORT)) return importStatement();
        if (match(TokenType.VAR)) return varDeclaration();
        if (match(TokenType.INLINE)) return inlineFunctionDeclaration();
        if (check(TokenType.IDENTIFIER) && checkNext(TokenType.COLON_COLON)) return functionDeclaration();
        return statement();
    }

    private Statement importStatement() {
        consume(TokenType.LEFT_PAREN, "Expect '(' after @import.");

        Expression pathExpr = expression();
        if (!(pathExpr instanceof Expression.Literal) || !(((Expression.Literal)pathExpr).value instanceof String)) {
            throw new RuntimeException("Import path must be a string literal.");
        }
        String path = (String)((Expression.Literal)pathExpr).value;

        consume(TokenType.RIGHT_PAREN, "Expect ')' after import path.");
        consume(TokenType.SEMICOLON, "Expect ';' after import statement.");
        return new Statement.Import(new Token(TokenType.IDENTIFIER, path, null, 0));
    }

    private Statement varDeclaration() {
        Token name = consume(TokenType.IDENTIFIER, "Expect variable name.");

        Expression initializer = null;
        if (match(TokenType.EQUAL)) {
            initializer = expression();
        }

        consume(TokenType.SEMICOLON, "Expect ';' after variable declaration.");
        return new Statement.Var(name, initializer);
    }

    private Statement functionDeclaration() {
        Token name = consume(TokenType.IDENTIFIER, "Expect function name.");
        consume(TokenType.COLON_COLON, "Expect '::' after function name.");
        consume(TokenType.LEFT_PAREN, "Expect '(' after function name.");
        
        List<Token> parameters = new ArrayList<>();
        if (!check(TokenType.RIGHT_PAREN)) {
            do {
                parameters.add(consume(TokenType.IDENTIFIER, "Expect parameter name."));
            } while (match(TokenType.COMMA));
        }
        consume(TokenType.RIGHT_PAREN, "Expect ')' after parameters.");
        
        consume(TokenType.LEFT_BRACE, "Expect '{' before function body.");
        List<Statement> body = block();
        return new Statement.Function(name, parameters, body);
    }

    private Statement inlineFunctionDeclaration() {
        Token name = consume(TokenType.IDENTIFIER, "Expect function name.");
        consume(TokenType.COLON_COLON, "Expect '::' after function name.");
        consume(TokenType.LEFT_PAREN, "Expect '(' after function name.");
        
        List<Token> parameters = new ArrayList<>();
        if (!check(TokenType.RIGHT_PAREN)) {
            do {
                parameters.add(consume(TokenType.IDENTIFIER, "Expect parameter name."));
            } while (match(TokenType.COMMA));
        }
        consume(TokenType.RIGHT_PAREN, "Expect ')' after parameters.");
        
        consume(TokenType.ARROW, "Expect '->' after parameters.");

        Statement stmt = statement();
        
        List<Statement> body = new ArrayList<>();
        body.add(stmt);
        
        return new Statement.Function(name, parameters, body);
    }

    private List<Statement> block() {
        List<Statement> statements = new ArrayList<>();
        while (!check(TokenType.RIGHT_BRACE) && !isAtEnd()) {
            statements.add(declaration());
        }
        consume(TokenType.RIGHT_BRACE, "Expect '}' after block.");
        return statements;
    }

    private Statement statement() {
        if (match(TokenType.PRINT)) return printStatement();
        if (match(TokenType.PRINTLN)) return printlnStatement();
        if (match(TokenType.RET)) return returnStatement();
        if (match(TokenType.IF)) return ifStatement();
        if (match(TokenType.WHILE)) return whileStatement();
        if (match(TokenType.FOR)) return forStatement();
        if (match(TokenType.LEFT_BRACE)) return new Statement.Block(block());
        return expressionStatement();
    }

    private Statement ifStatement() {
        consume(TokenType.LEFT_PAREN, "Expect '(' after 'if'.");
        Expression condition = expression();
        if (condition == null) {
            throw new RuntimeException("Expect condition expression in if statement.");
        }
        consume(TokenType.RIGHT_PAREN, "Expect ')' after if condition.");

        Statement thenBranch = statement();
        Statement elseBranch = null;
        if (match(TokenType.ELSE)) {
            elseBranch = statement();
        }

        return new Statement.If(condition, thenBranch, elseBranch);
    }

    private Statement printStatement() {
        Expression value = expression();
        consume(TokenType.SEMICOLON, "Expect ';' after value.");
        return new Statement.Print(value);
    }

    private Statement printlnStatement() {
        consume(TokenType.LEFT_PAREN, "Expect '(' after 'println'.");
        Expression value = expression();
        consume(TokenType.RIGHT_PAREN, "Expect ')' after println argument.");
        consume(TokenType.SEMICOLON, "Expect ';' after value.");
        return new Statement.Println(value);
    }

    private Statement returnStatement() {
        Token keyword = previous();
        Expression value = null;
        if (!check(TokenType.SEMICOLON)) {
            value = expression();
        }
        consume(TokenType.SEMICOLON, "Expect ';' after return value.");
        return new Statement.Return(keyword, value);
    }

    private Statement expressionStatement() {
        Expression expr = expression();
        consume(TokenType.SEMICOLON, "Expect ';' after expression.");
        return new Statement.Expression(expr);
    }

    private Statement whileStatement() {
        consume(TokenType.LEFT_PAREN, "Expect '(' after 'while'.");
        Expression condition = expression();
        consume(TokenType.RIGHT_PAREN, "Expect ')' after while condition.");

        consume(TokenType.LEFT_BRACE, "Expect '{' before while body.");
        List<Statement> body = block();
        return new Statement.While(condition, new Statement.Block(body));
    }

    private Statement forStatement() {
        consume(TokenType.LEFT_PAREN, "Expect '(' after 'for'.");

        Statement initializer;
        if (match(TokenType.SEMICOLON)) {
            initializer = null;
        } else if (match(TokenType.VAR)) {
            initializer = varDeclaration();
        } else {
            initializer = expressionStatement();
        }

        Expression condition = null;
        if (!check(TokenType.SEMICOLON)) {
            condition = expression();
        }
        consume(TokenType.SEMICOLON, "Expect ';' after loop condition.");

        Expression increment = null;
        if (!check(TokenType.RIGHT_PAREN)) {
            increment = expression();
        }
        consume(TokenType.RIGHT_PAREN, "Expect ')' after for clauses.");

        consume(TokenType.LEFT_BRACE, "Expect '{' before for body.");
        List<Statement> body = block();

        if (increment != null) {
            body.add(new Statement.Expression(increment));
        }

        Statement whileBody = new Statement.Block(body);
        if (condition == null) {
            condition = new Expression.Literal(true);
        }
        Statement whileLoop = new Statement.While(condition, whileBody);

        if (initializer != null) {
            List<Statement> statements = new ArrayList<>();
            statements.add(initializer);
            statements.add(whileLoop);
            return new Statement.Block(statements);
        }

        return whileLoop;
    }

    private Expression expression() {
        return assignment();
    }

    private Expression assignment() {
        Expression expr = or();

        if (match(TokenType.EQUAL)) {
            Token equals = previous();
            Expression value = assignment();

            if (expr instanceof Expression.Variable) {
                Token name = ((Expression.Variable)expr).name;
                return new Expression.Assign(name, value);
            } else if (expr instanceof Expression.ArrayGet) {
                Expression.ArrayGet get = (Expression.ArrayGet)expr;
                return new Expression.ArraySet(get.array, get.index, value);
            }

            throw new RuntimeException("Invalid assignment target.");
        }

        return expr;
    }

    private Expression or() {
        Expression expr = and();

        while (match(TokenType.OR)) {
            Token operator = previous();
            Expression right = and();
            expr = new Expression.Logical(expr, operator, right);
        }

        return expr;
    }

    private Expression and() {
        Expression expr = equality();

        while (match(TokenType.AND)) {
            Token operator = previous();
            Expression right = equality();
            expr = new Expression.Logical(expr, operator, right);
        }

        return expr;
    }

    private Expression equality() {
        Expression expr = comparison();

        while (match(TokenType.BANG_EQUAL, TokenType.EQUAL_EQUAL)) {
            Token operator = previous();
            Expression right = comparison();
            expr = new Expression.Binary(expr, operator, right);
        }

        return expr;
    }

    private Expression comparison() {
        Expression expr = addition();

        while (match(TokenType.GREATER, TokenType.GREATER_EQUAL, TokenType.LESS, TokenType.LESS_EQUAL)) {
            Token operator = previous();
            Expression right = addition();
            expr = new Expression.Binary(expr, operator, right);
        }

        return expr;
    }

    private Expression addition() {
        Expression expr = multiplication();

        while (match(TokenType.PLUS, TokenType.MINUS)) {
            Token operator = previous();
            Expression right = multiplication();
            expr = new Expression.Binary(expr, operator, right);
        }

        return expr;
    }

    private Expression multiplication() {
        Expression expr = unary();

        while (match(TokenType.STAR, TokenType.SLASH, TokenType.PERCENT)) {
            Token operator = previous();
            Expression right = unary();
            expr = new Expression.Binary(expr, operator, right);
        }

        return expr;
    }

    private Expression unary() {
        if (match(TokenType.MINUS)) {
            Token operator = previous();
            Expression right = unary();
            return new Expression.Unary(operator, right);
        }

        return primary();
    }

    private Expression primary() {
        if (match(TokenType.NUMBER)) {
            return new Expression.Literal(previous().literal);
        }

        if (match(TokenType.STRING)) {
            return new Expression.Literal(previous().literal);
        }

        if (match(TokenType.IDENTIFIER)) {
            Expression expr = new Expression.Variable(previous());
            if (match(TokenType.LEFT_PAREN)) {
                return finishCall(expr);
            }
            if (match(TokenType.LEFT_BRACKET)) {
                Expression index = expression();
                consume(TokenType.RIGHT_BRACKET, "Expect ']' after array index.");
                return new Expression.ArrayGet(expr, index);
            }
            if (match(TokenType.DOT)) {
                Token name = consume(TokenType.IDENTIFIER, "Expect property name after '.'.");
                return new Expression.Property(expr, name);
            }
            return expr;
        }

        if (match(TokenType.LEFT_PAREN)) {
            Expression expr = expression();
            consume(TokenType.RIGHT_PAREN, "Expect ')' after expression.");
            return new Expression.Grouping(expr);
        }

        if (match(TokenType.LEFT_BRACKET)) {
            List<Expression> elements = new ArrayList<>();
            if (!check(TokenType.RIGHT_BRACKET)) {
                do {
                    elements.add(expression());
                } while (match(TokenType.COMMA));
            }
            consume(TokenType.RIGHT_BRACKET, "Expect ']' after array elements.");
            return new Expression.Array(elements);
        }

        if (match(TokenType.TRUE)) return new Expression.Literal(true);
        if (match(TokenType.FALSE)) return new Expression.Literal(false);
        if (match(TokenType.NIL)) return new Expression.Literal(null);

        throw new RuntimeException("Expect expression. Got: " + peek().type);
    }

    private Expression finishCall(Expression callee) {
        List<Expression> arguments = new ArrayList<>();
        if (!check(TokenType.RIGHT_PAREN)) {
            do {
                arguments.add(expression());
            } while (match(TokenType.COMMA));
        }

        Token paren = consume(TokenType.RIGHT_PAREN, "Expect ')' after arguments.");
        return new Expression.Call(callee, paren, arguments);
    }

    private boolean match(TokenType... types) {
        for (TokenType type : types) {
            if (check(type)) {
                advance();
                return true;
            }
        }
        return false;
    }

    private Token consume(TokenType type, String message) {
        if (check(type)) return advance();
        throw new RuntimeException(message);
    }

    private boolean check(TokenType type) {
        if (isAtEnd()) return false;
        return peek().type == type;
    }

    private Token advance() {
        if (!isAtEnd()) current++;
        return previous();
    }

    private boolean isAtEnd() {
        return peek().type == TokenType.EOF;
    }

    private Token peek() {
        return tokens.get(current);
    }

    private Token previous() {
        return tokens.get(current - 1);
    }

    private boolean checkNext(TokenType type) {
        if (isAtEnd()) return false;
        if (tokens.get(current + 1).type == TokenType.EOF) return false;
        return tokens.get(current + 1).type == type;
    }
} 