From 227a9b7a25373e5a56712369cefdc98c823ebdbe Mon Sep 17 00:00:00 2001 From: bog Date: Sun, 10 Sep 2023 09:59:51 +0200 Subject: [PATCH] ADD: function literal (no closure yet). --- doc/grammar.bnf | 4 +++ lib/Code.cpp | 5 +++ lib/Code.hpp | 3 ++ lib/Compiler.cpp | 50 ++++++++++++++++++++-------- lib/Compiler.hpp | 7 ++-- lib/Function.cpp | 12 ++++++- lib/Function.hpp | 6 ++++ lib/Lexer.cpp | 1 + lib/Node.hpp | 7 +++- lib/Opcodes.hpp | 1 + lib/Parser.cpp | 44 +++++++++++++++++++++++++ lib/Parser.hpp | 3 ++ lib/Program.cpp | 1 + lib/StaticPass.cpp | 21 +++++++++++- lib/SymTable.cpp | 21 ++++++++++-- lib/SymTable.hpp | 6 +++- lib/VM.cpp | 82 +++++++++++++++++++++++++++++++++++++++------- lib/VM.hpp | 5 +++ src/main.cpp | 4 +-- tests/Lexer.cpp | 8 +++++ tests/Parser.cpp | 17 ++++++++++ 21 files changed, 270 insertions(+), 38 deletions(-) diff --git a/doc/grammar.bnf b/doc/grammar.bnf index 050043d..f3a9f41 100644 --- a/doc/grammar.bnf +++ b/doc/grammar.bnf @@ -3,6 +3,10 @@ EXPR ::= LITERAL | FUNCALL | VARDECL +| LAMBDA +LAMBDA ::= opar rarrow opar PARAMS cpar BODY cpar +PARAMS ::= ident* +BODY ::= expr* VARDECL ::= opar decl ident EXPR cpar FUNCALL ::= opar ident EXPR* cpar LITERAL ::= int | ident diff --git a/lib/Code.cpp b/lib/Code.cpp index 125f13c..cc177cb 100644 --- a/lib/Code.cpp +++ b/lib/Code.cpp @@ -7,6 +7,11 @@ namespace jk { } + /*explicit*/ Code::Code(std::shared_ptr program) + : m_program { program } + { + } + /*virtual*/ Code::~Code() { } diff --git a/lib/Code.hpp b/lib/Code.hpp index 654c79e..493ee5f 100644 --- a/lib/Code.hpp +++ b/lib/Code.hpp @@ -10,12 +10,15 @@ namespace jk { public: explicit Code(foreign_t foreign); + explicit Code(std::shared_ptr program); virtual ~Code(); foreign_t foreign() const { return m_foreign; } + std::shared_ptr program() const { return m_program; } private: foreign_t m_foreign; + std::shared_ptr m_program; }; } diff --git a/lib/Compiler.cpp b/lib/Compiler.cpp index 111d1e6..45f489c 100644 --- a/lib/Compiler.cpp +++ b/lib/Compiler.cpp @@ -4,10 +4,8 @@ namespace jk { - /*explicit*/ Compiler::Compiler(std::shared_ptr sym, - Logger& logger) - : m_sym { sym } - , m_logger(logger) + /*explicit*/ Compiler::Compiler(Logger& logger) + : m_logger(logger) { } @@ -16,21 +14,23 @@ namespace jk } void Compiler::compile(std::shared_ptr node, - std::shared_ptr program) + std::shared_ptr program, + std::shared_ptr sym) { switch (node->type()) { - case NODE_PROG: { + case NODE_PROG: + case NODE_BODY: { for (size_t i=0; isize(); i++) { - compile(node->child(i).lock(), program); + compile(node->child(i).lock(), program, sym); } } break; case NODE_FUNCALL: { for (size_t i=0; isize(); i++) { - compile(node->child(i).lock(), program); + compile(node->child(i).lock(), program, sym); } program->push_instr(OPCODE_CALL, node->size() - 1); @@ -38,26 +38,50 @@ namespace jk case NODE_VARDECL: { std::string ident = node->child(0).lock()->repr(); - auto entry = m_sym->find(ident); + auto entry = sym->find(ident); assert(entry); - compile(node->child(1).lock(), program); + compile(node->child(1).lock(), program, sym); program->push_instr(OPCODE_STORE, entry->addr); } break; + case NODE_LAMBDA: { + auto params = node->child(0).lock(); + auto body = node->child(1).lock(); + auto prog = std::make_shared(); + auto fun_sym = std::make_shared(m_logger, sym); + + for (size_t i=0; isize(); i++) + { + std::string param = params->child(i).lock()->repr(); + fun_sym->declare(param, + std::make_shared(TYPE_NIL), + node->loc()); + } + + compile(body, prog, fun_sym); + prog->push_instr(OPCODE_RET); + + auto code = std::make_shared(prog); + size_t addr = program->push_constant(Value::make_code(code)); + program->push_instr(OPCODE_PUSH_CONST, addr); + program->push_instr(OPCODE_MK_FUNCTION); + //program->push_instr(OpcodeType opcode) + } break; + case NODE_IDENT: { std::string ident = node->repr(); - auto sym = m_sym->find(ident); + auto mysym = sym->find(ident); assert(sym); OpcodeType op_load = OPCODE_LOAD; - if (sym->is_global) + if (mysym->is_global) { op_load = OPCODE_LOAD_GLOBAL; } - program->push_instr(op_load, sym->addr); + program->push_instr(op_load, mysym->addr); } break; diff --git a/lib/Compiler.hpp b/lib/Compiler.hpp index cbeec58..6a6a898 100644 --- a/lib/Compiler.hpp +++ b/lib/Compiler.hpp @@ -15,15 +15,14 @@ namespace jk class Compiler { public: - explicit Compiler(std::shared_ptr sym, - Logger& logger); + explicit Compiler(Logger& logger); virtual ~Compiler(); void compile(std::shared_ptr node, - std::shared_ptr program); + std::shared_ptr program, + std::shared_ptr sym); private: - std::shared_ptr m_sym; Logger& m_logger; }; } diff --git a/lib/Function.cpp b/lib/Function.cpp index ae6a9f8..0289535 100644 --- a/lib/Function.cpp +++ b/lib/Function.cpp @@ -8,12 +8,22 @@ namespace jk { } + /*explicit*/ Function::Function(std::shared_ptr program) + : m_program { program } + { + } + /*virtual*/ Function::~Function() { } value_t Function::call(std::vector const& args) { - return m_foreign(args); + if (m_foreign) + { + return m_foreign(args); + } + + return Value::make_nil(); } } diff --git a/lib/Function.hpp b/lib/Function.hpp index 4fda67f..d5c943e 100644 --- a/lib/Function.hpp +++ b/lib/Function.hpp @@ -2,6 +2,7 @@ #define jk_FUNCTION_HPP #include "commons.hpp" +#include "Program.hpp" namespace jk { @@ -14,12 +15,17 @@ namespace jk { public: explicit Function(foreign_t foreign); + explicit Function(std::shared_ptr program); + virtual ~Function(); + std::shared_ptr program() const { return m_program; } + value_t call(std::vector const& args); private: foreign_t m_foreign; + std::shared_ptr m_program; }; } diff --git a/lib/Lexer.cpp b/lib/Lexer.cpp index a65b352..4c06dfd 100644 --- a/lib/Lexer.cpp +++ b/lib/Lexer.cpp @@ -7,6 +7,7 @@ namespace jk , m_loc { loc } { std::vector> texts = { + {NODE_RARROW, "->", false}, {NODE_DECL, "$", false}, {NODE_OPAR, "(", false}, {NODE_CPAR, ")", false} diff --git a/lib/Node.hpp b/lib/Node.hpp index 5363772..73908f5 100644 --- a/lib/Node.hpp +++ b/lib/Node.hpp @@ -12,7 +12,12 @@ G(NODE_IDENT), \ G(NODE_DECL), \ G(NODE_FUNCALL), \ - G(NODE_VARDECL) + G(NODE_VARDECL), \ + G(NODE_RARROW), \ + G(NODE_LAMBDA), \ + G(NODE_PARAMS), \ + G(NODE_BODY), + namespace jk diff --git a/lib/Opcodes.hpp b/lib/Opcodes.hpp index fb8d941..f23c1d9 100644 --- a/lib/Opcodes.hpp +++ b/lib/Opcodes.hpp @@ -9,6 +9,7 @@ G(OPCODE_LOAD), \ G(OPCODE_STORE), \ G(OPCODE_LOAD_GLOBAL), \ + G(OPCODE_RET), \ G(OPCODE_MK_FUNCTION) namespace jk diff --git a/lib/Parser.cpp b/lib/Parser.cpp index 573b7f7..72b3116 100644 --- a/lib/Parser.cpp +++ b/lib/Parser.cpp @@ -111,6 +111,12 @@ namespace jk return parse_vardecl(); } + if (type_is(NODE_OPAR) + && type_is(NODE_RARROW, 1)) + { + return parse_lambda(); + } + return parse_literal(); } @@ -126,6 +132,44 @@ namespace jk return root; } + std::shared_ptr Parser::parse_lambda() + { + auto root = std::make_shared(NODE_LAMBDA, "", loc()); + consume(NODE_OPAR); + consume(NODE_RARROW); + consume(NODE_OPAR); + root->add_child(parse_params()); + consume(NODE_CPAR); + root->add_child(parse_body()); + consume(NODE_CPAR); + + return root; + } + + std::shared_ptr Parser::parse_params() + { + auto root = std::make_shared(NODE_PARAMS, "", loc()); + + while (!type_is(NODE_CPAR)) + { + root->add_child(consume(NODE_IDENT)); + } + + return root; + } + + std::shared_ptr Parser::parse_body() + { + auto root = std::make_shared(NODE_BODY, "", loc()); + + while (!type_is(NODE_CPAR)) + { + root->add_child(parse_expr()); + } + + return root; + } + std::shared_ptr Parser::parse_funcall() { auto root = std::make_shared(NODE_FUNCALL, "", loc()); diff --git a/lib/Parser.hpp b/lib/Parser.hpp index d96c09f..9830478 100644 --- a/lib/Parser.hpp +++ b/lib/Parser.hpp @@ -32,6 +32,9 @@ namespace jk std::shared_ptr parse_prog(); std::shared_ptr parse_expr(); std::shared_ptr parse_vardecl(); + std::shared_ptr parse_lambda(); + std::shared_ptr parse_params(); + std::shared_ptr parse_body(); std::shared_ptr parse_funcall(); std::shared_ptr parse_literal(); }; diff --git a/lib/Program.cpp b/lib/Program.cpp index 46b8bf7..a514921 100644 --- a/lib/Program.cpp +++ b/lib/Program.cpp @@ -48,6 +48,7 @@ namespace jk size_t Program::push_constant(std::shared_ptr value) { + assert(value); m_constants.push_back(value); return m_constants.size() - 1; } diff --git a/lib/StaticPass.cpp b/lib/StaticPass.cpp index bbe99a5..3f97509 100644 --- a/lib/StaticPass.cpp +++ b/lib/StaticPass.cpp @@ -17,7 +17,8 @@ namespace jk { switch (node->type()) { - case NODE_PROG: { + case NODE_PROG: + case NODE_BODY: { for (size_t i=0; isize(); i++) { pass(node->child(i).lock()); @@ -25,6 +26,24 @@ namespace jk } } break; + case NODE_LAMBDA: { + auto params = node->child(0).lock(); + auto body = node->child(1).lock(); + + for (size_t i=0; isize(); i++) + { + auto param = params->child(i).lock(); + std::string ident = params->child(i).lock()->repr(); + m_sym->declare(ident, + std::make_shared(TYPE_NIL), + param->loc()); + } + + pass(body); + + push(std::make_shared(TYPE_FUNCTION)); + } break; + case NODE_INT: { push(std::make_shared(TYPE_INT)); } break; diff --git a/lib/SymTable.cpp b/lib/SymTable.cpp index 70603a1..4bc3d51 100644 --- a/lib/SymTable.cpp +++ b/lib/SymTable.cpp @@ -2,8 +2,10 @@ namespace jk { - /*explicit*/ SymTable::SymTable(Logger& logger) + /*explicit*/ SymTable::SymTable(Logger& logger, + std::shared_ptr parent) : m_logger { logger } + , m_parent { parent } { } @@ -16,7 +18,7 @@ namespace jk std::shared_ptr type, Loc const& loc) { - if (find(name)) + if (find_local(name)) { m_logger.log(LOG_ERROR, loc, @@ -49,7 +51,7 @@ namespace jk return addr; } - std::optional SymTable::find(std::string const& name) const + std::optional SymTable::find_local(std::string const& name) const { for (auto const& sym: m_syms) { @@ -61,4 +63,17 @@ namespace jk return std::nullopt; } + + std::optional SymTable::find(std::string const& name) const + { + auto entry = find_local(name); + if (entry) { return entry; } + + if (m_parent) + { + return m_parent->find(name); + } + + return std::nullopt; + } } diff --git a/lib/SymTable.hpp b/lib/SymTable.hpp index 0696ca4..f7b1e05 100644 --- a/lib/SymTable.hpp +++ b/lib/SymTable.hpp @@ -20,7 +20,9 @@ namespace jk class SymTable { public: - explicit SymTable(Logger& logger); + explicit SymTable(Logger& logger, + std::shared_ptr parent=nullptr); + virtual ~SymTable(); size_t declare(std::string const& name, @@ -32,9 +34,11 @@ namespace jk Loc const& loc); std::optional find(std::string const& name) const; + std::optional find_local(std::string const& name) const; private: Logger& m_logger; + std::shared_ptr m_parent; std::vector m_syms; size_t m_addr = 0; }; diff --git a/lib/VM.cpp b/lib/VM.cpp index d426f69..8a1b404 100644 --- a/lib/VM.cpp +++ b/lib/VM.cpp @@ -18,7 +18,11 @@ namespace jk Frame frame; frame.program = program; m_frames.push_back(frame); + execute(); + } + void VM::execute() + { while (m_pc < m_frames.back().program->size()) { Instr instr = m_frames.back().program->get(m_pc); @@ -26,17 +30,28 @@ namespace jk switch (instr.opcode) { case OPCODE_MK_FUNCTION: { - auto code = program->constant(pop()); + auto code = program()->constant(pop()); - auto function = std::make_shared - (code->as_code()->foreign()); + auto prog = code->as_code()->program(); + + std::shared_ptr function; + + if (prog) + { + function = std::make_shared(prog); + } + else + { + function = std::make_shared + (code->as_code()->foreign()); + } auto value = Value::make_function(function); size_t addr = push_heap(value); auto ref = Value::make_ref(addr); - push(program->push_constant(ref)); + push(program()->push_constant(ref)); m_pc++; @@ -49,7 +64,7 @@ namespace jk case OPCODE_LOAD: { auto value = m_frames.back().locals[*instr.param]; - push(program->push_constant(value)); + push(program()->push_constant(value)); m_pc++; } break; @@ -57,35 +72,73 @@ namespace jk auto idx = pop(); m_frames.back().locals[*instr.param] = - program->constant(idx); + program()->constant(idx); m_pc++; } break; case OPCODE_LOAD_GLOBAL: { auto value = m_globals[*instr.param]; - push(program->push_constant(value)); + push(program()->push_constant(value)); m_pc++; } break; + case OPCODE_RET: { + size_t pc = m_frames.back().ret_addr; + size_t stack_sz = m_frames.back().stack_sz; + auto prog = m_frames.back().program; + + param_t ret = pop(); + auto ret_val = prog->constant(ret); + assert(ret_val); + + while (m_stack.size() > stack_sz) + { + pop(); + } + + m_frames.pop_back(); + + push(program()->push_constant(ret_val)); + + m_pc = pc + 1; + } break; + case OPCODE_CALL: { std::vector> args; for (size_t i=0; i<*instr.param; i++) { - args.insert(std::begin(args), program->constant(pop())); + args.insert(std::begin(args), program()->constant(pop())); } - auto ref_val = program->constant(pop()); + auto ref_val = program()->constant(pop()); auto ref = ref_val->as_ref(); auto fun_val = heap(ref); auto fun = fun_val->as_function(); - auto result = fun->call(args); + if (auto prog = fun->program(); + prog) + { + Frame frame; + frame.program = prog; + frame.ret_addr = m_pc; + frame.stack_sz = m_stack.size(); + m_pc = 0; - push(program->push_constant(result)); + for (size_t i=0; icall(args); + push(program()->push_constant(result)); + m_pc++; + } } break; default: @@ -137,4 +190,9 @@ namespace jk m_stack.pop_back(); return param; } + + std::shared_ptr VM::program() const + { + return m_frames.back().program; + } } diff --git a/lib/VM.hpp b/lib/VM.hpp index 90a6bdd..b0e2e0f 100644 --- a/lib/VM.hpp +++ b/lib/VM.hpp @@ -9,6 +9,8 @@ namespace jk struct Frame { std::shared_ptr program; std::unordered_map> locals; + size_t ret_addr; + size_t stack_sz; }; class VM @@ -18,6 +20,7 @@ namespace jk virtual ~VM(); void execute(std::shared_ptr program); + void execute(); std::string string() const; @@ -35,6 +38,8 @@ namespace jk void push(param_t param); param_t pop(); + + std::shared_ptr program() const; }; } diff --git a/src/main.cpp b/src/main.cpp index a9894b9..f43093c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -84,7 +84,7 @@ int main(int argc, char** argv) auto sym = std::make_shared(logger); - auto compiler = std::make_shared(sym, logger); + auto compiler = std::make_shared(logger); auto program = std::make_shared(); auto vm = std::make_shared(); auto loader = std::make_shared(vm, sym); @@ -93,7 +93,7 @@ int main(int argc, char** argv) auto static_pass = std::make_shared(sym, logger); static_pass->pass(ast); - compiler->compile(ast, program); + compiler->compile(ast, program, sym); if (debug_mode) { diff --git a/tests/Lexer.cpp b/tests/Lexer.cpp index b7d71f3..098db6f 100644 --- a/tests/Lexer.cpp +++ b/tests/Lexer.cpp @@ -73,3 +73,11 @@ TEST_CASE_METHOD(LexerTest, "Lexer_vardecl") test_next(*lexer, "DECL"); test_end(*lexer); } + +TEST_CASE_METHOD(LexerTest, "Lexer_lambda") +{ + auto lexer = jk::Factory(m_logger, "tests/lexer").make_lexer(); + lexer->scan(" -> "); + test_next(*lexer, "RARROW"); + test_end(*lexer); +} diff --git a/tests/Parser.cpp b/tests/Parser.cpp index 4c18aab..2160df6 100644 --- a/tests/Parser.cpp +++ b/tests/Parser.cpp @@ -44,3 +44,20 @@ TEST_CASE_METHOD(ParserTest, "Parser_vardecl") "FUNCALL(IDENT[f],INT[3],INT[2])))", " ($ world (f 3 2)) "); } + +TEST_CASE_METHOD(ParserTest, "Parser_lambda") +{ + test_parser("PROG(LAMBDA(PARAMS(" + "IDENT[x],IDENT[y]" + "),BODY(" + "FUNCALL(IDENT[add],IDENT[y],IDENT[x])" + ")))", + " (-> (x y) (add y x)) "); + + test_parser("PROG(LAMBDA(PARAMS(" + "IDENT[x]" + "),BODY(" + "IDENT[x]" + ")))", + " (-> (x) x)"); +}