forward chaining v0.

main
bog 2024-04-28 18:23:47 +02:00
parent fb41410797
commit bfd64b20b6
12 changed files with 376 additions and 48 deletions

3
.gitignore vendored
View File

@ -1 +1,2 @@
__pycache__ __pycache__
*.prof

5
Makefile Normal file
View File

@ -0,0 +1,5 @@
.PHONY: prof
prof:
python -m cProfile -o sine_patre.prof sine_patre.py
snakeviz sine_patre.prof

View File

@ -1,5 +1,6 @@
from fol import lexer from fol import lexer
from fol import parser from fol import parser
from fol import kb
from fol import pretty from fol import pretty
from fol import cnf from fol import cnf
from fol.unify import unify from fol.unify import unify

View File

@ -1,12 +1,28 @@
import fol import fol
class Counter:
def __init__(self):
self.values = {}
self.counter = 0
def get(self, name):
if name in self.values.keys():
return self.values[name]
self.values[name] = self.counter
self.counter += 1
return self.counter - 1
def begin(self):
self.values = {}
def elim_imp(f): def elim_imp(f):
res = fol.node.Node(f.name, f.value) res = fol.node.Node(f.name, f.value)
if f.name == 'IMP': if f.name == 'IMP':
a = fol.node.Node('NOT') a = fol.node.Node('NOT')
a.add_child(f.children[0]) a.add_child(elim_imp(f.children[0]))
res = fol.node.Node('OR') res = fol.node.Node('OR')
res.add_child(elim_imp(a)) res.add_child(elim_imp(a))
res.add_child(elim_imp(f.children[1])) res.add_child(elim_imp(f.children[1]))
@ -23,9 +39,9 @@ def neg(f):
if f.name == 'NOT' and f.children[0].name in ['OR', 'AND']: if f.name == 'NOT' and f.children[0].name in ['OR', 'AND']:
p = fol.node.Node('NOT') p = fol.node.Node('NOT')
p.add_child(f.children[0].children[0]) p.add_child(neg(f.children[0].children[0]))
q = fol.node.Node('NOT') q = fol.node.Node('NOT')
q.add_child(f.children[0].children[1]) q.add_child(neg(f.children[0].children[1]))
if f.children[0].name == 'OR': if f.children[0].name == 'OR':
res = fol.node.Node('AND') res = fol.node.Node('AND')
@ -38,24 +54,23 @@ def neg(f):
res.add_child(neg(q)) res.add_child(neg(q))
return res return res
res = fol.node.Node(f.name, f.value)
for c in f.children:
res.add_child(neg(c))
return res return res
def normalize(f, values={}): def normalize(f, counter):
res = fol.node.Node(f.name, f.value) res = fol.node.Node(f.name, f.value)
if f.name == 'VAR': if f.name == 'VAR':
if f.value in values.keys(): res.value = f'x{counter.get(f.value)}'
res.value = values[f.value]
else:
new_name = f'x{len(values)}'
values[res.value] = new_name
res.value = new_name
return res return res
for child in f.children: for child in f.children:
res.add_child(normalize(child)) res.add_child(normalize(child, counter))
return res return res
@ -207,7 +222,12 @@ def distrib(f):
return f return f
def cnf(f): def cnf(f, counter=None):
return distrib( if counter is None:
skolem(prenex(normalize(neg(elim_imp(f))))) counter = Counter()
g = distrib(
skolem(prenex(normalize(neg(elim_imp(f)), counter)))
).remove_quant('FORALL') ).remove_quant('FORALL')
return g

168
fol/kb.py Normal file
View File

@ -0,0 +1,168 @@
import fol
class Kb:
def __init__(self):
self.base = []
self.counter = fol.cnf.Counter()
def make_f(self, h):
self.counter.begin()
f = h
if type(f) is str:
f = fol.p(f)
self.counter.begin()
f = fol.cnf.cnf(f, self.counter)
if f.name == 'OR':
f = f.list_of('OR')
else:
f = [f]
return f
def make_req(self, h):
f = h
if type(f) is str:
f = fol.p(f)
if f.name == 'OR':
f = f.list_of('OR')
else:
f = [f]
return f
def check_request(self, request):
for r in request:
if r.name not in ['NOT', 'PRED']:
raise Exception(f'invalid statement {repr(r)}')
def conds(self, statement):
result = []
for f in statement:
if f.name == 'NOT':
result.append(f.children[0])
return result
def conclusion(self, statement):
for f in statement:
if f.name != 'NOT':
return f
return None
@property
def rules(self):
res = []
for r in self.base:
if len(r) > 1:
res.append(r)
return res
@property
def facts(self):
res = []
for r in self.base:
if len(r) == 1:
res.append(r[0])
return res
def exists(self, f):
h = self.make_f(f)
for g in self.base:
if self.clause_equals(h, g):
return True
return False
def clause_equals(self, left, right):
if len(left) != len(right):
return False
for element in left:
if not self.clause_in(element, right):
return False
for element in right:
if not self.clause_in(element, left):
return False
return True
def clause_in(self, clause, lst):
for other in lst:
same = fol.unify(clause, other) is not None
if same:
return True
return False
def tell(self, request):
req = self.make_f(request)
self.check_request(req)
self.base.append(req)
def ask(self, request):
for i in range(20):
n = len(self.base)
self.update()
if n == len(self.base):
break
req = self.make_req(request)
results = []
for f in self.base:
s = fol.unify(f, req)
if s is not None:
results.append(s)
return results
def update(self):
for rule in self.rules:
conclusion = self.conclusion(rule)
solution = []
solutions = []
self.solve(rule, self.conds(rule), self.facts, solution, solutions)
for sol in solutions:
concl = conclusion.subst(sol)
if not self.exists(concl):
self.base.append([concl])
return
def merge_substs(self, substs):
res = {}
for s in substs:
for k in s.keys():
if k in res.keys() and not res[k].equals(s[k]):
return None
res[k] = s[k]
return res
def solve(self, rule, conds, facts, solution, solutions):
if len(conds) == 0:
substs = []
for sol in solution:
s = fol.unify(sol[0], sol[1])
if s is not None:
substs.append(s)
s = self.merge_substs(substs)
if s is not None:
solutions.append(s)
return
for cond in conds:
for fact in facts:
solution.append((cond, fact))
self.solve(
rule,
[c for c in conds if c != cond],
[f for f in facts if f != fact],
solution,
solutions
)
solution.pop()

View File

@ -1,3 +1,6 @@
import fol
class Node: class Node:
def __init__(self, name, value=None): def __init__(self, name, value=None):
self.name = name self.name = name
@ -8,7 +11,7 @@ class Node:
self.children.append(child) self.children.append(child)
def __repr__(self): def __repr__(self):
return str(self) return fol.pretty.get(self)
def __str__(self): def __str__(self):
res = self.name res = self.name
@ -68,3 +71,32 @@ class Node:
return False return False
return True return True
def list_of(self, name):
result = []
if self.name == name:
for c in self.children:
if c.name != name:
result.append(c)
result.extend(c.list_of(name))
return result
def flatten(self):
result = self.list_of(self.name)
print(result)
root = fol.node.Node(self.name, self.value)
root.children = result
return root
def find_by_name(self, name):
if self.name == name:
return [self]
res = []
for c in self.children:
res.extend(c.find_by_name(name))
return res

View File

@ -1,6 +1,3 @@
import fol
def get(f): def get(f):
if f.name == 'EXISTS': if f.name == 'EXISTS':
return f'\\exist {get(f.children[0])}({get(f.children[1])})' return f'\\exist {get(f.children[0])}({get(f.children[1])})'

View File

@ -32,7 +32,8 @@ def test_cnf_neg(f, res):
'Pre(f(x0)) -> Post(g(g(x1), x0))'), 'Pre(f(x0)) -> Post(g(g(x1), x0))'),
]) ])
def test_cnf_normalize(f, res): def test_cnf_normalize(f, res):
assert fol.cnf.normalize(fol.p(f)).equals(fol.p(res)) counter = fol.cnf.Counter()
assert fol.cnf.normalize(fol.p(f), counter).equals(fol.p(res))
@pytest.mark.parametrize(['f', 'res'], [ @pytest.mark.parametrize(['f', 'res'], [
@ -116,7 +117,12 @@ def test_distrib(oracle, f):
@pytest.mark.parametrize(['f', 'oracle'], [ @pytest.mark.parametrize(['f', 'oracle'], [
('Friend(x, y) -> Friend(y, x)', ('Friend(x, y) -> Friend(y, x)',
'~Friend(x0, x1) | Friend(x1, x0)') '~Friend(x0, x1) | Friend(x1, x0)'),
('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)',
'(~Friend(x0, x1) | ~Friend(x1, x2)) | Friend(x0, x2)')
]) ])
def test_cnf(oracle, f): def test_cnf(oracle, f):
print(fol.pretty.get(fol.p(f)))
print(fol.pretty.get(fol.cnf.cnf(fol.p(f))))
assert fol.cnf.cnf(fol.p(f)).equals(fol.p(oracle)) assert fol.cnf.cnf(fol.p(f)).equals(fol.p(oracle))

11
fol/tests/test_kb.py Normal file
View File

@ -0,0 +1,11 @@
import fol
def test_kb_simple():
kb = fol.kb.Kb()
kb.tell('Friend(ALICE, BOB)')
res = kb.ask('Friend(ALICE, x)')
assert len(res) == 1
assert res[0]['x'].value == 'BOB'

View File

@ -6,16 +6,37 @@ def test_unify_simple():
fol.p('Hello(x)'), fol.p('Hello(x)'),
fol.p('Hello(ALICE)') fol.p('Hello(ALICE)')
) )
assert subst is not None assert subst is not None
assert len(subst) == 1 assert len(subst) == 1
assert str(subst['x']) == 'CONST[ALICE]' assert str(subst['x']) == 'CONST[ALICE]'
def test_unify_equals():
subst = fol.unify(
fol.p('Hello(ALICE)'),
fol.p('Hello(ALICE)')
)
assert subst is not None
assert len(subst) == 0
def test_unify_different_predicate_name():
subst = fol.unify(
fol.p('Hello(x)'),
fol.p('World(ALICE)')
)
assert subst is None
def test_unify_fun(): def test_unify_fun():
subst = fol.unify( subst = fol.unify(
fol.p('Hello(x, f(x))'), fol.p('Hello(x, f(x))'),
fol.p('Hello(ALICE, f(ALICE))') fol.p('Hello(ALICE, f(ALICE))')
) )
assert subst is not None assert subst is not None
assert len(subst) == 1 assert len(subst) == 1
assert str(subst['x']) == 'CONST[ALICE]' assert str(subst['x']) == 'CONST[ALICE]'

View File

@ -1,29 +1,84 @@
import fol
def unify(lhs, rhs): def unify(lhs, rhs):
res = {} s = {}
try: return _unify(lhs, rhs, s)
unify_res(lhs, rhs, res)
except Exception:
def _unify(lhs, rhs, s):
if s is None:
return None return None
return res elif is_equal(lhs, rhs):
return s
elif is_var(lhs):
return unify_var(lhs, rhs, s)
elif is_var(rhs):
return unify_var(rhs, lhs, s)
elif is_compound(lhs) and is_compound(rhs) and (
lhs.name == rhs.name and lhs.value == rhs.value
):
return _unify(
def unify_res(lhs, rhs, result): lhs.children,
if lhs.name == 'VAR': rhs.children,
unify_var(lhs, rhs, result) s
elif rhs.name == 'VAR': )
unify_var(rhs, lhs, result) elif is_list(lhs) and is_list(rhs):
else: if len(lhs) != len(rhs):
left = lhs.children return None
right = rhs.children elif len(lhs) == 0:
if len(left) == len(right): return s
for i in range(len(left)): elif len(lhs) == 1:
unify_res(left[i], right[i], result) return _unify(lhs[0], rhs[0], s)
def unify_var(var, rhs, result):
if rhs.name in ['VAR', 'CONST', 'FUN'] and rhs.value != var.value:
if var.value in result.keys() \
and not result[var.value].equals(rhs):
raise Exception('')
else: else:
result[var.value] = rhs return _unify(lhs[1:], rhs[1:], _unify(lhs[0], rhs[0], s))
else:
return None
def unify_var(var, x, s):
if var.value in s.keys():
return _unify(s[var.value], x, s)
elif x in s.keys():
return _unify(var, s[x], s)
elif occur_check(var, x):
return None
else:
s[var.value] = x
return s
def occur_check(var, x):
if var.name == x.name and var.value == x.value:
return True
for child in x.children:
if occur_check(var, child):
return True
return False
def is_compound(f):
if type(f) is fol.node.Node:
return len(f.children) > 0
return False
def is_equal(lhs, rhs):
if type(lhs) is str and type(rhs) is str:
return lhs == rhs
if type(lhs) is fol.node.Node and type(rhs) is fol.node.Node:
return lhs.equals(rhs)
return False
def is_var(f):
return type(f) is fol.node.Node and f.name == 'VAR'
def is_list(value):
return type(value) is list

View File

@ -1,2 +1,13 @@
import fol
if __name__ == '__main__': if __name__ == '__main__':
print('Hello, World!') kb = fol.kb.Kb()
try:
kb.tell('Friend(ALICE, BOB)')
kb.tell('Friend(BOB, CLAIRE)')
kb.tell('Friend(x, y) -> Friend(y, x)')
kb.tell('(Friend(x, y) & Friend(y, z)) -> Friend(x, z)')
print(kb.ask('Friend(ALICE, x)'))
except KeyboardInterrupt:
pass