Source code for finhjb.interface.model

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic

import jax.numpy as jnp

from finhjb.interface.boundary import BoundaryConditionTarget
from finhjb.interface.parameter import P
from finhjb.interface.policy import AbstractPolicy, D
from finhjb.structure._boundary import ImmutableBoundary
from finhjb.structure._grid import Grid
from finhjb.types import ArrayInter, ArrayN


@dataclass
class AbstractModel(ABC, Generic[P, D]):
    """
    Abstract base class for HJB models.

    This class defines the interface for HJB models, including the required
    HJB residual and several optional hooks for jumps, endogenous boundaries,
    outer-loop boundary updates, and custom diagnostics.

    Attributes
    ----------
    policy : AbstractPolicy
        Policy object associated with this model.

    Methods
    -------
    initialize_policy
        Create an initial policy guess for the solver.
    update_policy
        Update the policy variables using the current value function and its
        derivatives.
    hjb_residual
        Calculate the pointwise HJB residual on the interior grid.
    jump
        Calculate the jump term for the HJB equation.
    boundary_condition
        Declare endogenous boundary targets for `Solver.boundary_search()`.
    update_boundary
        Return direct boundary updates for `Solver.boundary_update()`.
    auxiliary
        Return user-defined diagnostics exposed through `Grid.aux`.

    Notes
    -----
    - Subclasses must implement the abstract `hjb_residual` method.
    - Every other hook is optional to override. The defaults are:
      `jump` returns zeros, `boundary_condition` returns an empty list, and
      `update_boundary` / `auxiliary` raise `NotImplementedError`.
    - Keep parameter orders and return shapes as specified; the solver
      expects these signatures.
    - The hooks are `static` methods, so remember to add the
      `@staticmethod` decorator when overriding them.
    """

    # NOTE(review): `initialize_policy` and `update_policy` are listed in the
    # docstring above but are not defined on this class; presumably they are
    # provided by the `policy` object (AbstractPolicy) -- confirm.
    policy: AbstractPolicy

    @staticmethod
    @abstractmethod
    def hjb_residual(
        v: ArrayInter,
        dv: ArrayInter,
        d2v: ArrayInter,
        s: ArrayInter,
        policy: D,
        jump: ArrayInter,
        boundary: ImmutableBoundary,
        p: P,
    ) -> ArrayInter:
        """
        (The NECESSARY method you need to implement)
        --------------------------------------------

        Calculate the pointwise HJB residual on the interior grid.

        Notes
        -----
        - Keep the parameter order and return shape as specified; the solver
          expects this signature.
        - This is a `static` method and does not have access to instance
          attributes.

        Parameters
        ----------
        v : Float[Array, "N-2"]
            Value function evaluated at interior grid points.
        dv : Float[Array, "N-2"]
            First derivative of the value function at interior points.
        d2v : Float[Array, "N-2"]
            Second derivative of the value function at interior points.
        s : Float[Array, "N-2"]
            State variable values at interior grid points.
        policy : dict[str, Float[Array, "N-2"]]
            Mapping of policy variable names to their values on the interior
            grid.
        jump : Float[Array, "N-2"]
            Jump term evaluated at each interior grid point.
        boundary : ImmutableBoundary
            Boundary values for both state and value function.
        p : Parameter
            Model parameters.

        Returns
        -------
        Float[Array, "N-2"]
            HJB residual evaluated at each interior grid point.

        Examples
        --------
        ::

            @staticmethod
            def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
                control1 = policy["control1"]
                residual = ...
                return residual
        """

    @staticmethod
    def jump(
        v: ArrayN,
        s: ArrayN,
        policy: D,
        boundary: ImmutableBoundary,
        p: P,
    ) -> ArrayN:
        """
        (This method is OPTIONAL to override)

        Calculate the jump term for the HJB equation.

        The default implementation returns zero jumps. Override this only
        when your HJB contains a non-zero jump or Poisson-arrival term.

        Notes
        -----
        - The solver evaluates this hook through `Grid.jump_inter`, so in
          practice `v`, `s`, and `policy` are the interior-grid slices.
        - You can override this method in subclasses to implement specific
          jump dynamics.
        - Keep the parameter order and return shape as specified; the solver
          expects this signature.
        - This is a `static` method and does not have access to instance
          attributes.

        Parameters
        ----------
        v : Float[Array, "N-2"]
            Value function evaluated at interior grid points.
        s : Float[Array, "N-2"]
            State variable values at interior grid points.
        policy : dict[str, Float[Array, "N-2"]]
            Mapping of policy variable names to their values on the interior
            grid.
        boundary : ImmutableBoundary
            Boundary values for both state and value function.
        p : Parameter
            Model parameters.

        Returns
        -------
        Float[Array, "N-2"]
            Jump term evaluated at each interior grid point.
        """
        # No jump contribution: zeros with the same shape and dtype as `s`.
        return jnp.zeros_like(s)

    @staticmethod
    def boundary_condition() -> list[BoundaryConditionTarget]:
        """
        (This method is OPTIONAL to override)

        Return endogenous-boundary targets for `Solver.boundary_search()`.

        Each returned `BoundaryConditionTarget` specifies:

        - which boundary field should be optimized,
        - how to evaluate the boundary residual on the solved grid,
        - and, for `method="bisection"`, the bracket and per-target search
          settings.

        Notes
        -----
        - Only boundaries listed here are optimized by `boundary_search()`.
        - The order of targets in the returned list defines the boundary
          vector order for nonlinear search methods.
        - For nested `bisection`, the same order defines the outer-to-inner
          search order.
        - `low`, `high`, `tol`, and `max_iter` are used by `bisection`.
          Other search methods use `Config.bs_tol` and `Config.bs_max_iter`.

        Returns
        -------
        list[BoundaryConditionTarget]
            The default implementation returns an empty list (no endogenous
            boundaries).
        """
        return []

    @staticmethod
    def update_boundary(grid: Grid):
        """
        (This method is OPTIONAL to override)

        Return direct boundary updates for the boundary-update workflow.

        Implement this hook when a solved grid directly implies revised
        boundary values and an update error, so that
        `Solver.boundary_update()` can run the outer loop
        `solve -> update boundary -> solve again`.

        Raises
        ------
        NotImplementedError
            Always, in the default implementation.
        """
        raise NotImplementedError(
            "The `update_boundary` method is not implemented for this model."
        )

    @staticmethod
    def auxiliary(grid: Grid):
        """
        (This method is OPTIONAL to override)

        Return user-defined diagnostics derived from a solved grid.

        `Grid.aux` is a thin proxy for this hook. Leaving it unimplemented
        means `grid.aux` will raise `NotImplementedError`, which is expected.

        A common pattern is to return a small dictionary of derived summaries,
        for example `{"value_mean": jnp.mean(grid.v)}`.

        Raises
        ------
        NotImplementedError
            Always, in the default implementation.
        """
        raise NotImplementedError(
            "The `auxiliary` method is not implemented for this model."
        )