diff --git a/.gitignore b/.gitignore index ed8ebf5..37a576f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -__pycache__ \ No newline at end of file +__pycache__ +*.prof \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3d6af1e --- /dev/null +++ b/Makefile @@ -0,0 +1,5 @@ +.PHONY: prof + +prof: + python -m cProfile -o sine_patre.prof sine_patre.py + snakeviz sine_patre.prof diff --git a/fol/__init__.py b/fol/__init__.py index 2f65b62..793d775 100644 --- a/fol/__init__.py +++ b/fol/__init__.py @@ -1,5 +1,6 @@ from fol import lexer from fol import parser +from fol import kb from fol import pretty from fol import cnf from fol.unify import unify diff --git a/fol/cnf.py b/fol/cnf.py index 794a222..34d2cc0 100644 --- a/fol/cnf.py +++ b/fol/cnf.py @@ -1,12 +1,28 @@ import fol +class Counter: + def __init__(self): + self.values = {} + self.counter = 0 + + def get(self, name): + if name in self.values.keys(): + return self.values[name] + self.values[name] = self.counter + self.counter += 1 + return self.counter - 1 + + def begin(self): + self.values = {} + + def elim_imp(f): res = fol.node.Node(f.name, f.value) if f.name == 'IMP': a = fol.node.Node('NOT') - a.add_child(f.children[0]) + a.add_child(elim_imp(f.children[0])) res = fol.node.Node('OR') res.add_child(elim_imp(a)) res.add_child(elim_imp(f.children[1])) @@ -23,9 +39,9 @@ def neg(f): if f.name == 'NOT' and f.children[0].name in ['OR', 'AND']: p = fol.node.Node('NOT') - p.add_child(f.children[0].children[0]) + p.add_child(neg(f.children[0].children[0])) q = fol.node.Node('NOT') - q.add_child(f.children[0].children[1]) + q.add_child(neg(f.children[0].children[1])) if f.children[0].name == 'OR': res = fol.node.Node('AND') @@ -38,24 +54,23 @@ def neg(f): res.add_child(neg(q)) return res + res = fol.node.Node(f.name, f.value) + + for c in f.children: + res.add_child(neg(c)) + return res -def normalize(f, values={}): +def normalize(f, counter): res = fol.node.Node(f.name, f.value) if f.name == 'VAR': - if f.value in values.keys(): - res.value = values[f.value] - else: - new_name = f'x{len(values)}' - values[res.value] = new_name - res.value = new_name - + res.value = f'x{counter.get(f.value)}' return res for child in f.children: - res.add_child(normalize(child)) + res.add_child(normalize(child, counter)) return res @@ -207,7 +222,12 @@ def distrib(f): return f -def cnf(f): - return distrib( - skolem(prenex(normalize(neg(elim_imp(f))))) +def cnf(f, counter=None): + if counter is None: + counter = Counter() + + g = distrib( + skolem(prenex(normalize(neg(elim_imp(f)), counter))) ).remove_quant('FORALL') + + return g diff --git a/fol/kb.py b/fol/kb.py new file mode 100644 index 0000000..53ffb8a --- /dev/null +++ b/fol/kb.py @@ -0,0 +1,168 @@ +import fol + + +class Kb: + def __init__(self): + self.base = [] + self.counter = fol.cnf.Counter() + + def make_f(self, h): + self.counter.begin() + f = h + if type(f) is str: + f = fol.p(f) + self.counter.begin() + f = fol.cnf.cnf(f, self.counter) + + if f.name == 'OR': + f = f.list_of('OR') + else: + f = [f] + + return f + + def make_req(self, h): + f = h + if type(f) is str: + f = fol.p(f) + + if f.name == 'OR': + f = f.list_of('OR') + else: + f = [f] + + return f + + def check_request(self, request): + for r in request: + if r.name not in ['NOT', 'PRED']: + raise Exception(f'invalid statement {repr(r)}') + + def conds(self, statement): + result = [] + for f in statement: + if f.name == 'NOT': + result.append(f.children[0]) + return result + + def conclusion(self, statement): + for f in statement: + if f.name != 'NOT': + return f + return None + + @property + def rules(self): + res = [] + for r in self.base: + if len(r) > 1: + res.append(r) + return res + + @property + def facts(self): + res = [] + for r in self.base: + if len(r) == 1: + res.append(r[0]) + return res + + def exists(self, f): + h = self.make_f(f) + for g in self.base: + if self.clause_equals(h, g): + return True + return False + + def clause_equals(self, left, right): + if len(left) != len(right): + return False + + for element in left: + if not self.clause_in(element, right): + return False + + for element in right: + if not self.clause_in(element, left): + return False + + return True + + def clause_in(self, clause, lst): + for other in lst: + same = fol.unify(clause, other) is not None + if same: + return True + return False + + def tell(self, request): + req = self.make_f(request) + self.check_request(req) + self.base.append(req) + + def ask(self, request): + for i in range(20): + n = len(self.base) + self.update() + if n == len(self.base): + break + + req = self.make_req(request) + results = [] + + for f in self.base: + s = fol.unify(f, req) + if s is not None: + results.append(s) + + return results + + def update(self): + for rule in self.rules: + conclusion = self.conclusion(rule) + solution = [] + solutions = [] + self.solve(rule, self.conds(rule), self.facts, solution, solutions) + for sol in solutions: + concl = conclusion.subst(sol) + if not self.exists(concl): + self.base.append([concl]) + return + + def merge_substs(self, substs): + res = {} + for s in substs: + for k in s.keys(): + if k in res.keys() and not res[k].equals(s[k]): + return None + res[k] = s[k] + return res + + def solve(self, rule, conds, facts, solution, solutions): + if len(conds) == 0: + substs = [] + + for sol in solution: + s = fol.unify(sol[0], sol[1]) + if s is not None: + substs.append(s) + + s = self.merge_substs(substs) + + if s is not None: + solutions.append(s) + return + + for cond in conds: + for fact in facts: + solution.append((cond, fact)) + + self.solve( + rule, + [c for c in conds if c != cond], + [f for f in facts if f != fact], + solution, + solutions + ) + + solution.pop() diff --git a/fol/node.py b/fol/node.py index d082bb5..7730fca 100644 --- a/fol/node.py +++ b/fol/node.py @@ -1,3 +1,6 @@ +import fol + + class Node: def __init__(self, name, value=None): self.name = name @@ -8,7 +11,7 @@ class Node: self.children.append(child) def __repr__(self): - return str(self) + return fol.pretty.get(self) def __str__(self): res = self.name @@ -68,3 +71,32 @@ class Node: return False return True + + def list_of(self, name): + result = [] + + if self.name == name: + for c in self.children: + if c.name != name: + result.append(c) + result.extend(c.list_of(name)) + + return result + + def flatten(self): + result = self.list_of(self.name) + print(result) + root = fol.node.Node(self.name, self.value) + root.children = result + return root + + def find_by_name(self, name): + if self.name == name: + return [self] + + res = [] + + for c in self.children: + res.extend(c.find_by_name(name)) + + return res diff --git a/fol/pretty.py b/fol/pretty.py index 8a49e25..e5cc949 100644 --- a/fol/pretty.py +++ b/fol/pretty.py @@ -1,6 +1,3 @@ -import fol - - def get(f): if f.name == 'EXISTS': return f'\\exist {get(f.children[0])}({get(f.children[1])})' diff --git a/fol/tests/test_cnf.py b/fol/tests/test_cnf.py index b1d40bf..e68e8b5 100644 --- a/fol/tests/test_cnf.py +++ b/fol/tests/test_cnf.py @@ -32,7 +32,8 @@ def test_cnf_neg(f, res): 'Pre(f(x0)) -> Post(g(g(x1), x0))'), ]) def test_cnf_normalize(f, res): - assert fol.cnf.normalize(fol.p(f)).equals(fol.p(res)) + counter = fol.cnf.Counter() + assert fol.cnf.normalize(fol.p(f), counter).equals(fol.p(res)) @pytest.mark.parametrize(['f', 'res'], [ @@ -116,7 +117,12 @@ def test_distrib(oracle, f): @pytest.mark.parametrize(['f', 'oracle'], [ ('Friend(x, y) -> Friend(y, x)', - '~Friend(x0, x1) | Friend(x1, x0)') + '~Friend(x0, x1) | Friend(x1, x0)'), + + ('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)', + '(~Friend(x0, x1) | ~Friend(x1, x2)) | Friend(x0, x2)') ]) def test_cnf(oracle, f): + print(fol.pretty.get(fol.p(f))) + print(fol.pretty.get(fol.cnf.cnf(fol.p(f)))) assert fol.cnf.cnf(fol.p(f)).equals(fol.p(oracle)) diff --git a/fol/tests/test_kb.py b/fol/tests/test_kb.py new file mode 100644 index 0000000..3650580 --- /dev/null +++ b/fol/tests/test_kb.py @@ -0,0 +1,11 @@ +import fol + + +def test_kb_simple(): + kb = fol.kb.Kb() + kb.tell('Friend(ALICE, BOB)') + + res = kb.ask('Friend(ALICE, x)') + + assert len(res) == 1 + assert res[0]['x'].value == 'BOB' diff --git a/fol/tests/test_unify.py b/fol/tests/test_unify.py index d1430b0..049493e 100644 --- a/fol/tests/test_unify.py +++ b/fol/tests/test_unify.py @@ -6,16 +6,37 @@ def test_unify_simple(): fol.p('Hello(x)'), fol.p('Hello(ALICE)') ) + assert subst is not None assert len(subst) == 1 assert str(subst['x']) == 'CONST[ALICE]' +def test_unify_equals(): + subst = fol.unify( + fol.p('Hello(ALICE)'), + fol.p('Hello(ALICE)') + ) + + assert subst is not None + assert len(subst) == 0 + + +def test_unify_different_predicate_name(): + subst = fol.unify( + fol.p('Hello(x)'), + fol.p('World(ALICE)') + ) + + assert subst is None + + def test_unify_fun(): subst = fol.unify( fol.p('Hello(x, f(x))'), fol.p('Hello(ALICE, f(ALICE))') ) + assert subst is not None assert len(subst) == 1 assert str(subst['x']) == 'CONST[ALICE]' diff --git a/fol/unify.py b/fol/unify.py index ea64e76..78cc37f 100644 --- a/fol/unify.py +++ b/fol/unify.py @@ -1,29 +1,84 @@ +import fol + + def unify(lhs, rhs): - res = {} - try: - unify_res(lhs, rhs, res) - except Exception: + s = {} + return _unify(lhs, rhs, s) + + +def _unify(lhs, rhs, s): + if s is None: return None - return res + elif is_equal(lhs, rhs): + return s + elif is_var(lhs): + return unify_var(lhs, rhs, s) + elif is_var(rhs): + return unify_var(rhs, lhs, s) + elif is_compound(lhs) and is_compound(rhs) and ( + lhs.name == rhs.name and lhs.value == rhs.value + ): - -def unify_res(lhs, rhs, result): - if lhs.name == 'VAR': - unify_var(lhs, rhs, result) - elif rhs.name == 'VAR': - unify_var(rhs, lhs, result) - else: - left = lhs.children - right = rhs.children - if len(left) == len(right): - for i in range(len(left)): - unify_res(left[i], right[i], result) - - -def unify_var(var, rhs, result): - if rhs.name in ['VAR', 'CONST', 'FUN'] and rhs.value != var.value: - if var.value in result.keys() \ - and not result[var.value].equals(rhs): - raise Exception('') + return _unify( + lhs.children, + rhs.children, + s + ) + elif is_list(lhs) and is_list(rhs): + if len(lhs) != len(rhs): + return None + elif len(lhs) == 0: + return s + elif len(lhs) == 1: + return _unify(lhs[0], rhs[0], s) else: - result[var.value] = rhs + return _unify(lhs[1:], rhs[1:], _unify(lhs[0], rhs[0], s)) + else: + return None + + +def unify_var(var, x, s): + if var.value in s.keys(): + return _unify(s[var.value], x, s) + elif x in s.keys(): + return _unify(var, s[x], s) + elif occur_check(var, x): + return None + else: + s[var.value] = x + return s + + +def occur_check(var, x): + if var.name == x.name and var.value == x.value: + return True + + for child in x.children: + if occur_check(var, child): + return True + + return False + + +def is_compound(f): + if type(f) is fol.node.Node: + return len(f.children) > 0 + return False + + +def is_equal(lhs, rhs): + if type(lhs) is str and type(rhs) is str: + return lhs == rhs + + if type(lhs) is fol.node.Node and type(rhs) is fol.node.Node: + return lhs.equals(rhs) + + return False + + +def is_var(f): + return type(f) is fol.node.Node and f.name == 'VAR' + + +def is_list(value): + return type(value) is list diff --git a/sine_patre.py b/sine_patre.py index ef005ec..a9154e7 100644 --- a/sine_patre.py +++ b/sine_patre.py @@ -1,2 +1,13 @@ +import fol + if __name__ == '__main__': - print('Hello, World!') + kb = fol.kb.Kb() + try: + kb.tell('Friend(ALICE, BOB)') + kb.tell('Friend(BOB, CLAIRE)') + kb.tell('Friend(x, y) -> Friend(y, x)') + kb.tell('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)') + + print(kb.ask('Friend(ALICE, x)')) + except KeyboardInterrupt: + pass