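API Reference#
This page is most useful once you already know which object or method you are looking for, or when you want to confirm an exact export name or loader behavior.
If you need a conceptual explanation rather than an exact name, start with the Modeling Guide or Results and Diagnostics.
If you are still getting familiar with the workflow, read these tutorials first, then come back here to look up names and members: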
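| If you want to... | Read this page first |
|---|---|
| Install and run the first case | |
| Reproduce the BCW benchmark case | |
| Understand the returned objects and numerical diagnostics | |
| Adapt BCW into your own model | |
| Decide which solver workflow to use | |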
Come back here when you need to confirm an exact export name, class member, method signature, or loader behavior.
Top-level exports (finhjb)#
Core objects#
Config, Solver, Grid, Grids, ImmutableBoundary
Interface types#
AbstractBoundary, BoundaryConditionTarget, AbstractModel, AbstractParameter, AbstractPolicy, AbstractPolicyDict, AbstractValueGuess
Helpers#
explicit_policy(order: int), implicit_policy(...), LinearInitialValue, QuadraticInitialValue
Loader functions#
load_grid(path), load_grids(path), load_sensitivity_result(path)
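Find the API by task#
| Task | Objects you will meet first |
|---|---|
| Define a model | |
| Run a fixed-boundary solve | |
| Search for endogenous boundaries | |
| Inspect a solution | |
| Save and reload results | |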
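Loader functions in detail#
The key difference between the three load_* functions is what you want to restore: a single Grid, a collection of Grids, or a complete SensitivityResult.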
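| Function | Restored object type | Corresponding save method | Typical use |
|---|---|---|---|
| load_grid | Grid | Grid.save | The full grid from a single solve |
| load_grids | Grids | Grids.save | A collection of grids over a set of parameter points |
| load_sensitivity_result | SensitivityResult | SensitivityResult.save | Continuation summary plus all grids |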
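Guaranteed behavior:
- The .pkl suffix is appended to the path automatically.
- The loaded object is type-checked.
- Using the wrong loader raises an explicit TypeError.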
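The three guarantees above can be illustrated with a small standalone sketch. It does not reproduce the finhjb loaders; `with_pkl_suffix`, `save_object`, and `load_typed` are hypothetical helpers, and built-in `dict` and `tuple` objects stand in for the saved collection and single-grid objects.

```python
import pickle
import tempfile
from pathlib import Path


def with_pkl_suffix(path):
    """Append ".pkl" unless the path already ends with it (hypothetical helper)."""
    path = Path(path)
    return path if path.suffix == ".pkl" else path.with_name(path.name + ".pkl")


def save_object(obj, path):
    """Pickle obj to path, adding the .pkl suffix when it is missing."""
    target = with_pkl_suffix(path)
    with open(target, "wb") as fh:
        pickle.dump(obj, fh)
    return target


def load_typed(path, expected_type):
    """Load a pickle and refuse objects that are not of the expected type."""
    with open(with_pkl_suffix(path), "rb") as fh:
        obj = pickle.load(fh)
    if not isinstance(obj, expected_type):
        raise TypeError(
            f"expected {expected_type.__name__}, got {type(obj).__name__}"
        )
    return obj


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp) / "sigma_grids"  # note: no suffix given
    # A dict stands in for a saved collection of grids.
    saved_name = save_object({0.09: "grid-a", 0.10: "grid-b"}, base).name
    loaded_type = type(load_typed(base, dict)).__name__
    try:
        # A tuple stands in for a single grid: this is the "wrong loader" case.
        load_typed(base, tuple)
        wrong_loader_error = None
    except TypeError as exc:
        wrong_loader_error = str(exc)

print(saved_name, loaded_type, wrong_loader_error)
```

The point of the pattern is that a mismatch fails immediately at load time with a readable message, instead of surfacing later as a missing attribute.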
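load_grid example#
import finhjb as fjb

# `solver` is assumed to be an already configured fjb.Solver instance.
state, _ = solver.solve()
# Save the solved grid, then restore it; the loader appends .pkl to the path.
state.grid.save("outputs/baseline_grid")
grid = fjb.load_grid("outputs/baseline_grid")
print(type(grid).__name__)
print(grid.df.head())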
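load_grids example#
import finhjb as fjb
import jax.numpy as jnp

# `solver` is assumed to be an already configured fjb.Solver instance.
result = solver.sensitivity_analysis(
    method="hybr",
    param_name="sigma",
    param_values=jnp.array([0.09, 0.10, 0.11]),
)
# Save only the collection of grids, then restore it with the matching loader.
result.grids.save("outputs/sigma_grids")
grids = fjb.load_grids("outputs/sigma_grids")
print(type(grids).__name__)
print(list(grids.values))  # sorted parameter values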
load_sensitivity_result example#
import finhjb as fjb

# `result` is the SensitivityResult returned by solver.sensitivity_analysis(...).
result.save("outputs/sigma_result")
loaded = fjb.load_sensitivity_result("outputs/sigma_result")
print(type(loaded).__name__)
print(loaded.df.head())
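Common loading errors#
- Using load_grid to read a saved continuation or sensitivity result.
- Forgetting that the loader appends .pkl automatically.
- Loading an entire SensitivityResult when you only wanted the grid at a single point.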
Main Solver methods#
- Solver.solve() -> (PolicyIterationState | EvaluationState, history)
- Solver.boundary_update() -> (BoundaryUpdateState, history)
- Solver.boundary_search(method, verbose=False) -> BoundarySearchState
- Solver.sensitivity_analysis(method, param_name, param_values) -> SensitivityResult
For guidance on when to use which method, read this alongside the Solver Guide.
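Notes on boundary_search methods#
Solver.boundary_search(method=...) currently supports:
bisection, hybr, lm, broyden, gauss_newton, lbfgs, krylov, broyden1
Key behavior:
- boundary_condition() returns exactly the list of boundaries that will be searched.
- The order of this list is the boundary-parameter order used by the nonlinear methods.
- For bisection, the same order also becomes the outer-to-inner order of the nested search.
- BoundaryConditionTarget.low and high are only meaningful for bisection.
- BoundaryConditionTarget.tol and max_iter are likewise used mainly by bisection.
- The other methods mainly use Config.bs_tol and Config.bs_max_iter.
- lbfgs minimizes the sum of squared residuals rather than solving the root problem directly.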
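To make the role of low, high, tol, and max_iter concrete, here is a minimal one-dimensional bisection sketch. It is not the finhjb implementation: `bisect_boundary` and `toy_residual` are hypothetical, and the residual is a plain function of the boundary value rather than a quantity computed from a solved Grid. The last lines show the squared-residual objective that a minimizer would work with instead of the root problem.

```python
def bisect_boundary(residual, low, high, tol=1e-6, max_iter=50):
    """Find a boundary value where residual crosses zero inside [low, high]."""
    f_low = residual(low)
    f_high = residual(high)
    if f_low * f_high > 0:
        raise ValueError("residual must change sign on [low, high]")
    for _ in range(max_iter):
        mid = 0.5 * (low + high)
        f_mid = residual(mid)
        # Stop when the residual or the remaining bracket is small enough.
        if abs(f_mid) < tol or 0.5 * (high - low) < tol:
            return mid
        if f_low * f_mid <= 0:
            high = mid  # sign change is in the left half
        else:
            low, f_low = mid, f_mid  # sign change is in the right half
    return 0.5 * (low + high)


def toy_residual(s_max):
    """Hypothetical boundary residual with a root at sqrt(2)."""
    return s_max**2 - 2.0


root = bisect_boundary(toy_residual, low=1.0, high=2.0, tol=1e-8)


def squared_residual(s_max):
    """Objective a least-squares minimizer would drive toward zero."""
    return toy_residual(s_max) ** 2


print(root, squared_residual(root))
```

Bisection needs a bracket with a sign change, which is why the bracket fields matter only for that method; the minimization view needs no bracket but can stop at a local minimum where the residual is not zero.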
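AbstractModel optional hooks at a glance#
The most important optional hooks are:
- jump(...): optional, defaults to zero; the solver calls it through Grid.jump_inter.
- boundary_condition(): returns the list of BoundaryConditionTarget objects for boundary_search().
- update_boundary(grid): used only in the boundary_update() workflow.
- auxiliary(grid): exposed through Grid.aux; if it is not implemented, it is normal for grid.aux to raise NotImplementedError.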
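The proxy relationship between the auxiliary hook and the aux attribute can be sketched in plain Python. The classes below (`BaseModel`, `ModelWithDiagnostics`, `ToyGrid`) are hypothetical stand-ins, not finhjb types; they only show why an unimplemented hook surfaces as NotImplementedError on attribute access.

```python
class BaseModel:
    """Stand-in base class whose optional hook is not implemented."""

    @staticmethod
    def auxiliary(grid):
        raise NotImplementedError("auxiliary() is not implemented for this model")


class ModelWithDiagnostics(BaseModel):
    """Stand-in model that returns a small dictionary of derived diagnostics."""

    @staticmethod
    def auxiliary(grid):
        return {"value_mean": sum(grid.v) / len(grid.v), "value_max": max(grid.v)}


class ToyGrid:
    """Stand-in grid holding a model and a list of values."""

    def __init__(self, model, v):
        self.model = model
        self.v = v

    @property
    def aux(self):
        # Thin proxy: delegate to the model hook, passing the grid itself.
        return self.model.auxiliary(self)


diagnostics = ToyGrid(ModelWithDiagnostics, [1.0, 2.0, 3.0]).aux

try:
    ToyGrid(BaseModel, [1.0, 2.0, 3.0]).aux
    missing_hook_raises = False
except NotImplementedError:
    missing_hook_raises = True

print(diagnostics, missing_hook_raises)
```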
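Grid convenience attributes#
Commonly used Grid attributes:
- s, v, dv, d2v
- s_inter, policy_inter, number_inter, jump_inter
- df, aux
Additional notes:
- jump_inter is the result of evaluating Model.jump(...) on the interior grid.
- aux is just a proxy for Model.auxiliary(grid).
- A common pattern for auxiliary(grid) is to return a small dictionary of derived diagnostics.
Commonly used Grid methods:
- reset(), update_grid(boundary), update_with_v_inter(v_inter), save(path)
Commonly used Grids methods:
- get, select_grids, add, merge, save
For how to interpret these objects, see Results and Diagnostics.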
Detailed API documentation#
Config#
- class Config(*, enable_x64=True, derivative_method='central', policy_guess=False, pi_method='scan', pi_max_iter=50, pi_tol=1e-06, pi_patience=10, pe_max_iter=10, pe_tol=1e-06, pe_patience=5, bs_max_iter=20, bs_tol=1e-06, bs_patience=5, aa_history_size=5, aa_mixing_frequency=1, aa_beta=1.0, aa_ridge=0)[source]
Runtime configuration for numerical solvers and differentiation rules.
The configuration object is validated by Pydantic and then used by algorithm modules (evaluation, policy_iteration, and boundary routines) to control convergence thresholds, iteration limits, and differentiation behavior.
- Parameters:
enable_x64 (bool)
derivative_method (Literal['central', 'forward', 'backward'])
policy_guess (bool)
pi_method (Literal['scan', 'anderson'])
pi_max_iter (Annotated[int, Gt(gt=0)])
pi_tol (Annotated[float, Gt(gt=0)])
pi_patience (Annotated[int, Ge(ge=0)])
pe_max_iter (Annotated[int, Gt(gt=0)])
pe_tol (Annotated[float, Gt(gt=0)])
pe_patience (Annotated[int, Ge(ge=0)])
bs_max_iter (Annotated[int, Gt(gt=0)])
bs_tol (Annotated[float, Gt(gt=0)])
bs_patience (Annotated[int, Ge(ge=0)])
aa_history_size (Annotated[int, Gt(gt=0)])
aa_mixing_frequency (Annotated[int, Gt(gt=0)])
aa_beta (Annotated[float, Gt(gt=0)])
aa_ridge (Annotated[float, Ge(ge=0)])
- model_config: ClassVar[ConfigDict] = {'use_enum_values': True, 'validate_assignment': True}
Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.
- set_derivative_function()[source]
Sets the derivative function based on the method.
- Return type:
Config
- property dv_func: Callable
Return the finite-difference function selected by configuration.
- setup_jax()[source]
Applies JAX configuration after model validation.
- Return type:
Config
- model_post_init(context, /)
This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that’s what pydantic-core passes when calling it.
- Parameters:
self (BaseModel) – The BaseModel instance.
context (Any) – The context.
- Return type:
None
Solver#
- class Solver(boundary=None, model=None, value_guess=None, policy_guess=True, number=1000, config=<factory>, grid=None)[source]
High-level orchestrator for solving HJB models on one-dimensional grids.
- Parameters:
boundary (AbstractBoundary | None)
model (AbstractModel | None)
value_guess (AbstractValueGuess[P] | None)
policy_guess (bool)
number (int)
config (Config)
grid (Grid | None)
- solve()[source]
Run policy iteration (or one-step evaluation) on the active grid.
- Return type:
tuple[PolicyIterationState | EvaluationState, Array]
- boundary_update()[source]
Iteratively update model boundaries and re-solve the HJB system.
- Return type:
tuple[BoundaryUpdateState, Array]
- boundary_search(method, verbose=False)[source]
Search optimal boundaries by solving boundary conditions as root problems.
- Parameters:
method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])
- sensitivity_analysis(method, param_name, param_values)[source]
Solve the model along a parameter path using continuation.
- Parameters:
method (Literal['gauss_newton', 'lm', 'broyden', 'lbfgs', 'bisection', 'hybr', 'broyden1', 'krylov'])
param_name (str)
param_values (Array)
- Return type:
SensitivityResult
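Continuation along a parameter path usually means that each solve starts from the solution at the previous parameter value. The sketch below shows that idea on a scalar toy equation, x^3 + x = p, solved with Newton's method; it is an assumption about the general technique, not a description of how sensitivity_analysis is implemented, and `newton` and `solve_along_path` are hypothetical names.

```python
def newton(f, df, x0, tol=1e-10, max_iter=50):
    """Scalar Newton iteration; returns the root and the iteration count."""
    x = x0
    for iteration in range(1, max_iter + 1):
        step = f(x) / df(x)
        x -= step
        if abs(step) < tol:
            return x, iteration
    raise RuntimeError("Newton iteration did not converge")


def solve_along_path(param_values, initial_guess):
    """Solve x**3 + x = p for each p, warm-starting from the previous root."""
    solutions = {}
    guess = initial_guess
    for p in param_values:
        root, _ = newton(lambda x: x**3 + x - p, lambda x: 3 * x**2 + 1, guess)
        solutions[p] = root
        guess = root  # continuation: reuse this root as the next starting point
    return solutions


path = solve_along_path([1.0, 1.5, 2.0], initial_guess=0.0)
print(path)
```

Because neighboring parameter values tend to have neighboring solutions, the warm start keeps each solve close to its answer, which is the main reason to order the parameter values monotonically.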
Data structures#
- class Grid(p, boundary, model, h=0, s=<factory>, v=<factory>, v_inter=<factory>, dv=<factory>, d2v=<factory>, policy=<factory>, value_guess=None, policy_guess=False, number=1000, config=<factory>)[source]
Immutable computational grid and solver state container.
Grid stores state-space coordinates, value function approximations, derivatives, boundary values, and policy variables. The object is a Flax PyTree so it can be transformed by JAX while keeping model/config objects as static fields.
- Parameters:
p (P)
boundary (ImmutableBoundary)
model (AbstractModel[P, PolicyDictType])
h (float)
s (Float[Array, 'N'])
v (Float[Array, 'N'])
v_inter (Float[Array, 'N-2'])
dv (Float[Array, 'N'])
d2v (Float[Array, 'N'])
policy (PolicyDictType)
value_guess (AbstractValueGuess[P] | None)
policy_guess (bool)
number (int)
config (Config)
- reset()[source]
Rebuild the full grid from boundary/parameter/policy configuration.
- Return type:
Self
- update_grid(boundary)[source]
Update boundary values and resample state grid if s bounds changed.
- Parameters:
boundary (ImmutableBoundary)
- Return type:
Self
- property optimizable_boundaries[source]
Return boundary names targeted by model-defined boundary conditions.
- property policy_in_axes[source]
Returns the axes for the policy parameters.
- property s_inter: Float[Array, 'N-2']
Interior state grid (excluding both boundary points).
- property policy_inter: PolicyDictType
Interior slices of all policy arrays.
- property number_inter: int
Number of interior grid points.
- property jump_inter
Evaluate Model.jump(…) on the interior grid slices.
- update_with_v_inter(v_inter)[source]
Update full value and derivative arrays from interior values.
- Parameters:
v_inter (Float[Array, 'N-2'])
- Return type:
Self
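The relationship between the N-2 interior values and the full N-point arrays can be sketched with plain lists. This is only an illustration under stated assumptions: a uniform spacing h, central differences at interior points (matching the default derivative_method='central' in Config), and an arbitrary one-sided treatment at the two edges. `rebuild_from_interior` is a hypothetical name and the actual update_with_v_inter may differ.

```python
def rebuild_from_interior(v_inter, v_left, v_right, h):
    """Pad interior values with boundary values and difference the result."""
    v = [v_left, *v_inter, v_right]
    n = len(v)
    dv = [0.0] * n
    d2v = [0.0] * n
    for i in range(1, n - 1):
        # Central first and second differences on a uniform grid.
        dv[i] = (v[i + 1] - v[i - 1]) / (2.0 * h)
        d2v[i] = (v[i + 1] - 2.0 * v[i] + v[i - 1]) / (h * h)
    # Edge treatment is an assumption of this sketch: one-sided slopes,
    # and the second derivative copied from the nearest interior point.
    dv[0] = (v[1] - v[0]) / h
    dv[-1] = (v[-1] - v[-2]) / h
    d2v[0], d2v[-1] = d2v[1], d2v[-2]
    return v, dv, d2v


# v(s) = s**2 sampled on s = 0, 0.5, 1, 1.5, 2; the interior has N-2 = 3 points.
v, dv, d2v = rebuild_from_interior([0.25, 1.0, 2.25], v_left=0.0, v_right=4.0, h=0.5)
print(v, dv[1:-1], d2v[1:-1])
```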
- property df
Convert grid data to a pandas DataFrame for easy inspection.
- property aux
Return Model.auxiliary(grid=self) for user-defined diagnostics.
- save(file_path)[source]
Save the Grid to a pickle file.
- Parameters:
file_path (str | Path)
- Return type:
None
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class Grids(param_name='???', data=<factory>)[source]
Collection of solved grids indexed by a scalar continuation parameter.
- Parameters:
param_name (str)
data (dict[float, Grid])
- property values
Sorted parameter values contained in this subset.
- get(value, default=None)[source]
Get a grid by key with a default fallback.
- Parameters:
value (float)
default (Grid | None)
- Return type:
Grid | None
- save(file_path)[source]
Save the Grids to a pickle file.
- Parameters:
file_path (str | Path)
- Return type:
None
- select_grids(values, *, atol=1e-08, rtol=1e-06)[source]
Select grids for specific parameter values and return a Grids object.
- Parameters:
values (Iterable[float])
atol (float)
rtol (float)
- Return type:
Grids
- add(label, grid)[source]
Insert or replace one grid at label.
- Parameters:
label (float)
grid (Grid)
- Return type:
Grids
- merge(other)[source]
Merge two Grids collections with right-hand overwrite on conflicts.
- Parameters:
other (Grids)
- Return type:
Grids
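Two behaviors of the collection are easy to show with an ordinary dictionary keyed by parameter value: tolerance-based selection and right-hand overwrite on merge. The sketch below uses `math.isclose` for the tolerance test, which is an assumption; the exact comparison rule behind atol and rtol in select_grids is not stated here. `select_by_value` and `merge_right` are hypothetical helpers and strings stand in for Grid objects.

```python
import math


def select_by_value(data, values, atol=1e-8, rtol=1e-6):
    """Pick entries whose float key matches a requested value within tolerance."""
    selected = {}
    for wanted in values:
        for key, item in data.items():
            if math.isclose(key, wanted, rel_tol=rtol, abs_tol=atol):
                selected[key] = item
    return selected


def merge_right(left, right):
    """Merge two mappings; entries from right win when keys collide."""
    return {**left, **right}


solved = {0.09: "grid@0.09", 0.10: "grid@0.10", 0.11: "grid@0.11"}

# A key that differs only by floating-point noise is still found.
picked = select_by_value(solved, [0.1 + 1e-12, 0.11])

merged = merge_right(solved, {0.11: "grid@0.11 (rerun)", 0.12: "grid@0.12"})
print(sorted(picked), merged[0.11])
```

Tolerance-based lookup matters because parameter values that come out of array arithmetic rarely compare equal to the literal used when the grid was stored.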
- class ImmutableBoundary(s_min, s_max, v_left, v_right, graph)[source]
Immutable boundary values structure.
- Parameters:
s_min (float)
s_max (float)
v_left (float)
v_right (float)
graph (list[DependencyMethod])
- s_min
Minimum state variable value.
- Type:
float
- s_max
Maximum state variable value.
- Type:
float
- v_left
Value function at the left boundary.
- Type:
float
- v_right
Value function at the right boundary.
- Type:
float
- get_boundaries()[source]
Return (s_min, s_max, v_left, v_right) as a tuple.
- Return type:
tuple[float, float, float, float]
- get_boundary_dict()[source]
Return all boundary values as a dictionary keyed by boundary name.
- Return type:
dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float]
- update_boundaries(boundary_dict, p)[source]
Return a new boundary object after applying dependency graph updates.
- Parameters:
boundary_dict (dict[Literal['s_min', 's_max', 'v_left', 'v_right'], float])
p (P)
- s_changed(boundary)[source]
Check whether state-space limits changed versus another boundary.
- Parameters:
boundary (Self)
- replace(**updates)
Returns a new object replacing the specified fields with new values.
Interface types#
- class AbstractBoundary(p, s_min=None, s_max=None, v_left=None, v_right=None)[source]
An intelligent, auto-configuring class for defining HJB problem boundaries.
This class serves as a configuration object for boundary values. It automatically discovers dependencies between boundaries by inspecting the signatures of compute_<boundary_name> methods defined in subclasses.
Usage#
Subclass AbstractBoundary.
For boundaries that can be computed from parameters or other boundaries, define methods like compute_v_right(self, s_max: float) -> float. The dependency (s_max) is automatically inferred from the signature.
Instantiate the subclass, providing any boundary values that cannot be computed as keyword arguments (e.g., MyBoundary(p=params, s_min=0.0)).
Examples
class MyBoundary(AbstractBoundary):
    @staticmethod
    def compute_s_max(p) -> float:
        return p.x_bar

    @staticmethod
    def compute_v_right(s_max: float, p) -> float:
        return s_max * 2.0

boundary = MyBoundary(p=params, s_min=0.0, v_left=0.0)
frozen_boundary = boundary.frozen_boundary
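The signature-based dependency discovery described above can be approximated with the standard library. The sketch below is an assumption about the general technique, not the finhjb code: `ToyBoundary`, `infer_dependencies`, and `resolve` are hypothetical, parameters are a plain dict, and ordering is delegated to `graphlib.TopologicalSorter` (Python 3.9 or later), whose CycleError is a subclass of ValueError.

```python
import inspect
from graphlib import TopologicalSorter

BOUNDARY_NAMES = ("s_min", "s_max", "v_left", "v_right")


class ToyBoundary:
    """Stand-in class with compute_<name> methods, mirroring the example above."""

    @staticmethod
    def compute_s_max(p):
        return p["x_bar"]

    @staticmethod
    def compute_v_right(s_max, p):
        return s_max * 2.0


def infer_dependencies(cls):
    """Map each computable boundary to the boundaries named in its signature."""
    deps = {}
    for name in BOUNDARY_NAMES:
        method = getattr(cls, f"compute_{name}", None)
        if method is not None:
            params = inspect.signature(method).parameters
            deps[name] = {arg for arg in params if arg in BOUNDARY_NAMES}
    return deps


def resolve(cls, p, **given):
    """Compute every missing boundary in dependency order."""
    deps = infer_dependencies(cls)
    values = dict(given)
    for name in TopologicalSorter(deps).static_order():
        if name in values:
            continue  # supplied directly by the caller
        if name not in deps:
            raise ValueError(f"boundary {name!r} is neither given nor computable")
        method = getattr(cls, f"compute_{name}")
        kwargs = {arg: values[arg] for arg in deps[name]}
        if "p" in inspect.signature(method).parameters:
            kwargs["p"] = p
        values[name] = method(**kwargs)
    return values


boundary_values = resolve(ToyBoundary, {"x_bar": 3.0}, s_min=0.0, v_left=0.0)
print(boundary_values)
```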
- p
The parameter object containing model parameters.
- Type:
Parameter
- s_min
The minimum state boundary value.
- Type:
Optional[float]
- s_max
The maximum state boundary value.
- Type:
Optional[float]
- v_left
The value function at the left boundary.
- Type:
Optional[float]
- v_right
The value function at the right boundary.
- Type:
Optional[float]
- already_boundary
A dictionary of boundary values that were provided directly.
- Type:
dict[BoundaryName, float]
- independent_boundary
A set of boundary names that were provided directly.
- Type:
set[BoundaryName]
- frozen_boundary
The fully computed and immutable boundary object.
- Type:
ImmutableBoundary
- required_boundary[source]
A set of all boundary names required to compute the full boundary.
- Type:
set[BoundaryName]
- boundary_dependencies[source]
A mapping of boundary names to their dependencies.
- Type:
dict[BoundaryName, set[BoundaryName]]
- graph[source]
A topologically sorted list of methods to compute boundary values.
- Type:
list[DependencyMethod]
- property required_boundary: set[Literal['s_min', 's_max', 'v_left', 'v_right']][source]
All dependencies required to compute boundaries.
- property boundary_dependencies: dict[Literal['s_min', 's_max', 'v_left', 'v_right'], set[Literal['s_min', 's_max', 'v_left', 'v_right']]][source]
Dict of compute methods and their dependencies.
- property graph: list[DependencyMethod][source]
Topologically sorted list of methods to compute boundary values.
- Returns:
list[DependencyMethod] – A list of methods with their metadata, sorted in the order they should be executed.
- Raises:
ValueError – If there are circular dependencies or missing dependencies.
- freeze()[source]
Computes and returns an ImmutableBoundary with all boundary values set.
- Returns:
ImmutableBoundary – The fully computed and validated immutable boundary object.
- Parameters:
p (P)
s_min (float | None)
s_max (float | None)
v_left (float | None)
v_right (float | None)
- class BoundaryConditionTarget(boundary_name, condition_func, low=None, high=None, tol=1e-06, max_iter=50)[source]
Specification for one endogenous-boundary target used by boundary search.
A target says:
which boundary field should be varied,
how to evaluate the resulting residual on a solved Grid,
and, optionally, which bracket/tolerance settings should be used for bisection-style search.
Notes
Only boundaries that appear in Model.boundary_condition() are searched by Solver.boundary_search().
The order of targets in that list defines the parameter order passed to nonlinear search methods.
For nested bisection, the same order also defines the outer-to-inner search order.
- Parameters:
boundary_name (Literal['s_min', 's_max', 'v_left', 'v_right'])
condition_func (Callable[[Grid], float])
low (float | None)
high (float | None)
tol (float)
max_iter (int)
- boundary_name
Boundary field to optimize, such as s_max or v_left.
- Type:
BoundaryName
- condition_func
Residual evaluated on the solved grid. Search aims to drive this value to zero.
- Type:
Callable[["Grid"], float]
- low
Lower bracket used by method="bisection". Ignored by the other search methods.
- Type:
Optional[float]
- high
Upper bracket used by method="bisection". Ignored by the other search methods.
- Type:
Optional[float]
- tol
Per-target tolerance used by method="bisection".
- Type:
float
- max_iter
Per-target iteration cap used by method="bisection".
- Type:
int
- class AbstractModel(policy)[source]
Abstract base class for HJB models.
This class defines the interface for HJB models, including the required HJB residual and several optional hooks for jumps, endogenous boundaries, outer-loop boundary updates, and custom diagnostics.
- Parameters:
policy (AbstractPolicy)
- initialize_policy()
Create an initial policy guess for the solver.
- update_policy()
Update the policy variables using the current value function and its derivatives.
- hjb_residual()[source]
Calculate the pointwise HJB residual on the interior grid.
- Parameters:
v (Float[Array, 'N-2'])
dv (Float[Array, 'N-2'])
d2v (Float[Array, 'N-2'])
s (Float[Array, 'N-2'])
policy (D)
jump (Float[Array, 'N-2'])
boundary (ImmutableBoundary)
p (P)
- Return type:
Float[Array, 'N-2']
- jump()[source]
Calculate the jump term for the HJB equation.
- Parameters:
v (Float[Array, 'N'])
s (Float[Array, 'N'])
policy (D)
boundary (ImmutableBoundary)
p (P)
- Return type:
Float[Array, 'N']
- boundary_condition()[source]
Declare endogenous boundary targets for Solver.boundary_search().
- Return type:
list[BoundaryConditionTarget]
- update_boundary()[source]
Return direct boundary updates for Solver.boundary_update().
- Parameters:
grid (Grid)
- auxiliary()[source]
Return user-defined diagnostics exposed through Grid.aux.
- Parameters:
grid (Grid)
Notes
Subclasses must implement the abstract hjb_residual method.
The jump method is optional to override; the default implementation returns zero jumps.
Keep parameter orders and return shapes as specified; the solver expects these signatures.
These are static methods so remember to add the @staticmethod decorator.
- abstract static hjb_residual(v, dv, d2v, s, policy, jump, boundary, p)[source]
(The NECESSARY method you need to implement)#
Calculate the pointwise HJB residual on the interior grid.
Notes
Keep the parameter order and return shape as specified; the solver expects this signature.
This is a static method and does not have access to instance attributes.
- Parameters:
v (Float[Array, "N-2"]) – Value function evaluated at interior grid points.
dv (Float[Array, "N-2"]) – First derivative of the value function at interior points.
d2v (Float[Array, "N-2"]) – Second derivative of the value function at interior points.
s (Float[Array, "N-2"]) – State variable values at interior grid points.
policy (dict[str, Float[Array, "N-2"]]) – Mapping of policy variable names to their values on the interior grid.
jump (Float[Array, "N-2"]) – Jump term evaluated at each interior grid point.
boundary (FrozenBoundary) – Boundary values for both state and value function.
p (Parameter) – Model parameters.
- Returns:
Float[Array, "N-2"] – HJB residual evaluated at each interior grid point.
Examples
@staticmethod
def hjb_residual(v, dv, d2v, s, policy, jump, boundary, p):
    control1 = policy["control1"]
    residual = ...
    return residual
- Parameters:
v (Float[Array, 'N-2'])
dv (Float[Array, 'N-2'])
d2v (Float[Array, 'N-2'])
s (Float[Array, 'N-2'])
policy (D)
jump (Float[Array, 'N-2'])
boundary (ImmutableBoundary)
p (P)
- Return type:
Float[Array, 'N-2']
- static jump(v, s, policy, boundary, p)[source]
(This method is OPTIONAL to override)
Calculate the jump term for the HJB equation.
The default implementation returns zero jumps. Override this only when your HJB contains a non-zero jump or Poisson-arrival term.
Notes
The solver evaluates this hook through Grid.jump_inter, so in practice v, s, and policy are the interior-grid slices.
You can override this method in subclasses to implement specific jump dynamics.
Keep the parameter order and return shape as specified; the solver expects this signature.
This is a static method and does not have access to instance attributes.
- Parameters:
v (Float[Array, "N-2"]) – Value function evaluated at interior grid points.
s (Float[Array, "N-2"]) – State variable values at interior grid points.
policy (dict[str, Float[Array, "N-2"]]) – Mapping of policy variable names to their values on the interior grid.
boundary (FrozenBoundary) – Boundary values for both state and value function.
p (Parameter) – Model parameters.
- Returns:
Float[Array, "N-2"] – Jump term evaluated at each interior grid point.
- Return type:
Float[Array, 'N']
- static boundary_condition()[source]
(This method is OPTIONAL to override)
Return endogenous-boundary targets for Solver.boundary_search().
Each returned BoundaryConditionTarget specifies:
which boundary field should be optimized,
how to evaluate the boundary residual on the solved grid,
and, for method="bisection", the bracket and per-target search settings.
Notes
Only boundaries listed here are optimized by boundary_search().
The order of targets in the returned list defines the boundary vector order for nonlinear search methods.
For nested bisection, the same order defines the outer-to-inner search order.
low, high, tol, and max_iter are used by bisection. Other search methods use Config.bs_tol and Config.bs_max_iter.
- Return type:
list[BoundaryConditionTarget]
- static update_boundary(grid)[source]
Return direct boundary updates for the boundary-update workflow.
Implement this hook when a solved grid directly implies revised boundary values and an update error, so that Solver.boundary_update() can run the outer loop solve -> update boundary -> solve again.
- Parameters:
grid (Grid)
- static auxiliary(grid)[source]
Return user-defined diagnostics derived from a solved grid.
Grid.aux is a thin proxy for this hook. Leaving it unimplemented means grid.aux will raise NotImplementedError, which is expected.
A common pattern is to return a small dictionary of derived summaries, for example {"value_mean": jnp.mean(grid.v)}.
- Parameters:
grid (Grid)
- class AbstractParameter[source]
Abstract base class for model parameters.
This class serves as a base for immutable, hashable parameter containers required by JAX (e.g. for use with static_argnames). Subclasses must be decorated with @struct.dataclass(frozen=True) and should declare all model parameters as class attributes with default values.
Examples
from functools import cached_property
from flax import struct

class Parameter(AbstractParameter):
    r: float
    g: float
    gamma: float
    ...
    apply_func: Callable = struct.field(pytree_node=False, repr=False)

    @cached_property
    def vFB(self) -> float:
        return self.r + self.gamma * self.g / (self.gamma - 1)

params = Parameter(r=0.06, g=0.03, gamma=0.25, ...)
Notes
Do not add mutable fields or methods that modify instance state.
Derived-parameter methods may be added in subclasses, but they must be decorated with @cached_property to ensure immutability.
- update(boundary)[source]
Update the parameter object with the boundary object.
- Parameters:
boundary (ImmutableBoundary) – The boundary object to update the parameter object with.
- Returns:
Self – The updated parameter object.
- Return type:
Self
- replace(**updates)
Returns a new object replacing the specified fields with new values.
- class AbstractPolicy[source]
Composable policy update interface for explicit/implicit steps.
- create_solve_func(solver)[source]
Create a JAX-vectored solve function from a jaxopt solver instance.
- Parameters:
solver (Broyden | LevenbergMarquardt | GaussNewton)
- Return type:
Callable[[Float[Array, 'N K'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], Float[Array, 'N'], P], OptStep]
- abstract static initialize(grid, p)[source]
(The NECESSARY method you need to implement)#
Create an initial policy guess for the solver.
This abstract method should be implemented by subclasses to provide an initial policy for the control variables used by the HJB solver.
Notes
The method is called once during solver setup.
The solver expects every policy variable it references during iteration to be present in the returned mapping. Failing to provide required keys may raise an error when the solver starts.
The returned object must be a dictionary-like mapping (type PolicyDict) that contains entries for every policy variable the solver expects.
If you have a meaningful initial guess for the policy, you should return it here, and the solver will use it as the starting point when policy_guess is True.
If you do not have a meaningful initial guess, you can set policy_guess to False in the solver configuration. Then the solver will ignore this initial guess and perform an initial policy improvement step before starting iterations.
Implementations may freely use the following inputs during initialization:
- p : Parameter
- grid.s : Float[Array, "N"]
- grid.v : Float[Array, "N"]
- grid.number : int
- …
- Returns:
A dictionary-like container (PolicyDict) that maps policy variable names to arrays of values. Each array should have a shape compatible with the solver's state grid (typically the same shape as self.s or a 1-D array of length self.number).
Examples
class PolicyDict(AbstractPolicyDict):
    control_var1: Array
    control_var2: Array

def initialize_policy(self, grid: Grid, p: P) -> PolicyDict:
    control_var1_val: Array = ...
    return PolicyDict(
        control_var1=jnp.full_like(grid.s, control_var1_val),
        control_var2=jnp.ones(grid.number),
    )
- Parameters:
grid (Grid)
p (P)
- Return type:
D
- update(grid)[source]
Execute compiled policy update steps and return updated grid.
- Parameters:
grid (Grid)
- Return type:
Grid
- class AbstractPolicyDict[source]
Base TypedDict for policy variables.
This class is used to store variables that are not changed during the policy evaluation step, such as control variables or other auxiliary variables.
So any variables that are not changed during the policy evaluation step should be included here to avoid repeated computations.
Subclass this to declare concrete policy keys and their types, for example:
class PolicyDict(AbstractPolicyDict):
    investment: Array
    consumption: Array
    drift: Array
    diffusion: Array
Notes
This class is intended for static type checking.
Do not instantiate AbstractPolicyDict directly — declare a subclass instead.
Keep value types as Array for JAX compatibility.
- class AbstractValueGuess(p, boundary)[source]
Abstract class for initial value function guess.
This class requires subclasses to implement the guess_value method, which provides an initial guess for the value function on the computational grid.
- Parameters:
p (Parameter) – Model parameters (subclass of AbstractParameter).
boundary (AbstractBoundary[P]) – Boundary conditions for the state variable and value function.
- s_min
Minimum boundary for the state variable.
- Type:
float
- s_max
Maximum boundary for the state variable.
- Type:
float
- v_left
Value function at the left boundary (corresponding to s_min).
- Type:
float
- v_right
Value function at the right boundary (corresponding to s_max).
- Type:
float
- guess_value()[source]
Provide an initial guess for the value function.
- Parameters:
s (Float[Array, 'N'])
- Return type:
Float[Array, 'N']
- abstract guess_value(s)[source]
(The NECESSARY method you need to implement)
Provide an initial guess for the value function.
Notes
You can use self.s_min, self.s_max, self.v_left, and self.v_right to access the boundary conditions.
You can also access parameters via self.p.
The input s is a grid of state variable values where the value function should be evaluated.
The output should be an array of the same shape as s, representing the initial guess for the value function at those points.
- Parameters:
s (Float[Array, "N"]) – Grid for the state variable.
- Returns:
v (Float[Array, "N"]) – Initial guess for the value function on the grid s.
- Return type:
Float[Array, 'N']
- class LinearInitialValue(p, boundary)[source]
Linear value function guess.
The value function is guessed to be a linear function connecting the boundary values.
- Parameters:
p (P)
boundary (ImmutableBoundary)
- guess_value(s)[source]
Construct a linear initial guess linking boundary value endpoints.
- Parameters:
s (Float[Array, 'N'])
- Return type:
Float[Array, 'N']
- class QuadraticInitialValue(p, boundary, a_sign, curvature=0.5)[source]
Quadratic initial value guess.
The quadratic function is defined as:
v(s) = a * s^2 + b * s + c
where the coefficient 'a' is either -1 or 1 (indicating concavity or convexity), and coefficients 'b' and 'c' are determined to satisfy the boundary values:
- v(s_min) = v_left
- v(s_max) = v_right
- Parameters:
a_sign (Literal[-1, 1]) – Coefficient determining the concavity (-1) or convexity (1) of the quadratic function.
curvature (float, default=0.5) – A parameter in the interval (0, 1] that influences the curvature of the quadratic function. A value closer to 0 results in a flatter curve, while a value of 1 yields a standard quadratic shape.
p (P)
boundary (ImmutableBoundary)
- guess_value(s)[source]
Evaluate the boundary-matching quadratic initial guess on grid s.
- Parameters:
s (Float[Array, 'N'])
- Return type:
Float[Array, 'N']
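For a fixed leading coefficient a, the two boundary conditions v(s_min) = v_left and v(s_max) = v_right determine b and c uniquely. The sketch below works that algebra out next to the linear guess; it is illustrative only, `quadratic_guess` and `linear_guess` are hypothetical functions, and the curvature parameter is left out because its exact effect on the coefficients is not specified here.

```python
def linear_guess(s, s_min, s_max, v_left, v_right):
    """Straight line through (s_min, v_left) and (s_max, v_right)."""
    slope = (v_right - v_left) / (s_max - s_min)
    return [v_left + slope * (x - s_min) for x in s]


def quadratic_guess(s, s_min, s_max, v_left, v_right, a):
    """Quadratic a*s^2 + b*s + c through both boundary points.

    Subtracting the two boundary conditions gives
    b = (v_right - v_left) / (s_max - s_min) - a * (s_max + s_min),
    and c then follows from v(s_min) = v_left.
    """
    b = (v_right - v_left) / (s_max - s_min) - a * (s_max + s_min)
    c = v_left - a * s_min**2 - b * s_min
    return [a * x**2 + b * x + c for x in s]


grid = [0.0, 1.0, 2.0]
line = linear_guess(grid, s_min=0.0, s_max=2.0, v_left=1.0, v_right=3.0)
# a = -1 gives a concave curve that still matches both endpoints.
curve = quadratic_guess(grid, s_min=0.0, s_max=2.0, v_left=1.0, v_right=3.0, a=-1.0)
print(line, curve)
```

In this toy case the concave guess lies above the straight line at the midpoint, which is the qualitative difference between the two starting shapes.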
Helper functions#
- explicit_policy(order)[source]
Decorator to mark a method as an explicit solver with optional execution order.
- Parameters:
order (int)
- Return type:
Callable
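A decorator that marks methods with an execution order can be written in a few lines. The sketch below shows the general pattern only and makes no claim about how explicit_policy is implemented: `explicit_step`, the `_step_order` attribute, and `ToyPolicy` are hypothetical, and the policy state is a plain dict.

```python
def explicit_step(order):
    """Tag a function with the position at which it should run."""

    def decorator(func):
        func._step_order = order
        return func

    return decorator


class ToyPolicy:
    """Stand-in policy whose update steps are declared out of order."""

    @explicit_step(order=2)
    def update_consumption(self, state):
        # Depends on the investment value written by the order=1 step.
        state["consumption"] = 0.5 * state["investment"]
        return state

    @explicit_step(order=1)
    def update_investment(self, state):
        state["investment"] = 2.0
        return state

    def update(self, state):
        # Collect tagged methods and run them by ascending order.
        steps = [
            getattr(self, name)
            for name in dir(self)
            if hasattr(getattr(type(self), name, None), "_step_order")
        ]
        for step in sorted(steps, key=lambda m: m._step_order):
            state = step(state)
        return state


final_state = ToyPolicy().update({})
print(final_state)
```

Declaring the order on each step keeps the sequencing next to the code it affects, so reordering steps does not require editing a central list.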
- implicit_policy(order, solver='gauss_newton', maxiter=10, tol=1e-08, implicit_diff=True, verbose=0, policy_order=[], **solver_kwargs)[source]
Decorator for implicit policy’s FOC.
- Parameters:
order (int)
solver (Literal['gauss_newton', 'broyden', 'lm', 'newton_raphson'])
maxiter (int)
tol (float)
verbose (int)
policy_order (list[str])
- Return type:
Callable
Loader functions#
- load_grid(file_path)[source]
Load a Grid object from .pkl.
- Parameters:
file_path (str | Path)
- Return type:
Grid
- load_grids(file_path)[source]
Load a Grids object from .pkl.
- Parameters:
file_path (str | Path)
- Return type:
Grids
- load_sensitivity_result(file_path)[source]
Load a SensitivityResult object from .pkl.
- Parameters:
file_path (str | Path)
- Return type:
SensitivityResult