# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run DD+AR or AlphaGeometry solver.
Please refer to README.md for detailed instructions.
"""
import time
import traceback
from absl import app
from absl import flags
from absl import logging
import ddar
import graph as gh
import lm_inference as lm
import pretty as pt
import problem as pr
#=============
import sys
import os
import math
import re
import multiprocessing
import warnings
warnings.filterwarnings("ignore")
model = None # global variable used in multi-processing workers
_GIN_SEARCH_PATHS = flags.DEFINE_list(
'gin_search_paths',
['third_party/py/meliad/transformer/configs'],
'List of paths where the Gin config files are located.',
)
_GIN_FILE = flags.DEFINE_multi_string(
'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
)
_GIN_PARAM = flags.DEFINE_multi_string(
'gin_param', None, 'Newline separated list of Gin parameter bindings.'
)
_PROBLEMS_FILE = flags.DEFINE_string(
'problems_file',
'imo_ag_30.txt',
    'text file containing the problem strings. See imo_ag_30.txt for an example.',
)
_PROBLEM_NAME = flags.DEFINE_string(
'problem_name',
'imo_2000_p1',
    'name of the problem to solve; must be in the problems_file.',
)
_MODE = flags.DEFINE_string(
'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`')
_DEFS_FILE = flags.DEFINE_string(
'defs_file',
'defs.txt',
'definitions of available constructions to state a problem.',
)
_RULES_FILE = flags.DEFINE_string(
'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
)
_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
_VOCAB_PATH = flags.DEFINE_string(
'vocab_path', '', 'path to the LM vocab file.'
)
_OUT_FILE = flags.DEFINE_string(
'out_file', '', 'path to the solution output file.'
) # pylint: disable=line-too-long
_BEAM_SIZE = flags.DEFINE_integer(
'beam_size', 1, 'beam size of the proof search.'
) # pylint: disable=line-too-long
_SEARCH_DEPTH = flags.DEFINE_integer(
'search_depth', 1, 'search depth of the proof search.'
) # pylint: disable=line-too-long
#===================================
_N_WORKERS = flags.DEFINE_integer(
    'n_workers', 1, 'number of parallel worker processes for proof search.'
)  # pylint: disable=line-too-long
DEFINITIONS = None # contains definitions of construction actions
RULES = None # contains rules of deductions
def natural_language_statement(logical_statement: pr.Dependency) -> str:
"""Convert logical_statement to natural language.
Args:
logical_statement: pr.Dependency with .name and .args
Returns:
a string of (pseudo) natural language of the predicate for human reader.
"""
names = [a.name.upper() for a in logical_statement.args]
names = [(n[0] + '_' + n[1:]) if len(n) > 1 else n for n in names]
return pt.pretty_nl(logical_statement.name, names)
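
# A hedged, illustrative example of the renaming above (hypothetical
# Dependency args named 'o', 'a1', 'b'; the exact sentence produced depends
# on pt.pretty_nl, which is not shown here):
#
#   names: ['o', 'a1', 'b'] -> ['O', 'A_1', 'B']
#   pt.pretty_nl('coll', ['O', 'A_1', 'B'])  # e.g. 'O, A_1, B are collinear'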
def proof_step_string(
proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
) -> str:
"""Translate proof to natural language.
Args:
proof_step: pr.Dependency with .name and .args
refs: dict(hash: int) to keep track of derived predicates
last_step: boolean to keep track whether this is the last step.
Returns:
a string of (pseudo) natural language of the proof step for human reader.
"""
premises, [conclusion] = proof_step
premises_nl = ' & '.join(
[
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
)
if not premises:
premises_nl = 'similarly'
refs[conclusion.hashed()] = len(refs)
conclusion_nl = natural_language_statement(conclusion)
if not last_step:
conclusion_nl += ' [{:02}]'.format(refs[conclusion.hashed()])
return f'{premises_nl} \u21d2 {conclusion_nl}'
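
# An illustrative rendering (hypothetical statements and refs): if two
# premises were registered as [00] and [01], a middle step might read
#
#   'AB = CD [00] & CD = EF [01] ⇒ AB = EF [02]'
#
# where the conclusion reference '[02]' is omitted on the last step.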
def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
"""Output the solution to out_file.
Args:
g: gh.Graph object, containing the proof state.
p: pr.Problem object, containing the theorem.
out_file: file to write to, empty string to skip writing to file.
"""
setup, aux, proof_steps, refs = ddar.get_proof_steps(
g, p.goal, merge_trivials=False
)
solution = ''
  solution += 'From the problem statement we have:\n'
premises_nl = []
for premises, [points] in setup:
solution += ' '.join([p.name.upper() for p in points]) + ' '
if not premises:
continue
premises_nl += [
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
solution += ': Points\n' + '\n'.join(premises_nl)
  solution += '\n\nAuxiliary points to construct:\n'
aux_premises_nl = []
  if not aux:
    solution += 'No auxiliary points are needed.'
else:
for premises, [points] in aux:
solution += ' '.join([p.name.upper() for p in points]) + ' '
aux_premises_nl += [
natural_language_statement(p) + ' [{:02}]'.format(refs[p.hashed()])
for p in premises
]
solution += ': Points\n' + '\n'.join(aux_premises_nl)
  # Special cases where the deduction rule has a well-known name.
r2name = {
'r32': '(SSS)',
'r33': '(SAS)',
'r34': '(Similar Triangles)',
'r35': '(Similar Triangles)',
'r36': '(ASA)',
'r37': '(ASA)',
'r38': '(Similar Triangles)',
'r39': '(Similar Triangles)',
'r40': '(Congruent Triangles)',
'a00': '(Distance chase)',
'a01': '(Ratio chase)',
'a02': '(Angle chase)',
}
  solution += '\n\nProof steps:\n'
for i, step in enumerate(proof_steps):
_, [con] = step
nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
rule_name = r2name.get(con.rule_name, '')
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
solution += '{:03}. '.format(i + 1) + nl + '\n'
logging.info(solution)
if out_file:
with open(out_file, 'w') as f:
f.write(solution)
logging.info('Solution written to %s.', out_file)
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
lm.parse_gin_configuration(
_GIN_FILE.value, _GIN_PARAM.value, gin_paths=_GIN_SEARCH_PATHS.value
)
return lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
"""Run DD+AR.
Args:
g: gh.Graph object, containing the proof state.
p: pr.Problem object, containing the problem statement.
out_file: path to output file if solution is found.
Returns:
Boolean, whether DD+AR finishes successfully.
"""
ddar.solve(g, RULES, p, max_level=1000)
goal_args = g.names2nodes(p.goal.args)
if not g.check(p.goal.name, goal_args):
logging.info('DD+AR failed to solve the problem.')
return False
write_solution(g, p, out_file)
  gh.nm.draw(
      g.type2nodes[gh.Point],
      g.type2nodes[gh.Line],
      g.type2nodes[gh.Circle],
      g.type2nodes[gh.Segment],
      save_to='ag4mout/output.png',
  )
return True
def translate_constrained_to_constructive(
point: str, name: str, args: list[str]
) -> tuple[str, list[str]]:
"""Translate a predicate from constraint-based to construction-based.
Args:
point: str: name of the new point
name: str: name of the predicate, e.g., perp, para, etc.
args: list[str]: list of predicate args.
Returns:
(name, args): translated to constructive predicate.
"""
if name in ['T', 'perp']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
if point == d:
c, d = d, c
if a == c and a == point:
return 'on_dia', [a, b, d]
return 'on_tline', [a, b, c, d]
elif name in ['P', 'para']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
return 'on_pline', [a, b, c, d]
elif name in ['D', 'cong']:
a, b, c, d = args
if point in [c, d]:
a, b, c, d = c, d, a, b
if point == b:
a, b = b, a
if point == d:
c, d = d, c
if a == c and a == point:
return 'on_bline', [a, b, d]
if b in [c, d]:
if b == d:
c, d = d, c # pylint: disable=unused-variable
return 'on_circle', [a, b, d]
return 'eqdistance', [a, b, c, d]
elif name in ['C', 'coll']:
a, b, c = args
if point == b:
a, b = b, a
if point == c:
a, b, c = c, a, b
return 'on_line', [a, b, c]
elif name in ['^', 'eqangle']:
a, b, c, d, e, f = args
if point in [d, e, f]:
a, b, c, d, e, f = d, e, f, a, b, c
x, b, y, c, d = b, c, e, d, f
if point == b:
a, b, c, d = b, a, d, c
if point == d and x == y: # x p x b = x c x p
return 'angle_bisector', [point, b, x, c]
if point == x:
return 'eqangle3', [x, a, b, y, c, d]
return 'on_aline', [a, x, b, c, y, d]
elif name in ['cyclic', 'O']:
a, b, c = [x for x in args if x != point]
return 'on_circum', [point, a, b, c]
return name, args
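
# Two hedged, worked examples of the translation above (point names are
# hypothetical):
#
#   translate_constrained_to_constructive('x', 'coll', ['a', 'x', 'b'])
#   # -> ('on_line', ['x', 'a', 'b'])
#
#   translate_constrained_to_constructive('x', 'perp', ['x', 'a', 'x', 'b'])
#   # -> ('on_dia', ['x', 'b', 'a'])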
def check_valid_args(name: str, args: list[str]) -> bool:
"""Check whether a predicate is grammarically correct.
Args:
name: str: name of the predicate
args: list[str]: args of the predicate
Returns:
bool: whether the predicate arg count is valid.
"""
if name == 'perp':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b}) < 2:
return False
if len({c, d}) < 2:
return False
elif name == 'para':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b, c, d}) < 4:
return False
elif name == 'cong':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b}) < 2:
return False
if len({c, d}) < 2:
return False
elif name == 'coll':
if len(args) != 3:
return False
a, b, c = args
if len({a, b, c}) < 3:
return False
elif name == 'cyclic':
if len(args) != 4:
return False
a, b, c, d = args
if len({a, b, c, d}) < 4:
return False
elif name == 'eqangle':
if len(args) != 8:
return False
a, b, c, d, e, f, g, h = args
if len({a, b, c, d}) < 3:
return False
if len({e, f, g, h}) < 3:
return False
return True
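
# Hedged examples of the arity/degeneracy checks above (hypothetical args):
#
#   check_valid_args('perp', ['a', 'b', 'c', 'd'])  # -> True
#   check_valid_args('perp', ['a', 'a', 'c', 'd'])  # -> False ('a a' is degenerate)
#   check_valid_args('coll', ['a', 'b'])            # -> False (needs 3 points)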
def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
"""Whether a string of aux construction can be constructed.
Args:
string: str: the string describing aux construction.
g: gh.Graph: the current proof state.
Returns:
str: whether this construction is valid. If not, starts with "ERROR:".
"""
if string[-1] != ';':
return 'ERROR: must end with ;'
  logging.info(
      'PID=%d: try_translate_constrained_to_construct: string=%s',
      os.getpid(), string)
# sometimes the LM may return ill-formed result with multiple colons.
# example:
#
# napoleon2
# a1 a2 a3 = triangle; c3 = s_angle a1 a2 c3 30, s_angle a2 a1 c3 150; c1 = s_angle a2 a3 c1 30, s_angle a3 a2 c1 150; c2 = s_angle a3 a1 c2 30, s_angle a1 a3 c2 150 ? cong c1 c2 c1 c3
#
# in the process,
# I0210 17:58:01.513668 140016515833856 alphageometry.py:550] Decoding from {S} a : ; b : ; c : ; d : ^ a d a b 5. pi / 6. 00 ^ b d b a 1. pi / 6. 01 ; e : ^ b e b c 5. pi / 6. 02 ^ c e c b 1. pi / 6. 03 ; f : ^ a f a c 1. pi / 6. 04 ^ c f c a 5. pi / 6. 05 ? D e f e d {F1} x00 g : C a b g 06 D a g b g 07 ; x00 h : C c b h 08 D c h b h 09 ; x00
# I0210 18:01:38.182158 140016515833856 alphageometry.py:384] !! try_translate_constrained_to_construct: string=i : C a c i 10 D a i c i 11 ? V d f {F1} x00 j : D g j h j 12 D h j i j 13 ;
  # Truncate the output at the first ' ? ' or ' . {' so that only the
  # construction clause remains.
  mch = re.match(r'(.*?)( \? | \. \{)', string)
  if mch:
    str_fixed = mch.group(1) + ';'
    logging.info(
        'PID=%d: Bad LM output: %s. Changed to %s',
        os.getpid(), string, str_fixed)
    string = str_fixed
# sometimes the constraint in string is empty:
# 0407 17:11:35.470240 126383800963072 alphageometry.py:394] !! try_translate_constrained_to_construct: string=j : ;
  hdprem = string.split(' : ')
  if len(hdprem) != 2 or hdprem[1].strip() == ';':
    logging.info('PID=%d: Bad LM output: %s. ERROR', os.getpid(), string)
    return f'ERROR: Bad LM output: {string}'
head, prem_str = hdprem
point = head.strip()
if len(point) != 1 or point == ' ':
return f'ERROR: invalid point name {point}'
existing_points = [p.name for p in g.all_points()]
if point in existing_points:
return f'ERROR: point {point} already exists.'
prem_toks = prem_str.split()[:-1] # remove the EOS ' ;'
prems = [[]]
for i, tok in enumerate(prem_toks):
if tok.isdigit():
if i < len(prem_toks) - 1:
prems.append([])
else:
prems[-1].append(tok)
if len(prems) > 2:
return 'ERROR: there cannot be more than two predicates.'
clause_txt = point + ' = '
constructions = []
for prem in prems:
name, *args = prem
if point not in args:
return f'ERROR: {point} not found in predicate args.'
if not check_valid_args(pt.map_symbol(name), args):
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
for a in args:
if a != point and a not in existing_points:
return f'ERROR: point {a} does not exist.'
try:
name, args = translate_constrained_to_constructive(point, name, args)
except: # pylint: disable=bare-except
return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)
if name == 'on_aline':
if args.count(point) > 1:
        return f'ERROR: on_aline involves {point} twice'
constructions += [name + ' ' + ' '.join(args)]
clause_txt += ', '.join(constructions)
clause = pr.Clause.from_txt(clause_txt)
try:
g.copy().add_clause(clause, 0, DEFINITIONS)
except: # pylint: disable=bare-except
return 'ERROR: ' + traceback.format_exc()
return clause_txt
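
# A hedged end-to-end example (hypothetical LM output; success also depends
# on the current graph g accepting the construction):
#
#   try_translate_constrained_to_construct('x : C a b x ;', g)
#   # -> 'x = on_line x a b' on success, or a string starting with 'ERROR:'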
def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
"""Insert auxiliary constructs from proof to premise.
Args:
pstring: str: describing the problem to solve.
    auxstring: str: describing the auxiliary construction.
Returns:
str: new pstring with auxstring inserted before the conclusion.
"""
setup, goal = pstring.split(' ? ')
return setup + '; ' + auxstring + ' ? ' + goal
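
# A worked example (the problem string is hypothetical):
#
#   insert_aux_to_premise(
#       'a b c = triangle a b c ? cong a b a c',
#       'd = on_line d b c')
#   # -> 'a b c = triangle a b c; d = on_line d b c ? cong a b a c'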
class BeamQueue:
"""Keep only the top k objects according to their values."""
def __init__(self, max_size: int = 512):
self.queue = []
self.max_size = max_size
def add(self, node: object, val: float) -> None:
"""Add a new node to this queue."""
if len(self.queue) < self.max_size:
self.queue.append((val, node))
return
# Find the minimum node:
min_idx, (min_val, _) = min(enumerate(self.queue), key=lambda x: x[1])
# replace it if the new node has higher value.
if val > min_val:
self.queue[min_idx] = (val, node)
def __iter__(self):
for val, node in self.queue:
yield val, node
def __len__(self) -> int:
return len(self.queue)
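
# A minimal usage sketch of BeamQueue (illustrative values):
#
#   q = BeamQueue(max_size=2)
#   q.add('node_a', 0.5)
#   q.add('node_b', 0.9)
#   q.add('node_c', 0.7)  # evicts ('node_a', 0.5), the lowest-valued entry
#   list(q)  # -> [(0.7, 'node_c'), (0.9, 'node_b')]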
def bqsearch_init():
global model
logging.info('Worker initializing. PID=%d', os.getpid())
model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]:
  """Expand one beam-search node: decode, translate, and try DD+AR.

  Returns:
    (i_nd, solved, results) where results is a list of (node, score) pairs,
    or None when the problem was solved.
  """
pid = os.getpid()
logging.info(f'Worker PID={pid} called for beam search node {i_nd}')
prev_score, (g, string, pstring) = srch_inputs
logging.info(f'Worker PID={pid}: Decoding from {string}')
outputs = model.beam_decode(string, eos_tokens=[';'])
# translate lm output to the constructive language.
# so that we can update the graph representing proof states:
translations = [
try_translate_constrained_to_construct(o, g)
for o in outputs['seqs_str']
]
  # couple the lm outputs with their translations
candidates = zip(outputs['seqs_str'], translations, outputs['scores'])
# bring the highest scoring candidate first
candidates = reversed(list(candidates))
ret = []
for lm_out, translation, score in candidates:
logging.info(f'Worker PID={pid}: LM output (score={score}): "{lm_out}"')
logging.info(f'Worker PID={pid}: Translation: "{translation}"')
if translation.startswith('ERROR:'):
# the construction is invalid.
continue
# Update the constructive statement of the problem with the aux point:
candidate_pstring = insert_aux_to_premise(pstring, translation)
logging.info(f'Worker PID={pid}: string=|{string}| lm_out=|{lm_out}|')
logging.info(f'Worker PID={pid}: Solving: "{candidate_pstring}"')
p_new = pr.Problem.from_txt(candidate_pstring)
# This is the new proof state graph representation:
g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
try:
if run_ddar(g_new, p_new, out_file):
        logging.info(f'Worker PID={pid}: Solved.')
return (i_nd, True, None)
except Exception as e:
logging.info(f'Worker PID={pid}: Error in run_ddar: {e}')
    # Queue this candidate; the caller adds it to the beam queue.
    # The string for the new node is old_string + lm output + the special
    # token asking for a new auxiliary point ' x00'. The score of a node is
    # the sum of scores of all nodes on the path to it; there is no need to
    # normalize by path length because all nodes in the beam have the same
    # path length.
    ret.append((
        (g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
        prev_score + score,
    ))
logging.info(f'Worker PID={pid} beam search node {i_nd}: returning')
return (i_nd, False, ret)
def run_alphageometry(
    p: pr.Problem,
    search_depth: int,
    beam_size: int,
    out_file: str,
) -> bool:
  """Simplified code to run AlphaGeometry proof search.

  We removed all optimizations that are infrastructure-dependent, e.g.
  parallelized model inference on multi GPUs,
  parallelized DD+AR on multiple CPUs,
  parallel execution of LM and DD+AR,
  shared pool of CPU workers across different problems, etc.

  Many other speed optimizations and abstractions are also removed to
  better present the core structure of the proof search. The language
  model itself is loaded per worker process in bqsearch_init(), so it is
  no longer passed in as an argument.

  Args:
    p: pr.Problem object describing the problem to solve.
    search_depth: max proof search depth.
    beam_size: beam size of the proof search.
    out_file: path to output file if solution is found.

  Returns:
    boolean of whether this is solved.
  """
# translate the problem to a string of grammar that the LM is trained on.
string = p.setup_str_from_problem(DEFINITIONS)
# special tokens prompting the LM to generate auxiliary points.
string += ' {F1} x00'
# the graph to represent the proof state.
g, _ = gh.Graph.build_problem(p, DEFINITIONS)
# First we run the symbolic engine DD+AR:
if run_ddar(g, p, out_file):
return True
  # When pickling the graph for some problems, the default recursion limit of
  # 1000 is not enough and raises 'maximum recursion depth exceeded while
  # pickling an object', so raise it before sending graphs to workers.
  sys.setrecursionlimit(10000)
# beam search for the proof
# each node in the search tree is a 3-tuple:
# (<graph representation of proof state>,
# <string for LM to decode from>,
# <original problem string>)
beam_queue = BeamQueue(max_size=beam_size)
# originally the beam search tree starts with a single node (a 3-tuple):
beam_queue.add(
node=(g, string, p.txt()), val=0.0 # value of the root node is simply 0.
)
  pool = None
  if _N_WORKERS.value == 1:
    bqsearch_init()
  else:
    pool = multiprocessing.Pool(_N_WORKERS.value, bqsearch_init)
for depth in range(search_depth):
logging.info(
'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
)
for _, (_, string, _) in beam_queue:
logging.info(string)
new_queue = BeamQueue(max_size=beam_size) # to replace beam_queue.
    if _N_WORKERS.value == 1:
for i, srch_inputs in enumerate(beam_queue):
_, solved, res = bqsearch(i, srch_inputs, out_file)
if solved:
return True
for node, val in res:
# Add the candidate to the beam queue.
new_queue.add(node, val)
          # Note that the queue only maintains at most beam_size nodes,
          # so this new node may be dropped depending on its value.
else:
      jobs = [
          pool.apply_async(bqsearch, (i, srch_inputs, out_file))
          for i, srch_inputs in enumerate(beam_queue)
      ]
n_done = 0
while n_done < len(beam_queue):
for i, jobres in enumerate(jobs):
if jobres and jobres.ready():
n_done += 1
jobs[i] = None
_, solved, res = jobres.get()
if solved:
# Clean up resources
pool.terminate()
pool.join()
return True
for node, val in res:
# Add the candidate to the beam queue.
new_queue.add(node, val)
              # Note that the queue only maintains at most beam_size nodes,
              # so this new node may be dropped depending on its value.
time.sleep(1) # Adjust wait time as needed
# replace the old queue with new queue before the new proof search depth.
beam_queue = new_queue
# Clean up resources
if pool:
pool.terminate()
pool.join()
return False
def main(_):
global DEFINITIONS
global RULES
# definitions of terms used in our domain-specific language.
DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
# load inference rules used in DD.
RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)
# when using the language model,
# point names will be renamed to alphabetical a, b, c, d, e, ...
# instead of staying with their original names,
# in order to match the synthetic training data generation.
need_rename = _MODE.value != 'ddar'
  # load problems from the problems_file.
problems = pr.Problem.from_txt_file(
_PROBLEMS_FILE.value, to_dict=True, translate=need_rename
)
if _PROBLEM_NAME.value not in problems:
raise ValueError(
f'Problem name `{_PROBLEM_NAME.value}` '
+ f'not found in `{_PROBLEMS_FILE.value}`'
)
this_problem = problems[_PROBLEM_NAME.value]
if _MODE.value == 'ddar':
g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS)
run_ddar(g, this_problem, _OUT_FILE.value)
elif _MODE.value == 'alphageometry':
    # The LM is loaded per worker process in bqsearch_init(), not here.
    run_alphageometry(
        this_problem,
        _SEARCH_DEPTH.value,
        _BEAM_SIZE.value,
        _OUT_FILE.value,
    )
else:
raise ValueError(f'Unknown FLAGS.mode: {_MODE.value}')
if __name__ == '__main__':
app.run(main)