ADD: function literal (no closure yet).

main
bog 2023-09-10 09:59:51 +02:00
parent 749f6c0a95
commit 227a9b7a25
21 changed files with 270 additions and 38 deletions

View File

@ -3,6 +3,10 @@ EXPR ::=
LITERAL
| FUNCALL
| VARDECL
| LAMBDA
LAMBDA ::= opar rarrow opar PARAMS cpar BODY cpar
PARAMS ::= ident*
BODY ::= expr*
VARDECL ::= opar decl ident EXPR cpar
FUNCALL ::= opar ident EXPR* cpar
LITERAL ::= int | ident

View File

@ -7,6 +7,11 @@ namespace jk
{
}
/*explicit*/ Code::Code(std::shared_ptr<Program> program)
: m_program { program }
{
}
/*virtual*/ Code::~Code()
{
}

View File

@ -10,12 +10,15 @@ namespace jk
{
public:
explicit Code(foreign_t foreign);
explicit Code(std::shared_ptr<Program> program);
virtual ~Code();
foreign_t foreign() const { return m_foreign; }
std::shared_ptr<Program> program() const { return m_program; }
private:
foreign_t m_foreign;
std::shared_ptr<Program> m_program;
};
}

View File

@ -4,10 +4,8 @@
namespace jk
{
/*explicit*/ Compiler::Compiler(std::shared_ptr<SymTable> sym,
Logger& logger)
: m_sym { sym }
, m_logger(logger)
/*explicit*/ Compiler::Compiler(Logger& logger)
: m_logger(logger)
{
}
@ -16,21 +14,23 @@ namespace jk
}
void Compiler::compile(std::shared_ptr<Node> node,
std::shared_ptr<Program> program)
std::shared_ptr<Program> program,
std::shared_ptr<SymTable> sym)
{
switch (node->type())
{
case NODE_PROG: {
case NODE_PROG:
case NODE_BODY: {
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i).lock(), program);
compile(node->child(i).lock(), program, sym);
}
} break;
case NODE_FUNCALL: {
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i).lock(), program);
compile(node->child(i).lock(), program, sym);
}
program->push_instr(OPCODE_CALL, node->size() - 1);
@ -38,26 +38,50 @@ namespace jk
case NODE_VARDECL: {
std::string ident = node->child(0).lock()->repr();
auto entry = m_sym->find(ident);
auto entry = sym->find(ident);
assert(entry);
compile(node->child(1).lock(), program);
compile(node->child(1).lock(), program, sym);
program->push_instr(OPCODE_STORE, entry->addr);
} break;
case NODE_LAMBDA: {
auto params = node->child(0).lock();
auto body = node->child(1).lock();
auto prog = std::make_shared<Program>();
auto fun_sym = std::make_shared<SymTable>(m_logger, sym);
for (size_t i=0; i<params->size(); i++)
{
std::string param = params->child(i).lock()->repr();
fun_sym->declare(param,
std::make_shared<Type>(TYPE_NIL),
node->loc());
}
compile(body, prog, fun_sym);
prog->push_instr(OPCODE_RET);
auto code = std::make_shared<Code>(prog);
size_t addr = program->push_constant(Value::make_code(code));
program->push_instr(OPCODE_PUSH_CONST, addr);
program->push_instr(OPCODE_MK_FUNCTION);
//program->push_instr(OpcodeType opcode)
} break;
case NODE_IDENT: {
std::string ident = node->repr();
auto sym = m_sym->find(ident);
auto mysym = sym->find(ident);
assert(sym);
OpcodeType op_load = OPCODE_LOAD;
if (sym->is_global)
if (mysym->is_global)
{
op_load = OPCODE_LOAD_GLOBAL;
}
program->push_instr(op_load, sym->addr);
program->push_instr(op_load, mysym->addr);
} break;

View File

@ -15,15 +15,14 @@ namespace jk
class Compiler
{
public:
explicit Compiler(std::shared_ptr<SymTable> sym,
Logger& logger);
explicit Compiler(Logger& logger);
virtual ~Compiler();
void compile(std::shared_ptr<Node> node,
std::shared_ptr<Program> program);
std::shared_ptr<Program> program,
std::shared_ptr<SymTable> sym);
private:
std::shared_ptr<SymTable> m_sym;
Logger& m_logger;
};
}

View File

@ -8,12 +8,22 @@ namespace jk
{
}
/*explicit*/ Function::Function(std::shared_ptr<Program> program)
: m_program { program }
{
}
/*virtual*/ Function::~Function()
{
}
value_t Function::call(std::vector<value_t> const& args)
{
if (m_foreign)
{
return m_foreign(args);
}
return Value::make_nil();
}
}

View File

@ -2,6 +2,7 @@
#define jk_FUNCTION_HPP
#include "commons.hpp"
#include "Program.hpp"
namespace jk
{
@ -14,12 +15,17 @@ namespace jk
{
public:
explicit Function(foreign_t foreign);
explicit Function(std::shared_ptr<Program> program);
virtual ~Function();
std::shared_ptr<Program> program() const { return m_program; }
value_t call(std::vector<value_t> const& args);
private:
foreign_t m_foreign;
std::shared_ptr<Program> m_program;
};
}

View File

@ -7,6 +7,7 @@ namespace jk
, m_loc { loc }
{
std::vector<std::tuple<NodeType, std::string, bool>> texts = {
{NODE_RARROW, "->", false},
{NODE_DECL, "$", false},
{NODE_OPAR, "(", false},
{NODE_CPAR, ")", false}

View File

@ -12,7 +12,12 @@
G(NODE_IDENT), \
G(NODE_DECL), \
G(NODE_FUNCALL), \
G(NODE_VARDECL)
G(NODE_VARDECL), \
G(NODE_RARROW), \
G(NODE_LAMBDA), \
G(NODE_PARAMS), \
G(NODE_BODY),
namespace jk

View File

@ -9,6 +9,7 @@
G(OPCODE_LOAD), \
G(OPCODE_STORE), \
G(OPCODE_LOAD_GLOBAL), \
G(OPCODE_RET), \
G(OPCODE_MK_FUNCTION)
namespace jk

View File

@ -111,6 +111,12 @@ namespace jk
return parse_vardecl();
}
if (type_is(NODE_OPAR)
&& type_is(NODE_RARROW, 1))
{
return parse_lambda();
}
return parse_literal();
}
@ -126,6 +132,44 @@ namespace jk
return root;
}
std::shared_ptr<Node> Parser::parse_lambda()
{
auto root = std::make_shared<Node>(NODE_LAMBDA, "", loc());
consume(NODE_OPAR);
consume(NODE_RARROW);
consume(NODE_OPAR);
root->add_child(parse_params());
consume(NODE_CPAR);
root->add_child(parse_body());
consume(NODE_CPAR);
return root;
}
std::shared_ptr<Node> Parser::parse_params()
{
auto root = std::make_shared<Node>(NODE_PARAMS, "", loc());
while (!type_is(NODE_CPAR))
{
root->add_child(consume(NODE_IDENT));
}
return root;
}
std::shared_ptr<Node> Parser::parse_body()
{
auto root = std::make_shared<Node>(NODE_BODY, "", loc());
while (!type_is(NODE_CPAR))
{
root->add_child(parse_expr());
}
return root;
}
std::shared_ptr<Node> Parser::parse_funcall()
{
auto root = std::make_shared<Node>(NODE_FUNCALL, "", loc());

View File

@ -32,6 +32,9 @@ namespace jk
std::shared_ptr<Node> parse_prog();
std::shared_ptr<Node> parse_expr();
std::shared_ptr<Node> parse_vardecl();
std::shared_ptr<Node> parse_lambda();
std::shared_ptr<Node> parse_params();
std::shared_ptr<Node> parse_body();
std::shared_ptr<Node> parse_funcall();
std::shared_ptr<Node> parse_literal();
};

View File

@ -48,6 +48,7 @@ namespace jk
size_t Program::push_constant(std::shared_ptr<Value> value)
{
assert(value);
m_constants.push_back(value);
return m_constants.size() - 1;
}

View File

@ -17,7 +17,8 @@ namespace jk
{
switch (node->type())
{
case NODE_PROG: {
case NODE_PROG:
case NODE_BODY: {
for (size_t i=0; i<node->size(); i++)
{
pass(node->child(i).lock());
@ -25,6 +26,24 @@ namespace jk
}
} break;
case NODE_LAMBDA: {
auto params = node->child(0).lock();
auto body = node->child(1).lock();
for (size_t i=0; i<params->size(); i++)
{
auto param = params->child(i).lock();
std::string ident = params->child(i).lock()->repr();
m_sym->declare(ident,
std::make_shared<Type>(TYPE_NIL),
param->loc());
}
pass(body);
push(std::make_shared<Type>(TYPE_FUNCTION));
} break;
case NODE_INT: {
push(std::make_shared<Type>(TYPE_INT));
} break;

View File

@ -2,8 +2,10 @@
namespace jk
{
/*explicit*/ SymTable::SymTable(Logger& logger)
/*explicit*/ SymTable::SymTable(Logger& logger,
std::shared_ptr<SymTable> parent)
: m_logger { logger }
, m_parent { parent }
{
}
@ -16,7 +18,7 @@ namespace jk
std::shared_ptr<Type> type,
Loc const& loc)
{
if (find(name))
if (find_local(name))
{
m_logger.log<symbolic_error>(LOG_ERROR,
loc,
@ -49,7 +51,7 @@ namespace jk
return addr;
}
std::optional<Sym> SymTable::find(std::string const& name) const
std::optional<Sym> SymTable::find_local(std::string const& name) const
{
for (auto const& sym: m_syms)
{
@ -61,4 +63,17 @@ namespace jk
return std::nullopt;
}
std::optional<Sym> SymTable::find(std::string const& name) const
{
auto entry = find_local(name);
if (entry) { return entry; }
if (m_parent)
{
return m_parent->find(name);
}
return std::nullopt;
}
}

View File

@ -20,7 +20,9 @@ namespace jk
class SymTable
{
public:
explicit SymTable(Logger& logger);
explicit SymTable(Logger& logger,
std::shared_ptr<SymTable> parent=nullptr);
virtual ~SymTable();
size_t declare(std::string const& name,
@ -32,9 +34,11 @@ namespace jk
Loc const& loc);
std::optional<Sym> find(std::string const& name) const;
std::optional<Sym> find_local(std::string const& name) const;
private:
Logger& m_logger;
std::shared_ptr<SymTable> m_parent;
std::vector<Sym> m_syms;
size_t m_addr = 0;
};

View File

@ -18,7 +18,11 @@ namespace jk
Frame frame;
frame.program = program;
m_frames.push_back(frame);
execute();
}
void VM::execute()
{
while (m_pc < m_frames.back().program->size())
{
Instr instr = m_frames.back().program->get(m_pc);
@ -26,17 +30,28 @@ namespace jk
switch (instr.opcode)
{
case OPCODE_MK_FUNCTION: {
auto code = program->constant(pop());
auto code = program()->constant(pop());
auto function = std::make_shared<Function>
auto prog = code->as_code()->program();
std::shared_ptr<Function> function;
if (prog)
{
function = std::make_shared<Function>(prog);
}
else
{
function = std::make_shared<Function>
(code->as_code()->foreign());
}
auto value = Value::make_function(function);
size_t addr = push_heap(value);
auto ref = Value::make_ref(addr);
push(program->push_constant(ref));
push(program()->push_constant(ref));
m_pc++;
@ -49,7 +64,7 @@ namespace jk
case OPCODE_LOAD: {
auto value = m_frames.back().locals[*instr.param];
push(program->push_constant(value));
push(program()->push_constant(value));
m_pc++;
} break;
@ -57,35 +72,73 @@ namespace jk
auto idx = pop();
m_frames.back().locals[*instr.param] =
program->constant(idx);
program()->constant(idx);
m_pc++;
} break;
case OPCODE_LOAD_GLOBAL: {
auto value = m_globals[*instr.param];
push(program->push_constant(value));
push(program()->push_constant(value));
m_pc++;
} break;
case OPCODE_RET: {
size_t pc = m_frames.back().ret_addr;
size_t stack_sz = m_frames.back().stack_sz;
auto prog = m_frames.back().program;
param_t ret = pop();
auto ret_val = prog->constant(ret);
assert(ret_val);
while (m_stack.size() > stack_sz)
{
pop();
}
m_frames.pop_back();
push(program()->push_constant(ret_val));
m_pc = pc + 1;
} break;
case OPCODE_CALL: {
std::vector<std::shared_ptr<Value>> args;
for (size_t i=0; i<*instr.param; i++)
{
args.insert(std::begin(args), program->constant(pop()));
args.insert(std::begin(args), program()->constant(pop()));
}
auto ref_val = program->constant(pop());
auto ref_val = program()->constant(pop());
auto ref = ref_val->as_ref();
auto fun_val = heap(ref);
auto fun = fun_val->as_function();
if (auto prog = fun->program();
prog)
{
Frame frame;
frame.program = prog;
frame.ret_addr = m_pc;
frame.stack_sz = m_stack.size();
m_pc = 0;
for (size_t i=0; i<args.size(); i++)
{
frame.locals[i] = args[i];
}
m_frames.push_back(frame);
}
else
{
auto result = fun->call(args);
push(program->push_constant(result));
push(program()->push_constant(result));
m_pc++;
}
} break;
default:
@ -137,4 +190,9 @@ namespace jk
m_stack.pop_back();
return param;
}
std::shared_ptr<Program> VM::program() const
{
return m_frames.back().program;
}
}

View File

@ -9,6 +9,8 @@ namespace jk
struct Frame {
std::shared_ptr<Program> program;
std::unordered_map<int, std::shared_ptr<Value>> locals;
size_t ret_addr;
size_t stack_sz;
};
class VM
@ -18,6 +20,7 @@ namespace jk
virtual ~VM();
void execute(std::shared_ptr<Program> program);
void execute();
std::string string() const;
@ -35,6 +38,8 @@ namespace jk
void push(param_t param);
param_t pop();
std::shared_ptr<Program> program() const;
};
}

View File

@ -84,7 +84,7 @@ int main(int argc, char** argv)
auto sym = std::make_shared<jk::SymTable>(logger);
auto compiler = std::make_shared<jk::Compiler>(sym, logger);
auto compiler = std::make_shared<jk::Compiler>(logger);
auto program = std::make_shared<jk::Program>();
auto vm = std::make_shared<jk::VM>();
auto loader = std::make_shared<jk::Loader>(vm, sym);
@ -93,7 +93,7 @@ int main(int argc, char** argv)
auto static_pass = std::make_shared<jk::StaticPass>(sym, logger);
static_pass->pass(ast);
compiler->compile(ast, program);
compiler->compile(ast, program, sym);
if (debug_mode)
{

View File

@ -73,3 +73,11 @@ TEST_CASE_METHOD(LexerTest, "Lexer_vardecl")
test_next(*lexer, "DECL");
test_end(*lexer);
}
TEST_CASE_METHOD(LexerTest, "Lexer_lambda")
{
auto lexer = jk::Factory(m_logger, "tests/lexer").make_lexer();
lexer->scan(" -> ");
test_next(*lexer, "RARROW");
test_end(*lexer);
}

View File

@ -44,3 +44,20 @@ TEST_CASE_METHOD(ParserTest, "Parser_vardecl")
"FUNCALL(IDENT[f],INT[3],INT[2])))",
" ($ world (f 3 2)) ");
}
TEST_CASE_METHOD(ParserTest, "Parser_lambda")
{
test_parser("PROG(LAMBDA(PARAMS("
"IDENT[x],IDENT[y]"
"),BODY("
"FUNCALL(IDENT[add],IDENT[y],IDENT[x])"
")))",
" (-> (x y) (add y x)) ");
test_parser("PROG(LAMBDA(PARAMS("
"IDENT[x]"
"),BODY("
"IDENT[x]"
")))",
" (-> (x) x)");
}