diff --git a/src/compiler.c b/src/compiler.c index 51b87c6..339e276 100644 --- a/src/compiler.c +++ b/src/compiler.c @@ -11,6 +11,9 @@ void compiler_init(compiler* self) self->binfo.capacity = 1; self->binfo.data = malloc(sizeof(block_info*) * self->binfo.capacity); + self->fun_stack.size = 0; + self->fun_stack.capacity = 1; + self->fun_stack.data = malloc(sizeof(char*) * self->fun_stack.capacity); } void compiler_free(compiler* self) @@ -21,6 +24,17 @@ void compiler_free(compiler* self) free(self->sym); self->sym = NULL; + for (size_t i=0; ifun_stack.size; i++) + { + free(self->fun_stack.data[i]); + } + + free(self->fun_stack.data); + + self->fun_stack.size = 0; + self->fun_stack.capacity = 0; + self->fun_stack.data = NULL; + for (size_t i=0; ibinfo.size; i++) { block_info_free(self->binfo.data[i]); @@ -34,6 +48,39 @@ void compiler_free(compiler* self) self->binfo.data = NULL; } +void compiler_push_fun(compiler* self, char* fun_name) +{ + assert(self); + + if (self->fun_stack.size >= self->fun_stack.capacity) + { + self->fun_stack.capacity *= 2; + self->fun_stack.data = realloc( + self->fun_stack.data, + sizeof(char*) * self->fun_stack.capacity + ); + } + + self->fun_stack.data[self->fun_stack.size] = str_new(fun_name); + self->fun_stack.size++; +} + +void compiler_pop_fun(compiler* self) +{ + assert(self); + assert(self->fun_stack.size > 0); + + free(self->fun_stack.data[self->fun_stack.size - 1]); + self->fun_stack.size--; +} + +char* compiler_top_fun(compiler* self) +{ + assert(self); + assert(self->fun_stack.size > 0); + return self->fun_stack.data[self->fun_stack.size - 1]; +} + block_info* compiler_current_info(compiler* self) { assert(self); @@ -174,15 +221,19 @@ void compile_node(compiler* self, node* root, program* prog) type* ty = cstatic_resolve_new(&cs, self->sym, root->children.data[1]); - + size_t id = symtable_declare( self->sym, root->children.data[0]->value, ty, root->children.data[1] ); - + + + compiler_push_fun(self, root->children.data[0]->value); compile_node(self, root->children.data[1], prog); + compiler_pop_fun(self); + program_add_instr(prog, OP_STORE, id); type_free(ty); @@ -205,7 +256,20 @@ void compile_node(compiler* self, node* root, program* prog) root->children.data[1] ); + int is_fn = root->children.data[1]->type == NODE_FUN; + + if (is_fn) + { + compiler_push_fun(self, root->children.data[0]->value); + } + compile_node(self, root->children.data[1], prog); + + if (is_fn) + { + compiler_pop_fun(self); + } + program_add_instr(prog, OP_STORE, id); type_free(ty); @@ -233,9 +297,22 @@ void compile_node(compiler* self, node* root, program* prog) root->children.data[1] ); - compile_node(self, root->children.data[1], prog); - program_add_instr(prog, OP_STORE, id); + int is_fn = root->children.data[1]->type == NODE_FUN; + if (is_fn) + { + compiler_push_fun(self, root->children.data[0]->value); + } + + compile_node(self, root->children.data[1], prog); + + if (is_fn) + { + compiler_pop_fun(self); + } + + program_add_instr(prog, OP_STORE, id); + type_free(ty); free(ty); cstatic_free(&cs); @@ -292,6 +369,7 @@ void compile_node(compiler* self, node* root, program* prog) int lhs = cstatic_resolve_base_type(&cs, self->sym, root->children.data[0]); + int rhs = cstatic_resolve_base_type(&cs, self->sym, root->children.data[1]); @@ -365,64 +443,70 @@ void compile_node(compiler* self, node* root, program* prog) } else if (root->type == NODE_FUN) { - // init body + + // prepare compiler compiler comp; compiler_init(&comp); - program* fun_prog = malloc(sizeof(program)); - program_init(fun_prog); - node* block = root->children.data[root->children.size - 1]; - // add params + // prepare program + program local_prog; + program_init(&local_prog); + + // resolve function type + cstatic cs; + cstatic_init(&cs); + type* fn_ty = cstatic_resolve_new(&cs, self->sym, root); + + // prepare environment (parameters and function) + for (size_t i=0; ichildren.size; i++) { node* param = root->children.data[i]; + if (param->type != NODE_FUN_PARAM) { continue; } - if (param->type == NODE_FUN_PARAM) - { - char const* param_name = param->children.data[0]->value; - type* param_ty = - type_new_from_node(param->children.data[1]); + char* name = param->children.data[0]->value; + type* ty = type_new_from_node(param->children.data[1]); + + symtable_declare( + comp.sym, + name, + ty, + param + ); - symtable_declare(comp.sym, param_name - , param_ty - , param->children.data[0]); - - type_free(param_ty); free(param_ty); - } + type_free(ty); free(ty); } - - // compile body - compile_node(&comp, block, fun_prog); - compiler_free(&comp); + char* fun_name = compiler_top_fun(self); - // init fun object - cstatic cs; - cstatic_init(&cs); - - fun* fn = malloc(sizeof(fun)); - type* ty = cstatic_resolve_new(&cs, self->sym, root); - - fun_init(fn, ty, (struct program*) fun_prog); - program_free(fun_prog); free(fun_prog); fun_prog = NULL; - - type_free(ty); free(ty); - - // create value - value* val = malloc(sizeof(value)); - value_init_fun(val, fn, root->lineno); - - // push it - program_add_instr( - prog, - OP_PUSH, - program_add_pool(prog, val) + symtable_declare( + comp.sym, + fun_name, + fn_ty, + root ); + // compile fun block + node* fun_block = root->children.data[root->children.size - 1]; + compile_node(&comp, fun_block, &local_prog); + //char c[1024]; node_str(root, c, 1024); printf("%s\n", c); + + // create value + fun* fn = malloc(sizeof(fun)); + fun_init(fn, fn_ty, (struct program*) &local_prog); + value* val = malloc(sizeof(value)); + value_init_fun(val, fn, root->lineno); + cstatic_free(&cs); + + // push value + program_add_instr(prog, OP_PUSH, program_add_pool(prog, val)); + + // free stuff + program_free(&local_prog); + compiler_free(&comp); + type_free(fn_ty); free(fn_ty); value_free(val); free(val); fun_free(fn); free(fn); - - cstatic_free(&cs); } else if (root->type == NODE_CALL) { diff --git a/src/compiler.h b/src/compiler.h index 386625e..257f455 100644 --- a/src/compiler.h +++ b/src/compiler.h @@ -16,11 +16,20 @@ typedef struct { block_info** data; } binfo; + struct { + size_t size; + size_t capacity; + char** data; + } fun_stack; + } compiler; void compiler_init(compiler* self); void compiler_free(compiler* self); +void compiler_push_fun(compiler* self, char* fun_name); +void compiler_pop_fun(compiler* self); +char* compiler_top_fun(compiler* self); block_info* compiler_current_info(compiler* self); block_info* compiler_find_info(compiler* self, int kind); diff --git a/src/cstatic.c b/src/cstatic.c index abd1dfc..abef04b 100644 --- a/src/cstatic.c +++ b/src/cstatic.c @@ -7,7 +7,7 @@ void cstatic_init(cstatic* self) self->fun_types.size = 0; self->fun_types.capacity = 1; - self->fun_types.data = malloc(sizeof(type*) * self->fun_types.capacity); + self->fun_types.data = malloc(sizeof(fun_info*) * self->fun_types.capacity); } void cstatic_free(cstatic* self) @@ -16,8 +16,8 @@ void cstatic_free(cstatic* self) for (size_t i=0; ifun_types.size; i++) { - type_free(self->fun_types.data[i]); - free(self->fun_types.data[i]); + type_free(self->fun_types.data[i]->ty); + free(self->fun_types.data[i]->ty); } free(self->fun_types.data); @@ -90,11 +90,36 @@ type* cstatic_resolve_new(cstatic* self, symtable* sym, node* ast) if (ast->type == NODE_CALL) { char const* fun_name = ast->children.data[0]->value; - + symentry* entry = symtable_find(sym, fun_name); - assert(entry); + fun_info* info = NULL; + + if (self->fun_types.size > 0) + { + info = cstatic_top_fun(self); + } - type* fun_type = entry->ty; + type* fun_type = NULL; + + if (entry) + { + fun_type = entry->ty; + } + else if (info && info->name && strcmp(info->name, fun_name) == 0) + { + fun_type = info->ty; + } + else + { + fprintf(stderr + , "E(%d): cannot call unknown function '%s'.\n" + , ast->children.data[0]->lineno + , fun_name); + + assert(0); + } + + assert(fun_type); size_t ret_count = fun_type->sub_types.size; @@ -106,6 +131,7 @@ type* cstatic_resolve_new(cstatic* self, symtable* sym, node* ast) if (ret_type->kind == KIND_RETURN) { + type* res = type_new_clone(ret_type); res->kind = KIND_CONJUNCTION; return res; @@ -160,8 +186,12 @@ type* cstatic_resolve_new(cstatic* self, symtable* sym, node* ast) return ty; } - if (ast->type == NODE_VARDECL) + if (ast->type == NODE_VARDECL + || ast->type == NODE_CONSTDECL + || ast->type == NODE_FUNDECL) { + char* name = ast->children.data[0]->value; + printf("-> %s\n", name); type* ty = cstatic_resolve_new(self, sym, ast->children.data[1]); @@ -338,7 +368,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz { assert(self); assert(ast); - + // BLOCK if (ast->type == NODE_BLOCK) { @@ -365,9 +395,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz if (ast->type == NODE_FUN) { type* ty = cstatic_resolve_new(self, sym, ast); - - cstatic_push_fun(self, ty); - + type_free(ty); free(ty); symtable* inner_table = symtable_new_clone(sym); @@ -387,7 +415,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz type_free(t); free(t); } } - + int status = cstatic_check(self, ast->children.data[ast->children.size - 1], inner_table, @@ -396,7 +424,6 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz ); symtable_free(inner_table); free(inner_table); - cstatic_pop_fun(self); return status; @@ -406,7 +433,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz if (ast->type == NODE_RETURN) { type* ret_ty = cstatic_resolve_new(self, sym, ast); - type* fun_ty = cstatic_top_fun(self); + type* fun_ty = cstatic_top_fun(self)->ty; type* fun_res = NULL; @@ -431,7 +458,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz msg, size ); - + if (status) { if (fun_res) @@ -459,6 +486,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz for (size_t i=0; isub_types.size; i++) { type* ty = fun_type->sub_types.data[i]; + if (ty->kind != KIND_RETURN) { arity++; @@ -513,7 +541,7 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz } // Children - for (size_t i=0; ichildren.size; i++) + /*for (size_t i=0; ichildren.size; i++) { int status = cstatic_check(self, ast->children.data[i], sym, @@ -524,6 +552,58 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz return 0; } + }*/ + + if (ast->type == NODE_VARDECL + || ast->type == NODE_CONSTDECL + || ast->type == NODE_FUNDECL) + { + char const* name = ast->children.data[0]->value; + symentry* entry = symtable_find(sym, name); + if (entry && entry->scope >= sym->scope) + { + fprintf(stderr, "E(%d): '%s' is already defined.\n", + ast->lineno, + name); + exit(-1); + } + + type* ty = cstatic_resolve_new(self, sym, ast->children.data[1]); + + if (ty->base_type == TY_FUNCTION) + { + cstatic_push_fun(self, ast->children.data[0]->value, ty); + int status = + cstatic_check(self, ast->children.data[1], sym, msg, size); + cstatic_pop_fun(self); + + if (!status) { return status; } + } + else + { + int status = + cstatic_check(self, ast->children.data[1], sym, msg, size); + + if (!status) + { + return status; + } + } + + assert(ty); + + if (ast->type == NODE_VARDECL) + { + symtable_declare_mut(sym, name, ty, ast->children.data[1]); + } + else + { + symtable_declare(sym, name, ty, ast->children.data[1]); + } + + type_free(ty); + free(ty); + return 1; } // Arrays @@ -725,37 +805,6 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz return 1; } - if (ast->type == NODE_VARDECL - || ast->type == NODE_CONSTDECL - || ast->type == NODE_FUNDECL) - { - char const* name = ast->children.data[0]->value; - symentry* entry = symtable_find(sym, name); - - if (entry && entry->scope >= sym->scope) - { - fprintf(stderr, "E(%d): '%s' is already defined.\n", - ast->lineno, - name); - exit(-1); - } - - type* ty = cstatic_resolve_new(self, sym, ast->children.data[1]); - assert(ty); - - if (ast->type == NODE_VARDECL) - { - symtable_declare_mut(sym, name, ty, ast->children.data[1]); - } - else - { - symtable_declare(sym, name, ty, ast->children.data[1]); - } - - type_free(ty); - free(ty); - return 1; - } // Types Operations if (ast->type == NODE_MUL) @@ -997,6 +1046,36 @@ int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t siz } } + for (size_t i=0; ichildren.size; i++) + { + int status = cstatic_check(self, ast->children.data[i], + sym, + msg, size); + + if (!status) + { + return 0; + } + + } + + return 1; +} + +int cstatic_check_children(cstatic* self, node* lhs, symtable* sym, char* msg, size_t size) +{ + for (size_t i=0; ichildren.size; i++) + { + int status = cstatic_check(self, lhs->children.data[i], + sym, + msg, size); + + if (!status) + { + return 0; + } + } + return 1; } @@ -1169,7 +1248,7 @@ int cstatic_check_same_type(cstatic* self, node* lhs, node* rhs, symtable* sym, return 1; } -void cstatic_push_fun(cstatic* self, type* fun_ty) +void cstatic_push_fun(cstatic* self, char* name, type* fun_ty) { assert(self); assert(fun_ty); @@ -1179,11 +1258,18 @@ void cstatic_push_fun(cstatic* self, type* fun_ty) self->fun_types.capacity *= 2; self->fun_types.data = realloc( self->fun_types.data, - sizeof(type*) * self->fun_types.capacity + sizeof(fun_info*) * self->fun_types.capacity ); } + + self->fun_types.data[self->fun_types.size] = malloc(sizeof(fun_info)); + self->fun_types.data[self->fun_types.size]->ty = type_new_clone(fun_ty); + + if (name) + { + self->fun_types.data[self->fun_types.size]->name = str_new(name); + } - self->fun_types.data[self->fun_types.size] = type_new_clone(fun_ty); self->fun_types.size++; } @@ -1192,13 +1278,20 @@ void cstatic_pop_fun(cstatic* self) { assert(self->fun_types.size > 0); - type_free(self->fun_types.data[self->fun_types.size - 1]); + type_free(self->fun_types.data[self->fun_types.size - 1]->ty); + free(self->fun_types.data[self->fun_types.size - 1]->ty); + + if (self->fun_types.data[self->fun_types.size - 1]->name) + { + free(self->fun_types.data[self->fun_types.size - 1]->name); + } + free(self->fun_types.data[self->fun_types.size - 1]); self->fun_types.size--; } -type* cstatic_top_fun(cstatic* self) +fun_info* cstatic_top_fun(cstatic* self) { assert(self->fun_types.size > 0); return self->fun_types.data[self->fun_types.size - 1]; diff --git a/src/cstatic.h b/src/cstatic.h index e805ea9..8aca092 100644 --- a/src/cstatic.h +++ b/src/cstatic.h @@ -7,11 +7,16 @@ #define TYPE_END (-1) +typedef struct { + type* ty; + char* name; +} fun_info; + typedef struct { struct { size_t size; size_t capacity; - type** data; + fun_info** data; } fun_types; } cstatic; @@ -24,6 +29,8 @@ int cstatic_resolve_base_type(cstatic* self, symtable* sym, node* ast); type* cstatic_resolve_new(cstatic* self, symtable* sym, node* ast); int cstatic_check(cstatic* self, node* ast, symtable* sym, char* msg, size_t size); +int cstatic_check_children(cstatic* self, node* lhs, symtable* sym, char* msg, size_t size); + int cstatic_check_type(cstatic* self, node* lhs, symtable* sym, char* msg, size_t size, type* types, ...); @@ -37,7 +44,7 @@ int cstatic_check_same_type_ptr(cstatic* self, int cstatic_check_same_type(cstatic* self, node* lhs, node* rhs, symtable* sym, char* msg, size_t size); -void cstatic_push_fun(cstatic* self, type* fun_ty); +void cstatic_push_fun(cstatic* self, char* name, type* fun_ty); void cstatic_pop_fun(cstatic* self); -type* cstatic_top_fun(cstatic* self); +fun_info* cstatic_top_fun(cstatic* self); #endif diff --git a/src/vm.c b/src/vm.c index 2de56e9..db4536e 100644 --- a/src/vm.c +++ b/src/vm.c @@ -1150,7 +1150,11 @@ void vm_call(vm* self, int param) { assert(self); - // get function and parameters + vm v; + vm_init(&v); + + + // get arguments value* args[param]; for (int i=0; ival.fun_val; - - // create new ecxecution environment - vm v; - vm_init(&v); - for (int i=0; iprog); - value* ret = value_new_clone(v.stack.data[v.stack.size - 1]); - vm_push_value(self, ret); + // get calling function + value* function_val = vm_pop_value(self); + fun* function = function_val->val.fun_val; + vm_set(&v, param, function_val); + + // execute and finalize + vm_exec(&v, (program*) function->prog); + vm_push_value(self, value_new_clone(v.stack.data[v.stack.size - 1])); + + + // cleanup vm_free(&v); - + value_free(function_val); free(function_val); + for (int i=0; ipc++; } diff --git a/tests/test_fun.wuz b/tests/test_fun.wuz index 08993c4..47779d3 100644 --- a/tests/test_fun.wuz +++ b/tests/test_fun.wuz @@ -36,3 +36,13 @@ end assert 14 == e 7 +fun f (n as int) as int + if n == 0 + return 1 + end + + return n * f(n - 1) +end + +assert 120 == f 5 +