Source code for l2p.task_builder

"""
PDDL Problem Formalization/Generation Functions

This module defines the `TaskBuilder` class and related utilities for constructing
PDDL problem specifications programmatically.

Refer to https://marcustantakoun.github.io/l2p.github.io/l2p.html for more information
on how to use the class functions. Refer to /templates in https://github.com/AI-Planning/l2p
for how to structure LLM prompts so that responses are compatible with the class functions' parsing.
"""

import re
import time

from .llm import BaseLLM, require_llm
from .utils import *


class TaskBuilder:
    """
    Builds a PDDL problem specification (:objects, :init and :goal).

    The builder stores the current specification and provides LLM-driven
    ``formalize_*`` methods, simple set/get/delete accessors, and
    ``generate_task`` to render a PDDL problem string.
    """

    # Validators that are called with the raw LLM response string only.
    _RAW_OUTPUT_VALIDATORS = (
        "validate_header",
        "validate_duplicate_headers",
        "validate_unsupported_keywords",
    )

    # Matches upper-case AND / OR only as standalone tokens, so identifiers
    # such as FLOOR, LANDMARK, AND-GATE or ?OR-var are left untouched.
    _CONNECTIVE_RE = re.compile(r"(?<![\w?-])(AND|OR)(?![\w-])")

    def __init__(
        self,
        objects: dict[str, str] | None = None,
        initial: list[dict[str, str]] | None = None,
        goal: list[dict[str, str]] | None = None,
    ) -> None:
        """
        Initializes an L2P task builder object.

        Args:
            objects (dict[str,str]): current dictionary of task objects in specification
            initial (list[dict[str,str]]): current initial states in specification
            goal (list[dict[str,str]]): current goal states in specification
        """
        # Falsy inputs (None or empty) are replaced by a fresh container
        # per instance, so instances never share mutable defaults.
        self.objects = objects or {}
        self.initial = initial or []
        self.goal = goal or []

    # ------------------------------------------------------------------
    # Formalize/generate functions
    # ------------------------------------------------------------------

    @require_llm
    def formalize_objects(
        self,
        model: BaseLLM,
        problem_desc: str,
        prompt_template: str,
        types: dict[str, str] | list[dict[str, str]] | None = None,
        constants: dict[str, str] | None = None,
        syntax_validator: SyntaxValidator | None = None,
        max_retries: int = 3,
    ) -> tuple[dict[str, str], str, tuple[bool, str]]:
        """
        Formalizes PDDL :objects via LLM.

        The ``{problem_desc}``, ``{types}`` and ``{constants}`` placeholders in
        ``prompt_template`` are filled in before the model is queried.

        Args:
            model (BaseLLM): LLM to query
            problem_desc (str): general problem description
            prompt_template (str): structured prompt template for :objects extraction
            types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
            constants (dict[str,str]): current constants in specification, defaults to None
            syntax_validator (SyntaxValidator): syntax checker for generated objects, defaults to None
            max_retries (int): max # of retries if failure occurs

        Returns:
            objects (dict[str,str]): dictionary of object types {<name>: <type>}
            llm_output (str): the raw string BaseLLM response
            validation_info (tuple[bool,str]): validation info containing pass flag and error message

        Raises:
            RuntimeError: if every attempt raised; the last error is chained as the cause
        """
        prompt = self._fill_prompt(
            prompt_template, self._base_fields(problem_desc, types, constants)
        )

        def handle(llm_output: str):
            objects = parse_objects(llm_output=llm_output)
            validation_info = self._run_validators(
                syntax_validator,
                llm_output,
                {"validate_task_objects": lambda check: check(objects, types)},
            )
            return objects, llm_output, validation_info

        return self._query_with_retries(
            model,
            prompt,
            handle,
            max_retries,
            "Max retries exceeded. Failed to extract objects.",
        )

    @require_llm
    def formalize_initial_state(
        self,
        model: BaseLLM,
        problem_desc: str,
        prompt_template: str,
        types: dict[str, str] | list[dict[str, str]] | None = None,
        constants: dict[str, str] | None = None,
        predicates: list[Predicate] | None = None,
        functions: list[Function] | None = None,
        objects: dict[str, str] | None = None,
        initial: list[dict[str, str]] | None = None,
        goal: list[dict[str, str]] | None = None,
        syntax_validator: SyntaxValidator | None = None,
        max_retries: int = 3,
    ) -> tuple[list[dict[str, str]], str, tuple[bool, str]]:
        """
        Formalizes PDDL :init states via LLM.

        Args:
            model (BaseLLM): LLM to query
            problem_desc (str): general problem description
            prompt_template (str): structured prompt template for :init extraction
            types (dict[str,str] | list[dict[str,str]]): current types in domain, defaults to None
            constants (dict[str,str]): current constants in specification, defaults to None
            predicates (list[Predicate]): list of current predicates in domain, defaults to None
            functions (list[Function]): list of current functions in specification, defaults to None
            objects (dict[str,str]): current dictionary of task :objects in specification, defaults to None
            initial (list[dict[str,str]]): current :init states in specification, defaults to None
            goal (list[dict[str,str]]): current :goal states in specification, defaults to None
            syntax_validator (SyntaxValidator): syntax checker for generated initial states, defaults to None
            max_retries (int): max # of retries if failure occurs

        Returns:
            initial (list[dict[str,str]]): list of dictionaries containing initial states consisting of:
                {<predicate_name>:<str>, <params>:<list[str]>, <neg>:<bool>}
                OR
                {<function_name>:<str>, <params>:<list[str]>, <value>:<int>, <op>:<str>}
            llm_output (str): the raw string BaseLLM response
            validation_info (tuple[bool,str]): validation info containing pass flag and error message

        Raises:
            RuntimeError: if every attempt raised; the last error is chained as the cause
        """
        prompt = self._fill_prompt(
            prompt_template,
            self._state_fields(
                problem_desc,
                types,
                constants,
                predicates,
                functions,
                objects,
                initial,
                goal,
            ),
        )

        def handle(llm_output: str):
            states = parse_initial(llm_output=llm_output)
            validation_info = self._run_validators(
                syntax_validator,
                llm_output,
                {
                    # Keyword arguments (as in formalize_task) so every value
                    # binds to its intended parameter and functions are checked.
                    "validate_task_states": lambda check: check(
                        states=states,
                        objects=objects,
                        predicates=predicates,
                        functions=functions,
                        state_type="initial",
                    )
                },
            )
            return states, llm_output, validation_info

        return self._query_with_retries(
            model,
            prompt,
            handle,
            max_retries,
            "Max retries exceeded. Failed to extract initial states.",
        )

    @require_llm
    def formalize_goal_state(
        self,
        model: BaseLLM,
        problem_desc: str,
        prompt_template: str,
        types: dict[str, str] | list[dict[str, str]] | None = None,
        constants: dict[str, str] | None = None,
        predicates: list[Predicate] | None = None,
        functions: list[Function] | None = None,
        objects: dict[str, str] | None = None,
        initial: list[dict[str, str]] | None = None,
        goal: list[dict[str, str]] | None = None,
        syntax_validator: SyntaxValidator | None = None,
        max_retries: int = 3,
    ) -> tuple[list[dict[str, str]], str, tuple[bool, str]]:
        """
        Formalizes PDDL :goal states via LLM.

        Args:
            model (BaseLLM): LLM to query
            problem_desc (str): general problem description
            prompt_template (str): structured prompt template for :goal extraction
            types (dict[str,str] | list[dict[str,str]]): current :types in domain, defaults to None
            constants (dict[str,str]): current constants in specification, defaults to None
            predicates (list[Predicate]): list of current predicates in domain, defaults to None
            functions (list[Function]): list of current functions in specification, defaults to None
            objects (dict[str,str]): current dictionary of task :objects in specification, defaults to None
            initial (list[dict[str,str]]): current :init states in specification, defaults to None
            goal (list[dict[str,str]]): current :goal states in specification, defaults to None
            syntax_validator (SyntaxValidator): syntax checker for generated goal states, defaults to None
            max_retries (int): max # of retries if failure occurs

        Returns:
            goal (list[dict[str,str]]): list of dictionaries containing goal states consisting of:
                {<predicate_name>:<str>, <params>:<list[str]>, <neg>:<bool>}
                OR
                {<function_name>:<str>, <params>:<list[str]>, <value>:<int>, <op>:<str>}
            llm_output (str): the raw string BaseLLM response
            validation_info (tuple[bool,str]): validation info containing pass flag and error message

        Raises:
            RuntimeError: if every attempt raised; the last error is chained as the cause
        """
        prompt = self._fill_prompt(
            prompt_template,
            self._state_fields(
                problem_desc,
                types,
                constants,
                predicates,
                functions,
                objects,
                initial,
                goal,
            ),
        )

        def handle(llm_output: str):
            states = parse_goal(llm_output=llm_output)
            validation_info = self._run_validators(
                syntax_validator,
                llm_output,
                {
                    # Keyword arguments (as in formalize_task) so every value
                    # binds to its intended parameter and functions are checked.
                    "validate_task_states": lambda check: check(
                        states=states,
                        objects=objects,
                        predicates=predicates,
                        functions=functions,
                        state_type="goal",
                    )
                },
            )
            return states, llm_output, validation_info

        return self._query_with_retries(
            model,
            prompt,
            handle,
            max_retries,
            "Max retries exceeded. Failed to extract goal states.",
        )

    @require_llm
    def formalize_task(
        self,
        model: BaseLLM,
        problem_desc: str,
        prompt_template: str,
        types: dict[str, str] | list[dict[str, str]] | None = None,
        constants: dict[str, str] | None = None,
        predicates: list[Predicate] | None = None,
        functions: list[Function] | None = None,
        syntax_validator: SyntaxValidator | None = None,
        max_retries: int = 3,
    ) -> tuple[
        dict[str, str],
        list[dict[str, str]],
        list[dict[str, str]],
        str,
        tuple[bool, str],
    ]:
        """
        Formalizes whole task specification via LLM.

        Args:
            model (BaseLLM): LLM to query
            problem_desc (str): general problem description
            prompt_template (str): structured prompt template for :problem extraction
            types (dict[str,str] | list[dict[str,str]]): current :types in domain, defaults to None
            constants (dict[str,str]): current constants in specification, defaults to None
            predicates (list[Predicate]): list of current predicates in domain, defaults to None
            functions (list[Function]): list of current functions in specification, defaults to None
            syntax_validator (SyntaxValidator): syntax checker for generated :problem, defaults to None
            max_retries (int): max # of retries if failure occurs

        Returns:
            objects (dict[str,str]): dictionary of object names and assigned types
            initial (list[dict[str,str]]): list of dictionary of initial states
            goal (list[dict[str,str]]): list of dictionary of goal states
            llm_output (str): the raw string BaseLLM response
            validation_info (tuple[bool,str]): validation info containing pass flag and error message

        Raises:
            RuntimeError: if every attempt raised; the last error is chained as the cause
        """
        prompt = self._fill_prompt(
            prompt_template,
            self._schema_fields(problem_desc, types, constants, predicates, functions),
        )

        def handle(llm_output: str):
            objects = parse_objects(llm_output=llm_output)
            initial = parse_initial(llm_output=llm_output)
            goal = parse_goal(llm_output=llm_output)

            def check_states(check):
                # Goal states are only validated once the initial states pass.
                info = check(
                    states=initial,
                    objects=objects,
                    predicates=predicates,
                    functions=functions,
                    state_type="initial",
                )
                if info[0]:
                    info = check(
                        states=goal,
                        objects=objects,
                        predicates=predicates,
                        functions=functions,
                        state_type="goal",
                    )
                return info

            validation_info = self._run_validators(
                syntax_validator,
                llm_output,
                {
                    "validate_task_objects": lambda check: check(objects, types),
                    "validate_task_states": check_states,
                },
            )
            return objects, initial, goal, llm_output, validation_info

        return self._query_with_retries(
            model,
            prompt,
            handle,
            max_retries,
            "Max retries exceeded. Failed to extract task.",
        )

    # ------------------------------------------------------------------
    # Delete functions
    # ------------------------------------------------------------------

    def delete_objects(self, object: dict[str, str] | str):
        """
        Deletes item(s) from :objects in the current specification.

        Args:
            object (dict[str,str] | str): either a single object name, or a
                mapping {<name>: <type>} whose names are all removed (matched
                by name only)
        """
        # NOTE: the parameter name shadows the builtin but is kept for
        # keyword-argument compatibility with existing callers.
        if self.objects is None:
            return
        if isinstance(object, dict):
            # A dict can never equal a string key, so match on its names.
            doomed = object.keys()
            self.objects = {
                var: type_ for var, type_ in self.objects.items() if var not in doomed
            }
        else:
            self.objects = {
                var: type_ for var, type_ in self.objects.items() if var != object
            }

    def delete_initial_state(self, state: dict[str, str]):
        """Deletes every :init entry equal to ``state`` from current specification"""
        if self.initial is not None:
            self.initial = [s for s in self.initial if s != state]

    def delete_goal_state(self, state: dict[str, str]):
        """Deletes every PDDL :goal entry equal to ``state`` from current specification"""
        if self.goal is not None:
            self.goal = [s for s in self.goal if s != state]

    # ------------------------------------------------------------------
    # Set functions
    # ------------------------------------------------------------------

    def set_objects(self, objects: dict[str, str]):
        """Sets PDDL :objects for current specification"""
        self.objects = objects

    def set_initial(self, initial: list[dict[str, str]]):
        """Sets PDDL :init states for current specification"""
        self.initial = initial

    def set_goal(self, goal: list[dict[str, str]]):
        """Sets PDDL :goal states for current specification"""
        self.goal = goal

    # ------------------------------------------------------------------
    # Get functions
    # ------------------------------------------------------------------

    def get_objects(self) -> dict[str, str]:
        """Returns PDDL :objects from current specification"""
        return self.objects

    def get_initial(self) -> list[dict[str, str]]:
        """Returns PDDL :init states from current specification"""
        return self.initial

    def get_goal(self) -> list[dict[str, str]]:
        """Returns PDDL :goal states from current specification"""
        return self.goal

    def generate_task(
        self,
        domain_name: str,
        problem_name: str,
        objects: dict[str, str],
        initial: list[dict[str, str]],
        goal: list[dict[str, str]],
    ) -> str:
        """
        Generates PDDL problem from given information.

        Upper-case ``AND`` / ``OR`` connectives are lower-cased; occurrences
        inside identifiers (e.g. ``FLOOR``) are left unchanged.

        Args:
            domain_name (str): domain name
            problem_name (str): specific task instance name
            objects (dict[str,str]): PDDL :objects
            initial (list[dict[str,str]]): PDDL :init states
            goal (list[dict[str,str]]): PDDL :goal states

        Returns:
            desc (str): PDDL problem in string format
        """
        desc = "(define\n"
        desc += f"   (problem {problem_name})\n"
        desc += f"   (:domain {domain_name})\n\n"
        desc += f"   (:objects \n{indent(format_objects(objects))}\n   )\n\n"
        desc += f"   (:init\n{indent(format_initial(initial))}\n   )\n\n"
        desc += f"   (:goal\n{indent(format_goal(goal))}\n   )\n"
        desc += ")"
        return self._CONNECTIVE_RE.sub(lambda match: match.group(1).lower(), desc)

    # ------------------------------------------------------------------
    # Internal helpers shared by the formalize_* methods
    # ------------------------------------------------------------------

    @staticmethod
    def _base_fields(problem_desc, types, constants) -> dict[str, str]:
        """Returns placeholder values for the problem description, types and constants."""
        return {
            "problem_desc": problem_desc,
            "types": pretty_print_dict(types) if types else "No types provided.",
            "constants": (
                format_constants(constants) if constants else "No constants provided."
            ),
        }

    @classmethod
    def _schema_fields(
        cls, problem_desc, types, constants, predicates, functions
    ) -> dict[str, str]:
        """Extends the base placeholder values with predicate and function listings."""
        fields = cls._base_fields(problem_desc, types, constants)
        fields["predicates"] = (
            "\n".join(f"{pred['raw']}" for pred in predicates)
            if predicates
            else "No predicates provided."
        )
        fields["functions"] = (
            "\n".join(f"{func['raw']}" for func in functions)
            if functions
            else "No functions provided."
        )
        return fields

    @classmethod
    def _state_fields(
        cls,
        problem_desc,
        types,
        constants,
        predicates,
        functions,
        objects,
        initial,
        goal,
    ) -> dict[str, str]:
        """Extends the schema placeholder values with current objects, :init and :goal."""
        fields = cls._schema_fields(problem_desc, types, constants, predicates, functions)
        fields["objects"] = (
            format_objects(objects) if objects else "No objects provided."
        )
        fields["initial_state"] = (
            format_initial(initial) if initial else "No initial state provided."
        )
        fields["goal_state"] = format_goal(goal) if goal else "No goal state provided."
        return fields

    @staticmethod
    def _fill_prompt(prompt_template: str, fields: dict[str, str]) -> str:
        """Replaces each ``{name}`` placeholder, applied in the insertion order of ``fields``."""
        prompt = prompt_template
        for name, value in fields.items():
            prompt = prompt.replace("{" + name + "}", value)
        return prompt

    @classmethod
    def _run_validators(cls, syntax_validator, llm_output: str, handlers) -> tuple[bool, str]:
        """
        Runs the validators named in ``syntax_validator.error_types``.

        Raw-output validators receive ``llm_output``. Validators named in
        ``handlers`` are invoked via ``handlers[name](validator)``. Names that are
        missing or not callable on ``syntax_validator`` are skipped.

        Returns:
            the first failing (bool, str) result; otherwise the last result
            produced, or a default pass tuple when nothing ran
        """
        validation_info = (True, "All validations passed.")
        if not syntax_validator:
            return validation_info
        for error_type in syntax_validator.error_types:
            validator = getattr(syntax_validator, f"{error_type}", None)
            if not callable(validator):
                continue
            if error_type in cls._RAW_OUTPUT_VALIDATORS:
                validation_info = validator(llm_output)
            elif error_type in handlers:
                validation_info = handlers[error_type](validator)
            if not validation_info[0]:
                break
        return validation_info

    @staticmethod
    def _query_with_retries(model, prompt: str, handle_output, max_retries: int, failure_msg: str):
        """
        Queries ``model`` and returns ``handle_output(llm_output)``.

        Any exception from the query, parsing or validation triggers a retry,
        up to ``max_retries`` attempts. There is a 2s pause between attempts.

        Raises:
            RuntimeError: ``failure_msg``, chained to the last error, once all attempts fail
        """
        last_error = None
        for attempt in range(max_retries):
            # Reset per attempt so a failed query never reports stale output
            # from a previous attempt.
            llm_output = None
            try:
                model.reset_tokens()
                llm_output = model.query(prompt=prompt)
                return handle_output(llm_output)
            except Exception as e:
                last_error = e
                print(
                    f"Error on attempt {attempt + 1}/{max_retries}: {e}\n"
                    f"LLM Output:\n{llm_output}\nRetrying...\n"
                )
                # No point sleeping after the final attempt.
                if attempt + 1 < max_retries:
                    time.sleep(2)
        raise RuntimeError(failure_msg) from last_error