"""
PDDL Domain Formalization/Generation Functions
This module defines the `DomainBuilder` class and related utilities for constructing
PDDL domain specifications programmatically.
Refer to: https://marcustantakoun.github.io/l2p.github.io/l2p.html for more information
on how to use class functions. Refer to /templates in: https://github.com/AI-Planning/l2p
for how to structure LLM prompts so that the responses are compatible with class function parsing.
"""
import re
import time
from collections import OrderedDict
from typing import Any
from .llm import BaseLLM, require_llm
from .utils import *
[docs]
class DomainBuilder:
[docs]
def __init__(
    self,
    requirements: list[str] | None = None,
    types: dict[str, str] | None = None,
    type_hierarchy: list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    pddl_actions: list[Action] | None = None,
) -> None:
    """
    Initializes an L2P domain builder object.

    Every argument is optional. A missing or empty (falsy) value is replaced by a
    fresh empty container, so instances never share default state; a non-empty
    value is stored by reference (it is not copied).

    Args:
        requirements (list[str]): list of PDDL requirements
        types (dict[str,str]): flat types dictionary w/ {name: description} key-value pair (PDDL :types)
        type_hierarchy (list[dict[str,str]]): type hierarchy dictionary list (PDDL :types)
        constants (dict[str,str]): flat constant dictionary w/ {name: type} key-value pair (PDDL :constants)
        predicates (list[Predicate]): list of Predicate objects (PDDL :predicates)
        functions (list[Function]): list of Function objects (PDDL :functions)
        pddl_actions (list[Action]): list of Action objects (PDDL :action)
    """
    self.requirements = requirements or []
    self.types = types or {}
    self.type_hierarchy = type_hierarchy or []
    self.constants = constants or {}
    self.predicates = predicates or []
    self.functions = functions or []
    self.pddl_actions = pddl_actions or []
"""Formalize/generate functions"""
[docs]
@require_llm
def formalize_types(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    check_invalid_obj_usage: bool = True,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[dict[str, str], str, tuple[bool, str]]:
    """
    Formalizes PDDL :types in singular flat hierarchy via LLM. It is recommended to use
    `formalize_type_hierarchy()` for sub-type support.

    The prompt is built once by substituting `{domain_desc}` and `{types}` in the template;
    the model is then queried up to `max_retries` times until a response parses without
    raising. A failed validation is NOT retried: the parsed types are returned together
    with the failing `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :types extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        check_invalid_obj_usage (bool): removes keyword `object` from types, defaults to True
        syntax_validator (SyntaxValidator): syntax checker for generated types, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        types (dict[str,str]): dictionary of types with {<name>: <description>} pair
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    prompt = prompt_template.replace("{domain_desc}", domain_desc).replace(
        "{types}", types_str
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # parse LLM output into types
            parsed_types = parse_types(llm_output=llm_output)

            # flag that removes keyword 'object' if detected
            if check_invalid_obj_usage and parsed_types and "object" in parsed_types:
                del parsed_types["object"]

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type == "validate_format_types":
                        validation_info = validator(parsed_types)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return parsed_types, llm_output, validation_info

            return parsed_types, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError("Max retries exceeded. Failed to extract types.") from last_error
[docs]
@require_llm
def formalize_type_hierarchy(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    check_invalid_obj_usage: bool = True,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[list[dict[str, str]], str, tuple[bool, str]]:
    """
    Formalizes PDDL :types in hierarchy format via LLM. Recommended to use over `formalize_types()`

    The prompt is built once by substituting `{domain_desc}` and `{types}` in the template;
    the model is then queried up to `max_retries` times until a response parses without
    raising. A failed validation is NOT retried: the parsed hierarchy is returned together
    with the failing `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :types extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        check_invalid_obj_usage (bool): replaces any top-level entry containing the key `object`
            with that entry's `children`, defaults to True
        syntax_validator (SyntaxValidator): syntax checker for generated types, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        type_hierarchy (list[dict[str,str]]): list of dictionaries containing the type hierarchy
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    prompt = prompt_template.replace("{domain_desc}", domain_desc).replace(
        "{types}", types_str
    )

    # both of these validators take the parsed hierarchy as their only argument
    hierarchy_checks = ("validate_format_types", "validate_cyclic_types")

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # extract respective types from response
            type_hierarchy = parse_type_hierarchy(llm_output=llm_output)

            # flag that removes keyword 'object' if detected
            if type_hierarchy is not None and check_invalid_obj_usage:
                # promote children if top-level "object" type exists
                promoted = []
                for entry in type_hierarchy:
                    if "object" in entry:
                        promoted.extend(entry.get("children", []))
                    else:
                        promoted.append(entry)
                type_hierarchy = promoted

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in hierarchy_checks:
                        validation_info = validator(type_hierarchy)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return type_hierarchy, llm_output, validation_info

            return type_hierarchy, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError("Max retries exceeded. Failed to extract types.") from last_error
[docs]
@require_llm
def formalize_constants(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[dict[str, str], str, tuple[bool, str]]:
    """
    Formalizes PDDL :constants in flat dictionary format via LLM.

    The prompt is built once by substituting `{domain_desc}`, `{types}` and `{constants}`
    in the template; the model is then queried up to `max_retries` times until a response
    parses without raising. A failed validation is NOT retried: the parsed constants are
    returned together with the failing `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :constants extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        syntax_validator (SyntaxValidator): syntax checker for generated constants, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        constants (dict[str,str]): dictionary of constants with {<name>: <type>} pair
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # parse LLM output into constants
            parsed_constants = parse_constants(llm_output=llm_output)

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type == "validate_constant_types":
                        validation_info = validator(parsed_constants, types)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return parsed_constants, llm_output, validation_info

            return parsed_constants, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract constants."
    ) from last_error
[docs]
@require_llm
def formalize_predicates(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[list[Predicate], str, tuple[bool, str]]:
    """
    Formalizes PDDL :predicates via LLM.

    The prompt is built once by substituting `{domain_desc}`, `{types}`, `{constants}`,
    `{predicates}` and `{functions}` in the template; the model is then queried up to
    `max_retries` times until a response parses without raising. A failed validation is
    NOT retried: the parsed predicates are returned together with the failing
    `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :predicates extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        syntax_validator (SyntaxValidator): syntax checker for generated predicates, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        new_predicates (list[Predicate]): a list of new predicates
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool, str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    preds_str = (
        "\n".join(f"{pred['raw']}" for pred in predicates)
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join(f"{func['raw']}" for func in functions)
        if functions
        else "No functions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    # validators that inspect the raw LLM response
    raw_output_checks = (
        "validate_header",
        "validate_duplicate_headers",
        "validate_unsupported_keywords",
    )
    # validators that take (new_predicates, types)
    typed_predicate_checks = (
        "validate_types_predicates",
        "validate_format_predicates",
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)  # prompt model

            # extract new predicates from response
            new_predicates = parse_new_predicates(llm_output=llm_output)

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in raw_output_checks:
                        validation_info = validator(llm_output)
                    elif error_type in typed_predicate_checks:
                        validation_info = validator(new_predicates, types)
                    elif error_type == "validate_duplicate_predicates":
                        validation_info = validator(predicates, new_predicates)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return new_predicates, llm_output, validation_info

            return new_predicates, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract predicates."
    ) from last_error
[docs]
@require_llm
def formalize_functions(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[list[Function], str, tuple[bool, str]]:
    """
    Formalizes PDDL :functions via LLM

    The prompt is built once by substituting `{domain_desc}`, `{types}`, `{constants}`,
    `{predicates}` and `{functions}` in the template; the model is then queried up to
    `max_retries` times until a response parses without raising. A failed validation is
    NOT retried: the parsed functions are returned together with the failing
    `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :functions extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        syntax_validator (SyntaxValidator): syntax checker for generated functions, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        functions (list[Function]): a list of generated :functions
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    preds_str = (
        "\n".join(f"{pred['raw']}" for pred in predicates)
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join(f"{func['raw']}" for func in functions)
        if functions
        else "No functions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    # validators that inspect the raw LLM response
    raw_output_checks = (
        "validate_header",
        "validate_duplicate_headers",
        "validate_unsupported_keywords",
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # extract functions from response
            parsed_functions = parse_functions(llm_output=llm_output)

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in raw_output_checks:
                        validation_info = validator(llm_output)
                    elif error_type == "validate_format_functions":
                        validation_info = validator(parsed_functions, types)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return parsed_functions, llm_output, validation_info

            return parsed_functions, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract functions."
    ) from last_error
[docs]
@require_llm
def extract_nl_actions(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    nl_actions: dict[str, str] | None = None,
    max_retries: int = 3,
) -> tuple[dict[str, str], str]:
    """
    Extract actions in natural language given domain description using BaseLLM.

    NOTE: This is not an official formalize function. It is inspired by the NL2PLAN framework
    (Gestrin et al., 2024) and is designed to guide the LLM in constructing appropriate actions.

    The prompt is built once by substituting `{domain_desc}`, `{types}` and `{nl_actions}`
    in the template. The response is parsed with `parse_types(..., heading="ACTIONS")`; an
    attempt whose parse result is None is retried immediately (no delay), while an attempt
    that raises is retried after a short pause.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for dictionary extraction
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        nl_actions (dict[str, str]): NL actions currently in class object w/ {<name>: <description>} key-value pair
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        nl_actions (dict[str, str]): a dictionary of extracted NL actions {<name>: <description>}
        llm_output (str): the raw string BaseLLM response

    Raises:
        RuntimeError: if no attempt produced a result; chained to the last underlying
            exception when one was raised.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    nl_act_str = (
        "\n".join(f" - {name}: {desc}" for name, desc in nl_actions.items())
        if nl_actions
        else "No actions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{types}", types_str)
        .replace("{nl_actions}", nl_act_str)
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # extract respective nl actions from response
            extracted = parse_types(llm_output=llm_output, heading="ACTIONS")
            if extracted is not None:
                return extracted, llm_output

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract NL actions."
    ) from last_error
[docs]
@require_llm
def formalize_pddl_action(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    action_name: str,
    action_desc: str | None = None,
    action_list: list[str] | None = None,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    extract_new_preds: bool = False,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[Action, list[Predicate], str, tuple[bool, str]]:
    """
    Formalizes an :action and new :predicates from a given action description using BaseLLM.
    Users can set `extract_new_preds (bool)` to True if tasking LLM to generate new predicates.

    The prompt is built once by substituting `{domain_desc}`, `{action_list}`,
    `{action_name}`, `{action_desc}`, `{types}`, `{constants}`, `{predicates}` and
    `{functions}` in the template; the model is then queried up to `max_retries` times
    until a response parses without raising. A failed validation is NOT retried: the
    parsed action is returned together with the failing `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :action extraction
        action_name (str): action name
        action_desc (str): action description, defaults to None
        action_list (list[str]): list of other actions to be translated, defaults to None
        types (dict[str,str] | list[dict[str,str]]): types in current specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        extract_new_preds (bool): flag for parsing new predicates generated from action, defaults to False
        syntax_validator (SyntaxValidator): syntax checker for generated actions, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        action (Action): constructed action class containing :parameters, :preconditions, and :effects
        new_predicates (list[Predicate]): a list of new predicates, defaults to empty list
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool, str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    act_list_str = (
        "\n".join(f"- {a}" for a in action_list)
        if action_list
        else "No other actions provided."
    )
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    preds_str = (
        "\n".join(f"{pred['raw']}" for pred in predicates)
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join(f"{func['raw']}" for func in functions)
        if functions
        else "No functions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{action_list}", act_list_str)
        .replace("{action_name}", action_name)
        .replace("{action_desc}", action_desc or "No description available.")
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    # validators that inspect the raw LLM response
    raw_output_checks = (
        "validate_header",
        "validate_duplicate_headers",
        "validate_unsupported_keywords",
    )
    # validators that take (new_predicates, types)
    typed_predicate_checks = (
        "validate_types_predicates",
        "validate_format_predicates",
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # parse LLM output into action and predicates
            action = parse_action(llm_output=llm_output, action_name=action_name)
            if extract_new_preds:
                new_predicates = parse_new_predicates(llm_output=llm_output)
            else:
                new_predicates = []

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in raw_output_checks:
                        validation_info = validator(llm_output)
                    elif error_type == "validate_params":
                        validation_info = validator(action["params"], types)
                    elif error_type == "validate_duplicate_predicates":
                        # must be an assignment: a comparison here would silently
                        # discard the duplicate-predicate check result
                        validation_info = validator(predicates, new_predicates)
                    elif error_type in typed_predicate_checks:
                        validation_info = validator(new_predicates, types)
                    elif error_type == "validate_usage_action":
                        validation_info = validator(
                            llm_output,
                            predicates,
                            types,
                            functions,
                            extract_new_preds,
                        )

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return action, new_predicates, llm_output, validation_info

            return action, new_predicates, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error on attempt {attempt + 1}/{max_retries}: {e}\n"
                f"LLM Output:\n{llm_output}\nRetrying...\n"
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract PDDL action."
    ) from last_error
# NOTE: This function is experimental and may be subject to change in future versions.
[docs]
@require_llm
def formalize_pddl_actions(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    action_list: list[str] | None = None,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    extract_new_preds: bool = False,
    max_retries: int = 3,
) -> tuple[list[Action], list[Predicate], str]:
    """
    Formalizes several :actions via LLM.

    The response is split on the literal delimiter `## NEXT ACTION`. In each non-blank
    segment, the first `[...]` group is taken as the action name and everything after it
    is handed to `parse_action`. Blank segments (e.g. produced by a trailing delimiter)
    are skipped; a response that yields no action at all counts as a failed attempt.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): domain description
        prompt_template (str): action construction prompt
        action_list (list[str]): list of other actions to be translated, defaults to None
        types (dict[str,str] | list[dict[str,str]]): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        extract_new_preds (bool): flag for parsing new predicates generated from action, defaults to False
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        actions (list[Action]): constructed actions, in the order they appear in the response
        new_predicates (list[Predicate]): new predicates whose names are not already in
            `predicates`; an empty list when none were requested or when parsing them failed
        llm_output (str): the raw string BaseLLM response

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    act_list_str = (
        "\n".join(f"- {a}" for a in action_list)
        if action_list
        else "No other actions provided."
    )
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    preds_str = (
        "\n".join(f"{pred['raw']}" for pred in predicates)
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join(f"{func['raw']}" for func in functions)
        if functions
        else "No functions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{action_list}", act_list_str)
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    # group 1: bracketed action name; group 2: everything after it (DOTALL spans newlines)
    header_pattern = re.compile(r"\[([^\]]+)\](.*)", re.DOTALL)

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # extract each action from its delimited segment of the response
            actions = []
            for segment in llm_output.split("## NEXT ACTION"):
                if not segment.strip():
                    # a leading/trailing delimiter leaves an empty segment; nothing to parse
                    continue
                header = header_pattern.search(segment)
                action_name = header.group(1) if header else None
                action_body = header.group(2).strip() if header else None
                actions.append(
                    parse_action(llm_output=action_body, action_name=action_name)
                )

            if not actions:
                # treat an empty response as a failed attempt so it is retried
                raise ValueError("No actions found in LLM output.")

            # if user queries predicate creation via LLM
            try:
                if extract_new_preds:
                    new_predicates = parse_new_predicates(llm_output)
                else:
                    new_predicates = []

                if predicates:
                    # remove re-defined predicates
                    known_names = {p["name"] for p in predicates}
                    new_predicates = [
                        pred
                        for pred in new_predicates
                        if pred["name"] not in known_names
                    ]
            except Exception as e:
                # best effort: keep the parsed actions, report no new predicates
                print(f"No new predicates: {e}")
                new_predicates = []

            return actions, new_predicates, llm_output

        except Exception as e:
            last_error = e
            print(
                f"Error on attempt {attempt + 1}/{max_retries}: {e}\n"
                f"LLM Output:\n{llm_output}\nRetrying...\n"
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract PDDL action."
    ) from last_error
[docs]
@require_llm
def formalize_parameters(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    action_name: str,
    action_desc: str | None = None,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[OrderedDict, list, str, tuple[bool, str]]:
    """
    Formalizes PDDL :parameters for single action via LLM.

    The prompt is built once by substituting `{domain_desc}`, `{action_name}`,
    `{action_desc}` and `{types}` in the template; the model is then queried up to
    `max_retries` times until a response parses without raising. A failed validation is
    NOT retried: the parsed parameters are returned together with the failing
    `validation_info`.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :parameters extraction
        action_name (str): action name
        action_desc (str): action description, defaults to None
        types (dict[str,str] | list(dict[str,str])): current types in specification, defaults to None
        syntax_validator (SyntaxValidator): syntax checker for generated params, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        param (OrderedDict): ordered list of parameters {<?var>: <type>}
        param_raw (list()): list of raw parameters
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    types_str = pretty_print_dict(types) if types else "No types provided."
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{action_name}", action_name)
        .replace("{action_desc}", action_desc or "No description available.")
        .replace("{types}", types_str)
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)  # get BaseLLM response

            # extract respective parameters from response
            param, param_raw = parse_params(llm_output=llm_output)

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in ("validate_header", "validate_duplicate_headers"):
                        validation_info = validator(llm_output)
                    elif error_type == "validate_unsupported_keywords":
                        # keyword check runs on the raw parameters, not the full response
                        validation_info = validator(param_raw)
                    elif error_type == "validate_params":
                        validation_info = validator(param, types)

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return param, param_raw, llm_output, validation_info

            return param, param_raw, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract parameters."
    ) from last_error
[docs]
@require_llm
def formalize_preconditions(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    action_name: str,
    action_desc: str | None = None,
    params: OrderedDict | None = None,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    extract_new_preds: bool = False,
    syntax_validator: SyntaxValidator | None = None,
    max_retries: int = 3,
) -> tuple[str, list[Predicate], str, tuple[bool, str]]:
    """
    Formalizes PDDL :preconditions from a single action via LLM.

    The prompt is built once by substituting `{domain_desc}`, `{action_name}`,
    `{action_desc}`, `{parameters}`, `{types}`, `{constants}`, `{predicates}` and
    `{functions}` in the template; the model is then queried up to `max_retries` times
    until a response parses without raising. A failed validation is NOT retried: the
    parsed preconditions are returned together with the failing `validation_info`.
    The caller's `predicates` list is never modified.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :preconditions extraction
        action_name (str): action name
        action_desc (str): action description, defaults to None
        params (OrderedDict): dictionary of parameters from action, defaults to None
        types (dict[str,str] | list(dict[str,str])): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        extract_new_preds (bool): flag for parsing new predicates generated from action, defaults to False
        syntax_validator (SyntaxValidator): syntax checker for generated preconditions, defaults to None
        max_retries (int): max # of retries if failure occurs, defaults to 3

    Returns:
        preconditions (str): PDDL format of :preconditions
        new_predicates (list[Predicate]): a list of new predicates, defaults to empty list
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised; chained to the last underlying exception.
    """
    params_str = format_params(params) if params else "No parameters provided."
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = format_constants(constants) if constants else "No constants provided."
    preds_str = (
        "\n".join(f"{pred['raw']}" for pred in predicates)
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join(f"{func['raw']}" for func in functions)
        if functions
        else "No functions provided."
    )
    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{action_name}", action_name)
        .replace("{action_desc}", action_desc or "No description available.")
        .replace("{parameters}", params_str)
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    last_error = None
    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)  # get BaseLLM response

            # extract respective preconditions from response
            preconditions = parse_preconditions(llm_output=llm_output)
            if extract_new_preds:
                new_predicates = parse_new_predicates(llm_output=llm_output)
            else:
                # documented default is an empty list (None would break the merge below)
                new_predicates = []

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, str(error_type), None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in ("validate_header", "validate_duplicate_headers"):
                        validation_info = validator(llm_output)
                    elif error_type == "validate_unsupported_keywords":
                        validation_info = validator(preconditions)
                    elif error_type == "validate_duplicate_predicates":
                        # must be an assignment: a comparison here would silently
                        # discard the duplicate-predicate check result
                        validation_info = validator(predicates, new_predicates)
                    elif error_type == "validate_pddl_action":
                        # build a fresh list so the caller's `predicates` is not mutated
                        # and a missing `predicates` argument is tolerated
                        all_predicates = list(predicates or []) + list(
                            new_predicates or []
                        )
                        validation_info = validator(
                            preconditions,
                            all_predicates,
                            params,
                            types,
                            "preconditions",
                        )

                    # stop at the first failed check and hand the error back to the caller
                    if not validation_info[0]:
                        return (
                            preconditions,
                            new_predicates,
                            llm_output,
                            validation_info,
                        )

            return preconditions, new_predicates, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            # only pause when another attempt will actually follow
            if attempt + 1 < max_retries:
                time.sleep(2)

    raise RuntimeError(
        "Max retries exceeded. Failed to extract preconditions."
    ) from last_error
[docs]
@require_llm
def formalize_effects(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    action_name: str,
    action_desc: str = None,
    params: OrderedDict = None,
    preconditions: str = None,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    extract_new_preds: bool = False,
    syntax_validator: SyntaxValidator = None,
    max_retries: int = 3,
) -> tuple[str, list[Predicate], str, tuple[bool, str]]:
    """
    Formalizes PDDL :effects from a single action via LLM

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): general domain description
        prompt_template (str): structured prompt template for :effects extraction
        action_name (str): action name
        action_desc (str): action description, defaults to None
        params (OrderedDict): parameters of the action, defaults to None
        preconditions (str): PDDL format of preconditions, defaults to None
        types (dict[str,str] | list(dict[str,str])): current types in specification, defaults to None
        constants (dict[str,str]): current constants in specification, defaults to None
        predicates (list[Predicate]): list of current predicates in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        extract_new_preds (bool): flag for parsing new predicates generated from action, defaults to False
        syntax_validator (SyntaxValidator): syntax checker for generated effects, defaults to None
        max_retries (int): max # of retries if failure occurs

    Returns:
        effects (str): PDDL format of :effects
        new_predicates (list[Predicate]): new predicates parsed from the response, or None
            when `extract_new_preds` is False
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message;
            it is the result of the first failing validator, or of the last validator run

    Raises:
        RuntimeError: if every attempt raised an exception while querying/parsing/validating
    """
    params_str = format_params(params) if params else "No parameters provided."
    types_str = pretty_print_dict(types) if types else "No types provided."
    const_str = (
        format_constants(constants) if constants else "No constants provided."
    )
    preds_str = (
        "\n".join([f"{pred['raw']}" for pred in predicates])
        if predicates
        else "No predicates provided."
    )
    funcs_str = (
        "\n".join([f"{func['raw']}" for func in functions])
        if functions
        else "No functions provided."
    )

    prompt = (
        prompt_template.replace("{domain_desc}", domain_desc)
        .replace("{action_name}", action_name)
        .replace("{action_desc}", action_desc or "No description available.")
        .replace("{parameters}", params_str)
        .replace("{preconditions}", preconditions or "No precondition provided.")
        .replace("{types}", types_str)
        .replace("{constants}", const_str)
        .replace("{predicates}", preds_str)
        .replace("{functions}", funcs_str)
    )

    last_error = None  # kept so the final RuntimeError can carry the root cause

    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)  # get BaseLLM response

            # extract respective effects from response
            effects = parse_effects(llm_output=llm_output)
            new_predicates = (
                parse_new_predicates(llm_output=llm_output)
                if extract_new_preds
                else None
            )

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, f"{error_type}", None)
                    if not callable(validator):
                        continue

                    # dispatch based on expected arguments
                    if error_type in ("validate_header", "validate_duplicate_headers"):
                        validation_info = validator(llm_output)
                    elif error_type == "validate_unsupported_keywords":
                        validation_info = validator(effects)
                    elif error_type == "validate_duplicate_predicates":
                        # assignment (not comparison) so the validator's verdict is kept
                        validation_info = validator(predicates, new_predicates)
                    elif error_type == "validate_pddl_action":
                        # build a fresh list: never mutate the caller's `predicates`,
                        # and tolerate either side being None
                        all_predicates = list(predicates or []) + list(
                            new_predicates or []
                        )
                        validation_info = validator(
                            effects, all_predicates, params, types, "effects"
                        )

                    # stop at the first failing validator and report it
                    if not validation_info[0]:
                        break

            return effects, new_predicates, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            time.sleep(2)  # add a delay before retrying

    raise RuntimeError(
        "Max retries exceeded. Failed to extract effects."
    ) from last_error
# NOTE: This function is experimental and may be subject to change in future versions.
[docs]
@require_llm
def formalize_domain_level_specs(
    self,
    model: BaseLLM,
    domain_desc: str,
    prompt_template: str,
    formalize_types: bool = False,
    formalize_constants: bool = False,
    formalize_predicates: bool = False,
    formalize_functions: bool = False,
    syntax_validator: SyntaxValidator = None,
    max_retries: int = 3,
) -> tuple[dict[str, Any], str, tuple[bool, str]]:
    """
    Formalizes domain-level specifications (i.e. :types, :constants, :predicates, :functions) via LLM.

    Args:
        model (BaseLLM): LLM to query
        domain_desc (str): domain description
        prompt_template (str): prompt template
        formalize_types (bool): flag for extracting :types, defaults to False
        formalize_constants (bool): flag for extracting :constants, defaults to False
        formalize_predicates (bool): flag for extracting :predicates, defaults to False
        formalize_functions (bool): flag for extracting :functions, defaults to False
        syntax_validator (SyntaxValidator): syntax checker for domain specs., defaults to None
        max_retries (int): max # of retries if failure occurs

    Returns:
        spec_results (dict[str, Any]): results keyed by "types", "constants", "predicates"
            and "functions"; a spec whose flag is False is stored as None
        llm_output (str): the raw string BaseLLM response
        validation_info (tuple[bool,str]): validation info containing pass flag and error message

    Raises:
        RuntimeError: if every attempt raised an exception while querying/parsing/validating
    """
    prompt = prompt_template.replace("{domain_desc}", domain_desc)
    last_error = None  # kept so the final RuntimeError can carry the root cause

    # iterate through attempts in case of extraction failure
    for attempt in range(max_retries):
        # reset per attempt so a failed query never reports a previous attempt's response
        llm_output = None
        try:
            model.reset_tokens()
            llm_output = model.query(prompt=prompt)

            # every spec is bound up front; a disabled flag simply leaves it as None
            types = parse_type_hierarchy(llm_output=llm_output) if formalize_types else None
            constants = parse_constants(llm_output=llm_output) if formalize_constants else None
            predicates = (
                parse_new_predicates(llm_output=llm_output)
                if formalize_predicates
                else None
            )
            functions = parse_functions(llm_output=llm_output) if formalize_functions else None

            # results dictionary of top-level PDDL domain specifications
            spec_results = {
                "types": types,
                "constants": constants,
                "predicates": predicates,
                "functions": functions,
            }

            # run syntax validation if applicable
            validation_info = (True, "All validations passed.")
            if syntax_validator:
                # validator name -> positional arguments; the first argument is the
                # spec being validated
                validator_args = {
                    "validate_format_types": (types,),
                    "validate_cyclic_types": (types,),
                    "validate_constant_types": (constants, types),
                    "validate_types_predicates": (predicates, types),
                    "validate_format_predicates": (predicates, types),
                    "validate_format_functions": (functions, types),
                }
                for error_type in syntax_validator.error_types:
                    validator = getattr(syntax_validator, f"{error_type}", None)
                    args = validator_args.get(error_type)
                    if not callable(validator) or args is None:
                        continue
                    # nothing to validate when the spec was not requested
                    if args[0] is None:
                        continue

                    validation_info = validator(*args)

                    # stop at the first failing validator and report it
                    if not validation_info[0]:
                        break

            return spec_results, llm_output, validation_info

        except Exception as e:
            last_error = e
            print(
                f"Error encountered during attempt {attempt + 1}/{max_retries}: {e}. "
                f"\nLLM Output: \n\n{llm_output}\n\n Retrying..."
            )
            time.sleep(2)  # add a delay before retrying

    raise RuntimeError(
        "Max retries exceeded. Failed to extract domain specification."
    ) from last_error
"""Delete functions"""
[docs]
def delete_type(self, name: str):
    """
    Removes the type called `name` from `self.types` and from `self.type_hierarchy`.

    In the hierarchy, the children of a removed node take its place at the level
    where the removed node used to be. Nodes that carry no key other than
    "children" are dropped.
    """

    def _prune(nodes):
        kept = []
        for entry in nodes:
            # the node's label is its first key that is not "children"
            label = None
            for key in entry:
                if key != "children":
                    label = key
                    break
            if label is None:
                continue

            # descendants are cleaned in both cases
            descendants = _prune(entry.get("children", []))
            if label == name:
                # matched node disappears; its cleaned children move up one level
                kept.extend(descendants)
            else:
                kept.append({label: entry[label], "children": descendants})
        return kept

    # flat types dictionary: rebuild without the deleted entry
    flat_types = self.types
    if flat_types is not None:
        remaining = {}
        for type_name, description in flat_types.items():
            if type_name != name:
                remaining[type_name] = description
        self.types = remaining

    # nested hierarchy: rebuild without the deleted node
    hierarchy = self.type_hierarchy
    if hierarchy is not None:
        self.type_hierarchy = _prune(hierarchy)
[docs]
def delete_constants(self, name: str):
    """Removes the constant called `name` from `self.constants`; no-op when it is None."""
    current = self.constants
    if current is None:
        return

    # rebuild the mapping without the deleted constant
    remaining = {}
    for constant, constant_type in current.items():
        if constant != name:
            remaining[constant] = constant_type
    self.constants = remaining
[docs]
def delete_predicate(self, name: str):
    """Removes every predicate whose "name" equals `name` from `self.predicates`; no-op when it is None."""
    current = self.predicates
    if current is None:
        return

    # keep only the entries that do not carry the deleted name
    self.predicates = list(filter(lambda entry: entry["name"] != name, current))
[docs]
def delete_function(self, name: str):
    """Removes every function whose "name" equals `name` from `self.functions`; no-op when it is None."""
    current = self.functions
    if current is None:
        return

    # keep only the entries that do not carry the deleted name
    survivors = []
    for entry in current:
        if entry["name"] != name:
            survivors.append(entry)
    self.functions = survivors
[docs]
def delete_pddl_action(self, name: str):
    """Removes every action whose "name" equals `name` from `self.pddl_actions`; no-op when it is None."""
    current = self.pddl_actions
    if current is None:
        return

    # keep only the entries that do not carry the deleted name
    self.pddl_actions = list(filter(lambda entry: entry["name"] != name, current))
"""Set functions"""
[docs]
def set_types(self, types: dict[str, str]):
    """Replaces the builder's flat types mapping (`self.types`) with `types`."""
    new_types = types
    self.types = new_types
[docs]
def set_type_hierarchy(self, type_hierarchy: list[dict[str, str]]):
    """Replaces the builder's type hierarchy (`self.type_hierarchy`) with `type_hierarchy`."""
    new_hierarchy = type_hierarchy
    self.type_hierarchy = new_hierarchy
[docs]
def set_constants(self, constants: dict[str, str]):
    """Replaces the builder's constants mapping (`self.constants`) with `constants`."""
    new_constants = constants
    self.constants = new_constants
[docs]
def set_predicate(self, predicate: Predicate):
    """Adds `predicate` to the end of `self.predicates` (existing entries are kept)."""
    stored = self.predicates
    stored.append(predicate)
[docs]
def set_function(self, function: Function):
    """Adds `function` to the end of `self.functions` (existing entries are kept)."""
    stored = self.functions
    stored.append(function)
[docs]
def set_pddl_action(self, pddl_action: Action):
    """Adds `pddl_action` to the end of `self.pddl_actions` (existing entries are kept)."""
    stored = self.pddl_actions
    stored.append(pddl_action)
"""Get functions"""
[docs]
def get_types(self) -> dict[str, str]:
    """Gives back the builder's flat types mapping (`self.types`) without copying it."""
    stored = self.types
    return stored
[docs]
def get_type_hierarchy(self) -> list[dict[str, str]]:
    """Gives back the builder's type hierarchy (`self.type_hierarchy`) without copying it."""
    stored = self.type_hierarchy
    return stored
[docs]
def get_constants(self) -> dict[str, str]:
    """Gives back the builder's constants mapping (`self.constants`) without copying it."""
    stored = self.constants
    return stored
[docs]
def get_predicates(self) -> list[Predicate]:
    """Gives back the builder's predicate list (`self.predicates`) without copying it."""
    stored = self.predicates
    return stored
[docs]
def get_functions(self) -> list[Function]:
    """Gives back the builder's function list (`self.functions`) without copying it."""
    stored = self.functions
    return stored
[docs]
def get_pddl_actions(self) -> list[Action]:
    """Gives back the builder's action list (`self.pddl_actions`) without copying it."""
    stored = self.pddl_actions
    return stored
[docs]
def generate_requirements(
    self,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    functions: list[Function] | None = None,
    actions: list[Action] | None = None,
) -> list[str]:
    """
    Generates necessary PDDL requirements based off of rest of domain specification.
    Motivation was not needing LLMs to specify :requirements predeterminate of domain generation.
    Currently does not support :durative-actions

    Keywords are only recognised where they open an expression, e.g. `(not ...)` or
    `(or ...)`, case-insensitively. Names that merely contain a keyword (`door`,
    `notified`, `whenever`) and numeric comparisons such as `>=` therefore do not
    trigger a requirement.

    Args:
        types (dict[str,str] | list(dict[str,str])): current types in specification, defaults to None
        functions (list[Function]): list of current functions in specification, defaults to None
        actions (list[Action]): domain :action(s), defaults to None (treated as no actions)

    Returns:
        requirements (list[str]): alphabetically sorted list of PDDL requirements
    """

    def uses(keyword: str, expression: str) -> bool:
        # keyword must directly follow "(" and must not continue into a longer identifier
        pattern = rf"\(\s*{re.escape(keyword)}(?![\w-])"
        return re.search(pattern, expression, re.IGNORECASE) is not None

    requirements = {":strips"}

    # check if each specification needs a :requirement
    if types:
        requirements.add(":typing")
    if functions:
        requirements.add(":numeric-fluents")

    # go through actions and checks if it needs a :requirement
    for action in actions or []:
        # a missing (None) section is treated as empty text
        preconditions = action["preconditions"] or ""
        effects = action["effects"] or ""

        if uses("not", preconditions):
            requirements.add(":negative-preconditions")
        if uses("or", preconditions):
            requirements.add(":disjunctive-preconditions")
        # "(=" only: ">=" / "<=" start with a different character after "("
        if re.search(r"\(\s*=", preconditions):
            requirements.add(":equality")

        has_exists = uses("exists", preconditions)
        has_forall = uses("forall", preconditions)
        if has_exists and has_forall:
            requirements.add(":quantified-preconditions")
        elif has_exists:
            requirements.add(":existential-preconditions")
        elif has_forall:
            requirements.add(":universal-preconditions")

        if uses("when", effects):
            requirements.add(":conditional-effects")

    # replace ADL components with :adl
    adl_components = {
        ":strips",
        ":typing",
        ":disjunctive-preconditions",
        ":equality",
        ":quantified-preconditions",
        ":conditional-effects",
    }
    if adl_components.issubset(requirements):
        requirements -= adl_components
        requirements.add(":adl")

    return sorted(requirements)  # convert set back into (sorted) list
[docs]
def generate_domain(
    self,
    domain_name: str,
    types: dict[str, str] | list[dict[str, str]] | None = None,
    constants: dict[str, str] | None = None,
    predicates: list[Predicate] | None = None,
    functions: list[Function] | None = None,
    actions: list[Action] | None = None,
    requirements: list[str] | None = None,
) -> str:
    """
    Generates PDDL domain from given information.

    Args:
        domain_name (str): domain name
        types (dict[str,str] | list[dict[str,str]] | None): domain :types, defaults to None
        constants (dict[str,str] | None): domain :constants, defaults to None
        predicates (list[Predicate] | None): domain :predicates, defaults to None
        functions (list[Function] | None): domain :functions, defaults to None
        actions (list[Action] | None): domain :action(s), defaults to None (no actions)
        requirements (list[str] | None): domain :requirements; when empty or None they are
            derived via `generate_requirements` from `types`, `functions` and `actions`

    Returns:
        desc (str): PDDL domain in string format
    """
    # None (rather than a shared mutable default) stands for "no actions"
    actions = actions or []

    # generates requirements if not set
    if not requirements:
        requirements = self.generate_requirements(
            types=types, functions=functions, actions=actions
        )

    sections = [
        f"(define (domain {domain_name})\n",
        indent(string=f"(:requirements\n   {' '.join(requirements)})", level=1),
    ]

    if types:
        types_str = format_types_to_string(types)
        sections.append(f"\n\n   (:types \n{indent(string=types_str, level=2)}\n   )")

    if constants:
        const_str = format_constants(constants)
        sections.append(
            f"\n\n   (:constants \n{indent(string=const_str, level=2)}\n   )"
        )

    if not predicates:
        print(
            "[WARNING]: Domain has no predicates. This may cause planners to reject the domain or behave unexpectedly."
        )
    else:
        pred_str = format_expression(predicates)
        sections.append(
            f"\n\n   (:predicates \n{indent(string=pred_str, level=2)}\n   )"
        )

    if functions:
        func_str = format_expression(functions)
        sections.append(
            f"\n\n   (:functions \n{indent(string=func_str, level=2)}\n   )"
        )

    if not actions:
        print(
            "[WARNING]: Domain has no actions. The planner will not be able to generate any plan unless the goal is already satisfied."
        )
    else:
        sections.append(format_actions(actions))

    sections.append("\n)")
    desc = "".join(sections)

    # lower-case the AND / OR connectives only where they open an expression, so that
    # identifiers that merely contain those letters (e.g. HAND, DOOR) are left intact
    desc = re.sub(
        r"\(\s*(?:AND|OR)(?![\w-])", lambda match: match.group(0).lower(), desc
    )
    return desc