# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Implements DAG-level traceback."""

from typing import Any, Optional

import geometry as gm
import pretty as pt
import problem


pretty = pt.pretty


def point_levels(
    setup: list[problem.Dependency],
    existing_points: Optional[list[gm.Point]],
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
  """Reformat setup into levels of point constructions.

  Each construction is filed under the largest `plevel` found among its
  `gm.Point` arguments, and each of those points (unless it is listed in
  `existing_points`) is filed under its own `plevel`.

  Args:
    setup: constructions to group.
    existing_points: points that must not be reported again; may be empty or
      None, in which case no point is filtered out.

  Returns:
    A list of `(points, constructions)` pairs ordered by level. Levels that
    hold neither a point nor a construction are dropped.

  Raises:
    ValueError: if a construction has no `gm.Point` argument at all.
  """
  levels = []
  for con in setup:
    con_points = [p for p in con.args if isinstance(p, gm.Point)]
    plevel = max(p.plevel for p in con_points)

    # Grow the level table until index `plevel` exists.
    while len(levels) - 1 < plevel:
      levels.append((set(), []))

    for p in con_points:
      if existing_points and p in existing_points:
        continue
      levels[p.plevel][0].add(p)

    levels[plevel][1].append(con)

  return [(points, cons) for points, cons in levels if points or cons]


def point_log(
    setup: list[problem.Dependency],
    ref_id: dict[tuple[str, ...], int],
    existing_points: Optional[list[gm.Point]] = None,
) -> list[tuple[set[gm.Point], list[problem.Dependency]]]:
  """Reformat setup into groups of point constructions.

  Args:
    setup: constructions to group (see `point_levels`).
    ref_id: mapping from a dependency hash to its reference number. It is
      updated in place: every construction not yet present receives the next
      free number (the current size of the mapping).
    existing_points: points to leave out of the reported groups. Defaults to
      None, meaning nothing is left out.

  Returns:
    The `(points, constructions)` groups produced by `point_levels`.
  """
  # NOTE: the default used to be the type object `list[gm.Point]` (a `=` was
  # written where `:` was meant), which leaked into membership tests whenever
  # the argument was omitted.
  levels = point_levels(setup, existing_points)

  log = []
  for points, cons in levels:
    for con in cons:
      hashed = con.hashed()
      if hashed not in ref_id:
        ref_id[hashed] = len(ref_id)
    log.append((points, cons))

  return log


def setup_to_levels(
    setup: list[problem.Dependency],
) -> list[list[problem.Dependency]]:
  """Reformat setup into levels of point constructions.

  Args:
    setup: dependencies to group.

  Returns:
    The dependencies grouped by the largest `plevel` among their `gm.Point`
    arguments, in increasing level order, with empty levels removed.

  Raises:
    ValueError: if a dependency has no `gm.Point` argument at all.
  """
  levels = []
  for d in setup:
    plevel = max(p.plevel for p in d.args if isinstance(p, gm.Point))
    while len(levels) - 1 < plevel:
      levels.append([])
    levels[plevel].append(d)

  return [lvl for lvl in levels if lvl]


def separate_dependency_difference(
    query: problem.Dependency,
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    list[problem.Dependency],
    list[problem.Dependency],
    set[gm.Point],
    set[gm.Point],
]:
  """Identify and separate the dependency difference.

  Args:
    query: the conclusion being traced back.
    log: `(premises, conclusions)` steps, e.g. from `recursive_traceback`.

  Returns:
    A tuple `(log, setup, aux_setup, points, aux_points)` where
      * `log` holds the remaining proof steps, with premises named 'ind'
        removed;
      * `setup` holds the premise-free / 'c0' conclusions whose points are all
        reachable from the query;
      * `aux_setup` holds those that mention at least one other point;
      * `points` is the set of query arguments plus everything reachable from
        them through `rely_on`;
      * `aux_points` is the set of points of `aux_setup` outside `points`.
  """
  # 1. Split the log: steps without premises and conclusions with rule 'c0'
  #    go to the setup, everything else stays a proof step.
  all_setup = []
  proof_log = []
  for prems, cons in log:
    if not prems:
      all_setup.extend(cons)
      continue

    derived = []
    for con in cons:
      if con.rule_name == 'c0':
        all_setup.append(con)
      else:
        derived.append(con)

    if not derived:
      continue

    prems = [p for p in prems if p.name != 'ind']
    proof_log.append((prems, derived))

  # 2. Breadth-first closure of the query arguments under `rely_on`.
  points = set(query.args)
  queue = list(query.args)
  i = 0
  while i < len(queue):
    q = queue[i]
    i += 1
    if not isinstance(q, gm.Point):
      continue
    for p in q.rely_on:
      if p not in points:
        points.add(p)
        queue.append(p)

  # 3. Setup items touching a point outside that closure are auxiliary.
  setup, aux_setup, aux_points = [], [], set()
  for con in all_setup:
    if con.name == 'ind':
      continue
    outside = [
        p for p in con.args if isinstance(p, gm.Point) and p not in points
    ]
    if outside:
      aux_setup.append(con)
      aux_points.update(outside)
    else:
      setup.append(con)

  return proof_log, setup, aux_setup, points, aux_points


def recursive_traceback(
    query: problem.Dependency,
) -> list[tuple[list[problem.Dependency], list[problem.Dependency]]]:
  """Recursively traceback from the query, i.e. the conclusion.

  Args:
    query: the dependency to trace back from.

  Returns:
    A list of `(premises, [conclusion])` steps, one conclusion per entry.
    Conclusions that share the same set of premise hashes are adjacent and
    share the same premise list object.
  """
  visited = set()
  log = []

  def read(q: problem.Dependency) -> None:
    q = q.remove_loop()
    hashed = q.hashed()
    if hashed in visited:
      return

    # These predicates are never recorded as steps (nor, since they never
    # enter `visited`, as premises of other steps).
    if hashed[0] in ['ncoll', 'npara', 'nperp', 'diff', 'sameside']:
      return

    prems = []
    if q.rule_name != problem.CONSTRUCTION_RULE:
      # Dedup the direct dependencies by hash, keeping first occurrences.
      all_deps = []
      dep_names = set()
      for d in q.why:
        h = d.hashed()
        if h in dep_names:
          continue
        dep_names.add(h)
        all_deps.append(d)

      for d in all_deps:
        h = d.hashed()
        if h not in visited:
          read(d)
        # Only dependencies that were actually recorded count as premises.
        if h in visited:
          prems.append(d)

    visited.add(hashed)

    # Attach q to an existing step with the same premise hashes, if any.
    hashs = sorted([d.hashed() for d in prems])
    for ps, qs in log:
      if sorted([d.hashed() for d in ps]) == hashs:
        qs.append(q)
        break
    else:
      log.append((prems, [q]))

  read(query)

  # Post-process log: separate multi-conclusion lines.
  return [(ps, [q]) for ps, qs in log for q in qs]


def _collx_to_coll_inplace(dep: problem.Dependency) -> None:
  """Rename a 'collx' dependency to 'coll' in place and dedup its arguments.

  Dependencies with any other name are left untouched. Duplicated arguments
  are removed keeping the first occurrence, so the resulting order is
  deterministic.
  """
  if dep.name == 'collx':
    dep.name = 'coll'
    dep.args = list(dict.fromkeys(dep.args))


def collx_to_coll_setup(
    setup: list[problem.Dependency],
) -> list[problem.Dependency]:
  """Convert collx to coll in setups.

  The dependencies are modified in place. Within each level (see
  `setup_to_levels`) dependencies whose hash was already seen are dropped.

  Args:
    setup: setup dependencies.

  Returns:
    The converted dependencies, flattened in level order.
  """
  result = []
  for level in setup_to_levels(setup):
    seen = set()
    for dep in level:
      _collx_to_coll_inplace(dep)
      hashed = dep.hashed()
      if hashed in seen:
        continue
      seen.add(hashed)
      result.append(dep)

  return result


def collx_to_coll(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
  """Convert collx to coll and dedup.

  Dependencies are modified in place. Premises are deduplicated within each
  step; conclusions are deduplicated across the whole log and against the
  setup and auxiliary setup. Steps left without premises or without
  conclusions are dropped.

  Args:
    setup: setup dependencies.
    aux_setup: auxiliary setup dependencies.
    log: `(premises, conclusions)` proof steps.

  Returns:
    The converted `(setup, aux_setup, log)`.
  """
  setup = collx_to_coll_setup(setup)
  aux_setup = collx_to_coll_setup(aux_setup)

  con_set = {p.hashed() for p in setup + aux_setup}
  new_log = []
  for prems, cons in log:
    prem_set = set()
    new_prems = []
    for p in prems:
      _collx_to_coll_inplace(p)
      hashed = p.hashed()
      if hashed in prem_set:
        continue
      prem_set.add(hashed)
      new_prems.append(p)

    new_cons = []
    for c in cons:
      _collx_to_coll_inplace(c)
      hashed = c.hashed()
      if hashed in con_set:
        continue
      con_set.add(hashed)
      new_cons.append(c)

    if not new_cons or not new_prems:
      continue

    new_log.append((new_prems, new_cons))

  return setup, aux_setup, new_log


def get_logs(
    query: problem.Dependency, g: Any, merge_trivials: bool = False
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    set[gm.Point],
]:
  """Given a DAG and conclusion N, return the premise, aux, proof.

  Args:
    query: the conclusion to explain.
    g: the graph object handed to `query.why_me_or_cache`.
    merge_trivials: forwarded to `shorten_and_shave`.

  Returns:
    A tuple `(setup, aux_setup, log, setup_points)`.
  """
  query = query.why_me_or_cache(g, query.level)
  log = recursive_traceback(query)
  log, setup, aux_setup, setup_points, _ = separate_dependency_difference(
      query, log
  )

  setup, aux_setup, log = collx_to_coll(setup, aux_setup, log)

  setup, aux_setup, log = shorten_and_shave(
      setup, aux_setup, log, merge_trivials
  )

  return setup, aux_setup, log, setup_points


def shorten_and_shave(
    setup: list[problem.Dependency],
    aux_setup: list[problem.Dependency],
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[problem.Dependency],
    list[problem.Dependency],
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
]:
  """Shorten the proof by removing unused predicates.

  Args:
    setup: setup dependencies.
    aux_setup: auxiliary setup dependencies.
    log: `(premises, [conclusion])` proof steps.
    merge_trivials: forwarded to `shorten_proof`.

  Returns:
    `(setup, aux_setup, log)` where `log` has been shortened and the two
    setups only keep dependencies used as a premise of some remaining step.
  """
  log, _ = shorten_proof(log, merge_trivials=merge_trivials)

  used_prems = {p.hashed() for prems, _ in log for p in prems}
  setup = [d for d in setup if d.hashed() in used_prems]
  aux_setup = [d for d in aux_setup if d.hashed() in used_prems]
  return setup, aux_setup, log


def join_prems(
    con: problem.Dependency,
    con2prems: dict[tuple[str, ...], list[problem.Dependency]],
    expanded: set[tuple[str, ...]],
) -> list[problem.Dependency]:
  """Join proof steps with the same premises.

  Recursively replaces `con` by the premises recorded for it in `con2prems`.

  Args:
    con: the dependency to expand.
    con2prems: mapping from a conclusion hash to the premises it is replaced
      with.
    expanded: hashes that must not be expanded.

  Returns:
    `[con]` if `con` is in `expanded` or has no entry in `con2prems`,
    otherwise the concatenated expansion of each of its premises.
  """
  h = con.hashed()
  if h in expanded or h not in con2prems:
    return [con]

  result = []
  for p in con2prems[h]:
    result += join_prems(p, con2prems, expanded)
  return result


def shorten_proof(
    log: list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    merge_trivials: bool = False,
) -> tuple[
    list[tuple[list[problem.Dependency], list[problem.Dependency]]],
    dict[tuple[str, ...], list[problem.Dependency]],
]:
  """Join multiple trivials proof steps into one.

  A step is treated as trivial when its conclusion has an empty `rule_name`.
  Trivial steps are removed from the log and inlined into the premises of the
  steps that use them.

  Args:
    log: proof steps; each must carry exactly one conclusion.
    merge_trivials: if False, a trivial step that is a premise of a
      non-trivial step is kept as a step of its own.

  Returns:
    A tuple `(log, con2prem)`: the shortened log and the mapping from the hash
    of each inlined conclusion to its premises.
  """
  pops = set()
  con2prem = {}
  for prems, cons in log:
    assert len(cons) == 1
    con = cons[0]
    if con.rule_name == '':  # pylint: disable=g-explicit-bool-comparison
      con2prem[con.hashed()] = prems
    elif not merge_trivials:
      # Trivial steps are inlined, except for the ones that are premises to
      # non-trivial steps.
      pops.update({p.hashed() for p in prems})

  for p in pops:
    con2prem.pop(p, None)

  expanded = set()
  log2 = []
  last = len(log) - 1
  for i, (prems, cons) in enumerate(log):
    con = cons[0]
    # The final step is always kept, even when it is trivial.
    if i < last and con.hashed() in con2prem:
      continue

    # Inline trivial premises, dropping duplicates but keeping order.
    hashs = set()
    new_prems = []
    for prem in prems:
      for p in join_prems(prem, con2prem, expanded):
        hashed = p.hashed()
        if hashed not in hashs:
          new_prems.append(p)
          hashs.add(hashed)

    log2.append((new_prems, [con]))
    expanded.add(con.hashed())

  return log2, con2prem