import fol


class Node:
    """A node of a first-order-logic syntax tree.

    Each node has a ``name`` (the node kind, e.g. ``'VAR'``), an optional
    ``value`` payload, and an ordered list of child nodes in ``children``.
    """

    def __init__(self, name, value=None):
        self.name = name
        self.value = value
        self.children = []

    def add_child(self, child):
        """Append *child* to this node's children (mutates ``self``)."""
        self.children.append(child)

    def __repr__(self):
        # Pretty-printing is delegated to the project's fol.pretty module.
        return fol.pretty.get(self)

    def __str__(self):
        """Return a compact ``name[value](child,child,...)`` rendering.

        The ``[value]`` part is omitted when ``value`` is None, and the
        parenthesised child list is omitted for leaf nodes.
        """
        res = self.name
        if self.value is not None:
            res += f'[{self.value}]'
        if self.children:
            res += '(' + ','.join(str(child) for child in self.children) + ')'
        return res

    def copy(self):
        """Return a deep copy of the subtree rooted at this node."""
        node = Node(self.name, self.value)
        node.children = [child.copy() for child in self.children]
        return node

    def remove_quant(self, name):
        """Return a copy of this subtree with every node named *name* removed.

        Each node named *name* is replaced by its body ``children[1]``
        (``children[0]`` presumably holds the bound variable -- confirm
        against the parser). Removal happens at every depth, including
        below a quantifier at the root. The original tree is not modified.
        """
        itr = self
        # Skip over a chain of directly nested quantifiers to reach the body.
        while itr.name == name:
            itr = itr.children[1]
        # Rebuild the body, recursing so that quantifiers nested deeper
        # inside it are removed as well.
        node = Node(itr.name, itr.value)
        for child in itr.children:
            node.add_child(child.remove_quant(name))
        return node

    def subst(self, values):
        """Return a copy of this subtree with variables substituted.

        Every ``'VAR'`` node whose ``value`` is a key of *values* is replaced
        by a copy of ``values[value]``; all other nodes are copied as-is.
        """
        if self.name == 'VAR' and self.value in values:
            # Copy so the result never shares nodes with *values* or between
            # multiple occurrences of the same variable.
            return values[self.value].copy()
        node = Node(self.name, self.value)
        for child in self.children:
            node.add_child(child.subst(values))
        return node

    def equals(self, rhs):
        """Return True if *rhs* is a structurally identical subtree.

        Compares ``name`` and ``value`` of each node, and the children
        pairwise in order.
        """
        if self.name != rhs.name or self.value != rhs.value:
            return False
        if len(self.children) != len(rhs.children):
            return False
        return all(
            mine.equals(theirs)
            for mine, theirs in zip(self.children, rhs.children)
        )

    def list_of(self, name):
        """Collect the operands of a chain of nested *name* nodes.

        If this node is named *name*, descend through every consecutive
        descendant also named *name* and return, in left-to-right order,
        the first node on each path whose name differs. Returns ``[]`` when
        this node's own name is not *name*.
        """
        if self.name != name:
            return []
        result = []
        for child in self.children:
            if child.name == name:
                result.extend(child.list_of(name))
            else:
                result.append(child)
        return result

    def flatten(self):
        """Return a new node whose children are ``self.list_of(self.name)``.

        For example ``X(X(a, b), c)`` becomes ``X(a, b, c)``. The operand
        subtrees are shared with ``self``, not copied.
        """
        root = Node(self.name, self.value)
        root.children = self.list_of(self.name)
        return root

    def find_by_name(self, name):
        """Return the outermost nodes named *name*, in pre-order.

        Descendants of a matching node are not searched.
        """
        if self.name == name:
            return [self]
        return [
            match
            for child in self.children
            for match in child.find_by_name(name)
        ]

    def depth(self):
        """Return the height of this subtree (0 for a leaf)."""
        return max((child.depth() + 1 for child in self.children), default=0)