# Source code for finhjb.orchestration.solver

from dataclasses import dataclass, field
from typing import Generic, Optional

from finhjb.algorithm.boundary_search import SearchMethods, get_boundary_search_solver
from finhjb.algorithm.boundary_update import BoundaryUpdate, BoundaryUpdateState
from finhjb.algorithm.continuation import SensitivityAnalysis, SensitivityResult
from finhjb.algorithm.evaluation import EvaluationState
from finhjb.algorithm.policy_iteration import PolicyIteration, PolicyIterationState
from finhjb.config import Config
from finhjb.interface.boundary import AbstractBoundary
from finhjb.interface.guess import AbstractValueGuess
from finhjb.interface.model import AbstractModel
from finhjb.interface.parameter import P
from finhjb.interface.policy import D
from finhjb.structure._grid import Grid
from finhjb.types import Array


@dataclass
class Solver(Generic[P, D]):
    """Orchestrate HJB solves on a one-dimensional grid.

    A solver is built either from a ready-made ``grid`` or from a
    ``boundary``/``model`` pair out of which a grid is assembled. In both
    cases the grid is ``reset()`` and the result is kept as the private
    working grid that every method below operates on.

    Raises:
        ValueError: If ``grid`` is omitted and ``boundary`` or ``model`` is
            missing as well.
    """

    # Inputs used to assemble a grid; only read when ``grid`` is None.
    boundary: Optional[AbstractBoundary] = field(repr=False, default=None)
    model: Optional[AbstractModel] = field(repr=False, default=None)
    value_guess: Optional[AbstractValueGuess[P]] = field(repr=False, default=None)
    # Forwarded unchanged to ``Grid`` as ``policy_guess`` / ``number``.
    policy_guess: bool = True
    number: int = 1000
    # Shared by the grid and by every algorithm backend created here.
    config: Config = field(default_factory=Config)
    # Optional pre-built grid; when given, boundary/model are not consulted.
    grid: Optional[Grid] = field(init=True, repr=False, default=None)

    def __post_init__(self):
        """Prepare the working grid and instantiate the algorithm backends."""
        if self.grid is not None:
            source_grid = self.grid
        elif self.boundary is None or self.model is None:
            raise ValueError(
                "Either provide a grid or both boundary and model to initialize the grid."
            )
        else:
            source_grid = Grid(
                p=self.boundary.p,
                boundary=self.boundary.frozen_boundary,
                model=self.model,
                value_guess=self.value_guess,
                policy_guess=self.policy_guess,
                number=self.number,
                config=self.config,
            )

        # Whatever the origin of the grid, the solver works on its reset() result.
        self._grid: Grid[P, D] = source_grid.reset()

        self.policy_iteration = PolicyIteration(config=self.config)
        # The boundary-update backend reuses the same policy-iteration instance.
        self.boundary_update_solver = BoundaryUpdate(
            config=self.config,
            policy_iteration=self.policy_iteration,
        )

    def solve(self) -> tuple[PolicyIterationState | EvaluationState, Array]:
        """Run policy iteration (or one-step evaluation) on the working grid.

        Returns:
            The final state and the history of errors, exactly as produced by
            ``PolicyIteration.policy_iteration`` called with ``jit=True``.
        """
        state, error_history = self.policy_iteration.policy_iteration(
            grid=self._grid,
            jit=True,
        )
        return state, error_history

    def _ensure_boundary_update_available(self) -> None:
        """Check that the grid's model class overrides ``update_boundary``.

        Raises:
            NotImplementedError: If the model class still exposes
                ``AbstractModel.update_boundary`` itself.
        """
        model_cls = type(self._grid.model)
        if model_cls.update_boundary is not AbstractModel.update_boundary:
            return
        raise NotImplementedError(
            "`Solver.boundary_update()` requires the model class to implement "
            "`update_boundary(grid)`."
        )

    def boundary_update(self) -> tuple[BoundaryUpdateState, Array]:
        """Iteratively update model boundaries and re-solve the HJB system.

        Returns:
            Whatever ``BoundaryUpdate.update`` returns for the working grid.

        Raises:
            NotImplementedError: If the model class does not override
                ``update_boundary``.
        """
        self._ensure_boundary_update_available()
        updater = self.boundary_update_solver
        return updater.update(grid=self._grid)

    def sensitivity_analysis(
        self,
        method: SearchMethods,
        param_name: str,
        param_values: Array,
    ) -> SensitivityResult:
        """Solve the model along a parameter path using continuation.

        Args:
            method: Passed to ``SensitivityAnalysis`` as ``method``.
            param_name: Name of the parameter to vary along the path.
            param_values: Values of that parameter to solve for.

        Returns:
            The result of ``SensitivityAnalysis.solve_over_parameter_path``
            on the working grid.
        """
        # A fresh analysis object is created per call; only ``config`` is shared.
        analysis = SensitivityAnalysis(config=self.config, method=method)
        return analysis.solve_over_parameter_path(
            grid=self._grid,
            param_name=param_name,
            param_values=param_values,
        )