111 lines
2.7 KiB
Python
111 lines
2.7 KiB
Python
import fol
|
|
|
|
|
|
class Node:
    """A node of an expression tree.

    Each node carries a ``name`` (its kind/label, e.g. ``'VAR'``), an optional
    ``value`` (shown as ``name[value]`` by ``str``) and an ordered list of
    child nodes.
    """

    def __init__(self, name, value=None):
        self.name = name
        self.value = value
        self.children = []

    def add_child(self, child):
        """Append *child* to this node's children (mutates this node)."""
        self.children.append(child)

    def __repr__(self):
        # Delegates to the package's formatter; presumably a pretty-printer
        # defined in fol.pretty -- not visible here.
        return fol.pretty.get(self)

    def __str__(self):
        """Return the compact form ``name[value](child1,child2,...)``.

        The ``[value]`` part is omitted when ``value`` is None, and the
        parenthesised child list is omitted for leaf nodes.
        """
        res = self.name
        if self.value is not None:
            res += f'[{self.value}]'
        if self.children:
            res += '(' + ','.join(str(child) for child in self.children) + ')'
        return res

    def copy(self):
        """Return a deep copy of this subtree (new Node objects throughout)."""
        node = Node(self.name, self.value)
        node.children = [child.copy() for child in self.children]
        return node

    def remove_quant(self, name):
        """Return a copy of this tree with every node named *name* removed.

        Each such node is replaced by its ``children[1]`` (its body;
        ``children[0]`` is presumably the bound variable). Chains of them are
        stripped repeatedly, and the removal applies at every depth, including
        directly below a removed node at the root. The input tree is not
        modified, and the result never aliases it.

        Raises:
            IndexError: if a node named *name* has fewer than two children.
        """
        # Skip over any run of matching nodes at this position.
        node = self
        while node.name == name:
            node = node.children[1]

        # Copy the first non-matching node and keep removing below it; the
        # original returned it as-is here, leaving nested quantifiers intact.
        result = Node(node.name, node.value)
        for child in node.children:
            result.add_child(child.remove_quant(name))
        return result

    def subst(self, values):
        """Return a copy of this tree with variables substituted.

        Every node named ``'VAR'`` whose ``value`` is a key of *values* is
        replaced by ``values[value]``. All other nodes are copied.

        NOTE: substituted nodes are inserted as-is (not copied), so the
        result shares them with *values* and with each other.
        """
        if self.name == 'VAR' and self.value in values:
            return values[self.value]

        node = Node(self.name, self.value)
        for child in self.children:
            node.add_child(child.subst(values))
        return node

    def equals(self, rhs):
        """Return True if *rhs* is structurally identical to this subtree.

        Two nodes are equal when their names, values and child counts match,
        and their children are pairwise equal.
        """
        if self.name != rhs.name or self.value != rhs.value:
            return False
        if len(self.children) != len(rhs.children):
            return False
        return all(mine.equals(theirs)
                   for mine, theirs in zip(self.children, rhs.children))

    def list_of(self, name):
        """Collect the operands of a chain of nested nodes named *name*.

        If this node is named *name*, descend through every child that is
        also named *name* and return the other descendants reached, left to
        right. Return an empty list if this node is not named *name*.
        """
        result = []
        if self.name == name:
            for child in self.children:
                if child.name == name:
                    result.extend(child.list_of(name))
                else:
                    result.append(child)
        return result

    def flatten(self):
        """Return a new node that merges nested same-named children into one level.

        For example, ``OR(OR(A,B),C)`` becomes ``OR(A,B,C)``. This is
        presumably meant for associative operators such as AND/OR. The
        operand nodes are shared with this tree, not copied.
        """
        # Construct with the class directly, as copy()/remove_quant() do, so
        # that the fol.node submodule attribute does not have to be bound.
        root = Node(self.name, self.value)
        root.children = self.list_of(self.name)
        return root

    def find_by_name(self, name):
        """Return the topmost nodes named *name*, in left-to-right order.

        The search does not descend into a node that already matches.
        """
        if self.name == name:
            return [self]

        res = []
        for child in self.children:
            res.extend(child.find_by_name(name))
        return res

    def depth(self):
        """Return the height of this subtree (0 for a leaf)."""
        return max((child.depth() for child in self.children), default=-1) + 1