"""finhjb.config: runtime configuration for finhjb numerical solvers."""
from typing import Callable, Literal
import jax
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from finhjb.types import ArrayFloat, ArrayInter
[docs]
class Config(BaseModel):
"""Runtime configuration for numerical solvers and differentiation rules.
The configuration object is validated by Pydantic and then used by
algorithm modules (`evaluation`, `policy_iteration`, and boundary routines)
to control convergence thresholds, iteration limits, and differentiation
behavior.
"""
# --- Grid ---
enable_x64: bool = True
derivative_method: Literal["central", "forward", "backward"] = "central"
# --- Policy Iteration (PI) ---
policy_guess: bool = False
pi_method: Literal["scan", "anderson"] = "scan"
pi_max_iter: int = Field(default=50, gt=0)
pi_tol: float = Field(default=1e-6, gt=0.0)
pi_patience: int = Field(default=10, ge=0)
# --- Policy Evaluation (PE) ---
pe_max_iter: int = Field(default=10, gt=0)
pe_tol: float = Field(default=1e-6, gt=0.0)
pe_patience: int = Field(default=5, ge=0)
# --- Boundary Search (BS) ---
bs_max_iter: int = Field(default=20, gt=0)
bs_tol: float = Field(default=1e-6, gt=0.0)
bs_patience: int = Field(default=5, ge=0)
# --- Anderson Acceleration (AA) ---
aa_history_size: int = Field(default=5, gt=0)
aa_mixing_frequency: int = Field(default=1, gt=0)
aa_beta: float = Field(default=1.0, gt=0.0)
aa_ridge: float = Field(default=0, ge=0.0)
_dv_func: Callable[
[
ArrayInter,
ArrayInter,
ArrayInter,
float | ArrayFloat,
],
ArrayInter,
] = PrivateAttr()
model_config = ConfigDict(use_enum_values=True, validate_assignment=True)
[docs]
@model_validator(mode="after")
def set_derivative_function(self) -> "Config":
"""Sets the derivative function based on the method."""
match self.derivative_method:
case "central":
self._dv_func = lambda v_im1, v_i, v_ip1, h: (v_ip1 - v_im1) / (2 * h)
case "forward":
self._dv_func = lambda v_im1, v_i, v_ip1, h: (v_ip1 - v_i) / h
case "backward":
self._dv_func = lambda v_im1, v_i, v_ip1, h: (v_i - v_im1) / h
case _:
raise ValueError(f"Unknown method: {self.derivative_method}!")
return self
@property
def dv_func(self) -> Callable:
"""Return the finite-difference function selected by configuration."""
return self._dv_func
[docs]
@model_validator(mode="after")
def setup_jax(self) -> "Config":
"""Applies JAX configuration after model validation."""
jax.config.update("jax_enable_x64", self.enable_x64)
return self
if __name__ == "__main__":
    # Smoke check: build the default configuration (running all validators,
    # including the JAX precision setup) and display its field values.
    default_config = Config()
    print(default_config)