ADD: function literal.

main
bog 2023-09-11 20:59:54 +02:00
parent 59af9d809e
commit a5ec909a5a
27 changed files with 383 additions and 69 deletions

View File

@ -5,5 +5,9 @@ EXPR ::=
| ident
| VARDECL
| FUNCALL
| LAMBDA
VARDECL ::= opar decl ident EXPR cpar
FUNCALL ::= opar ident EXPR* cpar
FUNCALL ::= opar EXPR EXPR* cpar
LAMBDA ::= opar lambda opar PARAMS cpar BODY cpar
PARAMS ::= ident*
BODY ::= EXPR*

10
examples/fun.gri Normal file
View File

@ -0,0 +1,10 @@
;; declare lambda
($ a (-> (x) (* x 2)))
(assert-eq? 14 (a 7))
;; high order function
($ b (-> (x y) (x (x y))))
(assert-eq? 12 (b a 3))
;; calling function literal
(assert-eq? 7 ( (-> (x y) (+ x y 1)) 2 4 ))

View File

@ -195,12 +195,14 @@ extern "C" void lib_bool(grino::Loader& loader)
return grino::Value::make_bool(args[0]->loc(), !args[0]->as_bool());
});
loader.add_static("and", [](auto& compiler, auto node, auto& program){
loader.add_static("and", [](auto& compiler, auto node,
auto& program,
auto& sym){
std::vector<size_t> to_false;
for (size_t i=1; i<node->size(); i++)
{
compiler.compile(node->child(i).lock(), program);
compiler.compile(node->child(i).lock(), program, sym);
to_false.push_back(program.size());
program.push_instr(grino::OPCODE_BRF, 0 /* to false */);
}
@ -223,12 +225,13 @@ extern "C" void lib_bool(grino::Loader& loader)
program.set_param(to_end, program.size());
});
loader.add_static("or", [](auto& compiler, auto node, auto& program){
loader.add_static("or", [](auto& compiler, auto node,
auto& program, auto& sym){
std::vector<size_t> to_true;
for (size_t i=1; i<node->size(); i++)
{
compiler.compile(node->child(i).lock(), program);
compiler.compile(node->child(i).lock(), program, sym);
program.push_instr(grino::OPCODE_NOT);
to_true.push_back(program.size());
program.push_instr(grino::OPCODE_BRF, 0 /* to true */);

View File

@ -6,10 +6,10 @@
namespace grino
{
/*explicit*/ Compiler::Compiler(Logger& logger, SymTable& sym)
/*explicit*/ Compiler::Compiler(Logger& logger)
: m_logger { logger }
, m_sym { sym }
{
enter_scope();
}
/*virtual*/ Compiler::~Compiler()
@ -17,18 +17,31 @@ namespace grino
}
void Compiler::compile(std::shared_ptr<Node> node,
Program& program)
Program& program,
SymTable& sym)
{
switch (node->type())
{
case NODE_MODULE: {
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i).lock(), program);
compile(node->child(i).lock(), program, sym);
program.push_instr(OPCODE_POP);
}
} break;
case NODE_BODY: {
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i).lock(), program, sym);
if (i < node->size() - 1)
{
program.push_instr(OPCODE_POP);
}
}
} break;
case NODE_FUNCALL: {
std::string ident = node->child(0).lock()->repr();
@ -36,13 +49,13 @@ namespace grino
itr != std::end(m_statics))
{
auto fun = itr->second;
fun->call(*this, node, program);
fun->call(*this, node, program, sym);
}
else
{
for (size_t i=0; i<node->size(); i++)
{
compile(node->child(i).lock(), program);
compile(node->child(i).lock(), program, sym);
}
program.push_instr(OPCODE_CALL, node->size() - 1);
@ -50,6 +63,32 @@ namespace grino
} break;
case NODE_LAMBDA: {
auto params = node->child(0).lock();
enter_scope();
for (size_t i=0; i<params->size(); i++)
{
auto param = params->child(i).lock();
std::string ident = param->repr();
sym.declare(param->loc(), ident, get_local_address(),
m_scope.size());
}
auto prog = std::make_shared<Program>();
auto body = node->child(1).lock();
compile(body, *prog, sym);
prog->push_instr(OPCODE_RET);
leave_scope();
sym.purge(m_scope.size());
program.push_value(Value::make_program(node->loc(), prog));
program.push_instr(OPCODE_MK_FUN);
} break;
case NODE_BOOL: {
std::string repr = node->repr();
auto value = Value::make_bool(node->loc(), repr == "true");
@ -71,15 +110,15 @@ namespace grino
auto expr = node->child(1).lock();
size_t address = get_local_address();
compile(expr, program);
compile(expr, program, sym);
m_sym.declare(node->loc(), ident, address);
sym.declare(node->loc(), ident, address, m_scope.size());
program.push_instr(OPCODE_STORE_LOCAL, address);
} break;
case NODE_IDENT: {
std::string ident = node->repr();
auto entry = m_sym.find(ident);
auto entry = sym.find(ident, m_scope.size());
if (entry == std::nullopt)
{
@ -114,8 +153,18 @@ namespace grino
size_t Compiler::get_local_address()
{
static size_t addr = 0;
addr++;
return addr - 1;
m_scope.back() += 1;
return m_scope.back() - 1;
}
void Compiler::enter_scope()
{
m_scope.push_back(0);
}
void Compiler::leave_scope()
{
assert(m_scope.empty() == false);
m_scope.pop_back();
}
}

View File

@ -17,22 +17,26 @@ namespace grino
class Compiler
{
public:
explicit Compiler(Logger& logger, SymTable& sym);
explicit Compiler(Logger& logger);
virtual ~Compiler();
void compile(std::shared_ptr<Node> node,
Program& program);
Program& program,
SymTable& sym);
void add_static_func(std::string const& name,
std::shared_ptr<StaticFunction> fun);
private:
Logger& m_logger;
SymTable& m_sym;
std::unordered_map<std::string,
std::shared_ptr<StaticFunction>> m_statics;
std::vector<size_t> m_scope;
size_t get_local_address();
void enter_scope();
void leave_scope();
};
}

View File

@ -1,5 +1,6 @@
#include "Function.hpp"
#include "Value.hpp"
#include "Program.hpp"
namespace grino
{
@ -8,10 +9,25 @@ namespace grino
{
}
/*explicit*/ Function::Function(std::shared_ptr<Program> prog)
: m_prog { prog }
{
}
/*virtual*/ Function::~Function()
{
}
bool Function::is_native() const
{
return m_native != nullptr;
}
std::shared_ptr<Program> Function::program() const
{
return m_prog;
}
value_t Function::call(args_t args)
{
assert(m_native);

View File

@ -6,6 +6,7 @@
namespace grino
{
class Value;
class Program;
using value_t = std::shared_ptr<Value>;
using args_t = std::vector<value_t>;
@ -15,12 +16,16 @@ namespace grino
{
public:
explicit Function(native_t native);
explicit Function(std::shared_ptr<Program> prog);
virtual ~Function();
bool is_native() const;
std::shared_ptr<Program> program() const;
value_t call(args_t args);
private:
native_t m_native;
std::shared_ptr<Program> m_prog;
};
}

View File

@ -11,6 +11,7 @@ namespace grino
add_text(NODE_CPAR, ")", false);
add_text(NODE_DECL, "$", false);
add_keyword(NODE_LAMBDA, "->", false);
add_keyword(NODE_BOOL, "true", true);
add_keyword(NODE_BOOL, "false", true);

View File

@ -43,7 +43,7 @@ namespace grino
grino::Loc loc {"???", 0};
m_vm.set_heap(addr, grino::Value::make_native_function(loc, native));
m_sym_table.declare_object(loc, name, addr);
m_sym_table.declare_object(loc, name, addr, 0);
}
void Loader::add_static(std::string const& name, static_fun_t fun)

View File

@ -13,7 +13,10 @@
G(NODE_IDENT), \
G(NODE_OPAR), \
G(NODE_CPAR), \
G(NODE_DECL),
G(NODE_DECL), \
G(NODE_LAMBDA), \
G(NODE_PARAMS), \
G(NODE_BODY),
namespace grino
{

View File

@ -1,4 +1,5 @@
#include "Parser.hpp"
#include "src/Node.hpp"
#include "src/mutils.hpp"
namespace grino
@ -117,16 +118,21 @@ namespace grino
std::shared_ptr<Node> Parser::parse_expr()
{
if (type_is({NODE_OPAR, NODE_IDENT}))
{
return parse_funcall();
}
if (type_is({NODE_OPAR, NODE_DECL}))
{
return parse_vardecl();
}
if (type_is({NODE_OPAR, NODE_LAMBDA}))
{
return parse_lambda();
}
if (type_is(NODE_OPAR))
{
return parse_funcall();
}
if (type_is(NODE_IDENT)
|| type_is(NODE_BOOL)
|| type_is(NODE_INT))
@ -163,7 +169,7 @@ namespace grino
consume(NODE_OPAR);
auto node = make_node(NODE_FUNCALL);
node->add_child(consume(NODE_IDENT));
node->add_child(parse_expr());
while (!type_is(NODE_CPAR))
{
@ -174,4 +180,44 @@ namespace grino
return node;
}
std::shared_ptr<Node> Parser::parse_lambda()
{
consume(NODE_OPAR);
auto node = consume(NODE_LAMBDA);
consume(NODE_OPAR);
node->add_child(parse_params());
consume(NODE_CPAR);
node->add_child(parse_body());
consume(NODE_CPAR);
return node;
}
std::shared_ptr<Node> Parser::parse_params()
{
auto node = make_node(NODE_PARAMS);
while (!type_is(NODE_CPAR))
{
node->add_child(consume(NODE_IDENT));
}
return node;
}
std::shared_ptr<Node> Parser::parse_body()
{
auto node = make_node(NODE_BODY);
while (!type_is(NODE_CPAR))
{
node->add_child(parse_expr());
}
return node;
}
}

View File

@ -36,6 +36,9 @@ namespace grino
std::shared_ptr<Node> parse_expr();
std::shared_ptr<Node> parse_vardecl();
std::shared_ptr<Node> parse_funcall();
std::shared_ptr<Node> parse_lambda();
std::shared_ptr<Node> parse_params();
std::shared_ptr<Node> parse_body();
};
}

View File

@ -54,6 +54,12 @@ namespace grino
return m_constants[index];
}
void Program::push_value(std::shared_ptr<Value> value)
{
size_t addr = push_constant(value);
push_instr(OPCODE_LOAD_CONST, addr);
}
std::string Program::string() const
{
std::stringstream ss;

View File

@ -31,6 +31,8 @@ namespace grino
size_t push_constant(std::shared_ptr<Value> value);
std::shared_ptr<Value> constant(size_t index) const;
void push_value(std::shared_ptr<Value> value);
std::string string() const;
private:

View File

@ -14,8 +14,9 @@ namespace grino
void StaticFunction::call(Compiler& compiler,
node_t node,
prog_t prog)
prog_t prog,
SymTable& sym)
{
m_fun(compiler, node, prog);
m_fun(compiler, node, prog,sym);
}
}

View File

@ -5,13 +5,15 @@
#include "Function.hpp"
#include "Node.hpp"
#include "Program.hpp"
#include "SymTable.hpp"
namespace grino
{
class Compiler;
using node_t = std::shared_ptr<Node>;
using prog_t = Program&;
using static_fun_t = std::function<void (Compiler&, node_t, prog_t)>;
using sym_t = SymTable&;
using static_fun_t = std::function<void (Compiler&, node_t, prog_t, sym_t)>;
class StaticFunction
{
@ -19,7 +21,7 @@ namespace grino
explicit StaticFunction(static_fun_t fun);
virtual ~StaticFunction();
void call(Compiler& compiler, node_t node, prog_t prog);
void call(Compiler& compiler, node_t node, prog_t prog, sym_t sym);
private:
static_fun_t m_fun;

View File

@ -11,9 +11,11 @@ namespace grino
{
}
void SymTable::declare(Loc const& loc, std::string const& name, size_t addr)
void SymTable::declare(Loc const& loc, std::string const& name, size_t addr,
size_t scope)
{
if (find(name))
if (auto e = find(name, scope);
e && e->scope == scope)
{
m_logger.log<symbolic_error>(LOG_ERROR, loc, "'"
+ name
@ -24,29 +26,46 @@ namespace grino
entry.addr = addr;
entry.name = name;
entry.is_object = false;
entry.scope = scope;
m_entries.push_back(entry);
}
void SymTable::declare_object(Loc const& loc,
std::string const& name,
size_t addr)
size_t addr,
size_t scope)
{
declare(loc, name, addr);
declare(loc, name, addr, scope);
m_entries.back().is_object = true;
}
std::optional<SymEntry> SymTable::find(std::string const& name)
std::optional<SymEntry> SymTable::find(std::string const& name, size_t scope)
{
std::optional<SymEntry> entry;
for (size_t i=0; i<m_entries.size(); i++)
{
if (m_entries[i].name == name)
if (m_entries[i].name == name
&& m_entries[i].scope <= scope
&& (!entry || m_entries[i].scope > entry->scope))
{
return m_entries[i];
entry = m_entries[i];
}
}
return std::nullopt;
return entry;
}
void SymTable::purge(size_t scope)
{
auto itr = std::remove_if(std::begin(m_entries),
std::end(m_entries),
[scope](auto const& entry){
return entry.scope > scope;
});
m_entries.erase(itr, std::end(m_entries));
}
std::string SymTable::string() const

View File

@ -10,6 +10,7 @@ namespace grino
std::string name;
bool is_object; /* object are on the heap instead of the stack */
size_t addr; /* address on the heap if object, local address otherwise */
size_t scope;
};
GRINO_ERROR(symbolic_error);
@ -20,10 +21,15 @@ namespace grino
explicit SymTable(Logger& logger);
virtual ~SymTable();
void declare(Loc const& loc, std::string const& name, size_t addr);
void declare_object(Loc const& loc, std::string const& name, size_t addr);
void declare(Loc const& loc, std::string const& name, size_t addr,
size_t scope);
std::optional<SymEntry> find(std::string const& name);
void declare_object(Loc const& loc, std::string const& name, size_t addr,
size_t scope);
std::optional<SymEntry> find(std::string const& name, size_t scope);
void purge(size_t scope);
std::string string() const;

View File

@ -1,33 +1,52 @@
#include "VM.hpp"
#include "src/opcodes.hpp"
#include <optional>
namespace grino
{
/*explicit*/ VM::VM(Logger& logger)
/*explicit*/ VM::VM(Logger& logger, Program& program)
: m_logger { logger }
{
m_frames.push_back(Frame {});
Frame frame;
frame.pc = m_pc;
frame.sp = m_sp;
frame.program = &program;
m_frames.push_back(frame);
}
/*virtual*/ VM::~VM()
{
}
void VM::run(Program& program)
void VM::run()
{
m_bp = 0;
m_sp = 0;
m_pc = 0;
while (m_pc < program.size())
while (m_pc < program().size())
{
Instr instr = program.get(m_pc);
Instr instr = program().get(m_pc);
switch (instr.opcode)
{
case OPCODE_MK_FUN: {
auto prog_val = program().constant(pop());
auto prog = prog_val->as_program();
auto fun_val = Value::make_function(prog_val->loc(), prog);
size_t addr = m_heap.size();
set_heap(addr, fun_val);
auto ref = Value::make_ref(prog_val->loc(), addr);
push(program().push_constant(ref));
m_pc++;
} break;
case OPCODE_NOT: {
auto val = program.constant(pop());
push(program.push_constant(Value::make_bool(val->loc(),
auto val = program().constant(pop());
push(program().push_constant(Value::make_bool(val->loc(),
!val->as_bool())));
m_pc++;
} break;
@ -37,7 +56,7 @@ namespace grino
} break;
case OPCODE_BRF: {
auto val = program.constant(pop())->as_bool();
auto val = program().constant(pop())->as_bool();
size_t addr = *instr.param;
if (!val)
@ -62,21 +81,39 @@ namespace grino
case OPCODE_STORE_LOCAL: {
size_t addr = *instr.param;
auto value = program.constant(top());
auto value = program().constant(top());
set_local(addr, value);
m_pc++;
} break;
case OPCODE_LOAD_LOCAL: {
size_t addr = *instr.param;
push(program.push_constant(local(addr)));
push(program().push_constant(local(addr)));
m_pc++;
} break;
case OPCODE_LOAD_OBJ: {
size_t addr = *instr.param;
auto ref = Value::make_ref(Loc {"???", 0}, addr);
push(program.push_constant(ref));
push(program().push_constant(ref));
m_pc++;
} break;
case OPCODE_RET: {
m_pc = m_frames.back().pc;
size_t old_sp = m_frames.back().sp;
auto ret_val = program().constant(m_stack[m_sp - 1]);
m_frames.pop_back();
while (m_sp > old_sp)
{
pop();
}
push(program().push_constant(ret_val));
m_pc++;
} break;
@ -87,18 +124,33 @@ namespace grino
for (size_t i=0; i<N; i++)
{
args.insert(std::begin(args), program.constant(pop()));
args.insert(std::begin(args), program().constant(pop()));
}
size_t ref = program.constant(pop())->as_ref();
size_t ref = program().constant(pop())->as_ref();
auto fun = heap(ref)->as_function();
auto ret = fun->call(args);
if (fun->is_native())
{
auto ret = fun->call(args);
push(program().push_constant(ret));
m_pc++;
}
else
{
Frame frame;
frame.pc = m_pc;
frame.sp = m_sp;
frame.program = fun->program().get();
m_frames.push_back(frame);
push(program.push_constant(ret));
m_pc++;
for (size_t i=0; i<args.size(); i++)
{
set_local(i, args[i]);
}
m_pc = 0;
}
} break;
default:
@ -110,6 +162,7 @@ namespace grino
}
}
std::string VM::string() const
{
std::stringstream ss;
@ -173,4 +226,10 @@ namespace grino
size_t addr = m_stack[m_sp - 1];
return addr;
}
Program& VM::program() const
{
assert(m_frames.back().program);
return *m_frames.back().program;
}
}

View File

@ -12,18 +12,21 @@ namespace grino
GRINO_ERROR(execution_error);
struct Frame {
size_t pc;
size_t sp;
Program* program;
std::unordered_map<size_t, std::shared_ptr<Value>> locals;
};
class VM
{
public:
explicit VM(Logger& logger);
explicit VM(Logger& logger, Program& program);
virtual ~VM();
size_t heap_size() const { return m_heap.size(); }
void run(Program& program);
void run();
std::string string() const;
@ -36,8 +39,10 @@ namespace grino
std::shared_ptr<Value> value);
private:
Logger& m_logger;
std::array<size_t, STACK_SIZE> m_stack;
std::unordered_map<size_t, std::shared_ptr<Value>> m_heap;
std::vector<Frame> m_frames;
size_t m_sp; /* stack pointer */
@ -48,6 +53,7 @@ namespace grino
size_t pop();
size_t top();
Program& program() const;
};
}

View File

@ -1,4 +1,5 @@
#include "Value.hpp"
#include "Program.hpp"
namespace grino
{
@ -35,6 +36,16 @@ namespace grino
return value;
}
/*static*/ std::shared_ptr<Value>
Value::make_function(Loc const& loc,
std::shared_ptr<Program> val)
{
auto value = std::make_shared<Value>(loc);
value->m_type = TYPE_FUNCTION;
value->m_function_val = std::make_shared<Function>(val);
return value;
}
/*static*/ std::shared_ptr<Value> Value::make_ref(Loc const& loc, size_t val)
{
auto value = std::make_shared<Value>(loc);
@ -43,6 +54,21 @@ namespace grino
return value;
}
/*static*/
std::shared_ptr<Value> Value::make_program(Loc const& loc,
std::shared_ptr<Program> val)
{
auto value = std::make_shared<Value>(loc);
value->m_type = TYPE_PROGRAM;
value->m_program_val = val;
return value;
}
std::shared_ptr<Program> Value::as_program() const
{
return m_program_val;
}
std::string Value::string() const
{
switch (m_type)
@ -52,6 +78,7 @@ namespace grino
case TYPE_BOOL: return *m_bool_val ? "true" : "false";
case TYPE_FUNCTION: return "<function>";
case TYPE_REF: return "&" + std::to_string(*m_ref_val);
case TYPE_PROGRAM: return "<program>";
default:
std::cerr << "cannot stringify value "
<< TypeTypeStr[m_type] << std::endl;
@ -69,6 +96,7 @@ namespace grino
case TYPE_BOOL: return *m_bool_val == *other.m_bool_val;
case TYPE_INT: return *m_int_val == *other.m_int_val;
case TYPE_REF: return *m_ref_val == *other.m_ref_val;
case TYPE_PROGRAM: return false;
default:
std::cerr << "cannot compare equality with value "

View File

@ -8,6 +8,8 @@
namespace grino
{
class Program;
class Value
{
public:
@ -16,9 +18,14 @@ namespace grino
static std::shared_ptr<Value> make_int(Loc const& loc, int val);
static std::shared_ptr<Value> make_native_function(Loc const& loc,
native_t val);
static std::shared_ptr<Value> make_function(Loc const& loc,
std::shared_ptr<Program> val);
static std::shared_ptr<Value> make_ref(Loc const& loc,
size_t val);
static std::shared_ptr<Value> make_program(Loc const& loc,
std::shared_ptr<Program> val);
explicit Value(Loc const& loc);
virtual ~Value() = default;
@ -28,6 +35,7 @@ namespace grino
int as_int() const { return *m_int_val; }
std::shared_ptr<Function> as_function() const { return m_function_val; }
size_t as_ref() const { return *m_ref_val; }
std::shared_ptr<Program> as_program() const;
std::string string() const;
bool equals(Value const& other) const;
@ -38,6 +46,7 @@ namespace grino
std::optional<bool> m_bool_val;
std::optional<int> m_int_val;
std::shared_ptr<Function> m_function_val;
std::shared_ptr<Program> m_program_val;
std::optional<size_t> m_ref_val;
};
}

View File

@ -36,14 +36,14 @@ void run(char** argv, bool debug_mode)
}
grino::SymTable sym_table {logger};
grino::VM vm {logger};
grino::Compiler compiler {logger, sym_table};
grino::Program program;
grino::VM vm {logger, program};
grino::Compiler compiler {logger};
grino::Loader loader {vm, compiler, sym_table};
loader.load_libraries();
grino::Program program;
compiler.compile(ast, program);
compiler.compile(ast, program, sym_table);
if (debug_mode)
{
@ -51,7 +51,7 @@ void run(char** argv, bool debug_mode)
std::cout << program.string() << std::endl;
}
vm.run(program);
vm.run();
if (debug_mode)
{

View File

@ -15,7 +15,8 @@
G(OPCODE_BRF), \
G(OPCODE_BR), \
G(OPCODE_NOT), \
G(OPCODE_RET),
G(OPCODE_RET), \
G(OPCODE_MK_FUN),
namespace grino
{

View File

@ -8,7 +8,9 @@
G(TYPE_BOOL), \
G(TYPE_INT), \
G(TYPE_FUNCTION), \
G(TYPE_REF)
G(TYPE_REF), \
G(TYPE_PROGRAM)
namespace grino
{

View File

@ -90,3 +90,12 @@ TEST_CASE_METHOD(LexerTest, "Lexer_int")
test_next(lexer, "INT[-7]");
test_end(lexer);
}
TEST_CASE_METHOD(LexerTest, "Lexer_lambda")
{
grino::Lexer lexer {m_logger, "tests/lexer"};
lexer.scan(" -> ");
test_next(lexer, "LAMBDA");
test_end(lexer);
}

View File

@ -54,3 +54,23 @@ TEST_CASE_METHOD(ParserTest, "Parser_integer")
test_parse("MODULE(FUNCALL(IDENT[hello],INT[2],INT[34]))",
"(hello 2 34)");
}
TEST_CASE_METHOD(ParserTest, "Parser_lambda")
{
test_parse("MODULE(LAMBDA(PARAMS,BODY))",
"(-> ())");
test_parse("MODULE(LAMBDA(PARAMS(IDENT[x]),BODY))",
"(-> (x))");
test_parse("MODULE(LAMBDA(PARAMS(IDENT[x],IDENT[y]),BODY(IDENT[y])))",
"(-> (x y) y)");
test_parse("MODULE(FUNCALL(LAMBDA(PARAMS("
"IDENT[x]"
"),BODY("
"IDENT[x]"
")),INT[4]))",
"( (-> (x) x ) 4 )");
}