Source code for finhjb.interface.guess
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generic, Literal
import jax.numpy as jnp
from finhjb.interface.parameter import P
from finhjb.structure._boundary import ImmutableBoundary
from finhjb.types import ArrayN
[docs]
@dataclass
class AbstractValueGuess(ABC, Generic[P]):
    """
    Base class for initial guesses of the value function.

    A concrete subclass only has to implement :meth:`guess_value`, which maps
    a grid of state values to a first guess of the value function on that
    grid. The four boundary scalars are read from `boundary` once, right after
    construction, and stored on the instance so that subclasses can use them
    without querying the boundary object again.

    Parameters
    ----------
    p : P
        Model parameters. Stored as given; left out of the dataclass repr.
    boundary : ImmutableBoundary
        Boundary object. Its ``get_boundaries()`` method must return four
        values, unpacked in the order ``(s_min, s_max, v_left, v_right)``.

    Attributes
    ----------
    s_min : float
        Minimum boundary for the state variable.
    s_max : float
        Maximum boundary for the state variable.
    v_left : float
        Value function at the left boundary (corresponding to `s_min`).
    v_right : float
        Value function at the right boundary (corresponding to `s_max`).

    Methods
    -------
    guess_value
        Provide an initial guess for the value function.
    """

    # Constructor arguments.
    p: P = field(repr=False)
    boundary: ImmutableBoundary

    # Derived in __post_init__; not accepted by the generated __init__.
    s_min: float = field(init=False, repr=False)
    s_max: float = field(init=False, repr=False)
    v_left: float = field(init=False, repr=False)
    v_right: float = field(init=False, repr=False)

    def __post_init__(self):
        """Read the boundary scalars once and keep them on the instance."""
        lower, upper, left_value, right_value = self.boundary.get_boundaries()
        self.s_min = lower
        self.s_max = upper
        self.v_left = left_value
        self.v_right = right_value

    @abstractmethod
    def guess_value(self, s: ArrayN) -> ArrayN:
        """
        Return an initial guess for the value function (must be overridden).

        Notes
        -----
        - The boundary scalars are available as `self.s_min`, `self.s_max`,
          `self.v_left` and `self.v_right`.
        - Model parameters are available as `self.p`.
        - `s` holds the state values at which the guess is wanted; the result
          should have the same shape as `s`.

        Parameters
        ----------
        s : Float[Array, "N"]
            Grid for the state variable.

        Returns
        -------
        v : Float[Array, "N"]
            Initial guess for the value function on the grid `s`.
        """
[docs]
@dataclass
class LinearInitialValue(AbstractValueGuess[P]):
    """
    Linear value function guess.

    The guess is the straight line through ``(s_min, v_left)`` and
    ``(s_max, v_right)``, evaluated at the grid nodes that are passed in.
    """

    def guess_value(self, s: ArrayN) -> ArrayN:
        """
        Evaluate the boundary-connecting straight line on the grid `s`.

        The line is evaluated at the actual node positions, so the guess is
        linear in the state variable even when the grid is not evenly spaced.
        On an evenly spaced grid that runs from `s_min` to `s_max` the result
        agrees with ``jnp.linspace(self.v_left, self.v_right, s.size)`` up to
        floating-point rounding.

        Parameters
        ----------
        s : Float[Array, "N"]
            Grid for the state variable.

        Returns
        -------
        v : Float[Array, "N"]
            Linear initial guess, same shape as `s`.

        Notes
        -----
        `s_min` and `s_max` must differ; the interpolation weight divides by
        ``s_max - s_min``.
        """
        nodes = jnp.asarray(s)
        # Relative position of each node inside [s_min, s_max]: 0 at s_min, 1 at s_max.
        weight = (nodes - self.s_min) / (self.s_max - self.s_min)
        # Convex-combination form returns v_left / v_right exactly at weight 0 / 1.
        return (1.0 - weight) * self.v_left + weight * self.v_right
[docs]
@dataclass
class QuadraticInitialValue(AbstractValueGuess[P]):
    """
    Quadratic initial value guess.

    The guess is the parabola

        v(s) = a * s^2 + b * s + c

    whose leading coefficient is

        a = a_sign * curvature * |v_right - v_left| / (s_max - s_min)^2

    and whose `b` and `c` are chosen so that the boundary values are matched:

    - v(s_min) = v_left
    - v(s_max) = v_right

    Parameters
    ----------
    a_sign : Literal[-1, 1]
        Sign of the leading coefficient: -1 gives a concave guess, 1 a convex
        one.
    curvature : float, default=0.5
        Scale factor in the interval (0, 1] applied to the leading
        coefficient. Values near 0 give an almost straight line. The slope at
        the two endpoints equals
        ``slope -/+ a_sign * curvature * |slope|``, so keeping
        ``curvature <= 1`` keeps the guess monotone between the boundaries.

    Notes
    -----
    - When ``v_left == v_right`` the leading coefficient is zero and the guess
      is the constant `v_left`, whatever `a_sign` and `curvature` are.
    - `s_min` and `s_max` must differ; the coefficients divide by
      ``s_max - s_min``.
    """

    a_sign: Literal[-1, 1]
    curvature: float = 0.5

    def _calculate_coefficients(self):
        """
        Validate the settings and store the coefficients as `a`, `b`, `c`.

        Raises
        ------
        ValueError
            If `a_sign` is not -1 or 1, or if `curvature` lies outside (0, 1].
        """
        # Validate inputs
        if self.a_sign not in (-1, 1):
            raise ValueError("a_sign must be either -1 or 1.")
        if self.curvature <= 0 or self.curvature > 1:
            raise ValueError("Curvature must be in the interval (0, 1].")

        # Width of the state domain, used by every coefficient below.
        span = self.s_max - self.s_min
        # Slope of the straight line joining the two boundary points.
        slope = (self.v_right - self.v_left) / span

        # Leading coefficient: signed, scaled by `curvature`, proportional to
        # the magnitude of the boundary-to-boundary slope.
        self.a = abs(slope) / span * self.a_sign * self.curvature
        # b and c follow from v(s_min) = v_left and v(s_max) = v_right.
        self.b = (
            (self.v_right - self.v_left) - self.a * (self.s_max**2 - self.s_min**2)
        ) / span
        self.c = self.v_left - self.a * self.s_min**2 - self.b * self.s_min

    def guess_value(self, s: ArrayN) -> ArrayN:
        """
        Evaluate the boundary-matching quadratic initial guess on grid `s`.

        The coefficients are recomputed (and the settings re-validated) on
        every call, so `self.a`, `self.b` and `self.c` are overwritten each
        time.

        Parameters
        ----------
        s : Float[Array, "N"]
            Grid for the state variable.

        Returns
        -------
        v : Float[Array, "N"]
            Quadratic initial guess on the grid `s`.

        Raises
        ------
        ValueError
            If `a_sign` or `curvature` is invalid.
        """
        self._calculate_coefficients()
        return self.a * s**2 + self.b * s + self.c