API Reference#

Use this page when you already know what object or method you are looking for and want exact exported names, members, and loader behavior.

Read it after Library Quickstart, Getting Started, or Solver Guide. If you still need conceptual guidance, go back to Modeling Guide or Results and Diagnostics.

Use the tutorials first if you are still learning the workflow:

| If you want to… | Read this first |
| --- | --- |
| install and run the first example | Installation and Environment |
| reproduce the BCW baseline examples | Getting Started |
| understand returned objects and diagnostics | Results and Diagnostics |
| adapt BCW to your own model | Adapting BCW to Your Model |
| understand workflow choice | Solver Guide |

Come back here when you want exact exported names, method members, and loader behavior.

Top-Level Exports (finhjb)#

Core#

  • Config

  • Solver

  • Grid

  • Grids

  • ImmutableBoundary

Interfaces#

  • AbstractBoundary

  • BoundaryConditionTarget

  • AbstractModel

  • AbstractParameter

  • AbstractPolicy

  • AbstractPolicyDict

  • AbstractValueGuess

Helpers#

  • explicit_policy(order: int)

  • implicit_policy(...)

  • LinearInitialValue

  • QuadraticInitialValue

Loading#

  • load_grid(path)

  • load_grids(path)

  • load_sensitivity_result(path)

API By Task#

| Task | Objects you will touch first |
| --- | --- |
| define a model | AbstractParameter, AbstractBoundary, AbstractPolicy, AbstractModel |
| run one fixed-boundary solve | Solver, Config |
| search for an endogenous boundary | BoundaryConditionTarget, Solver.boundary_search() |
| inspect a solved object | Grid, Grid.df, Grid.aux |
| store and reload results | Grid.save, load_grid, Grids.save, load_grids, load_sensitivity_result |

Loading Functions In Detail#

The three load_* functions differ by what you want to restore: a single solved grid, a grid collection, or a full sensitivity result.

| Function | Restored type | Matching save call | Typical use |
| --- | --- | --- | --- |
| load_grid(path) | Grid | state.grid.save(path) | one solved run |
| load_grids(path) | Grids | result.grids.save(path) | many solved grids along parameter values |
| load_sensitivity_result(path) | SensitivityResult | result.save(path) | full continuation output (summary + grids) |

Guaranteed behavior:

  • .pkl suffix is auto-added,

  • type is validated after loading,

  • using the wrong loader raises TypeError.
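The three guarantees above can be pictured with a small stand-in loader. This is a minimal sketch, not finhjb's implementation: `save_typed`, `load_typed`, and `_with_pkl` are hypothetical helpers, the suffix rule is an assumption about what "auto-added" means, and built-in types stand in for `Grid` and `SensitivityResult`.

```python
import pickle
import tempfile
from pathlib import Path


def _with_pkl(path):
    """Append ".pkl" unless the path already ends with it (assumed rule)."""
    path = Path(path)
    return path if path.suffix == ".pkl" else path.with_name(path.name + ".pkl")


def save_typed(obj, path):
    """Hypothetical saver: pickle `obj` to `path` with a .pkl suffix."""
    target = _with_pkl(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "wb") as fh:
        pickle.dump(obj, fh)
    return target


def load_typed(path, expected_type):
    """Hypothetical loader: unpickle, then validate the restored type."""
    with open(_with_pkl(path), "rb") as fh:
        obj = pickle.load(fh)
    if not isinstance(obj, expected_type):
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(obj).__name__}"
        )
    return obj


# Demo: a dict plays the role of a saved grid, a list the role of another type.
with tempfile.TemporaryDirectory() as tmp:
    saved = save_typed({"v": [0.0, 1.0]}, Path(tmp) / "baseline_grid")
    restored = load_typed(Path(tmp) / "baseline_grid", dict)
    try:
        load_typed(Path(tmp) / "baseline_grid", list)
        wrong_loader_error = None
    except TypeError as exc:
        wrong_loader_error = str(exc)
```

Because these files are pickles, only load files that come from a source you trust.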

Example: load_grid#

import finhjb as fjb

# Assumes `solver` is an already configured fjb.Solver instance.
state, _ = solver.solve()
state.grid.save("outputs/baseline_grid")

grid = fjb.load_grid("outputs/baseline_grid")
print(type(grid).__name__)
print(grid.df.head())

Example: load_grids#

import finhjb as fjb
import jax.numpy as jnp

# Assumes `solver` is an already configured fjb.Solver instance.
result = solver.sensitivity_analysis(
    method="hybr",
    param_name="sigma",
    param_values=jnp.array([0.09, 0.10, 0.11]),
)
result.grids.save("outputs/sigma_grids")

grids = fjb.load_grids("outputs/sigma_grids")
print(type(grids).__name__)
print(list(grids.values))

Example: load_sensitivity_result#

import finhjb as fjb

# Assumes `result` is the SensitivityResult returned by solver.sensitivity_analysis(...).
result.save("outputs/sigma_result")
loaded = fjb.load_sensitivity_result("outputs/sigma_result")

print(type(loaded).__name__)
print(loaded.df.head())

Common Loading Mistakes#

  1. Loading a continuation result with load_grid. The type check fails and raises TypeError; use load_sensitivity_result instead.

  2. Forgetting that the loader auto-adds .pkl.

  3. Loading the full continuation result when you only needed a single grid.

Solver Methods#

  • Solver.solve() -> (PolicyIterationState | EvaluationState, history)

  • Solver.boundary_update() -> (BoundaryUpdateState, history)

  • Solver.boundary_search(method, verbose=False) -> BoundarySearchState

  • Solver.sensitivity_analysis(method, param_name, param_values) -> SensitivityResult

See Solver Guide for when to use each one.

Boundary Search Method Notes#

Supported Solver.boundary_search(method=...) values:

  • bisection

  • hybr

  • lm

  • broyden

  • gauss_newton

  • lbfgs

  • krylov

  • broyden1

Important behavior:

  • boundary_condition() returns the exact list of boundaries that will be searched.

  • The order of that list defines the boundary-parameter order for nonlinear methods.

  • For bisection, the same order also defines the nested outer-to-inner search order.

  • BoundaryConditionTarget.low and high matter only for bisection.

  • BoundaryConditionTarget.tol and max_iter also matter only for bisection.

  • All the other methods use Config.bs_tol and Config.bs_max_iter.

  • lbfgs minimizes squared residual loss rather than solving the root problem directly.
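To make the outer-to-inner ordering concrete, here is a minimal sketch of nested bisection over two scalar unknowns. It is not finhjb's code: `bisect`, the two toy residuals, and the brackets are hypothetical stand-ins for two `BoundaryConditionTarget` entries, with `low`, `high`, `tol`, and `max_iter` playing the roles described above. The first target is the outer loop, and every outer evaluation re-solves the inner target.

```python
def bisect(residual, low, high, tol=1e-6, max_iter=50):
    """Find a root of `residual` inside the bracket [low, high]."""
    f_low = residual(low)
    f_high = residual(high)
    if f_low == 0.0:
        return low
    if f_high == 0.0:
        return high
    if f_low * f_high > 0:
        raise ValueError("bracket does not contain a sign change")
    for _ in range(max_iter):
        mid = 0.5 * (low + high)
        f_mid = residual(mid)
        if abs(f_mid) < tol or 0.5 * (high - low) < tol:
            return mid
        if f_low * f_mid < 0:
            high = mid  # root lies in the lower half
        else:
            low, f_low = mid, f_mid  # root lies in the upper half
    return 0.5 * (low + high)


# Toy residuals: the exact solution is x = 2, y = 1.
def outer_residual(x, y):
    return x + y - 3.0


def inner_residual(x, y):
    return y - 0.5 * x


def solve_inner(x):
    """Inner search: for a fixed outer value x, find y with zero inner residual."""
    return bisect(lambda y: inner_residual(x, y), low=0.0, high=4.0)


# Outer search: each residual evaluation triggers a full inner search.
x_star = bisect(lambda x: outer_residual(x, solve_inner(x)), low=0.0, high=4.0)
y_star = solve_inner(x_star)
```

The cost grows multiplicatively with the number of nested targets, which is one reason the listed order matters for bisection.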

Model Hook Quick Reference#

The most important optional AbstractModel hooks are:

  • jump(...): optional, default zero, evaluated by the solver through Grid.jump_inter.

  • boundary_condition(): returns the BoundaryConditionTarget list for boundary_search().

  • update_boundary(grid): used only by boundary_update().

  • auxiliary(grid): exposed through Grid.aux; leaving it unimplemented means grid.aux raises NotImplementedError.

Grid Convenience#

Grid properties:

  • s, v, dv, d2v

  • s_inter, policy_inter, number_inter, jump_inter

  • df, aux

Notes:

  • jump_inter is the interior-grid evaluation of Model.jump(...).

  • aux is just a proxy for Model.auxiliary(grid).

  • A common auxiliary(grid) pattern is to return a small dictionary of derived diagnostics.
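The auxiliary pattern mentioned in the notes can be sketched as follows. This is an illustration only: `demo_grid` is a hypothetical stand-in built from plain lists, whereas a real hook would be a static method on an AbstractModel subclass receiving a solved Grid, and the diagnostic names are invented for the example.

```python
from types import SimpleNamespace


def auxiliary(grid):
    """Return a small dictionary of diagnostics derived from a solved grid."""
    v_max = max(grid.v)
    return {
        "value_mean": sum(grid.v) / len(grid.v),
        "value_max": v_max,
        # State at which the value function attains its maximum.
        "s_at_value_max": grid.s[grid.v.index(v_max)],
    }


# Stand-in for a solved grid: only the `s` and `v` fields are mimicked.
demo_grid = SimpleNamespace(s=[0.0, 0.5, 1.0], v=[1.0, 3.0, 2.0])
diagnostics = auxiliary(demo_grid)
```

Keeping the returned dictionary small and flat makes it easy to tabulate diagnostics across many grids.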

Grid methods:

  • reset()

  • update_grid(boundary)

  • update_with_v_inter(v_inter)

  • save(path)

Grids methods:

  • get, select_grids, add, merge, save

For interpretation, go to Results and Diagnostics.

API Details#

Config#

class Config(*, enable_x64=True, derivative_method='central', policy_guess=False, pi_method='scan', pi_max_iter=50, pi_tol=1e-06, pi_patience=10, pe_max_iter=10, pe_tol=1e-06, pe_patience=5, bs_max_iter=20, bs_tol=1e-06, bs_patience=5, aa_history_size=5, aa_mixing_frequency=1, aa_beta=1.0, aa_ridge=0)[source]

Runtime configuration for numerical solvers and differentiation rules.

The configuration object is validated by Pydantic and then used by algorithm modules (evaluation, policy_iteration, and boundary routines) to control convergence thresholds, iteration limits, and differentiation behavior.

Parameters:
  • enable_x64 (bool)

  • derivative_method (Literal['central', 'forward', 'backward'])

  • policy_guess (bool)

  • pi_method (Literal['scan', 'anderson'])

  • pi_max_iter (int)

  • pi_tol (float)

  • pi_patience (int)

  • pe_max_iter (int)

  • pe_tol (float)

  • pe_patience (int)

  • bs_max_iter (int)

  • bs_tol (float)

  • bs_patience (int)

  • aa_history_size (int)

  • aa_mixing_frequency (int)

  • aa_beta (float)

  • aa_ridge (float)

model_config: ClassVar[ConfigDict] = {'use_enum_values': True, 'validate_assignment': True}

Configuration for the model; should be a dictionary conforming to pydantic.config.ConfigDict.

set_derivative_function()[source]

Sets the derivative function based on the method.

Return type:

Config

property dv_func: Callable

Return the finite-difference function selected by configuration.
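For reference, the three derivative_method options correspond to the standard finite-difference stencils. The sketch below is an assumption-laden illustration in plain Python, not the function that dv_func returns: `finite_difference` is a hypothetical name, the grid is assumed uniform with spacing `h`, and only interior points are handled (how finhjb treats the two boundary points is not covered here).

```python
def finite_difference(v, h, method="central"):
    """First derivative of `v` at the interior points of a uniform grid."""
    interior = range(1, len(v) - 1)
    if method == "central":
        return [(v[i + 1] - v[i - 1]) / (2.0 * h) for i in interior]
    if method == "forward":
        return [(v[i + 1] - v[i]) / h for i in interior]
    if method == "backward":
        return [(v[i] - v[i - 1]) / h for i in interior]
    raise ValueError(f"unknown derivative method: {method!r}")


# v(s) = s^2 on a uniform grid; the exact derivative is 2 s.
s = [0.0, 0.5, 1.0, 1.5, 2.0]
v = [x * x for x in s]
dv_central = finite_difference(v, h=0.5, method="central")
dv_forward = finite_difference(v, h=0.5, method="forward")
dv_backward = finite_difference(v, h=0.5, method="backward")
```

On this quadratic the central stencil reproduces the exact derivative, while the one-sided stencils are off by `h` in opposite directions, which reflects their lower order of accuracy.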

setup_jax()[source]

Applies JAX configuration after model validation.

Return type:

Config

model_post_init(context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Parameters:
  • self (BaseModel) – The BaseModel instance.

  • context (Any) – The context.

Return type:

None

Solver#

class Solver(boundary=None, model=None, value_guess=None, policy_guess=True, number=1000, config=<factory>, grid=None)[source]

High-level orchestrator for solving HJB models on one-dimensional grids.

Parameters:
  • boundary (AbstractBoundary | None)

  • model (AbstractModel | None)

  • value_guess (AbstractValueGuess[P] | None)

  • policy_guess (bool)

  • number (int)

  • config (Config)

  • grid (Grid | None)

solve()[source]

Run policy iteration (or one-step evaluation) on the active grid.

Return type:

tuple[PolicyIterationState | EvaluationState, Array]

boundary_update()[source]

Iteratively update model boundaries and re-solve the HJB system.

Return type:

tuple[BoundaryUpdateState, Array]

boundary_search(method, verbose=False)[source]

Search optimal boundaries by solving boundary conditions as root problems.

Parameters:

method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])

sensitivity_analysis(method, param_name, param_values)[source]

Solve the model along a parameter path using continuation.

Parameters:
  • method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])

  • param_name (str)

  • param_values (Array)

Return type:

SensitivityResult

Structures#

class Grid(p, boundary, model, h=0, s=<factory>, v=<factory>, v_inter=<factory>, dv=<factory>, d2v=<factory>, policy=<factory>, value_guess=None, policy_guess=False, number=1000, config=<factory>)[source]

Immutable computational grid and solver state container.

Grid stores state-space coordinates, value function approximations, derivatives, boundary values, and policy variables. The object is a Flax PyTree so it can be transformed by JAX while keeping model/config objects as static fields.

Parameters:
  • p (P)

  • boundary (ImmutableBoundary)

  • model (AbstractModel[P, PolicyDictType])

  • h (float)

  • s (Float[Array, 'N'])

  • v (Float[Array, 'N'])

  • v_inter (Float[Array, 'N-2'])

  • dv (Float[Array, 'N'])

  • d2v (Float[Array, 'N'])

  • policy (PolicyDictType)

  • value_guess (AbstractValueGuess[P] | None)

  • policy_guess (bool)

  • number (int)

  • config (Config)

reset()[source]

Rebuild the full grid from boundary/parameter/policy configuration.

Return type:

Self

update_grid(boundary)[source]

Update boundary values and resample state grid if s bounds changed.

Parameters:

boundary (ImmutableBoundary)

Return type:

Self

property optimizable_boundaries[source]

Return boundary names targeted by model-defined boundary conditions.

property policy_in_axes[source]

Returns the axes for the policy parameters.

property s_inter: Float[Array, 'N-2']

Interior state grid (excluding both boundary points).

property policy_inter: PolicyDictType

Interior slices of all policy arrays.

property number_inter: int

Number of interior grid points.

property jump_inter

Evaluate Model.jump(…) on the interior grid slices.

update_with_v_inter(v_inter)[source]

Update full value and derivative arrays from interior values.

Parameters:

v_inter (Float[Array, 'N-2'])

Return type:

Self

property df

Convert grid data to a pandas DataFrame for easy inspection.

property aux

Return Model.auxiliary(grid=self) for user-defined diagnostics.

save(file_path)[source]

Save the Grid to a pickle file.

Parameters:

file_path (str | Path)

Return type:

None

replace(**updates)

Returns a new object replacing the specified fields with new values.
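The replace semantics can be illustrated with a frozen standard-library dataclass. `DemoGrid` below is a hypothetical stand-in with two of the Grid field names; the real Grid is a Flax PyTree with many more fields, so treat this only as a sketch of the "new object, original untouched" behavior.

```python
from dataclasses import FrozenInstanceError, dataclass, replace


@dataclass(frozen=True)
class DemoGrid:
    """Stand-in for an immutable grid container."""

    number: int
    h: float


grid = DemoGrid(number=1000, h=0.001)

# replace() builds a new instance; the original is left unchanged.
finer = replace(grid, number=2000, h=0.0005)

# In-place mutation is rejected on a frozen dataclass.
try:
    grid.number = 5
    mutated = True
except FrozenInstanceError:
    mutated = False
```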

class Grids(param_name='???', data=<factory>)[source]

Collection of solved grids indexed by a scalar continuation parameter.

Parameters:
  • param_name (str)

  • data (dict[float, Grid])

property values

Sorted parameter values contained in this subset.

get(value, default=None)[source]

Get a grid by key with a default fallback.

Parameters:
  • value (float)

  • default (Grid | None)

Return type:

Grid | None

save(file_path)[source]

Save the Grids to a pickle file.

Parameters:

file_path (str | Path)

Return type:

None

select_grids(values, *, atol=1e-08, rtol=1e-06)[source]

Select grids for specific parameter values and return a Grids object.

Parameters:
  • values (Iterable[float])

  • atol (float)

  • rtol (float)

Return type:

Grids
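The atol and rtol arguments exist because float keys rarely compare equal after arithmetic. The following sketch shows tolerance-based selection on a plain dictionary. It is not the library's implementation: `select_by_value` is a hypothetical helper, `math.isclose` is used as one possible closeness rule, and raising KeyError on a miss is an assumption.

```python
import math


def select_by_value(data, values, *, atol=1e-8, rtol=1e-6):
    """Pick entries whose float key is close to one of the requested values."""
    selected = {}
    for target in values:
        for key, item in data.items():
            if math.isclose(key, target, rel_tol=rtol, abs_tol=atol):
                selected[key] = item
                break
        else:
            # Assumed behavior: fail loudly when nothing is close enough.
            raise KeyError(f"no grid stored near {target!r}")
    return selected


# Strings stand in for solved grids keyed by a continuation parameter.
stored = {0.09: "grid_a", 0.1: "grid_b", 0.11: "grid_c"}
picked = select_by_value(stored, [0.1 + 1e-12, 0.11])
```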

add(label, grid)[source]

Insert or replace one grid at label.

Parameters:
  • label (float)

  • grid (Grid)

Return type:

Grids

merge(other)[source]

Merge two Grids collections with right-hand overwrite on conflicts.

Parameters:

other (Grids)

Return type:

Grids

class ImmutableBoundary(s_min, s_max, v_left, v_right, graph)[source]

Immutable boundary values structure.

Parameters:
  • s_min (float)

  • s_max (float)

  • v_left (float)

  • v_right (float)

  • graph (list[DependencyMethod])

s_min

Minimum state variable value.

Type:

float

s_max

Maximum state variable value.

Type:

float

v_left

Value function at the left boundary.

Type:

float

v_right

Value function at the right boundary.

Type:

float

get_boundaries()[source]

Return (s_min, s_max, v_left, v_right) as a tuple.

Return type:

tuple[float, float, float, float]

get_boundary_dict()[source]

Return all boundary values as a dictionary keyed by boundary name.

Return type:

dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float]

update_boundaries(boundary_dict, p)[source]

Return a new boundary object after applying dependency graph updates.

Parameters:
  • boundary_dict (dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float])

  • p (P)

s_changed(boundary)[source]

Check whether state-space limits changed versus another boundary.

Parameters:

boundary (Self)

replace(**updates)

Returns a new object replacing the specified fields with new values.

Interfaces#

class AbstractBoundary(p, s_min=None, s_max=None, v_left=None, v_right=None)[source]

An intelligent, auto-configuring class for defining HJB problem boundaries.

This class serves as a configuration object for boundary values. It automatically discovers dependencies between boundaries by inspecting the signatures of compute_<boundary_name> methods defined in subclasses.

Usage#

  1. Subclass AbstractBoundary.

  2. For boundaries that can be computed from parameters or other boundaries, define methods like compute_v_right(self, s_max: float) -> float. The dependency (s_max) is automatically inferred from the signature.

  3. Instantiate the subclass, providing any boundary values that cannot be computed as keyword arguments (e.g., MyBoundary(p=params, s_min=0.0)).

Examples

class MyBoundary(AbstractBoundary):
    @staticmethod
    def compute_s_max(p) -> float:
        return p.x_bar

    @staticmethod
    def compute_v_right(s_max: float, p) -> float:
        return s_max * 2.0

boundary = MyBoundary(p=params, s_min=0.0, v_left=0.0)
frozen_boundary = boundary.frozen_boundary

p

The parameter object containing model parameters.

Type:

Parameter

s_min

The minimum state boundary value.

Type:

Optional[float]

s_max

The maximum state boundary value.

Type:

Optional[float]

v_left

The value function at the left boundary.

Type:

Optional[float]

v_right

The value function at the right boundary.

Type:

Optional[float]

already_boundary

A dictionary of boundary values that were provided directly.

Type:

dict[BoundaryName, float]

independent_boundary

A set of boundary names that were provided directly.

Type:

set[BoundaryName]

frozen_boundary

The fully computed and immutable boundary object.

Type:

ImmutableBoundary

required_boundary[source]

A set of all boundary names required to compute the full boundary.

Type:

set[BoundaryName]

boundary_dependencies[source]

A mapping of boundary names to their dependencies.

Type:

dict[BoundaryName, set[BoundaryName]]

graph[source]

A topologically sorted list of methods to compute boundary values.

Type:

list[DependencyMethod]

property required_boundary: set[Literal['s_min', 's_max', 'v_left', 'v_right']][source]

All dependencies required to compute boundaries.

property boundary_dependencies: dict[Literal['s_min', 's_max', 'v_left', 'v_right'], set[Literal['s_min', 's_max', 'v_left', 'v_right']]][source]

Dict of compute methods and their dependencies.

property graph: list[DependencyMethod][source]

Topologically sorted list of methods to compute boundary values.

Returns:

list[DependencyMethod] – A list of methods with their metadata, sorted in the order they should be executed.

Raises:

ValueError – If there are circular dependencies or missing dependencies.
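The idea of inferring dependencies from compute_<boundary_name> signatures and then ordering them topologically can be sketched with the standard library. This is a hedged illustration, not AbstractBoundary's actual code: `DemoBoundary`, `infer_dependencies`, and `resolve` are hypothetical, and the parameter object is a plain dictionary.

```python
import inspect
from graphlib import TopologicalSorter

BOUNDARY_NAMES = ("s_min", "s_max", "v_left", "v_right")


class DemoBoundary:
    @staticmethod
    def compute_s_max(p):
        return p["x_bar"]

    @staticmethod
    def compute_v_right(s_max, p):
        return s_max * 2.0


def infer_dependencies(cls):
    """Map each computable boundary to the boundaries its method needs."""
    deps = {}
    for name in BOUNDARY_NAMES:
        method = getattr(cls, f"compute_{name}", None)
        if method is not None:
            params = inspect.signature(method).parameters
            deps[name] = {arg for arg in params if arg in BOUNDARY_NAMES}
    return deps


def resolve(cls, p, **given):
    """Compute all boundaries in dependency order, starting from given values."""
    deps = infer_dependencies(cls)
    values = dict(given)
    # static_order() yields dependencies first; a cycle raises CycleError,
    # which is a subclass of ValueError.
    for name in TopologicalSorter(deps).static_order():
        if name in values:
            continue
        if name not in deps:
            raise ValueError(f"missing boundary value: {name}")
        method = getattr(cls, f"compute_{name}")
        kwargs = {arg: values[arg] for arg in deps[name]}
        if "p" in inspect.signature(method).parameters:
            kwargs["p"] = p
        values[name] = method(**kwargs)
    return values


resolved = resolve(DemoBoundary, {"x_bar": 3.0}, s_min=0.0, v_left=0.0)
```

Here s_max is computed before v_right because the signature of compute_v_right names s_max as an argument.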

freeze()[source]

Computes and returns an ImmutableBoundary with all boundary values set.

Returns:

ImmutableBoundary – The fully computed and validated immutable boundary object.

Parameters:
  • p (P)

  • s_min (float | None)

  • s_max (float | None)

  • v_left (float | None)

  • v_right (float | None)

class BoundaryConditionTarget(boundary_name, condition_func, low=None, high=None, tol=1e-06, max_iter=50)[source]

Specification for one endogenous-boundary target used by boundary search.

A target says:

  • which boundary field should be varied,

  • how to evaluate the resulting residual on a solved Grid,

  • and, optionally, which bracket/tolerance settings should be used for bisection-style search.

Notes

  • Only boundaries that appear in Model.boundary_condition() are searched by Solver.boundary_search().

  • The order of targets in that list defines the parameter order passed to nonlinear search methods.

  • For nested bisection, the same order also defines the outer-to-inner search order.

Parameters:
  • boundary_name (Literal['s_min', 's_max', 'v_left', 'v_right'])

  • condition_func (Callable[[Grid], float])

  • low (float | None)

  • high (float | None)

  • tol (float)

  • max_iter (int)

boundary_name

Boundary field to optimize, such as s_max or v_left.

Type:

BoundaryName

condition_func

Residual evaluated on the solved grid. Search aims to drive this value to zero.

Type:

Callable[["Grid"], float]

low

Lower bracket used by method="bisection". Ignored by the other search methods.

Type:

Optional[float]

high

Upper bracket used by method="bisection". Ignored by the other search methods.

Type:

Optional[float]

tol

Per-target tolerance used by method="bisection".

Type:

float

max_iter

Per-target iteration cap used by method="bisection".

Type:

int

class AbstractModel(policy)[source]

Abstract base class for HJB models.

This class defines the interface for HJB models, including the required HJB residual and several optional hooks for jumps, endogenous boundaries, outer-loop boundary updates, and custom diagnostics.

Parameters:

policy (AbstractPolicy)

initialize_policy()

Create an initial policy guess for the solver.

update_policy()

Update the policy variables using the current value function and its derivatives.

hjb_residual()[source]

Calculate the pointwise HJB residual on the interior grid.

Parameters:
  • v (Float[Array, 'N-2'])

  • dv (Float[Array, 'N-2'])

  • d2v (Float[Array, 'N-2'])

  • s (Float[Array, 'N-2'])

  • policy (D)

  • jump (Float[Array, 'N-2'])

  • boundary (ImmutableBoundary)

  • p (P)

Return type:

Float[Array, 'N-2']

jump()[source]

Calculate the jump term for the HJB equation.

Parameters:
  • v (Float[Array, 'N'])

  • s (Float[Array, 'N'])

  • policy (D)

  • boundary (ImmutableBoundary)

  • p (P)

Return type:

Float[Array, 'N']

boundary_condition()[source]

Declare endogenous boundary targets for Solver.boundary_search().

Return type:

list[BoundaryConditionTarget]

update_boundary()[source]

Return direct boundary updates for Solver.boundary_update().

Parameters:

grid (Grid)

auxiliary()[source]

Return user-defined diagnostics exposed through Grid.aux.

Parameters:

grid (Grid)

Notes

  • Subclasses must implement the abstract hjb_residual method.

  • The jump method is optional to override; the default implementation returns zero jumps.

  • Keep parameter orders and return shapes as specified; the solver expects these signatures.

  • These are static methods so remember to add the @staticmethod decorator.

abstract static hjb_residual(v, dv, d2v, s, policy, jump, boundary, p)[source]

(The NECESSARY method you need to implement)#

Calculate the pointwise HJB residual on the interior grid.

Notes

  • Keep the parameter order and return shape as specified; the solver expects this signature.

  • This is a static method and does not have access to instance attributes.

Parameters:
  • v (Float[Array, "N-2"]) – Value function evaluated at interior grid points.

  • dv (Float[Array, "N-2"]) – First derivative of the value function at interior points.

  • d2v (Float[Array, "N-2"]) – Second derivative of the value function at interior points.

  • s (Float[Array, "N-2"]) – State variable values at interior grid points.

  • policy (dict[str, Float[Array, "N-2"]]) – Mapping of policy variable names to their values on the interior grid.

  • jump (Float[Array, "N-2"]) – Jump term evaluated at each interior grid point.

  • boundary (FrozenBoundary) – Boundary values for both state and value function.

  • p (Parameter) – Model parameters.

Returns:

Float[Array, "N-2"] – HJB residual evaluated at each interior grid point.

Examples

@staticmethod
def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
    control1 = policy["control1"]
    residual = ...
    return residual

Parameters:
  • v (Float[Array, 'N-2'])

  • dv (Float[Array, 'N-2'])

  • d2v (Float[Array, 'N-2'])

  • s (Float[Array, 'N-2'])

  • policy (D)

  • jump (Float[Array, 'N-2'])

  • boundary (ImmutableBoundary)

  • p (P)

Return type:

Float[Array, 'N-2']

static jump(v, s, policy, boundary, p)[source]

(This method is OPTIONAL to override)

Calculate the jump term for the HJB equation.

The default implementation returns zero jumps. Override this only when your HJB contains a non-zero jump or Poisson-arrival term.

Notes

  • The solver evaluates this hook through Grid.jump_inter, so in practice v, s, and policy are the interior-grid slices.

  • You can override this method in subclasses to implement specific jump dynamics.

  • Keep the parameter order and return shape as specified; the solver expects this signature.

  • This is a static method and does not have access to instance attributes.

Parameters:
  • v (Float[Array, "N-2"]) – Value function evaluated at interior grid points.

  • s (Float[Array, "N-2"]) – State variable values at interior grid points.

  • policy (dict[str, Float[Array, "N-2"]]) – Mapping of policy variable names to their values on the interior grid.

  • boundary (FrozenBoundary) – Boundary values for both state and value function.

  • p (Parameter) – Model parameters.

Returns:

Float[Array, "N-2"] – Jump term evaluated at each interior grid point.

Return type:

Float[Array, 'N']

static boundary_condition()[source]

(This method is OPTIONAL to override)

Return endogenous-boundary targets for Solver.boundary_search().

Each returned BoundaryConditionTarget specifies:

  • which boundary field should be optimized,

  • how to evaluate the boundary residual on the solved grid,

  • and, for method="bisection", the bracket and per-target search settings.

Notes

  • Only boundaries listed here are optimized by boundary_search().

  • The order of targets in the returned list defines the boundary vector order for nonlinear search methods.

  • For nested bisection, the same order defines the outer-to-inner search order.

  • low, high, tol, and max_iter are used by bisection. Other search methods use Config.bs_tol and Config.bs_max_iter.

Return type:

list[BoundaryConditionTarget]

static update_boundary(grid)[source]

Return direct boundary updates for the boundary-update workflow.

Implement this hook when a solved grid directly implies revised boundary values and an update error, so that Solver.boundary_update() can run the outer loop solve -> update boundary -> solve again.

Parameters:

grid (Grid)
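The outer loop described above (solve, update the boundary, solve again) has the shape of a fixed-point iteration. The sketch below shows that shape on a toy problem. Everything in it is hypothetical: `boundary_update_loop`, `toy_solve`, and `toy_update` are invented for illustration, and the assumption that the hook yields a new boundary value plus a scalar error is a simplification of whatever update_boundary returns in practice.

```python
def boundary_update_loop(solve, update_boundary, boundary, tol=1e-6, max_iter=50):
    """Alternate solving and boundary updates until the update error is small."""
    for iteration in range(1, max_iter + 1):
        solution = solve(boundary)
        boundary, error = update_boundary(solution)
        if abs(error) < tol:
            return boundary, iteration
    raise RuntimeError("boundary update did not converge")


def toy_solve(boundary):
    """Toy 'solve': the solution implies a revised boundary 0.5 * b + 1."""
    return {"boundary": boundary, "implied": 0.5 * boundary + 1.0}


def toy_update(solution):
    """Toy hook: return the implied boundary and the size of the change."""
    new_boundary = solution["implied"]
    return new_boundary, new_boundary - solution["boundary"]


# The fixed point of b -> 0.5 * b + 1 is b = 2.
final_boundary, n_iter = boundary_update_loop(toy_solve, toy_update, boundary=0.0)
```

This kind of loop converges only when the update map is a contraction near the solution; when it is not, a root-finding search over the boundary is the usual alternative.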

static auxiliary(grid)[source]

Return user-defined diagnostics derived from a solved grid.

Grid.aux is a thin proxy for this hook. Leaving it unimplemented means grid.aux will raise NotImplementedError, which is expected.

A common pattern is to return a small dictionary of derived summaries, for example {"value_mean": jnp.mean(grid.v)}.

Parameters:

grid (Grid)

class AbstractParameter[source]

Abstract base class for model parameters.

This class serves as a base for immutable, hashable parameter containers required by JAX (e.g. for use with static_argnames). Subclasses must be decorated with @struct.dataclass(frozen=True) and should declare all model parameters as class attributes with default values.

Examples

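from functools import cached_property
from typing import Callable

from flax import struct
from finhjb import AbstractParameter

@struct.dataclass(frozen=True)  # required decorator, as stated above
class Parameter(AbstractParameter):
    r: float
    g: float
    gamma: float
    ...

    # Static (non-pytree) field: not traced by JAX and hidden from repr.
    apply_func: Callable = struct.field(pytree_node=False, repr=False)

    # Derived parameter, computed once and cached.
    @cached_property
    def vFB(self) -> float:
        return self.r + self.gamma * self.g / (self.gamma - 1)

params = Parameter(r=0.06, g=0.03, gamma=0.25, ...)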

Notes

  • Do not add mutable fields or methods that modify instance state.

  • Derived-parameter methods may be added in subclasses, but they must be decorated with @cached_property to ensure immutability.

update(boundary)[source]

Update the parameter object with the boundary object.

Parameters:

boundary (ImmutableBoundary) – The boundary object to update the parameter object with.

Returns:

Self – The updated parameter object.

Return type:

Self

replace(**updates)

Returns a new object replacing the specified fields with new values.

class AbstractPolicy[source]

Composable policy update interface for explicit/implicit steps.

create_solve_func(solver)[source]

Create a JAX-vectorized solve function from a jaxopt solver instance.

Parameters:

solver (Broyden | LevenbergMarquardt | GaussNewton)

Return type:

Callable[[Float[Array, 'N K'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], P], OptStep]

abstract static initialize(grid, p)[source]

(The NECESSARY method you need to implement)#

Create an initial policy guess for the solver.

This abstract method should be implemented by subclasses to provide an initial policy for the control variables used by the HJB solver.

Notes

  • The method is called once during solver setup.

  • The solver expects every policy variable it references during iteration to be present in the returned mapping. Failing to provide required keys may raise an error when the solver starts.

  • The returned object must be a dictionary-like mapping (type PolicyDict) that contains entries for every policy variable the solver expects.

  • If you have a meaningful initial guess for the policy, return it here; the solver uses it as the starting point when policy_guess is True.

  • If you do not have a meaningful initial guess, set policy_guess to False in the solver configuration. The solver then ignores this initial guess and performs an initial policy improvement step before starting iterations.

Implementations may freely use the following inputs during initialization:

  • p : Parameter

  • grid.s : Float[Array, "N"]

  • grid.v : Float[Array, "N"]

  • grid.number : int

  • …

Returns:

A dictionary-like container (PolicyDict) that maps policy variable names to arrays of values. Each array should have a shape compatible with the solver's state grid (typically the same shape as grid.s, or a 1-D array of length grid.number).

Examples

class PolicyDict(AbstractPolicyDict):
    control_var1: Array
    control_var2: Array

def initialize_policy(self, grid: Grid, p: P) -> PolicyDict:
    control_var1_val: Array = ...
    return PolicyDict(
        control_var1=jnp.full_like(grid.s, control_var1_val),
        control_var2=jnp.ones(grid.number),
    )

Parameters:
  • grid (Grid)

  • p (P)

Return type:

D

update(grid)[source]

Execute compiled policy update steps and return updated grid.

Parameters:

grid (Grid)

Return type:

Grid

class AbstractPolicyDict[source]

Base TypedDict for policy variables.

This class is used to store variables that are not changed during the policy evaluation step, such as control variables or other auxiliary variables.

So any variables that are not changed during the policy evaluation step should be included here to avoid repeated computations.

Subclass this to declare concrete policy keys and their types, for example:

class PolicyDict(AbstractPolicyDict):
    investment: Array
    consumption: Array
    drift: Array
    diffusion: Array

Notes

  • This class is intended for static type checking.

  • Do not instantiate AbstractPolicyDict directly — declare a subclass instead.

  • Keep value types as Array for JAX compatibility.

class AbstractValueGuess(p, boundary)[source]

Abstract class for initial value function guess.

This class requires subclasses to implement the guess_value method, which provides an initial guess for the value function on the computational grid.

Parameters:
  • p (Parameter) – Model parameters (subclass of AbstractParameter).

  • boundary (AbstractBoundary[P]) – Boundary conditions for the state variable and value function.

s_min

Minimum boundary for the state variable.

Type:

float

s_max

Maximum boundary for the state variable.

Type:

float

v_left

Value function at the left boundary (corresponding to s_min).

Type:

float

v_right

Value function at the right boundary (corresponding to s_max).

Type:

float

guess_value()[source]

Provide an initial guess for the value function.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']

abstract guess_value(s)[source]

(The NECESSARY method you need to implement)

Provide an initial guess for the value function.

Notes

  • You can use self.s_min, self.s_max, self.v_left, and self.v_right to access the boundary conditions.

  • You can also access parameters via self.p.

  • The input s is a grid of state variable values where the value function should be evaluated.

  • The output should be an array of the same shape as s, representing the initial guess for the value function at those points.

Parameters:

s (Float[Array, "N"]) – Grid for the state variable.

Returns:

v (Float[Array, "N"]) – Initial guess for the value function on the grid s.

Return type:

Float[Array, 'N']

class LinearInitialValue(p, boundary)[source]

Linear value function guess.

The value function is guessed to be a linear function connecting the boundary values.

Parameters:
  • p (P)

  • boundary (ImmutableBoundary)

guess_value(s)[source]

Construct a linear initial guess linking boundary value endpoints.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']
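A linear guess that connects the two boundary values is the straight line through (s_min, v_left) and (s_max, v_right). The sketch below writes that formula out in plain Python. `linear_guess` is a hypothetical function that takes the boundary values as arguments; it is meant to show the arithmetic, not the class's actual code.

```python
def linear_guess(s, s_min, s_max, v_left, v_right):
    """Evaluate the line through (s_min, v_left) and (s_max, v_right) on grid s."""
    slope = (v_right - v_left) / (s_max - s_min)
    return [v_left + slope * (x - s_min) for x in s]


guess = linear_guess(
    [0.0, 0.5, 1.0, 2.0], s_min=0.0, s_max=2.0, v_left=1.0, v_right=5.0
)
```

By construction the first and last grid points reproduce v_left and v_right exactly.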

class QuadraticInitialValue(p, boundary, a_sign, curvature=0.5)[source]

Quadratic initial value guess.

The quadratic function is defined as:

v(s) = a * s^2 + b * s + c

where the coefficient 'a' is either -1 or 1 (indicating concavity or convexity), and the coefficients 'b' and 'c' are determined to satisfy the boundary values:

  • v(s_min) = v_left

  • v(s_max) = v_right

Parameters:
  • a_sign (Literal[-1, 1]) – Coefficient determining the concavity (-1) or convexity (1) of the quadratic function.

  • curvature (float, default=0.5) – A parameter in the interval (0, 1] that influences the curvature of the quadratic function. A value closer to 0 results in a flatter curve, while a value of 1 yields a standard quadratic shape.

  • p (P)

  • boundary (ImmutableBoundary)

guess_value(s)[source]

Evaluate the boundary-matching quadratic initial guess on grid s.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']
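Given a fixed leading coefficient a, the two boundary conditions determine b and c uniquely: subtracting them gives b = (v_right - v_left - a (s_max^2 - s_min^2)) / (s_max - s_min), and then c = v_left - a s_min^2 - b s_min. The sketch below applies this with a equal to a_sign. It deliberately ignores the curvature argument, whose exact role is not specified on this page, so `quadratic_guess` is a hypothetical simplification rather than the class's formula.

```python
def quadratic_guess(s, s_min, s_max, v_left, v_right, a_sign):
    """Boundary-matching quadratic a*s^2 + b*s + c with a = a_sign (assumption)."""
    a = float(a_sign)
    # Solve the two boundary conditions for b and c.
    b = (v_right - v_left - a * (s_max**2 - s_min**2)) / (s_max - s_min)
    c = v_left - a * s_min**2 - b * s_min
    return [a * x * x + b * x + c for x in s]


# Concave case: a = -1 gives b = 3 and c = 0 for these boundary values.
guess = quadratic_guess(
    [0.0, 1.0, 2.0], s_min=0.0, s_max=2.0, v_left=0.0, v_right=2.0, a_sign=-1
)
```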

Helpers#

explicit_policy(order)[source]

Decorator to mark a method as an explicit solver with optional execution order.

Parameters:

order (int)

Return type:

Callable
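One common way to build an order-aware decorator is to tag each function with its order and sort the tagged functions before running them. The sketch below illustrates that general pattern only. `ordered_step`, `DemoPolicy`, and `run_steps` are hypothetical names, and nothing here claims to match how explicit_policy stores or uses its order internally.

```python
def ordered_step(order):
    """Hypothetical decorator: tag a function with its execution order."""

    def decorator(func):
        func._step_order = order
        return func

    return decorator


class DemoPolicy:
    # Declared first but tagged to run second.
    @staticmethod
    @ordered_step(order=2)
    def consumption(state):
        return {**state, "consumption": 0.5 * state["investment"]}

    @staticmethod
    @ordered_step(order=1)
    def investment(state):
        return {**state, "investment": 2.0}


def run_steps(policy_cls, state):
    """Collect tagged steps from the class and run them in ascending order."""
    candidates = [getattr(policy_cls, name) for name in dir(policy_cls)]
    steps = [func for func in candidates if hasattr(func, "_step_order")]
    for step in sorted(steps, key=lambda func: func._step_order):
        state = step(state)
    return state


result = run_steps(DemoPolicy, {})
```

The explicit order lets a later step depend on values produced by an earlier one, regardless of declaration order in the class body.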

implicit_policy(order, solver='gauss_newton', maxiter=10, tol=1e-08, implicit_diff=True, verbose=0, policy_order=[], **solver_kwargs)[source]

Decorator for an implicit policy's first-order condition (FOC).

Parameters:
  • order (int)

  • solver (Literal['gauss_newton', 'broyden', 'lm', 'newton_raphson'])

  • maxiter (int)

  • tol (float)

  • verbose (int)

  • policy_order (list[str])

Return type:

Callable

Loading Functions#

load_grid(file_path)[source]

Load a Grid object from .pkl.

Parameters:

file_path (str | Path)

Return type:

Grid

load_grids(file_path)[source]

Load a Grids object from .pkl.

Parameters:

file_path (str | Path)

Return type:

Grids

load_sensitivity_result(file_path)[source]

Load a SensitivityResult object from .pkl.

Parameters:

file_path (str | Path)

Return type:

SensitivityResult