"""finhjb.interface.boundary: boundary configuration objects for HJB problems."""

import inspect
from dataclasses import dataclass, field
from functools import cached_property
from graphlib import CycleError, TopologicalSorter
from typing import Callable, Generic, Iterable, Optional

from finhjb.interface.parameter import P
from finhjb.structure._boundary import (
    DependencyMethod,
    ImmutableBoundary,
)
from finhjb.structure._grid import Grid
from finhjb.types import BOUNDARY_NAMES, BoundaryName


@dataclass(frozen=True)
class BoundaryConditionTarget:
    """
    Describe one endogenous-boundary target for boundary search.

    Each target names the boundary field to vary and the residual to drive to
    zero on the solved `Grid`. Optionally it also carries the bracket and
    stopping settings used by bisection-style search.

    Notes
    -----
    - `Solver.boundary_search()` only searches boundaries listed in
      `Model.boundary_condition()`.
    - The order of targets in that list is the parameter order handed to the
      nonlinear search methods.
    - With nested `bisection`, the same order is also the outer-to-inner
      search order.

    Attributes
    ----------
    boundary_name : BoundaryName
        Boundary field that is varied, e.g. `s_max` or `v_left`.
    condition_func : Callable[["Grid"], float]
        Residual evaluated on the solved grid; search aims to make it zero.
    low : Optional[float]
        Lower bracket for `method="bisection"`; other methods ignore it.
    high : Optional[float]
        Upper bracket for `method="bisection"`; other methods ignore it.
    tol : float
        Per-target tolerance for `method="bisection"`.
    max_iter : int
        Per-target iteration cap for `method="bisection"`.
    """

    # Required: which boundary moves, and what residual must vanish.
    boundary_name: BoundaryName
    condition_func: Callable[["Grid"], float]

    # Bisection-only settings; other search methods leave them untouched.
    low: Optional[float] = None
    high: Optional[float] = None
    tol: float = 1e-6
    max_iter: int = 50
@dataclass
class AbstractBoundary(Generic[P]):
    """
    An intelligent, auto-configuring class for defining HJB problem boundaries.

    This class serves as a configuration object for boundary values. It
    automatically discovers dependencies between boundaries by inspecting the
    signatures of `compute_<boundary_name>` methods defined in subclasses.

    Usage
    -----
    1. Subclass AbstractBoundary.
    2. For boundaries that can be computed from parameters or other
       boundaries, define methods like
       `compute_v_right(self, s_max: float) -> float`. Every named argument
       other than `p` is a boundary dependency (here `s_max`), inferred from
       the signature. Add a `p` argument to receive the parameter object; it
       is passed only when the signature accepts `p` (or `**kwargs`).
    3. Instantiate the subclass, providing any boundary values that cannot be
       computed as keyword arguments (e.g., `MyBoundary(p=params, s_min=0.0)`).

    Compute methods without dependencies are evaluated while dependencies are
    discovered. The remaining ones run in topological order inside `freeze()`,
    which `__post_init__` calls. Chains such as
    `compute_s_max(s_min)` -> `compute_v_right(s_max)` are supported.

    Examples
    --------
    ::

        class MyBoundary(AbstractBoundary):
            @staticmethod
            def compute_s_max(p) -> float:
                return p.x_bar

            @staticmethod
            def compute_v_right(s_max: float, p) -> float:
                return s_max * 2.0

        boundary = MyBoundary(p=params, s_min=0.0, v_left=0.0)
        frozen_boundary = boundary.frozen_boundary

    Attributes
    ----------
    p : Parameter
        The parameter object containing model parameters.
    s_min : Optional[float]
        The minimum state boundary value.
    s_max : Optional[float]
        The maximum state boundary value.
    v_left : Optional[float]
        The value function at the left boundary.
    v_right : Optional[float]
        The value function at the right boundary.
    already_boundary : dict[BoundaryName, float]
        Boundary values known before dependent compute methods run: those
        provided directly plus those produced by dependency-free compute
        methods. `freeze()` adds every computed value to this dict.
    independent_boundary : set[BoundaryName]
        Names present in `already_boundary` when dependencies were discovered.
    frozen_boundary : ImmutableBoundary
        The fully computed and immutable boundary object.
    required_boundary : set[BoundaryName]
        Every boundary name that some compute method depends on.
    boundary_dependencies : dict[BoundaryName, set[BoundaryName]]
        A mapping of boundary names to their dependencies.
    graph : list[DependencyMethod]
        A topologically sorted list of methods to compute boundary values.

    Raises
    ------
    ValueError
        At construction, if a boundary is both provided and computed, if
        dependencies are circular or unsatisfiable, if a boundary cannot be
        determined, or if `s_min` is not strictly less than `s_max`.
    TypeError
        At construction, if a `compute_<name>` attribute is not callable.
    """

    p: P = field(repr=False)
    s_min: Optional[float] = None
    s_max: Optional[float] = None
    v_left: Optional[float] = None
    v_right: Optional[float] = None

    already_boundary: dict[BoundaryName, float] = field(init=False, repr=False)
    independent_boundary: set[BoundaryName] = field(init=False, repr=False)
    frozen_boundary: ImmutableBoundary = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Compute and cache immutable boundary values at construction."""
        self.frozen_boundary = self.freeze()

    @cached_property
    def required_boundary(self) -> set[BoundaryName]:
        """All boundary names that some compute method depends on."""
        return {dep for deps in self.boundary_dependencies.values() for dep in deps}

    @cached_property
    def boundary_dependencies(
        self,
    ) -> dict[BoundaryName, set[BoundaryName]]:
        """
        Dict of compute methods (by boundary name) and their dependencies.

        Only boundaries whose compute method has at least one dependency are
        included. As a side effect, dependency-free compute methods are
        evaluated here and their results stored on the instance. Both
        `already_boundary` and `independent_boundary` are also initialized
        here.

        Raises
        ------
        ValueError
            If a boundary is both provided and has a compute method.
        TypeError
            If a `compute_<name>` attribute is not callable.
        """
        dependency_dict: dict[BoundaryName, set[BoundaryName]] = {}
        for name in BOUNDARY_NAMES:
            compute_method_name = f"compute_{name}"
            if not hasattr(self, compute_method_name):
                continue
            method = getattr(self, compute_method_name)
            if getattr(self, name) is not None:
                raise ValueError(
                    f'Boundary "{name}" and its compute method "{compute_method_name}()" cannot be defined simultaneously!'
                )
            if not callable(method):
                raise TypeError(
                    f"Attribute '{compute_method_name}' must be callable."
                )
            # Every named argument except `p` is a boundary dependency;
            # *args / **kwargs cannot name a boundary, so they are skipped.
            dependencies: set[BoundaryName] = {
                param.name
                for param in inspect.signature(method).parameters.values()
                if param.name != "p"
                and param.kind
                not in (
                    inspect.Parameter.VAR_POSITIONAL,
                    inspect.Parameter.VAR_KEYWORD,
                )
            }  # pyright: ignore[reportAssignmentType]
            if dependencies:
                dependency_dict[name] = dependencies  # pyright: ignore[reportArgumentType]
            else:
                # No dependencies: the value can be computed immediately.
                setattr(self, name, self._call_compute(method, {}))

        # Store already provided (or dependency-free computed) boundary values
        self.already_boundary = {  # pyright: ignore[reportAttributeAccessIssue]
            name: getattr(self, name)
            for name in BOUNDARY_NAMES
            if getattr(self, name) is not None
        }
        self.independent_boundary = set(self.already_boundary.keys())
        return dependency_dict

    @cached_property
    def graph(self) -> list[DependencyMethod]:
        """
        Topologically sorted list of methods to compute boundary values.

        Returns
        -------
        list[DependencyMethod]
            A list of methods with their metadata, sorted in the order they
            should be executed.

        Raises
        ------
        ValueError
            If there are circular dependencies or missing dependencies.
        """
        dependencies = self.boundary_dependencies
        try:
            # static_order() is a lazy generator: the cycle check only runs
            # once iteration starts, so materialize it inside the try block.
            order: Iterable[BoundaryName] = tuple(
                TopologicalSorter(dependencies).static_order()
            )
        except CycleError as e:
            raise ValueError(
                f"Circular dependency detected in boundary calculations. "
                f"Dependencies: {dependencies}"
            ) from e

        # Names whose values are known up front, plus names produced by steps
        # already scheduled earlier in `order` (values are computed in freeze).
        available: set[BoundaryName] = set(self.already_boundary)
        methods: list[DependencyMethod] = []
        for name in order:
            if name in available:
                continue
            required_deps = dependencies.get(name)
            if required_deps is None:
                # Neither provided nor computable. Every boundary depending on
                # it comes later in `order` and reports it as missing below.
                continue
            if missing_deps := required_deps - available:
                raise ValueError(
                    f"Cannot compute '{name}': missing dependencies {missing_deps}. "
                    f"These must be either provided in the constructor or have their own compute_* methods."
                )
            methods.append(
                DependencyMethod(
                    order=len(methods) + 1,
                    name=name,
                    deps=required_deps,
                    method=getattr(self, f"compute_{name}"),
                )
            )
            available.add(name)
        return methods

    def _call_compute(
        self, method: Callable[..., float], kwargs: dict[str, object]
    ) -> float:
        """Call a compute method, adding `p` only if its signature accepts it."""
        accepts_p = any(
            param.name == "p" or param.kind is inspect.Parameter.VAR_KEYWORD
            for param in inspect.signature(method).parameters.values()
        )
        if accepts_p:
            kwargs = {**kwargs, "p": self.p}
        return method(**kwargs)

    def freeze(self) -> ImmutableBoundary:
        """
        Computes and returns an ImmutableBoundary with all boundary values set.

        Each dependent compute method runs in `graph` order. Its result is
        stored in `already_boundary` so that later steps can use it.

        Returns
        -------
        ImmutableBoundary
            The fully computed and validated immutable boundary object.

        Raises
        ------
        ValueError
            If the dependency graph is invalid or the resulting boundary
            values are missing or inconsistent.
        """
        if not self.boundary_dependencies:
            return self._create_boundary(self.already_boundary)

        for item in self.graph:
            dep_values = {dep: self.already_boundary[dep] for dep in item["deps"]}
            self.already_boundary[item["name"]] = self._call_compute(
                item["method"], dep_values
            )
        return self._create_boundary(self.already_boundary)

    def _create_boundary(self, values: dict[BoundaryName, float]) -> ImmutableBoundary:
        """
        Creates an ImmutableBoundary after validating all values are present and valid.

        Parameters
        ----------
        values : dict[BoundaryName, float]
            Dictionary containing all four boundary values.

        Returns
        -------
        ImmutableBoundary
            The validated immutable boundary object.

        Raises
        ------
        ValueError
            If any boundary values are missing, or `s_min` is not strictly
            less than `s_max` (this includes NaN bounds).
        """
        # Validate all boundaries are present
        missing = set(BOUNDARY_NAMES) - values.keys()
        if missing:
            raise ValueError(
                f"Could not determine all boundary values. Missing: {missing}"
            )

        # Semantic validation; `not (a < b)` also rejects NaN, which `a >= b`
        # would let through.
        if not (values["s_min"] < values["s_max"]):
            raise ValueError(
                f"Invalid boundary: s_min ({values['s_min']}) must be "
                f"strictly less than s_max ({values['s_max']})."
            )

        return ImmutableBoundary(
            s_min=float(values["s_min"]),
            s_max=float(values["s_max"]),
            v_left=float(values["v_left"]),
            v_right=float(values["v_right"]),
            graph=self.graph,
        )