ADD: function literal (no closure yet).

main
bog 2023-09-10 09:59:51 +02:00
parent 749f6c0a95
commit 227a9b7a25
21 changed files with 270 additions and 38 deletions

View File

@ -3,6 +3,10 @@ EXPR ::=
LITERAL LITERAL
| FUNCALL | FUNCALL
| VARDECL | VARDECL
| LAMBDA
LAMBDA ::= opar rarrow opar PARAMS cpar BODY cpar
PARAMS ::= ident*
BODY ::= expr*
VARDECL ::= opar decl ident EXPR cpar VARDECL ::= opar decl ident EXPR cpar
FUNCALL ::= opar ident EXPR* cpar FUNCALL ::= opar ident EXPR* cpar
LITERAL ::= int | ident LITERAL ::= int | ident

View File

@ -7,6 +7,11 @@ namespace jk
{ {
} }
/*explicit*/ Code::Code(std::shared_ptr<Program> program)
: m_program { program }
{
}
/*virtual*/ Code::~Code() /*virtual*/ Code::~Code()
{ {
} }

View File

@ -10,12 +10,15 @@ namespace jk
{ {
public: public:
explicit Code(foreign_t foreign); explicit Code(foreign_t foreign);
explicit Code(std::shared_ptr<Program> program);
virtual ~Code(); virtual ~Code();
foreign_t foreign() const { return m_foreign; } foreign_t foreign() const { return m_foreign; }
std::shared_ptr<Program> program() const { return m_program; }
private: private:
foreign_t m_foreign; foreign_t m_foreign;
std::shared_ptr<Program> m_program;
}; };
} }

View File

@ -4,10 +4,8 @@
namespace jk namespace jk
{ {
/*explicit*/ Compiler::Compiler(std::shared_ptr<SymTable> sym, /*explicit*/ Compiler::Compiler(Logger& logger)
Logger& logger) : m_logger(logger)
: m_sym { sym }
, m_logger(logger)
{ {
} }
@ -16,21 +14,23 @@ namespace jk
} }
void Compiler::compile(std::shared_ptr<Node> node, void Compiler::compile(std::shared_ptr<Node> node,
std::shared_ptr<Program> program) std::shared_ptr<Program> program,
std::shared_ptr<SymTable> sym)
{ {
switch (node->type()) switch (node->type())
{ {
case NODE_PROG: { case NODE_PROG:
case NODE_BODY: {
for (size_t i=0; i<node->size(); i++) for (size_t i=0; i<node->size(); i++)
{ {
compile(node->child(i).lock(), program); compile(node->child(i).lock(), program, sym);
} }
} break; } break;
case NODE_FUNCALL: { case NODE_FUNCALL: {
for (size_t i=0; i<node->size(); i++) for (size_t i=0; i<node->size(); i++)
{ {
compile(node->child(i).lock(), program); compile(node->child(i).lock(), program, sym);
} }
program->push_instr(OPCODE_CALL, node->size() - 1); program->push_instr(OPCODE_CALL, node->size() - 1);
@ -38,26 +38,50 @@ namespace jk
case NODE_VARDECL: { case NODE_VARDECL: {
std::string ident = node->child(0).lock()->repr(); std::string ident = node->child(0).lock()->repr();
auto entry = m_sym->find(ident); auto entry = sym->find(ident);
assert(entry); assert(entry);
compile(node->child(1).lock(), program); compile(node->child(1).lock(), program, sym);
program->push_instr(OPCODE_STORE, entry->addr); program->push_instr(OPCODE_STORE, entry->addr);
} break; } break;
case NODE_LAMBDA: {
auto params = node->child(0).lock();
auto body = node->child(1).lock();
auto prog = std::make_shared<Program>();
auto fun_sym = std::make_shared<SymTable>(m_logger, sym);
for (size_t i=0; i<params->size(); i++)
{
std::string param = params->child(i).lock()->repr();
fun_sym->declare(param,
std::make_shared<Type>(TYPE_NIL),
node->loc());
}
compile(body, prog, fun_sym);
prog->push_instr(OPCODE_RET);
auto code = std::make_shared<Code>(prog);
size_t addr = program->push_constant(Value::make_code(code));
program->push_instr(OPCODE_PUSH_CONST, addr);
program->push_instr(OPCODE_MK_FUNCTION);
//program->push_instr(OpcodeType opcode)
} break;
case NODE_IDENT: { case NODE_IDENT: {
std::string ident = node->repr(); std::string ident = node->repr();
auto sym = m_sym->find(ident); auto mysym = sym->find(ident);
assert(sym); assert(sym);
OpcodeType op_load = OPCODE_LOAD; OpcodeType op_load = OPCODE_LOAD;
if (sym->is_global) if (mysym->is_global)
{ {
op_load = OPCODE_LOAD_GLOBAL; op_load = OPCODE_LOAD_GLOBAL;
} }
program->push_instr(op_load, sym->addr); program->push_instr(op_load, mysym->addr);
} break; } break;

View File

@ -15,15 +15,14 @@ namespace jk
class Compiler class Compiler
{ {
public: public:
explicit Compiler(std::shared_ptr<SymTable> sym, explicit Compiler(Logger& logger);
Logger& logger);
virtual ~Compiler(); virtual ~Compiler();
void compile(std::shared_ptr<Node> node, void compile(std::shared_ptr<Node> node,
std::shared_ptr<Program> program); std::shared_ptr<Program> program,
std::shared_ptr<SymTable> sym);
private: private:
std::shared_ptr<SymTable> m_sym;
Logger& m_logger; Logger& m_logger;
}; };
} }

View File

@ -8,12 +8,22 @@ namespace jk
{ {
} }
/*explicit*/ Function::Function(std::shared_ptr<Program> program)
: m_program { program }
{
}
/*virtual*/ Function::~Function() /*virtual*/ Function::~Function()
{ {
} }
value_t Function::call(std::vector<value_t> const& args) value_t Function::call(std::vector<value_t> const& args)
{ {
return m_foreign(args); if (m_foreign)
{
return m_foreign(args);
}
return Value::make_nil();
} }
} }

View File

@ -2,6 +2,7 @@
#define jk_FUNCTION_HPP #define jk_FUNCTION_HPP
#include "commons.hpp" #include "commons.hpp"
#include "Program.hpp"
namespace jk namespace jk
{ {
@ -14,12 +15,17 @@ namespace jk
{ {
public: public:
explicit Function(foreign_t foreign); explicit Function(foreign_t foreign);
explicit Function(std::shared_ptr<Program> program);
virtual ~Function(); virtual ~Function();
std::shared_ptr<Program> program() const { return m_program; }
value_t call(std::vector<value_t> const& args); value_t call(std::vector<value_t> const& args);
private: private:
foreign_t m_foreign; foreign_t m_foreign;
std::shared_ptr<Program> m_program;
}; };
} }

View File

@ -7,6 +7,7 @@ namespace jk
, m_loc { loc } , m_loc { loc }
{ {
std::vector<std::tuple<NodeType, std::string, bool>> texts = { std::vector<std::tuple<NodeType, std::string, bool>> texts = {
{NODE_RARROW, "->", false},
{NODE_DECL, "$", false}, {NODE_DECL, "$", false},
{NODE_OPAR, "(", false}, {NODE_OPAR, "(", false},
{NODE_CPAR, ")", false} {NODE_CPAR, ")", false}

View File

@ -12,7 +12,12 @@
G(NODE_IDENT), \ G(NODE_IDENT), \
G(NODE_DECL), \ G(NODE_DECL), \
G(NODE_FUNCALL), \ G(NODE_FUNCALL), \
G(NODE_VARDECL) G(NODE_VARDECL), \
G(NODE_RARROW), \
G(NODE_LAMBDA), \
G(NODE_PARAMS), \
G(NODE_BODY),
namespace jk namespace jk

View File

@ -9,6 +9,7 @@
G(OPCODE_LOAD), \ G(OPCODE_LOAD), \
G(OPCODE_STORE), \ G(OPCODE_STORE), \
G(OPCODE_LOAD_GLOBAL), \ G(OPCODE_LOAD_GLOBAL), \
G(OPCODE_RET), \
G(OPCODE_MK_FUNCTION) G(OPCODE_MK_FUNCTION)
namespace jk namespace jk

View File

@ -111,6 +111,12 @@ namespace jk
return parse_vardecl(); return parse_vardecl();
} }
if (type_is(NODE_OPAR)
&& type_is(NODE_RARROW, 1))
{
return parse_lambda();
}
return parse_literal(); return parse_literal();
} }
@ -126,6 +132,44 @@ namespace jk
return root; return root;
} }
std::shared_ptr<Node> Parser::parse_lambda()
{
auto root = std::make_shared<Node>(NODE_LAMBDA, "", loc());
consume(NODE_OPAR);
consume(NODE_RARROW);
consume(NODE_OPAR);
root->add_child(parse_params());
consume(NODE_CPAR);
root->add_child(parse_body());
consume(NODE_CPAR);
return root;
}
std::shared_ptr<Node> Parser::parse_params()
{
auto root = std::make_shared<Node>(NODE_PARAMS, "", loc());
while (!type_is(NODE_CPAR))
{
root->add_child(consume(NODE_IDENT));
}
return root;
}
std::shared_ptr<Node> Parser::parse_body()
{
auto root = std::make_shared<Node>(NODE_BODY, "", loc());
while (!type_is(NODE_CPAR))
{
root->add_child(parse_expr());
}
return root;
}
std::shared_ptr<Node> Parser::parse_funcall() std::shared_ptr<Node> Parser::parse_funcall()
{ {
auto root = std::make_shared<Node>(NODE_FUNCALL, "", loc()); auto root = std::make_shared<Node>(NODE_FUNCALL, "", loc());

View File

@ -32,6 +32,9 @@ namespace jk
std::shared_ptr<Node> parse_prog(); std::shared_ptr<Node> parse_prog();
std::shared_ptr<Node> parse_expr(); std::shared_ptr<Node> parse_expr();
std::shared_ptr<Node> parse_vardecl(); std::shared_ptr<Node> parse_vardecl();
std::shared_ptr<Node> parse_lambda();
std::shared_ptr<Node> parse_params();
std::shared_ptr<Node> parse_body();
std::shared_ptr<Node> parse_funcall(); std::shared_ptr<Node> parse_funcall();
std::shared_ptr<Node> parse_literal(); std::shared_ptr<Node> parse_literal();
}; };

View File

@ -48,6 +48,7 @@ namespace jk
size_t Program::push_constant(std::shared_ptr<Value> value) size_t Program::push_constant(std::shared_ptr<Value> value)
{ {
assert(value);
m_constants.push_back(value); m_constants.push_back(value);
return m_constants.size() - 1; return m_constants.size() - 1;
} }

View File

@ -17,7 +17,8 @@ namespace jk
{ {
switch (node->type()) switch (node->type())
{ {
case NODE_PROG: { case NODE_PROG:
case NODE_BODY: {
for (size_t i=0; i<node->size(); i++) for (size_t i=0; i<node->size(); i++)
{ {
pass(node->child(i).lock()); pass(node->child(i).lock());
@ -25,6 +26,24 @@ namespace jk
} }
} break; } break;
case NODE_LAMBDA: {
auto params = node->child(0).lock();
auto body = node->child(1).lock();
for (size_t i=0; i<params->size(); i++)
{
auto param = params->child(i).lock();
std::string ident = params->child(i).lock()->repr();
m_sym->declare(ident,
std::make_shared<Type>(TYPE_NIL),
param->loc());
}
pass(body);
push(std::make_shared<Type>(TYPE_FUNCTION));
} break;
case NODE_INT: { case NODE_INT: {
push(std::make_shared<Type>(TYPE_INT)); push(std::make_shared<Type>(TYPE_INT));
} break; } break;

View File

@ -2,8 +2,10 @@
namespace jk namespace jk
{ {
/*explicit*/ SymTable::SymTable(Logger& logger) /*explicit*/ SymTable::SymTable(Logger& logger,
std::shared_ptr<SymTable> parent)
: m_logger { logger } : m_logger { logger }
, m_parent { parent }
{ {
} }
@ -16,7 +18,7 @@ namespace jk
std::shared_ptr<Type> type, std::shared_ptr<Type> type,
Loc const& loc) Loc const& loc)
{ {
if (find(name)) if (find_local(name))
{ {
m_logger.log<symbolic_error>(LOG_ERROR, m_logger.log<symbolic_error>(LOG_ERROR,
loc, loc,
@ -49,7 +51,7 @@ namespace jk
return addr; return addr;
} }
std::optional<Sym> SymTable::find(std::string const& name) const std::optional<Sym> SymTable::find_local(std::string const& name) const
{ {
for (auto const& sym: m_syms) for (auto const& sym: m_syms)
{ {
@ -61,4 +63,17 @@ namespace jk
return std::nullopt; return std::nullopt;
} }
std::optional<Sym> SymTable::find(std::string const& name) const
{
auto entry = find_local(name);
if (entry) { return entry; }
if (m_parent)
{
return m_parent->find(name);
}
return std::nullopt;
}
} }

View File

@ -20,7 +20,9 @@ namespace jk
class SymTable class SymTable
{ {
public: public:
explicit SymTable(Logger& logger); explicit SymTable(Logger& logger,
std::shared_ptr<SymTable> parent=nullptr);
virtual ~SymTable(); virtual ~SymTable();
size_t declare(std::string const& name, size_t declare(std::string const& name,
@ -32,9 +34,11 @@ namespace jk
Loc const& loc); Loc const& loc);
std::optional<Sym> find(std::string const& name) const; std::optional<Sym> find(std::string const& name) const;
std::optional<Sym> find_local(std::string const& name) const;
private: private:
Logger& m_logger; Logger& m_logger;
std::shared_ptr<SymTable> m_parent;
std::vector<Sym> m_syms; std::vector<Sym> m_syms;
size_t m_addr = 0; size_t m_addr = 0;
}; };

View File

@ -18,7 +18,11 @@ namespace jk
Frame frame; Frame frame;
frame.program = program; frame.program = program;
m_frames.push_back(frame); m_frames.push_back(frame);
execute();
}
void VM::execute()
{
while (m_pc < m_frames.back().program->size()) while (m_pc < m_frames.back().program->size())
{ {
Instr instr = m_frames.back().program->get(m_pc); Instr instr = m_frames.back().program->get(m_pc);
@ -26,17 +30,28 @@ namespace jk
switch (instr.opcode) switch (instr.opcode)
{ {
case OPCODE_MK_FUNCTION: { case OPCODE_MK_FUNCTION: {
auto code = program->constant(pop()); auto code = program()->constant(pop());
auto function = std::make_shared<Function> auto prog = code->as_code()->program();
(code->as_code()->foreign());
std::shared_ptr<Function> function;
if (prog)
{
function = std::make_shared<Function>(prog);
}
else
{
function = std::make_shared<Function>
(code->as_code()->foreign());
}
auto value = Value::make_function(function); auto value = Value::make_function(function);
size_t addr = push_heap(value); size_t addr = push_heap(value);
auto ref = Value::make_ref(addr); auto ref = Value::make_ref(addr);
push(program->push_constant(ref)); push(program()->push_constant(ref));
m_pc++; m_pc++;
@ -49,7 +64,7 @@ namespace jk
case OPCODE_LOAD: { case OPCODE_LOAD: {
auto value = m_frames.back().locals[*instr.param]; auto value = m_frames.back().locals[*instr.param];
push(program->push_constant(value)); push(program()->push_constant(value));
m_pc++; m_pc++;
} break; } break;
@ -57,35 +72,73 @@ namespace jk
auto idx = pop(); auto idx = pop();
m_frames.back().locals[*instr.param] = m_frames.back().locals[*instr.param] =
program->constant(idx); program()->constant(idx);
m_pc++; m_pc++;
} break; } break;
case OPCODE_LOAD_GLOBAL: { case OPCODE_LOAD_GLOBAL: {
auto value = m_globals[*instr.param]; auto value = m_globals[*instr.param];
push(program->push_constant(value)); push(program()->push_constant(value));
m_pc++; m_pc++;
} break; } break;
case OPCODE_RET: {
size_t pc = m_frames.back().ret_addr;
size_t stack_sz = m_frames.back().stack_sz;
auto prog = m_frames.back().program;
param_t ret = pop();
auto ret_val = prog->constant(ret);
assert(ret_val);
while (m_stack.size() > stack_sz)
{
pop();
}
m_frames.pop_back();
push(program()->push_constant(ret_val));
m_pc = pc + 1;
} break;
case OPCODE_CALL: { case OPCODE_CALL: {
std::vector<std::shared_ptr<Value>> args; std::vector<std::shared_ptr<Value>> args;
for (size_t i=0; i<*instr.param; i++) for (size_t i=0; i<*instr.param; i++)
{ {
args.insert(std::begin(args), program->constant(pop())); args.insert(std::begin(args), program()->constant(pop()));
} }
auto ref_val = program->constant(pop()); auto ref_val = program()->constant(pop());
auto ref = ref_val->as_ref(); auto ref = ref_val->as_ref();
auto fun_val = heap(ref); auto fun_val = heap(ref);
auto fun = fun_val->as_function(); auto fun = fun_val->as_function();
auto result = fun->call(args); if (auto prog = fun->program();
prog)
{
Frame frame;
frame.program = prog;
frame.ret_addr = m_pc;
frame.stack_sz = m_stack.size();
m_pc = 0;
push(program->push_constant(result)); for (size_t i=0; i<args.size(); i++)
{
frame.locals[i] = args[i];
}
m_pc++; m_frames.push_back(frame);
}
else
{
auto result = fun->call(args);
push(program()->push_constant(result));
m_pc++;
}
} break; } break;
default: default:
@ -137,4 +190,9 @@ namespace jk
m_stack.pop_back(); m_stack.pop_back();
return param; return param;
} }
std::shared_ptr<Program> VM::program() const
{
return m_frames.back().program;
}
} }

View File

@ -9,6 +9,8 @@ namespace jk
struct Frame { struct Frame {
std::shared_ptr<Program> program; std::shared_ptr<Program> program;
std::unordered_map<int, std::shared_ptr<Value>> locals; std::unordered_map<int, std::shared_ptr<Value>> locals;
size_t ret_addr;
size_t stack_sz;
}; };
class VM class VM
@ -18,6 +20,7 @@ namespace jk
virtual ~VM(); virtual ~VM();
void execute(std::shared_ptr<Program> program); void execute(std::shared_ptr<Program> program);
void execute();
std::string string() const; std::string string() const;
@ -35,6 +38,8 @@ namespace jk
void push(param_t param); void push(param_t param);
param_t pop(); param_t pop();
std::shared_ptr<Program> program() const;
}; };
} }

View File

@ -84,7 +84,7 @@ int main(int argc, char** argv)
auto sym = std::make_shared<jk::SymTable>(logger); auto sym = std::make_shared<jk::SymTable>(logger);
auto compiler = std::make_shared<jk::Compiler>(sym, logger); auto compiler = std::make_shared<jk::Compiler>(logger);
auto program = std::make_shared<jk::Program>(); auto program = std::make_shared<jk::Program>();
auto vm = std::make_shared<jk::VM>(); auto vm = std::make_shared<jk::VM>();
auto loader = std::make_shared<jk::Loader>(vm, sym); auto loader = std::make_shared<jk::Loader>(vm, sym);
@ -93,7 +93,7 @@ int main(int argc, char** argv)
auto static_pass = std::make_shared<jk::StaticPass>(sym, logger); auto static_pass = std::make_shared<jk::StaticPass>(sym, logger);
static_pass->pass(ast); static_pass->pass(ast);
compiler->compile(ast, program); compiler->compile(ast, program, sym);
if (debug_mode) if (debug_mode)
{ {

View File

@ -73,3 +73,11 @@ TEST_CASE_METHOD(LexerTest, "Lexer_vardecl")
test_next(*lexer, "DECL"); test_next(*lexer, "DECL");
test_end(*lexer); test_end(*lexer);
} }
TEST_CASE_METHOD(LexerTest, "Lexer_lambda")
{
auto lexer = jk::Factory(m_logger, "tests/lexer").make_lexer();
lexer->scan(" -> ");
test_next(*lexer, "RARROW");
test_end(*lexer);
}

View File

@ -44,3 +44,20 @@ TEST_CASE_METHOD(ParserTest, "Parser_vardecl")
"FUNCALL(IDENT[f],INT[3],INT[2])))", "FUNCALL(IDENT[f],INT[3],INT[2])))",
" ($ world (f 3 2)) "); " ($ world (f 3 2)) ");
} }
TEST_CASE_METHOD(ParserTest, "Parser_lambda")
{
test_parser("PROG(LAMBDA(PARAMS("
"IDENT[x],IDENT[y]"
"),BODY("
"FUNCALL(IDENT[add],IDENT[y],IDENT[x])"
")))",
" (-> (x y) (add y x)) ");
test_parser("PROG(LAMBDA(PARAMS("
"IDENT[x]"
"),BODY("
"IDENT[x]"
")))",
" (-> (x) x)");
}