from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic
import jax.numpy as jnp
from finhjb.interface.boundary import BoundaryConditionTarget
from finhjb.interface.parameter import P
from finhjb.interface.policy import AbstractPolicy, D
from finhjb.structure._boundary import ImmutableBoundary
from finhjb.structure._grid import Grid
from finhjb.types import ArrayInter, ArrayN
@dataclass
class AbstractModel(ABC, Generic[P, D]):
    """
    Abstract base class for HJB models.

    This class defines the interface for HJB models, including the required
    HJB residual and several optional hooks for jumps, endogenous boundaries,
    outer-loop boundary updates, and custom diagnostics.

    Attributes
    ----------
    policy : AbstractPolicy
        Policy object attached to the model.

    Methods
    -------
    initialize_policy
        Create an initial policy guess for the solver.
    update_policy
        Update the policy variables using the current value function and its
        derivatives.
    hjb_residual
        Calculate the pointwise HJB residual on the interior grid.
    jump
        Calculate the jump term for the HJB equation.
    boundary_condition
        Declare endogenous boundary targets for `Solver.boundary_search()`.
    update_boundary
        Return direct boundary updates for `Solver.boundary_update()`.
    auxiliary
        Return user-defined diagnostics exposed through `Grid.aux`.

    Notes
    -----
    - Subclasses must implement the abstract `hjb_residual` method.
    - `initialize_policy` and `update_policy` are not defined on
      `AbstractModel` itself; they are presumably provided through the
      `policy` attribute (`AbstractPolicy`) -- confirm against that class.
    - The `jump` method is optional to override; the default implementation
      returns zero jumps.
    - By default, `boundary_condition` returns an empty target list, while
      `update_boundary` and `auxiliary` raise `NotImplementedError` until
      overridden.
    - Keep parameter orders and return shapes as specified; the solver
      expects these signatures.
    - These are `static` methods, so remember to add the `@staticmethod`
      decorator when overriding them.
    """

    policy: AbstractPolicy

    # `@staticmethod` must be the outermost decorator so that
    # `@abstractmethod` marks the underlying function as abstract.
    @staticmethod
    @abstractmethod
    def hjb_residual(
        v: ArrayInter,
        dv: ArrayInter,
        d2v: ArrayInter,
        s: ArrayInter,
        policy: D,
        jump: ArrayInter,
        boundary: ImmutableBoundary,
        p: P,
    ) -> ArrayInter:
        """
        Calculate the pointwise HJB residual on the interior grid.

        (This is the NECESSARY method you need to implement.)

        Notes
        -----
        - Keep the parameter order and return shape as specified; the solver
          expects this signature.
        - This is a `static` method and does not have access to instance
          attributes.

        Parameters
        ----------
        v : Float[Array, "N-2"]
            Value function evaluated at interior grid points.
        dv : Float[Array, "N-2"]
            First derivative of the value function at interior points.
        d2v : Float[Array, "N-2"]
            Second derivative of the value function at interior points.
        s : Float[Array, "N-2"]
            State variable values at interior grid points.
        policy : dict[str, Float[Array, "N-2"]]
            Mapping of policy variable names to their values on the interior
            grid.
        jump : Float[Array, "N-2"]
            Jump term evaluated at each interior grid point.
        boundary : ImmutableBoundary
            Boundary values for both state and value function.
        p : Parameter
            Model parameters.

        Returns
        -------
        Float[Array, "N-2"]
            HJB residual evaluated at each interior grid point.

        Examples
        --------
        ::

            @staticmethod
            def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
                control1 = policy["control1"]
                residual = ...
                return residual
        """

    @staticmethod
    def jump(
        v: ArrayN,
        s: ArrayN,
        policy: D,
        boundary: ImmutableBoundary,
        p: P,
    ) -> ArrayN:
        """
        Calculate the jump term for the HJB equation.

        (This method is OPTIONAL to override.)

        The default implementation returns zero jumps. Override this only when
        your HJB contains a non-zero jump or Poisson-arrival term.

        Notes
        -----
        - Although annotated as full-grid arrays (`ArrayN`), the solver
          evaluates this hook through `Grid.jump_inter`, so in practice `v`,
          `s`, and `policy` are the interior-grid slices.
        - You can override this method in subclasses to implement specific
          jump dynamics.
        - Keep the parameter order and return shape as specified; the solver
          expects this signature.
        - This is a `static` method and does not have access to instance
          attributes.

        Parameters
        ----------
        v : Float[Array, "N-2"]
            Value function evaluated at interior grid points.
        s : Float[Array, "N-2"]
            State variable values at interior grid points.
        policy : dict[str, Float[Array, "N-2"]]
            Mapping of policy variable names to their values on the interior
            grid.
        boundary : ImmutableBoundary
            Boundary values for both state and value function.
        p : Parameter
            Model parameters.

        Returns
        -------
        Float[Array, "N-2"]
            Jump term evaluated at each interior grid point. The default is an
            array of zeros with the same shape and dtype as `s`.
        """
        return jnp.zeros_like(s)

    @staticmethod
    def boundary_condition() -> list[BoundaryConditionTarget]:
        """
        Return endogenous-boundary targets for `Solver.boundary_search()`.

        (This method is OPTIONAL to override.)

        Each returned `BoundaryConditionTarget` specifies:

        - which boundary field should be optimized,
        - how to evaluate the boundary residual on the solved grid,
        - and, for `method="bisection"`, the bracket and per-target search
          settings.

        Returns
        -------
        list[BoundaryConditionTarget]
            Targets to optimize. The default is an empty list, so no
            boundaries are optimized.

        Notes
        -----
        - Only boundaries listed here are optimized by `boundary_search()`.
        - The order of targets in the returned list defines the boundary
          vector order for nonlinear search methods.
        - For nested `bisection`, the same order defines the outer-to-inner
          search order.
        - `low`, `high`, `tol`, and `max_iter` are used by `bisection`.
          Other search methods use `Config.bs_tol` and `Config.bs_max_iter`.
        """
        return []

    @staticmethod
    def update_boundary(grid: Grid):
        """
        Return direct boundary updates for the boundary-update workflow.

        (This method is OPTIONAL to override.)

        Implement this hook when a solved grid directly implies revised
        boundary values and an update error, so that `Solver.boundary_update()`
        can run the outer loop `solve -> update boundary -> solve again`.

        Parameters
        ----------
        grid : Grid
            The solved grid from which the boundary updates are derived.

        Raises
        ------
        NotImplementedError
            Always, in this default implementation; override to enable
            `Solver.boundary_update()`.
        """
        raise NotImplementedError(
            "The `update_boundary` method is not implemented for this model."
        )

    @staticmethod
    def auxiliary(grid: Grid):
        """
        Return user-defined diagnostics derived from a solved grid.

        (This method is OPTIONAL to override.)

        `Grid.aux` is a thin proxy for this hook. Leaving it unimplemented
        means `grid.aux` will raise `NotImplementedError`, which is expected.

        A common pattern is to return a small dictionary of derived summaries,
        for example `{"value_mean": jnp.mean(grid.v)}`.

        Parameters
        ----------
        grid : Grid
            The solved grid to derive diagnostics from.

        Raises
        ------
        NotImplementedError
            Always, in this default implementation.
        """
        raise NotImplementedError(
            "The `auxiliary` method is not implemented for this model."
        )