API Reference

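This page is most useful once you already know which object or method you are looking for, or when you need to confirm exact export names and loader behavior.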

If you need conceptual explanations rather than exact names, start with the Modeling Guide and Results and Diagnostics pages.

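If you are still getting familiar with the workflow, read these tutorials first and come back here for names and members:

  • To install the package and run your first case, read Installation and Environment.

  • To reproduce the BCW benchmark case, read Quick Start.

  • To understand the returned objects and numerical diagnostics, read Results and Diagnostics.

  • To adapt BCW into your own model, read Turning BCW into Your Own Model.

  • To decide which solver workflow to use, read the Solver Guide.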

Come back here when you need to confirm exact export names, class members, method signatures, or loader behavior.

Top-level exports (finhjb)

Core objects

  • Config

  • Solver

  • Grid

  • Grids

  • ImmutableBoundary

Interface types

  • AbstractBoundary

  • BoundaryConditionTarget

  • AbstractModel

  • AbstractParameter

  • AbstractPolicy

  • AbstractPolicyDict

  • AbstractValueGuess

Helpers

  • explicit_policy(order: int)

  • implicit_policy(...)

  • LinearInitialValue

  • QuadraticInitialValue

Loader functions

  • load_grid(path)

  • load_grids(path)

  • load_sensitivity_result(path)

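Find the API by task

Each task below lists the objects you are likely to meet first.

  • Define a model: AbstractParameter, AbstractBoundary, AbstractPolicy, AbstractModel

  • Run a fixed-boundary solve: Solver, Config

  • Search for endogenous boundaries: BoundaryConditionTarget, Solver.boundary_search()

  • Inspect a solution: Grid, Grid.df, Grid.aux

  • Save and reload results: Grid.save, load_grid, Grids.save, load_grids, load_sensitivity_result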

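Loader functions in detail

The key difference among the three load_* functions is what you want to restore: a single Grid, a collection of grids (Grids), or a complete SensitivityResult.

  • load_grid(path) restores a Grid. Saved with state.grid.save(path). Typical use: the full grid from a single solve.

  • load_grids(path) restores a Grids collection. Saved with result.grids.save(path). Typical use: the grids over a set of parameter values.

  • load_sensitivity_result(path) restores a SensitivityResult. Saved with result.save(path). Typical use: the continuation summary plus all grids.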

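Guaranteed behavior:

  • The .pkl suffix is added to the path automatically;

  • The loaded object is type-checked;

  • Using the wrong loader raises an explicit TypeError.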

load_grid example

import finhjb as fjb

# Assumes `solver` is an already configured fjb.Solver.
state, _ = solver.solve()
state.grid.save("outputs/baseline_grid")

grid = fjb.load_grid("outputs/baseline_grid")
print(type(grid).__name__)
print(grid.df.head())

load_grids example

import finhjb as fjb
import jax.numpy as jnp

# Assumes `solver` is an already configured fjb.Solver.
result = solver.sensitivity_analysis(
    method="hybr",
    param_name="sigma",
    param_values=jnp.array([0.09, 0.10, 0.11]),
)
result.grids.save("outputs/sigma_grids")

grids = fjb.load_grids("outputs/sigma_grids")
print(type(grids).__name__)
print(list(grids.values))

load_sensitivity_result example

import finhjb as fjb

# `result` is the SensitivityResult from the load_grids example.
result.save("outputs/sigma_result")
loaded = fjb.load_sensitivity_result("outputs/sigma_result")

print(type(loaded).__name__)
print(loaded.df.head())

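Common loading errors

  1. Using load_grid to read a saved continuation or sensitivity result;

  2. Forgetting that the loaders append .pkl automatically;

  3. Loading an entire SensitivityResult when you only need the grid for a single parameter point.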

Main Solver methods

  • Solver.solve() -> (PolicyIterationState | EvaluationState, history)

  • Solver.boundary_update() -> (BoundaryUpdateState, history)

  • Solver.boundary_search(method, verbose=False) -> BoundarySearchState

  • Solver.sensitivity_analysis(method, param_name, param_values) -> SensitivityResult

For when to use which method, read this together with the Solver Guide.

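AbstractModel optional hooks at a glance

The most important optional hooks are:

  • jump(...): optional and zero by default; the solver calls it through Grid.jump_inter.

  • boundary_condition(): returns a list of BoundaryConditionTarget objects for boundary_search().

  • update_boundary(grid): used only in the boundary_update() workflow.

  • auxiliary(grid): exposed through Grid.aux; if you have not implemented it, a NotImplementedError from grid.aux is expected.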

Grid convenience attributes

Commonly used Grid attributes:

  • s, v, dv, d2v

  • s_inter, policy_inter, number_inter, jump_inter

  • df, aux

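Additional notes:

  • jump_inter is Model.jump(...) evaluated on the interior grid.

  • aux is simply a proxy for Model.auxiliary(grid).

  • A common way to write auxiliary(grid) is to return a small dictionary of derived diagnostics.

Commonly used Grid methods: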

  • reset()

  • update_grid(boundary)

  • update_with_v_inter(v_inter)

  • save(path)

Commonly used Grids methods:

  • get

  • select_grids

  • add

  • merge

  • save

For how to interpret these objects, see Results and Diagnostics.

Detailed API documentation

Config

class Config(*, enable_x64=True, derivative_method='central', policy_guess=False, pi_method='scan', pi_max_iter=50, pi_tol=1e-06, pi_patience=10, pe_max_iter=10, pe_tol=1e-06, pe_patience=5, bs_max_iter=20, bs_tol=1e-06, bs_patience=5, aa_history_size=5, aa_mixing_frequency=1, aa_beta=1.0, aa_ridge=0)

Runtime configuration for numerical solvers and differentiation rules.

The configuration object is validated by Pydantic and then used by algorithm modules (evaluation, policy_iteration, and boundary routines) to control convergence thresholds, iteration limits, and differentiation behavior.

Parameters:
  • enable_x64 (bool)

  • derivative_method (Literal['central', 'forward', 'backward'])

  • policy_guess (bool)

  • pi_method (Literal['scan', 'anderson'])

  • pi_max_iter (Annotated[int, Gt(gt=0)])

  • pi_tol (Annotated[float, Gt(gt=0)])

  • pi_patience (Annotated[int, Ge(ge=0)])

  • pe_max_iter (Annotated[int, Gt(gt=0)])

  • pe_tol (Annotated[float, Gt(gt=0)])

  • pe_patience (Annotated[int, Ge(ge=0)])

  • bs_max_iter (Annotated[int, Gt(gt=0)])

  • bs_tol (Annotated[float, Gt(gt=0)])

  • bs_patience (Annotated[int, Ge(ge=0)])

  • aa_history_size (Annotated[int, Gt(gt=0)])

  • aa_mixing_frequency (Annotated[int, Gt(gt=0)])

  • aa_beta (Annotated[float, Gt(gt=0)])

  • aa_ridge (Annotated[float, Ge(ge=0)])

model_config: ClassVar[ConfigDict] = {'use_enum_values': True, 'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to ConfigDict (pydantic.config.ConfigDict).

set_derivative_function()

Sets the derivative function based on the method.

Return type:

Config

property dv_func: Callable

Return the finite-difference function selected by configuration.
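As a rough illustration of what the three derivative_method choices mean, the sketch below implements forward, backward, and central first differences on a uniform grid with NumPy. It is a stand-in under assumptions: finhjb's actual stencils, its treatment of the end points, and its use of jax.numpy may differ, and all names here are hypothetical.

```python
import numpy as np


def forward_diff(v, h):
    # (v[i+1] - v[i]) / h, falling back to a backward step at the last point.
    dv = np.empty_like(v)
    dv[:-1] = (v[1:] - v[:-1]) / h
    dv[-1] = (v[-1] - v[-2]) / h
    return dv


def backward_diff(v, h):
    # (v[i] - v[i-1]) / h, falling back to a forward step at the first point.
    dv = np.empty_like(v)
    dv[1:] = (v[1:] - v[:-1]) / h
    dv[0] = (v[1] - v[0]) / h
    return dv


def central_diff(v, h):
    # (v[i+1] - v[i-1]) / (2h) inside, one-sided steps at both ends.
    dv = np.empty_like(v)
    dv[1:-1] = (v[2:] - v[:-2]) / (2 * h)
    dv[0] = (v[1] - v[0]) / h
    dv[-1] = (v[-1] - v[-2]) / h
    return dv


DERIVATIVE_METHODS = {
    "central": central_diff,
    "forward": forward_diff,
    "backward": backward_diff,
}


def select_dv_func(method):
    # Pick a stencil by name, in the spirit of Config.dv_func.
    return DERIVATIVE_METHODS[method]
```

On v(s) = s^2 the central stencil is exact at interior points (2s), while forward and backward are off by +h and -h, which is the usual accuracy trade-off between the three.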

setup_jax()

Applies JAX configuration after model validation.

Return type:

Config

model_post_init(context, /)

This function is meant to behave like a BaseModel method to initialise private attributes.

It takes context as an argument since that’s what pydantic-core passes when calling it.

Parameters:
  • self (BaseModel) – The BaseModel instance.

  • context (Any) – The context.

Return type:

None

Solver

class Solver(boundary=None, model=None, value_guess=None, policy_guess=True, number=1000, config=<factory>, grid=None)

High-level orchestrator for solving HJB models on one-dimensional grids.

Parameters:
  • boundary (AbstractBoundary | None)

  • model (AbstractModel | None)

  • value_guess (AbstractValueGuess[P] | None)

  • policy_guess (bool)

  • number (int)

  • config (Config)

  • grid (Grid | None)

solve()

Run policy iteration (or one-step evaluation) on the active grid.

Return type:

tuple[PolicyIterationState | EvaluationState, Array]

boundary_update()

Iteratively update model boundaries and re-solve the HJB system.

Return type:

tuple[BoundaryUpdateState, Array]

boundary_search(method, verbose=False)

Search for optimal boundaries by solving the boundary conditions as root-finding problems.

Parameters:

method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])

Return type:

BoundarySearchState

sensitivity_analysis(method, param_name, param_values)

Solve the model along a parameter path using continuation.

Parameters:
  • method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])

  • param_name (str)

  • param_values (Array)

Return type:

SensitivityResult
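Continuation usually means solving the problem at each parameter value in turn and warm-starting each solve from the previous solution. The sketch below shows that idea on a hypothetical scalar root problem (x**2 - a = 0) solved by Newton's method; in finhjb each point is a full model solve with the chosen method, and whether it warm-starts exactly this way is an assumption.

```python
def newton(f, df, x0, tol=1e-12, max_iter=50):
    # Scalar Newton iteration; stops when the step falls below tol.
    x = x0
    for _ in range(max_iter):
        step = f(x) / df(x)
        x -= step
        if abs(step) < tol:
            return x
    raise RuntimeError("Newton did not converge")


def continuation(solve, param_values, x0):
    # Solve at each parameter value, starting from the previous solution.
    results = {}
    x = x0
    for value in param_values:
        x = solve(value, x)
        results[value] = x
    return results


def solve_point(a, x_start):
    # Hypothetical per-point problem: find x with x**2 - a = 0.
    return newton(lambda x: x * x - a, lambda x: 2.0 * x, x_start)


path = continuation(solve_point, [1.0, 1.21, 1.44], x0=1.0)
```

Because neighboring parameter values have nearby solutions, each warm start begins close to the answer, which is why small steps along param_values tend to be more robust than solving every point from scratch.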

Data structures

class Grid(p, boundary, model, h=0, s=<factory>, v=<factory>, v_inter=<factory>, dv=<factory>, d2v=<factory>, policy=<factory>, value_guess=None, policy_guess=False, number=1000, config=<factory>)

Immutable computational grid and solver state container.

Grid stores state-space coordinates, value function approximations, derivatives, boundary values, and policy variables. The object is a Flax PyTree so it can be transformed by JAX while keeping model/config objects as static fields.

Parameters:
  • p (P)

  • boundary (ImmutableBoundary)

  • model (AbstractModel[P, PolicyDictType])

  • h (float)

  • s (Float[Array, 'N'])

  • v (Float[Array, 'N'])

  • v_inter (Float[Array, 'N-2'])

  • dv (Float[Array, 'N'])

  • d2v (Float[Array, 'N'])

  • policy (PolicyDictType)

  • value_guess (AbstractValueGuess[P] | None)

  • policy_guess (bool)

  • number (int)

  • config (Config)

reset()

Rebuild the full grid from boundary/parameter/policy configuration.

Return type:

Self

update_grid(boundary)

Update boundary values and resample state grid if s bounds changed.

Parameters:

boundary (ImmutableBoundary)

Return type:

Self

property optimizable_boundaries

Return boundary names targeted by model-defined boundary conditions.

property policy_in_axes

Returns the axes for the policy parameters.

property s_inter: Float[Array, 'N-2']

Interior state grid (excluding both boundary points).

property policy_inter: PolicyDictType

Interior slices of all policy arrays.

property number_inter: int

Number of interior grid points.

property jump_inter

Evaluate Model.jump(…) on the interior grid slices.

update_with_v_inter(v_inter)

Update full value and derivative arrays from interior values.

Parameters:

v_inter (Float[Array, 'N-2'])

Return type:

Self
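The interior/full split used by s_inter, number_inter, and update_with_v_inter can be pictured with plain arrays. A minimal sketch, assuming a uniform grid built with np.linspace and boundary values re-attached at both ends; the function name and numbers are hypothetical, and the real method also recomputes the derivative arrays.

```python
import numpy as np

# Hypothetical boundary values and grid size.
s_min, s_max, v_left, v_right = 0.0, 1.0, 1.0, 3.0
N = 6

s = np.linspace(s_min, s_max, N)   # full grid, both end points included
h = s[1] - s[0]                    # uniform spacing
s_inter = s[1:-1]                  # interior points only
number_inter = s_inter.size        # N - 2


def full_from_interior(v_inter, v_left, v_right):
    # Re-attach the boundary values to a vector of interior values.
    return np.concatenate(([v_left], v_inter, [v_right]))
```

Only the N - 2 interior values are unknowns; the two end values come from the boundary object, which is why the residual and v_inter are both length N - 2.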

property df

Convert grid data to a pandas DataFrame for easy inspection.

property aux

Return Model.auxiliary(grid=self) for user-defined diagnostics.

save(file_path)

Save the Grid to a pickle file.

Parameters:

file_path (str | Path)

Return type:

None

replace(**updates)

Returns a new object replacing the specified fields with new values.

class Grids(param_name='???', data=<factory>)

Collection of solved grids indexed by a scalar continuation parameter.

Parameters:
  • param_name (str)

  • data (dict[float, Grid])

property values

Sorted parameter values contained in this subset.

get(value, default=None)

Get a grid by key with a default fallback.

Parameters:
  • value (float)

  • default (Grid | None)

Return type:

Grid | None

save(file_path)

Save the Grids to a pickle file.

Parameters:

file_path (str | Path)

Return type:

None

select_grids(values, *, atol=1e-08, rtol=1e-06)

Select grids for specific parameter values and return a Grids object.

Parameters:
  • values (Iterable[float])

  • atol (float)

  • rtol (float)

Return type:

Grids
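select_grids matches requested parameter values with tolerances because continuation keys are floats that rarely compare equal exactly. The sketch below assumes NumPy-style isclose semantics (|key - wanted| <= atol + rtol * |key|) and simply skips values with no match; both choices are assumptions, and a plain dict stands in for Grids.

```python
def select_by_tolerance(data, values, atol=1e-8, rtol=1e-6):
    # For each requested value, keep the first stored key within tolerance.
    selected = {}
    for wanted in values:
        for key in data:
            if abs(key - wanted) <= atol + rtol * abs(key):
                selected[key] = data[key]
                break
    return selected


stored = {0.09: "grid_a", 0.1: "grid_b", 0.11: "grid_c"}
picked = select_by_tolerance(stored, [0.1 + 1e-9, 0.2])
```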

add(label, grid)

Insert or replace one grid at label.

Parameters:
  • label (float)

  • grid (Grid)

Return type:

Grids

merge(other)

Merge two Grids collections with right-hand overwrite on conflicts.

Parameters:

other (Grids)

Return type:

Grids
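The add and merge semantics described above (insert-or-replace, right-hand overwrite on conflicts) behave like dictionary unpacking. A hypothetical sketch on plain dicts keyed by parameter value, returning new dicts rather than mutating; whether Grids copies or mutates internally is not stated in this reference.

```python
def add_one(data, label, grid):
    # Insert or replace a single entry, returning a new dict.
    return {**data, label: grid}


def merge_right_wins(left, right):
    # On duplicate keys the right-hand collection wins.
    return {**left, **right}


def sorted_values(data):
    # Parameter values in ascending order, like Grids.values.
    return sorted(data)
```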

class ImmutableBoundary(s_min, s_max, v_left, v_right, graph)

Immutable boundary values structure.

Parameters:
  • s_min (float)

  • s_max (float)

  • v_left (float)

  • v_right (float)

  • graph (list[DependencyMethod])

s_min

Minimum state variable value.

Type:

float

s_max

Maximum state variable value.

Type:

float

v_left

Value function at the left boundary.

Type:

float

v_right

Value function at the right boundary.

Type:

float

get_boundaries()

Return (s_min, s_max, v_left, v_right) as a tuple.

Return type:

tuple[float, float, float, float]

get_boundary_dict()

Return all boundary values as a dictionary keyed by boundary name.

Return type:

dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float]

update_boundaries(boundary_dict, p)

Return a new boundary object after applying dependency graph updates.

Parameters:
  • boundary_dict (dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float])

  • p (P)

s_changed(boundary)

Check whether state-space limits changed versus another boundary.

Parameters:

boundary (Self)

replace(**updates)

Returns a new object replacing the specified fields with new values.

Interface types

class AbstractBoundary(p, s_min=None, s_max=None, v_left=None, v_right=None)

An intelligent, auto-configuring class for defining HJB problem boundaries.

This class serves as a configuration object for boundary values. It automatically discovers dependencies between boundaries by inspecting the signatures of compute_<boundary_name> methods defined in subclasses.

Usage

  1. Subclass AbstractBoundary.

  2. For boundaries that can be computed from parameters or other boundaries, define methods like compute_v_right(self, s_max: float) -> float. The dependency (s_max) is automatically inferred from the signature.

  3. Instantiate the subclass, providing any boundary values that cannot be computed as keyword arguments (e.g., MyBoundary(p=params, s_min=0.0)).

Examples

from finhjb import AbstractBoundary

class MyBoundary(AbstractBoundary):
    @staticmethod
    def compute_s_max(p) -> float:
        return p.x_bar

    @staticmethod
    def compute_v_right(s_max: float, p) -> float:
        return s_max * 2.0

# `params` is your AbstractParameter instance (it must define x_bar).
boundary = MyBoundary(p=params, s_min=0.0, v_left=0.0)
frozen_boundary = boundary.frozen_boundary

p

The parameter object containing model parameters.

Type:

Parameter

s_min

The minimum state boundary value.

Type:

Optional[float]

s_max

The maximum state boundary value.

Type:

Optional[float]

v_left

The value function at the left boundary.

Type:

Optional[float]

v_right

The value function at the right boundary.

Type:

Optional[float]

already_boundary

A dictionary of boundary values that were provided directly.

Type:

dict[BoundaryName, float]

independent_boundary

A set of boundary names that were provided directly.

Type:

set[BoundaryName]

frozen_boundary

The fully computed and immutable boundary object.

Type:

ImmutableBoundary

required_boundary

A set of all boundary names required to compute the full boundary.

Type:

set[BoundaryName]

boundary_dependencies

A mapping of boundary names to their dependencies.

Type:

dict[BoundaryName, set[BoundaryName]]

graph

A topologically sorted list of methods to compute boundary values.

Type:

list[DependencyMethod]

property required_boundary: set[Literal['s_min', 's_max', 'v_left', 'v_right']]

All dependencies required to compute boundaries.

property boundary_dependencies: dict[Literal['s_min', 's_max', 'v_left', 'v_right'], set[Literal['s_min', 's_max', 'v_left', 'v_right']]]

Dict of compute methods and their dependencies.

property graph: list[DependencyMethod]

Topologically sorted list of methods to compute boundary values.

Returns:

list[DependencyMethod] – A list of methods with their metadata, sorted in the order they should be executed.

Raises:

ValueError – If there are circular dependencies or missing dependencies.
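Signature-based dependency discovery plus a topological sort can be sketched with inspect and graphlib from the standard library. This illustrates the mechanism described for AbstractBoundary under stated assumptions: ToyBoundary is hypothetical, every compute_<name> method is assumed to accept p, and a plain dict stands in for the parameter object. graphlib.CycleError is a subclass of ValueError, consistent with the documented error type for circular dependencies.

```python
import inspect
from graphlib import TopologicalSorter

BOUNDARY_NAMES = ("s_min", "s_max", "v_left", "v_right")


class ToyBoundary:
    # Hypothetical stand-in for an AbstractBoundary subclass.
    @staticmethod
    def compute_s_max(p):
        return p["x_bar"]

    @staticmethod
    def compute_v_right(s_max, p):
        return s_max * 2.0


def dependency_graph(cls):
    # Map each computable boundary to the boundary names in its signature.
    deps = {}
    for name in BOUNDARY_NAMES:
        method = getattr(cls, f"compute_{name}", None)
        if method is not None:
            params = inspect.signature(method).parameters
            deps[name] = {arg for arg in params if arg in BOUNDARY_NAMES}
    return deps


def compute_boundaries(cls, p, **given):
    # Fill in missing boundaries in dependency order.
    deps = dependency_graph(cls)
    values = dict(given)
    # static_order() raises graphlib.CycleError (a ValueError) on cycles.
    for name in TopologicalSorter(deps).static_order():
        if name in values:
            continue
        method = getattr(cls, f"compute_{name}", None)
        if method is None:
            raise ValueError(f"no value or compute method for {name!r}")
        kwargs = {arg: values[arg] for arg in deps[name]}
        values[name] = method(p=p, **kwargs)
    return values
```

For ToyBoundary with s_min and v_left given, the order is s_max and then v_right, mirroring the MyBoundary example above.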

freeze()

Computes and returns an ImmutableBoundary with all boundary values set.

Returns:

ImmutableBoundary – The fully computed and validated immutable boundary object.

Parameters:
  • p (P)

  • s_min (float | None)

  • s_max (float | None)

  • v_left (float | None)

  • v_right (float | None)

class BoundaryConditionTarget(boundary_name, condition_func, low=None, high=None, tol=1e-06, max_iter=50)

Specification for one endogenous-boundary target used by boundary search.

A target says:

  • which boundary field should be varied,

  • how to evaluate the resulting residual on a solved Grid,

  • and, optionally, which bracket/tolerance settings should be used for bisection-style search.

Notes

  • Only boundaries that appear in Model.boundary_condition() are searched by Solver.boundary_search().

  • The order of targets in that list defines the parameter order passed to nonlinear search methods.

  • For nested bisection, the same order also defines the outer-to-inner search order.

Parameters:
  • boundary_name (Literal['s_min', 's_max', 'v_left', 'v_right'])

  • condition_func (Callable[[Grid], float])

  • low (float | None)

  • high (float | None)

  • tol (float)

  • max_iter (int)

boundary_name

Boundary field to optimize, such as s_max or v_left.

Type:

BoundaryName

condition_func

Residual evaluated on the solved grid. Search aims to drive this value to zero.

Type:

Callable[['Grid'], float]

low

Lower bracket used by method="bisection". Ignored by the other search methods.

Type:

Optional[float]

high

Upper bracket used by method="bisection". Ignored by the other search methods.

Type:

Optional[float]

tol

Per-target tolerance used by method="bisection".

Type:

float

max_iter

Per-target iteration cap used by method="bisection".

Type:

int
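The per-target low, high, tol, and max_iter fields are the ingredients of a plain bisection. A minimal sketch with a hypothetical Target dataclass; to stay self-contained, condition_func here takes the candidate boundary value directly, whereas in finhjb it receives a solved Grid, so each evaluation implies a full solve.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Target:
    # Hypothetical mirror of the bisection-related BoundaryConditionTarget fields.
    boundary_name: str
    condition_func: Callable[[float], float]
    low: Optional[float] = None
    high: Optional[float] = None
    tol: float = 1e-6
    max_iter: int = 50


def bisect(target):
    # Halve [low, high] until the residual or the bracket is within tol.
    lo, hi = target.low, target.high
    if lo is None or hi is None:
        raise ValueError("bisection needs both low and high")
    f_lo = target.condition_func(lo)
    if f_lo * target.condition_func(hi) > 0:
        raise ValueError("residual must change sign on [low, high]")
    for _ in range(target.max_iter):
        mid = 0.5 * (lo + hi)
        f_mid = target.condition_func(mid)
        if abs(f_mid) < target.tol or 0.5 * (hi - lo) < target.tol:
            return mid
        if f_lo * f_mid < 0:
            hi = mid                  # sign change lies in [lo, mid]
        else:
            lo, f_lo = mid, f_mid     # sign change lies in [mid, hi]
    return 0.5 * (lo + hi)
```

With several targets, nested bisection would run one such loop per target, with the first target in the list as the outermost loop.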

class AbstractModel(policy)

Abstract base class for HJB models.

This class defines the interface for HJB models, including the required HJB residual and several optional hooks for jumps, endogenous boundaries, outer-loop boundary updates, and custom diagnostics.

Parameters:

policy (AbstractPolicy)

initialize_policy()

Create an initial policy guess for the solver.

update_policy()

Update the policy variables using the current value function and its derivatives.

hjb_residual()

Calculate the pointwise HJB residual on the interior grid.

Parameters:
  • v (Float[Array, 'N-2'])

  • dv (Float[Array, 'N-2'])

  • d2v (Float[Array, 'N-2'])

  • s (Float[Array, 'N-2'])

  • policy (D)

  • jump (Float[Array, 'N-2'])

  • boundary (ImmutableBoundary)

  • p (P)

Return type:

Float[Array, 'N-2']

jump()

Calculate the jump term for the HJB equation.

Parameters:
  • v (Float[Array, 'N'])

  • s (Float[Array, 'N'])

  • policy (D)

  • boundary (ImmutableBoundary)

  • p (P)

Return type:

Float[Array, 'N']

boundary_condition()

Declare endogenous boundary targets for Solver.boundary_search().

Return type:

list[BoundaryConditionTarget]

update_boundary()

Return direct boundary updates for Solver.boundary_update().

Parameters:

grid (Grid)

auxiliary()

Return user-defined diagnostics exposed through Grid.aux.

Parameters:

grid (Grid)

Notes

  • Subclasses must implement the abstract hjb_residual method.

  • The jump method is optional to override; the default implementation returns zero jumps.

  • Keep parameter orders and return shapes as specified; the solver expects these signatures.

  • These are static methods so remember to add the @staticmethod decorator.

abstract static hjb_residual(v, dv, d2v, s, policy, jump, boundary, p)

(The NECESSARY method you need to implement)

Calculate the pointwise HJB residual on the interior grid.

Notes

  • Keep the parameter order and return shape as specified; the solver expects this signature.

  • This is a static method and does not have access to instance attributes.

Parameters:
  • v (Float[Array, 'N-2']) – Value function evaluated at interior grid points.

  • dv (Float[Array, 'N-2']) – First derivative of the value function at interior points.

  • d2v (Float[Array, 'N-2']) – Second derivative of the value function at interior points.

  • s (Float[Array, 'N-2']) – State variable values at interior grid points.

  • policy (dict[str, Float[Array, 'N-2']]) – Mapping of policy variable names to their values on the interior grid.

  • jump (Float[Array, 'N-2']) – Jump term evaluated at each interior grid point.

  • boundary (ImmutableBoundary) – Boundary values for both the state and the value function.

  • p (Parameter) – Model parameters.

Returns:

Float[Array, 'N-2'] – HJB residual evaluated at each interior grid point.

Examples

@staticmethod
def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
    control1 = policy["control1"]
    residual = ...
    return residual
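To make the signature concrete, here is a hypothetical residual for a toy model r v = mu v' + (1/2) sigma^2 v'' + payout + jump, written as a free function so it runs without finhjb. The model, the policy key "payout", and the parameter names are illustrative assumptions; in a real subclass the same body would sit inside an @staticmethod hjb_residual.

```python
from types import SimpleNamespace

import numpy as np


def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
    # Toy model: r v = mu v' + 0.5 sigma^2 v'' + payout + jump.
    payout = policy["payout"]
    return 0.5 * p.sigma**2 * d2v + p.mu * dv + payout + jump - p.r * v


# Check: v(s) = s solves the toy model when payout = r s - mu.
p = SimpleNamespace(r=0.05, mu=0.02, sigma=0.1)  # hypothetical parameters
s = np.linspace(0.1, 0.9, 5)                      # stands in for interior points
residual = hjb_residual(
    v=s,
    dv=np.ones_like(s),
    d2v=np.zeros_like(s),
    s=s,
    policy={"payout": p.r * s - p.mu},
    jump=np.zeros_like(s),
    boundary=None,  # unused by this toy model
    p=p,
)
```

The solver drives this array to zero at every interior point, so writing the residual as "right-hand side minus left-hand side" pointwise is all the hook has to do.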

static jump(v, s, policy, boundary, p)

(This method is OPTIONAL to override)

Calculate the jump term for the HJB equation.

The default implementation returns zero jumps. Override this only when your HJB contains a non-zero jump or Poisson-arrival term.

Notes

  • The solver evaluates this hook through Grid.jump_inter, so in practice v, s, and policy are the interior-grid slices.

  • You can override this method in subclasses to implement specific jump dynamics.

  • Keep the parameter order and return shape as specified; the solver expects this signature.

  • This is a static method and does not have access to instance attributes.

Parameters:
  • v (Float[Array, "N-2"]) – Value function evaluated at interior grid points.

  • s (Float[Array, "N-2"]) – State variable values at interior grid points.

  • policy (dict[str, Float[Array, "N-2"]]) – Mapping of policy variable names to their values on the interior grid.

  • boundary (ImmutableBoundary) – Boundary values for both the state and the value function.

  • p (Parameter) – Model parameters.

Returns:

Float[Array, "N-2"] – Jump term evaluated at each interior grid point.

Return type:

Float[Array, 'N']
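A common non-zero jump term is a Poisson shock that moves the state by a fixed amount, giving lam * (v(s - loss) - v(s)). The sketch below is hypothetical: it takes lam and loss as explicit arguments instead of reading them from p, and it relies on np.interp, which holds v at v[0] for points below the first grid point; a real model needs its own rule there (for example a liquidation value).

```python
import numpy as np


def poisson_jump(v, s, lam, loss):
    # Hypothetical shock: with intensity lam the state falls by `loss`.
    # v(s - loss) is linearly interpolated; np.interp clamps below s[0].
    v_after = np.interp(s - loss, s, v)
    return lam * (v_after - v)
```

Inside a model this body would live in an @staticmethod jump(v, s, policy, boundary, p), with lam and loss taken from p.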

static boundary_condition()

(This method is OPTIONAL to override)

Return endogenous-boundary targets for Solver.boundary_search().

Each returned BoundaryConditionTarget specifies:

  • which boundary field should be optimized,

  • how to evaluate the boundary residual on the solved grid,

  • and, for method="bisection", the bracket and per-target search settings.

Notes

  • Only boundaries listed here are optimized by boundary_search().

  • The order of targets in the returned list defines the boundary vector order for nonlinear search methods.

  • For nested bisection, the same order defines the outer-to-inner search order.

  • low, high, tol, and max_iter are used by bisection. Other search methods use Config.bs_tol and Config.bs_max_iter.

Return type:

list[BoundaryConditionTarget]

static update_boundary(grid)

Return direct boundary updates for the boundary-update workflow.

Implement this hook when a solved grid directly implies revised boundary values and an update error, so that Solver.boundary_update() can run the outer loop solve -> update boundary -> solve again.

Parameters:

grid (Grid)
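The outer loop that Solver.boundary_update() runs can be sketched as a fixed-point iteration. This is a minimal sketch under assumptions: solve and update_boundary are hypothetical callables, update_boundary is assumed to return a (new_boundary, error) pair, and the toy update (a Newton step toward sqrt(2)) only stands in for the boundary revision a solved grid would imply.

```python
def boundary_update_loop(solve, update_boundary, boundary, tol=1e-8, max_iter=50):
    # Outer loop: solve -> revise the boundary -> solve again, until the
    # reported update error drops below tol.
    for iteration in range(1, max_iter + 1):
        solution = solve(boundary)
        boundary, error = update_boundary(solution)
        if error < tol:
            return boundary, iteration
    raise RuntimeError("boundary update did not converge")


def toy_solve(boundary):
    # Hypothetical "solve": the solution just carries the boundary along.
    return boundary


def toy_update(solution):
    # Hypothetical revision rule: a Newton step toward sqrt(2).
    new = 0.5 * (solution + 2.0 / solution)
    return new, abs(new - solution)


final_boundary, iterations = boundary_update_loop(toy_solve, toy_update, boundary=1.0)
```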

static auxiliary(grid)

Return user-defined diagnostics derived from a solved grid.

Grid.aux is a thin proxy for this hook. Leaving it unimplemented means grid.aux will raise NotImplementedError, which is expected.

A common pattern is to return a small dictionary of derived summaries, for example {"value_mean": jnp.mean(grid.v)}.

Parameters:

grid (Grid)
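A minimal sketch of the small-dictionary pattern, using NumPy and a SimpleNamespace stand-in for a solved Grid (only s, v, and dv are assumed). The diagnostic names are hypothetical; in a model class this would be an @staticmethod auxiliary(grid) and would typically use jax.numpy.

```python
from types import SimpleNamespace

import numpy as np


def auxiliary(grid):
    # Hypothetical diagnostics derived from a solved grid.
    return {
        "value_mean": float(np.mean(grid.v)),
        "min_dv": float(np.min(grid.dv)),
        "s_at_max_v": float(grid.s[np.argmax(grid.v)]),
    }


# Stand-in for a solved Grid: only s, v, and dv are provided.
s = np.linspace(0.0, 1.0, 5)
v = np.array([1.0, 1.5, 1.8, 1.9, 1.95])
fake_grid = SimpleNamespace(s=s, v=v, dv=np.gradient(v, s))
diagnostics = auxiliary(fake_grid)
```

With a real model, grid.aux would return the same kind of dictionary after a solve.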

class AbstractParameter

Abstract base class for model parameters.

This class serves as a base for immutable, hashable parameter containers required by JAX (e.g. for use with static_argnames). Subclasses must be decorated with @struct.dataclass(frozen=True) and should declare all model parameters as class attributes with default values.

Examples

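from functools import cached_property
from typing import Callable

from flax import struct
from finhjb import AbstractParameter

@struct.dataclass  # flax struct dataclasses are frozen by default
class Parameter(AbstractParameter):
    r: float
    g: float
    gamma: float
    # ... other model parameters

    apply_func: Callable = struct.field(pytree_node=False, repr=False)

    @cached_property
    def vFB(self) -> float:
        return self.r + self.gamma * self.g / (self.gamma - 1)

params = Parameter(r=0.06, g=0.03, gamma=0.25)  # ...plus the remaining fields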

Notes

  • Do not add mutable fields or methods that modify instance state.

  • Derived-parameter methods may be added in subclasses, but they must be decorated with @cached_property to ensure immutability.

update(boundary)

Update the parameter object with the boundary object.

Parameters:

boundary (ImmutableBoundary) – The boundary object to update the parameter object with.

Returns:

Self – The updated parameter object.

Return type:

Self

replace(**updates)

Returns a new object replacing the specified fields with new values.

class AbstractPolicy

Composable policy update interface for explicit/implicit steps.

create_solve_func(solver)

Create a JAX-vectorized solve function from a jaxopt solver instance.

Parameters:

solver (Broyden | LevenbergMarquardt | GaussNewton)

Return type:

Callable[[Float[Array, 'N K'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], P], OptStep]

abstract static initialize(grid, p)

(The NECESSARY method you need to implement)

Create an initial policy guess for the solver.

This abstract method should be implemented by subclasses to provide an initial policy for the control variables used by the HJB solver.

Notes

  • The method is called once during solver setup.

  • The solver expects every policy variable it references during iteration to be present in the returned mapping. Failing to provide required keys may raise an error when the solver starts.

  • The returned object must be a dictionary-like mapping (type PolicyDict) that contains entries for every policy variable the solver expects.

  • If you have a meaningful initial guess for the policy, return it here; the solver uses it as the starting point when policy_guess is True.

  • If you do not have a meaningful initial guess, set policy_guess to False. The solver then ignores this initial guess and performs a policy improvement step before starting iterations.

Implementations may freely use the following inputs during initialization:

  • p : Parameter

  • grid.s : Float[Array, 'N']

  • grid.v : Float[Array, 'N']

  • grid.number : int

  • …

Returns:
  • A dictionary-like container (PolicyDict) that maps policy variable names to arrays of values.

  • Each array should have a shape compatible with the solver's state grid (typically the same shape as grid.s, or a 1-D array of length grid.number).

Examples

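import jax.numpy as jnp
from jax import Array
from finhjb import AbstractPolicy, AbstractPolicyDict

class PolicyDict(AbstractPolicyDict):
    control_var1: Array
    control_var2: Array

class MyPolicy(AbstractPolicy):
    @staticmethod
    def initialize(grid, p) -> PolicyDict:
        control_var1_val = ...  # placeholder: your initial value for control_var1
        return PolicyDict(
            control_var1=jnp.full_like(grid.s, control_var1_val),
            control_var2=jnp.ones(grid.number),
        )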
Parameters:
  • grid (Grid)

  • p (P)

Return type:

D

update(grid)

Execute compiled policy update steps and return updated grid.

Parameters:

grid (Grid)

Return type:

Grid

class AbstractPolicyDict

Base TypedDict for policy variables.

This class stores variables that do not change during the policy evaluation step, such as control variables or other auxiliary variables.

Include every such variable here so that it is not recomputed during policy evaluation.

Subclass this to declare concrete policy keys and their types, for example:

from jax import Array
from finhjb import AbstractPolicyDict

class PolicyDict(AbstractPolicyDict):
    investment: Array
    consumption: Array
    drift: Array
    diffusion: Array

Notes

  • This class is intended for static type checking.

  • Do not instantiate AbstractPolicyDict directly — declare a subclass instead.

  • Keep value types as Array for JAX compatibility.

class AbstractValueGuess(p, boundary)

Abstract class for initial value function guess.

This class requires subclasses to implement the guess_value method, which provides an initial guess for the value function on the computational grid.

Parameters:
  • p (Parameter) – Model parameters (subclass of AbstractParameter).

  • boundary (AbstractBoundary[P]) – Boundary conditions for the state variable and value function.

s_min

Minimum boundary for the state variable.

Type:

float

s_max

Maximum boundary for the state variable.

Type:

float

v_left

Value function at the left boundary (corresponding to s_min).

Type:

float

v_right

Value function at the right boundary (corresponding to s_max).

Type:

float

guess_value()

Provide an initial guess for the value function.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']

abstract guess_value(s)

(The NECESSARY method you need to implement)

Provide an initial guess for the value function.

Notes

  • You can use self.s_min, self.s_max, self.v_left, and self.v_right to access the boundary conditions.

  • You can also access parameters via self.p.

  • The input s is a grid of state variable values where the value function should be evaluated.

  • The output should be an array of the same shape as s, representing the initial guess for the value function at those points.

Parameters:

s (Float[Array, "N"]) – Grid for the state variable.

Returns:

v (Float[Array, "N"]) – Initial guess for the value function on the grid s.

Return type:

Float[Array, 'N']

class LinearInitialValue(p, boundary)

Linear value function guess.

The value function is guessed to be a linear function connecting the boundary values.

Parameters:
  • p (P)

  • boundary (ImmutableBoundary)

guess_value(s)

Construct a linear initial guess linking boundary value endpoints.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']
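The linear guess is the straight line through (s_min, v_left) and (s_max, v_right). A minimal NumPy sketch, with a hypothetical free function in place of the class method.

```python
import numpy as np


def linear_guess(s, s_min, s_max, v_left, v_right):
    # Straight line through (s_min, v_left) and (s_max, v_right).
    slope = (v_right - v_left) / (s_max - s_min)
    return v_left + slope * (s - s_min)
```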

class QuadraticInitialValue(p, boundary, a_sign, curvature=0.5)

Quadratic initial value guess.

The quadratic function is defined as v(s) = a * s^2 + b * s + c, where the coefficient a is either -1 or 1 (concave or convex), and the coefficients b and c are chosen to satisfy the boundary values:

  • v(s_min) = v_left

  • v(s_max) = v_right

Parameters:
  • a_sign (Literal[-1, 1]) – Coefficient determining the concavity (-1) or convexity (1) of the quadratic function.

  • curvature (float, default=0.5) – A parameter in the interval (0, 1] that influences the curvature of the quadratic function. A value closer to 0 results in a flatter curve, while a value of 1 yields a standard quadratic shape.

  • p (P)

  • boundary (ImmutableBoundary)

guess_value(s)

Evaluate the boundary-matching quadratic initial guess on grid s.

Parameters:

s (Float[Array, 'N'])

Return type:

Float[Array, 'N']
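With a fixed, the two boundary conditions give b and c in closed form: subtracting v(s_min) = v_left from v(s_max) = v_right gives b = (v_right - v_left) / (s_max - s_min) - a * (s_max + s_min), and then c = v_left - a * s_min^2 - b * s_min. The sketch below implements that with a = a_sign; how curvature rescales the quadratic term is not documented here, so it is deliberately left out, and the function name is hypothetical.

```python
import numpy as np


def quadratic_guess(s, s_min, s_max, v_left, v_right, a_sign):
    # v(s) = a s^2 + b s + c with a = a_sign (+1 convex, -1 concave);
    # b and c make the curve pass through both boundary values.
    a = float(a_sign)
    b = (v_right - v_left) / (s_max - s_min) - a * (s_max + s_min)
    c = v_left - a * s_min**2 - b * s_min
    return a * s**2 + b * s + c
```

With s_min = 0, s_max = 2, v_left = 1, v_right = 3, the concave choice puts v(1) above the straight line (3 versus 2) and the convex choice below it (1 versus 2).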

Helper functions

explicit_policy(order)

Decorator that marks a method as an explicit policy-update step with the given execution order.

Parameters:

order (int)

Return type:

Callable

implicit_policy(order, solver='gauss_newton', maxiter=10, tol=1e-08, implicit_diff=True, verbose=0, policy_order=[], **solver_kwargs)

Decorator for an implicit policy's first-order condition (FOC).

Parameters:
  • order (int)

  • solver (Literal['gauss_newton', 'broyden', 'lm', 'newton_raphson'])

  • maxiter (int)

  • tol (float)

  • verbose (int)

  • policy_order (list[str])

Return type:

Callable
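The order argument on these decorators suggests a simple pattern: tag methods with an order and run them in ascending order. The sketch below shows that pattern in plain Python with a hypothetical ordered_step decorator; it is not finhjb's implementation, which compiles the steps (see AbstractPolicy.update) and also solves first-order conditions for implicit steps.

```python
def ordered_step(order):
    # Hypothetical decorator: tag a function with its execution order.
    def decorate(func):
        func._step_order = order
        return func
    return decorate


class ToyPolicy:
    @ordered_step(order=2)
    def scale(self, state):
        return {**state, "x": state["x"] * 10}

    @ordered_step(order=1)
    def shift(self, state):
        return {**state, "x": state["x"] + 1}

    def run_steps(self, state):
        # Collect tagged methods and apply them in ascending order.
        steps = sorted(
            (getattr(self, name) for name in dir(self)
             if hasattr(getattr(self, name), "_step_order")),
            key=lambda m: m._step_order,
        )
        for step in steps:
            state = step(state)
        return state
```

Here shift (order 1) runs before scale (order 2) regardless of where each is defined in the class body.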

Loader functions

load_grid(file_path)

Load a Grid object from .pkl.

Parameters:

file_path (str | Path)

Return type:

Grid

load_grids(file_path)

Load a Grids object from .pkl.

Parameters:

file_path (str | Path)

Return type:

Grids

load_sensitivity_result(file_path)

Load a SensitivityResult object from .pkl.

Parameters:

file_path (str | Path)

Return type:

SensitivityResult

Related pages