From fb41410797c30c6c5caa203f2179b6a398066a2c Mon Sep 17 00:00:00 2001 From: bog Date: Thu, 25 Apr 2024 18:47:45 +0200 Subject: [PATCH] :sparkles: formula to CNF converter. --- doc/grammar.bnf | 2 +- fol/__init__.py | 15 +++ fol/cnf.py | 213 ++++++++++++++++++++++++++++++++++++++++ fol/lexer.py | 10 +- fol/node.py | 70 +++++++++++++ fol/parser.py | 26 +---- fol/pretty.py | 26 +++++ fol/tests/test_cnf.py | 122 +++++++++++++++++++++++ fol/tests/test_subst.py | 32 ++++++ fol/tests/test_unify.py | 66 +++++++++++++ fol/unify.py | 29 ++++++ sine_patre.py | 2 +- 12 files changed, 585 insertions(+), 28 deletions(-) create mode 100644 fol/cnf.py create mode 100644 fol/node.py create mode 100644 fol/pretty.py create mode 100644 fol/tests/test_cnf.py create mode 100644 fol/tests/test_subst.py create mode 100644 fol/tests/test_unify.py create mode 100644 fol/unify.py diff --git a/doc/grammar.bnf b/doc/grammar.bnf index 4d66d6e..547428e 100644 --- a/doc/grammar.bnf +++ b/doc/grammar.bnf @@ -5,7 +5,7 @@ IMP ::= OR imp OR OR ::= AND (or AND)* AND ::= NOT (and NOT)* NOT ::= not? GROUP -GROUP ::= NOT | opar F cpar +GROUP ::= PRED | opar F cpar PRED ::= pred opar TERM (comma TERM)* cpar TERM ::= diff --git a/fol/__init__.py b/fol/__init__.py index 0967ba8..2f65b62 100644 --- a/fol/__init__.py +++ b/fol/__init__.py @@ -1,5 +1,12 @@ from fol import lexer from fol import parser +from fol import pretty +from fol import cnf +from fol.unify import unify + + +def p(text): + return parse(text) def parse(text): @@ -8,3 +15,11 @@ def parse(text): p = parser.Parser(lex) root = p.parse() return root + + +def term(text): + lex = lexer.Lexer() + lex.scan(text) + p = parser.Parser(lex) + root = p.parse_term() + return root diff --git a/fol/cnf.py b/fol/cnf.py new file mode 100644 index 0000000..794a222 --- /dev/null +++ b/fol/cnf.py @@ -0,0 +1,213 @@ +import fol + + +def elim_imp(f): + res = fol.node.Node(f.name, f.value) + + if f.name == 'IMP': + a = fol.node.Node('NOT') + a.add_child(f.children[0]) + res = fol.node.Node('OR') + res.add_child(elim_imp(a)) + res.add_child(elim_imp(f.children[1])) + return res + else: + for c in f.children: + res.add_child(elim_imp(c)) + + return res + + +def neg(f): + res = f.copy() + + if f.name == 'NOT' and f.children[0].name in ['OR', 'AND']: + p = fol.node.Node('NOT') + p.add_child(f.children[0].children[0]) + q = fol.node.Node('NOT') + q.add_child(f.children[0].children[1]) + + if f.children[0].name == 'OR': + res = fol.node.Node('AND') + res.add_child(neg(p)) + res.add_child(neg(q)) + return res + else: + res = fol.node.Node('OR') + res.add_child(neg(p)) + res.add_child(neg(q)) + return res + + return res + + +def normalize(f, values={}): + res = fol.node.Node(f.name, f.value) + + if f.name == 'VAR': + if f.value in values.keys(): + res.value = values[f.value] + else: + new_name = f'x{len(values)}' + values[res.value] = new_name + res.value = new_name + + return res + + for child in f.children: + res.add_child(normalize(child)) + + return res + + +def prenex(f): + def mkneg(n): + res = fol.node.Node('NOT') + res.add_child(n) + return res + + def inv(n): + if n.name == 'FORALL': + res = n.copy() + res.name = 'EXISTS' + return res + elif n.name == 'EXISTS': + res = n.copy() + res.name = 'FORALL' + return res + else: + return n.copy() + + res = fol.node.Node(f.name, f.value) + + if f.name == 'NOT': + arg = f.children[0] + + if arg.name in ['FORALL', 'EXISTS']: + res = inv(arg) + res.children[1] = mkneg(arg.children[1]) + return res + + elif f.name in ['AND', 'OR', 'IMP']: + def apply_prenex(lhs, rhs, op_name): + swapped = False + + if rhs.name in ['FORALL', 'EXISTS']: + lhs, rhs = rhs, lhs + swapped = True + + if lhs.name in ['FORALL', 'EXISTS']: + left = lhs.children[1] + right = rhs + if swapped: + left, right = right, left + var = lhs.children[0] + + op = fol.node.Node(op_name) + op.add_child(left) + op.add_child(right) + + res = fol.node.Node(lhs.name) + if op_name == 'IMP' and not swapped: + res.name = 'EXISTS' if lhs.name == 'FORALL' else 'FORALL' + res.add_child(var) + res.add_child(op) + return res + return f.copy() + + lhs = prenex(f.children[0]) + rhs = prenex(f.children[1]) + + return apply_prenex(lhs, rhs, f.name) + return f.copy() + + +def skolem(f): + def collect_exists(n): + res = [] + if n.name == 'EXISTS': + res.append(n) + for child in n.children: + res.extend(collect_exists(child)) + return res + + def context(n, exists_quant, ctx=[]): + if n.equals(exists_quant): + return ctx + + if n.name == 'FORALL': + ctx.append(n.children[0].value) + + for child in n.children: + context(child, target, ctx) + + return ctx + + res = f.copy() + targets = collect_exists(res) + + for target in targets: + idx = targets.index(target) + var_name = target.children[0].value + + subst = { + var_name: fol.term(f'S{idx}') + } + + ctx = context(f, target) + + if len(ctx) > 0: + args = ','.join(ctx) + subst = { + var_name: fol.term(f's{idx}({args})') + } + + res = res.subst(subst) + + return res.remove_quant('EXISTS') + + +def distrib(f): + if len(f.children) != 2: + return f + + lhs = f.children[0] + rhs = f.children[1] + + if f.name == 'OR' and rhs.name == 'AND': + a = distrib(lhs) + b = distrib(rhs.children[0]) + c = distrib(rhs.children[1]) + res = fol.node.Node('AND') + left = fol.node.Node('OR') + left.add_child(a) + left.add_child(b) + right = fol.node.Node('OR') + right.add_child(a) + right.add_child(c) + res.add_child(distrib(left)) + res.add_child(distrib(right)) + return res + + if f.name == 'OR' and lhs.name == 'AND': + a = distrib(rhs) + b = distrib(lhs.children[0]) + c = distrib(lhs.children[1]) + res = fol.node.Node('AND') + left = fol.node.Node('OR') + left.add_child(b) + left.add_child(a) + right = fol.node.Node('OR') + right.add_child(c) + right.add_child(a) + res.add_child(distrib(left)) + res.add_child(distrib(right)) + return res + + return f + + +def cnf(f): + return distrib( + skolem(prenex(normalize(neg(elim_imp(f))))) + ).remove_quant('FORALL') diff --git a/fol/lexer.py b/fol/lexer.py index 07c02da..06ef018 100644 --- a/fol/lexer.py +++ b/fol/lexer.py @@ -66,7 +66,10 @@ class Lexer: cursor = self.cursor value = '' - while cursor < len(self.text) and self.text[cursor].islower(): + while cursor < len(self.text) and ( + self.text[cursor].islower() + or self.text[cursor].isdigit() + ): value += self.text[cursor] cursor += 1 @@ -80,7 +83,10 @@ class Lexer: cursor = self.cursor value = '' - while cursor < len(self.text) and self.text[cursor].isupper(): + while cursor < len(self.text) and ( + self.text[cursor].isupper() + or self.text[cursor].isdigit() + ): value += self.text[cursor] cursor += 1 diff --git a/fol/node.py b/fol/node.py new file mode 100644 index 0000000..d082bb5 --- /dev/null +++ b/fol/node.py @@ -0,0 +1,70 @@ +class Node: + def __init__(self, name, value=None): + self.name = name + self.value = value + self.children = [] + + def add_child(self, child): + self.children.append(child) + + def __repr__(self): + return str(self) + + def __str__(self): + res = self.name + if self.value is not None: + res += f'[{self.value}]' + + if len(self.children) > 0: + res += '(' + sep = '' + + for child in self.children: + res += sep + str(child) + sep = ',' + res += ')' + return res + + def copy(self): + node = Node(self.name, self.value) + for child in self.children: + node.add_child(child.copy()) + return node + + def remove_quant(self, name): + if self.name == name: + itr = self + while itr.name == name: + itr = itr.children[1] + return itr + + node = Node(self.name, self.value) + for child in self.children: + if child.name == name: + node.add_child(child.children[1].remove_quant(name)) + else: + node.add_child(child.remove_quant(name)) + return node + + def subst(self, values): + node = Node(self.name, self.value) + if self.name == 'VAR' and self.value in values.keys(): + node = values[self.value] + else: + for child in self.children: + node.add_child(child.subst(values)) + + return node + + def equals(self, rhs): + if self.name != rhs.name or self.value != rhs.value: + return False + + if len(self.children) != len(rhs.children): + return False + + for i in range(len(self.children)): + if not self.children[i].equals(rhs.children[i]): + return False + + return True diff --git a/fol/parser.py b/fol/parser.py index af894d1..4540f38 100644 --- a/fol/parser.py +++ b/fol/parser.py @@ -1,26 +1,4 @@ -class Node: - def __init__(self, name, value=None): - self.name = name - self.value = value - self.children = [] - - def add_child(self, child): - self.children.append(child) - - def __str__(self): - res = self.name - if self.value is not None: - res += f'[{self.value}]' - - if len(self.children) > 0: - res += '(' - sep = '' - - for child in self.children: - res += sep + str(child) - sep = ',' - res += ')' - return res +from fol.node import Node class Parser: @@ -114,7 +92,7 @@ class Parser: def parse_not(self): if self.consume('NOT'): node = Node('NOT') - node.add_child(self.parse_pred()) + node.add_child(self.parse_group()) return node return self.parse_group() diff --git a/fol/pretty.py b/fol/pretty.py new file mode 100644 index 0000000..8a49e25 --- /dev/null +++ b/fol/pretty.py @@ -0,0 +1,26 @@ +import fol + + +def get(f): + if f.name == 'EXISTS': + return f'\\exist {get(f.children[0])}({get(f.children[1])})' + + if f.name == 'NOT': + return '~' + get(f.children[0]) + + if f.name == 'AND': + return '(' + " & ".join([get(c) for c in f.children]) + ')' + + if f.name == 'OR': + return '(' + " | ".join([get(c) for c in f.children]) + ')' + + if f.name in ['PRED', 'FUN']: + return f.value + "(" + ", ".join([get(c) for c in f.children]) + ")" + + if f.name in ['VAR', 'CONST']: + return f.value + + if f.name == 'IMP': + return '(' + get(f.children[0]) + ' -> ' + get(f.children[1]) + ')' + + raise Exception(f'cannot print f = {f.name}') diff --git a/fol/tests/test_cnf.py b/fol/tests/test_cnf.py new file mode 100644 index 0000000..b1d40bf --- /dev/null +++ b/fol/tests/test_cnf.py @@ -0,0 +1,122 @@ +import pytest +import fol + + +@pytest.mark.parametrize(['imp', 'res'], [ + ('\\exists x (Pre(y) -> Post(y))', + '\\exists x (~Pre(y) | Post(y))'), + ('Pre(x) -> Post(x)', '~Pre(x) | Post(x)'), + ('Pre(x) -> (Aux(y) -> Post(x))', + '~Pre(x) | (~Aux(y) | Post(x))'), + ('(Pre(x) & Aux(y)) -> Post(x)', '~(Pre(x) & Aux(y)) | Post(x)') +]) +def test_cnf_imp(imp, res): + assert fol.cnf.elim_imp(fol.p(imp)).equals(fol.p(res)) + + +@pytest.mark.parametrize(['f', 'res'], [ + ('~(Pre(x) | Post(y))', '~Pre(x) & ~Post(y)'), + ('~(Pre(x) & Post(y))', '~Pre(x) | ~Post(y)'), + ('~(Pre(x) & (Post(y) | Post(z)))', + '~Pre(x) | (~Post(y) & ~Post(z))'), +]) +def test_cnf_neg(f, res): + assert fol.cnf.neg(fol.p(f)).equals(fol.p(res)) + + +@pytest.mark.parametrize(['f', 'res'], [ + ('Pre(x) -> Post(x)', 'Pre(x0) -> Post(x0)'), + ('Pre(x) -> Post(y)', 'Pre(x0) -> Post(x1)'), + ('Pre(x, y) -> Post(y, x)', 'Pre(x0, x1) -> Post(x1, x0)'), + ('Pre(f(x)) -> Post(g(g(y), x))', + 'Pre(f(x0)) -> Post(g(g(x1), x0))'), +]) +def test_cnf_normalize(f, res): + assert fol.cnf.normalize(fol.p(f)).equals(fol.p(res)) + + +@pytest.mark.parametrize(['f', 'res'], [ + # FOR ALL + ('~(\\forall x Pre(x))', '\\exists x ~Pre(x)'), + + ('(\\forall x Pre(x)) & Post(x)', + '\\forall x (Pre(x) & Post(x))'), + + ('(\\forall x Pre(x)) | Post(x)', + '\\forall x (Pre(x) | Post(x))'), + + ('(\\forall x Pre(x)) -> Post(x)', + '\\exists x (Pre(x) -> Post(x))'), + + ('Pre(x) & (\\forall x Post(x))', + '\\forall x (Pre(x) & Post(x))'), + + ('Pre(x) | (\\forall x Post(x))', + '\\forall x (Pre(x) | Post(x))'), + + ('Pre(x) -> (\\forall x Post(x))', + '\\forall x (Pre(x) -> Post(x))'), + # EXISTS + ('~(\\exists x Pre(x))', '\\forall x ~Pre(x)'), + + ('(\\exists x Pre(x)) & Post(x)', + '\\exists x (Pre(x) & Post(x))'), + + ('(\\exists x Pre(x)) | Post(x)', + '\\exists x (Pre(x) | Post(x))'), + + ('(\\exists x Pre(x)) -> Post(x)', + '\\forall x (Pre(x) -> Post(x))'), + + ('Pre(x) & (\\exists x Post(x))', + '\\exists x (Pre(x) & Post(x))'), + + ('Pre(x) | (\\exists x Post(x))', + '\\exists x (Pre(x) | Post(x))'), + + ('Pre(x) -> (\\exists x Post(x))', + '\\exists x (Pre(x) -> Post(x))'), + +]) +def test_cnf_prenex(f, res): + assert fol.cnf.prenex(fol.p(f)).equals(fol.p(res)) + + +@pytest.mark.parametrize(['imp', 'res'], [ + ('\\exists x (Pre(x) -> Post(x))', 'Pre(S0) -> Post(S0)'), + + ('\\exists x (\\exists y (Pre(x) -> Post(y)))', + 'Pre(S0) -> Post(S1)'), + + ('\\forall a (\\exists x (Pre(a, x) -> Post(x)))', + '\\forall a (Pre(a, s0(a)) -> Post(s0(a)))'), + + ('\\forall a (\\forall b (\\exists x (Pre(a, x) -> Post(x))))', + '\\forall a (\\forall b (Pre(a, s0(a, b)) -> Post(s0(a, b))))'), + + ('\\forall a (\\exists x (\\forall b (Pre(x) -> Post(x))))', + '\\forall a (\\forall b (Pre(s0(a)) -> Post(s0(a))))'), + + ('\\forall a (\\exists x (\\forall b (Pre(x) -> Post(x))))', + '\\forall a (\\forall b (Pre(s0(a)) -> Post(s0(a))))'), +]) +def test_cnf_skolem(imp, res): + assert fol.cnf.skolem(fol.p(imp)).equals(fol.p(res)) + + +@pytest.mark.parametrize(['f', 'oracle'], [ + ('Pre(x) | (Aux(x) & Post(x))', + '(Pre(x) | Aux(x)) & (Pre(x) | Post(x))'), + ('(Aux(x) & Post(x)) | Pre(x)', + '(Aux(x) | Pre(x)) & (Post(x) | Pre(x))') +]) +def test_distrib(oracle, f): + assert fol.cnf.distrib(fol.p(f)).equals(fol.p(oracle)) + + +@pytest.mark.parametrize(['f', 'oracle'], [ + ('Friend(x, y) -> Friend(y, x)', + '~Friend(x0, x1) | Friend(x1, x0)') +]) +def test_cnf(oracle, f): + assert fol.cnf.cnf(fol.p(f)).equals(fol.p(oracle)) diff --git a/fol/tests/test_subst.py b/fol/tests/test_subst.py new file mode 100644 index 0000000..21098d4 --- /dev/null +++ b/fol/tests/test_subst.py @@ -0,0 +1,32 @@ +import fol + + +def test_empty_subst(): + root = fol.parse('Happy(x)').subst({}) + assert 'PRED[Happy](VAR[x])' == str(root) + + +def test_one_var_subst(): + root = fol.parse('Happy(x)').subst({'x': fol.term('z')}) + assert 'PRED[Happy](VAR[z])' == str(root) + + +def test_one_function_subst(): + root = fol.parse('Happy(x)').subst({'x': fol.term('f(z)')}) + assert 'PRED[Happy](FUN[f](VAR[z]))' == str(root) + + +def test_ambiguous_subst(): + root = fol.parse('Happy(x)').subst({ + 'x': fol.term('y'), + 'y': fol.term('z'), + }) + assert 'PRED[Happy](VAR[y])' == str(root) + + +def test_ambiguous_reverse_subst(): + root = fol.parse('Happy(x, y)').subst({ + 'y': fol.term('z'), + 'x': fol.term('y'), + }) + assert 'PRED[Happy](VAR[y],VAR[z])' == str(root) diff --git a/fol/tests/test_unify.py b/fol/tests/test_unify.py new file mode 100644 index 0000000..d1430b0 --- /dev/null +++ b/fol/tests/test_unify.py @@ -0,0 +1,66 @@ +import fol + + +def test_unify_simple(): + subst = fol.unify( + fol.p('Hello(x)'), + fol.p('Hello(ALICE)') + ) + assert subst is not None + assert len(subst) == 1 + assert str(subst['x']) == 'CONST[ALICE]' + + +def test_unify_fun(): + subst = fol.unify( + fol.p('Hello(x, f(x))'), + fol.p('Hello(ALICE, f(ALICE))') + ) + assert subst is not None + assert len(subst) == 1 + assert str(subst['x']) == 'CONST[ALICE]' + + +def test_unify_two_vars(): + subst = fol.unify( + fol.p('Hello(z, f(x), z)'), + fol.p('Hello(ALICE, f(BOB), ALICE)') + ) + + assert subst is not None + assert len(subst) == 2 + assert str(subst['x']) == 'CONST[BOB]' + assert str(subst['z']) == 'CONST[ALICE]' + + +def test_unify_fun_var(): + subst = fol.unify( + fol.p('Hello(z, f(x), z)'), + fol.p('Hello(ALICE, a, ALICE)') + ) + + assert subst is not None + assert len(subst) == 2 + assert str(subst['a']) == 'FUN[f](VAR[x])' + assert str(subst['z']) == 'CONST[ALICE]' + + +def test_unify_formula(): + subst = fol.unify( + fol.p('Friend(x, y) -> Friend(y, x)'), + fol.p('Friend(ALICE, y) -> Friend(BOB, x)') + ) + + assert subst is not None + assert len(subst) == 2 + assert str(subst['x']) == 'CONST[ALICE]' + assert str(subst['y']) == 'CONST[BOB]' + + +def test_unify_formula_contradiction(): + subst = fol.unify( + fol.p('Friend(x, y) -> Friend(y, x)'), + fol.p('Friend(ALICE, y) -> Friend(BOB, CLARA)') + ) + + assert subst is None diff --git a/fol/unify.py b/fol/unify.py new file mode 100644 index 0000000..ea64e76 --- /dev/null +++ b/fol/unify.py @@ -0,0 +1,29 @@ +def unify(lhs, rhs): + res = {} + try: + unify_res(lhs, rhs, res) + except Exception: + return None + return res + + +def unify_res(lhs, rhs, result): + if lhs.name == 'VAR': + unify_var(lhs, rhs, result) + elif rhs.name == 'VAR': + unify_var(rhs, lhs, result) + else: + left = lhs.children + right = rhs.children + if len(left) == len(right): + for i in range(len(left)): + unify_res(left[i], right[i], result) + + +def unify_var(var, rhs, result): + if rhs.name in ['VAR', 'CONST', 'FUN'] and rhs.value != var.value: + if var.value in result.keys() \ + and not result[var.value].equals(rhs): + raise Exception('') + else: + result[var.value] = rhs diff --git a/sine_patre.py b/sine_patre.py index 19fc553..ef005ec 100644 --- a/sine_patre.py +++ b/sine_patre.py @@ -1,2 +1,2 @@ if __name__ == '__main__': - print('Sine Patre') + print('Hello, World!')