# Source code for astrodynx.twobody._kep_equ

import jax.numpy as jnp
from jax.typing import ArrayLike, DTypeLike
from jax import Array
from astrodynx.twobody._uniformulas import ufunc0, ufunc1, ufunc2, ufunc3
import jax
from astrodynx.twobody._orb_integrals import equ_of_orb_uvi

"""Kepler's equations and generalized anomaly for two-body orbital mechanics."""


def kepler_equ_elps(E: ArrayLike, e: ArrayLike, M: ArrayLike = 0) -> Array:
    r"""Evaluate the residual of Kepler's equation for an elliptical orbit.

    Args:
        E: Eccentric anomaly.
        e: Eccentricity of the orbit, 0 <= e < 1.
        M: (optional) Mean anomaly.

    Returns:
        The residual ``E - e*sin(E) - M``. It vanishes when ``E`` satisfies
        Kepler's equation for the given ``e`` and ``M``.

    Notes:
        For elliptical orbits the eccentric anomaly $E$ and the mean anomaly
        $M$ are linked by
        $$
        E - e \sin E = M
        $$
        The function returns the left-hand side minus the right-hand side,
        i.e. the form f(E) = 0 expected by root-finding algorithms.

    References:
        Battin, 1999, pp.160.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> E = jnp.pi/4
        >>> e = 0.1
        >>> M = 0.7
        >>> adx.kepler_equ_elps(E, e, M)
        Array(0.01468..., dtype=float32, weak_type=True)

        The arguments broadcast against each other, so several eccentric
        anomalies, eccentricities and mean anomalies can be evaluated at once:

        >>> E = jnp.array([jnp.pi/4, jnp.pi/2])
        >>> e = jnp.array([0.1, 0.2])
        >>> M = jnp.array([0.7, 0.8])
        >>> adx.kepler_equ_elps(E, e, M)
        Array([0.01468..., 0.5707...], dtype=float32)
    """
    eccentric_term = e * jnp.sin(E)
    return E - eccentric_term - M
def dE(E: ArrayLike, e: ArrayLike) -> Array:
    r"""Evaluate the derivative of the elliptical Kepler equation in ``E``.

    Args:
        E: Eccentric anomaly.
        e: Eccentricity of the orbit, 0 <= e < 1.

    Returns:
        The value ``1 - e*cos(E)``, the derivative of ``E - e*sin(E) - M``
        with respect to the eccentric anomaly.

    Notes:
        Differentiating Kepler's equation for elliptical orbits with respect
        to the eccentric anomaly gives
        $$
        dE = 1 - e \cos E
        $$
        where $E$ is the eccentric anomaly and $e$ is the eccentricity.
    """
    cos_E = jnp.cos(E)
    return 1 - e * cos_E
def kepler_equ_hypb(H: ArrayLike, e: ArrayLike, N: ArrayLike = 0) -> Array:
    r"""Evaluate the residual of Kepler's equation for a hyperbolic orbit.

    Args:
        H: Hyperbolic eccentric anomaly.
        e: Eccentricity of the orbit, e > 1.
        N: (optional) Hyperbolic mean anomaly.

    Returns:
        The residual ``e*sinh(H) - H - N``. It vanishes when ``H`` satisfies
        the hyperbolic Kepler equation for the given ``e`` and ``N``.

    Notes:
        For hyperbolic orbits the hyperbolic eccentric anomaly $H$ and the
        hyperbolic mean anomaly $N$ are linked by
        $$
        e \sinh H - H = N
        $$
        The function returns the left-hand side minus the right-hand side,
        i.e. the form f(H) = 0 expected by root-finding algorithms.

    References:
        Battin, 1999, pp.168.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> H = 1.0
        >>> e = 1.5
        >>> N = 1.0
        >>> adx.kepler_equ_hypb(H, e, N)
        Array(-0.2371..., dtype=float32, weak_type=True)

        The arguments broadcast against each other, so several hyperbolic
        eccentric anomalies, eccentricities and hyperbolic mean anomalies can
        be evaluated at once:

        >>> H = jnp.array([1.0, 2.0])
        >>> e = jnp.array([1.5, 1.5])
        >>> N = jnp.array([1.0, 1.0])
        >>> adx.kepler_equ_hypb(H, e, N)
        Array([-0.2371..., 2.4402...], dtype=float32)
    """
    hyperbolic_term = e * jnp.sinh(H)
    return hyperbolic_term - H - N
def dH(H: ArrayLike, e: ArrayLike) -> Array:
    r"""Evaluate the derivative of the hyperbolic Kepler equation in ``H``.

    Args:
        H: Hyperbolic eccentric anomaly.
        e: Eccentricity of the orbit, e > 1.

    Returns:
        The value ``e*cosh(H) - 1``, the derivative of ``e*sinh(H) - H - N``
        with respect to the hyperbolic eccentric anomaly.

    Notes:
        Differentiating Kepler's equation for hyperbolic orbits with respect
        to the hyperbolic eccentric anomaly gives
        $$
        dH = e \cosh H - 1
        $$
        where $H$ is the hyperbolic eccentric anomaly and $e$ is the
        eccentricity.
    """
    cosh_H = jnp.cosh(H)
    return e * cosh_H - 1
def mean_anomaly_elps(a: ArrayLike, deltat: ArrayLike, mu: ArrayLike = 1) -> Array:
    r"""Compute the mean anomaly of an elliptical orbit.

    Args:
        a: Semimajor axis of the orbit, a > 0.
        deltat: Time since periapsis passage.
        mu: (optional) Gravitational parameter of the central body.

    Returns:
        The mean anomaly ``sqrt(mu / a**3) * deltat``.

    Notes:
        The mean anomaly of an elliptical orbit is
        $$
        M = \sqrt{\frac{\mu}{a^3}} \Delta t
        $$
        where $M$ is the mean anomaly, $a>0$ is the semimajor axis, $\mu$ is
        the gravitational parameter, and $\Delta t$ is the time since
        periapsis passage.

    References:
        Battin, 1999, pp.160.

    Examples:
        Mean anomaly for an orbit with semimajor axis 1.0, gravitational
        parameter 1.0, and time since periapsis passage 1.0:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> a = 1.0
        >>> mu = 1.0
        >>> deltat = 1.0
        >>> adx.mean_anomaly_elps(a, deltat, mu)
        Array(1., dtype=float32, weak_type=True)

        The arguments broadcast against each other, so several semimajor
        axes, gravitational parameters and times can be evaluated at once:

        >>> a = jnp.array([1.0, 2.0])
        >>> mu = jnp.array([1.0, 2.0])
        >>> deltat = jnp.array([1.0, 1.0])
        >>> adx.mean_anomaly_elps(a, deltat, mu)
        Array([1. , 0.5], dtype=float32)
    """
    mean_motion = jnp.sqrt(mu / a**3)
    return mean_motion * deltat
def mean_anomaly_hypb(a: ArrayLike, deltat: ArrayLike, mu: ArrayLike = 1) -> Array:
    r"""Compute the mean anomaly of a hyperbolic orbit.

    Args:
        a: Semimajor axis of the orbit, a < 0.
        deltat: Time since periapsis passage.
        mu: (optional) Gravitational parameter of the central body.

    Returns:
        The hyperbolic mean anomaly ``sqrt(mu / -(a**3)) * deltat``.

    Notes:
        The mean anomaly of a hyperbolic orbit is
        $$
        N = \sqrt{\frac{\mu}{-a^3}} \Delta t
        $$
        where $N$ is the mean anomaly, $a<0$ is the semimajor axis, $\mu$ is
        the gravitational parameter, and $\Delta t$ is the time since
        periapsis passage.

    References:
        Battin, 1999, pp.166.

    Examples:
        Mean anomaly for an orbit with semimajor axis -1.0, gravitational
        parameter 1.0, and time since periapsis passage 1.0:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> a = -1.0
        >>> mu = 1.0
        >>> deltat = 1.0
        >>> adx.mean_anomaly_hypb(a, deltat, mu)
        Array(1., dtype=float32, weak_type=True)

        The arguments broadcast against each other, so several semimajor
        axes, gravitational parameters and times can be evaluated at once:

        >>> a = jnp.array([-1.0, -2.0])
        >>> mu = jnp.array([1.0, 2.0])
        >>> deltat = jnp.array([1.0, 1.0])
        >>> adx.mean_anomaly_hypb(a, deltat, mu)
        Array([1. , 0.5], dtype=float32)
    """
    # a < 0 on a hyperbola, so the sign flip keeps the radicand positive.
    neg_a_cubed = -(a**3)
    hyperbolic_mean_motion = jnp.sqrt(mu / neg_a_cubed)
    return hyperbolic_mean_motion * deltat
def kepler_equ_uni(
    chi: ArrayLike,
    alpha: DTypeLike = 1,
    r0: ArrayLike = 1,
    sigma0: ArrayLike = 0,
    deltat: ArrayLike = 0,
    mu: ArrayLike = 1,
) -> Array:
    r"""Evaluate the residual of the universal Kepler equation.

    Args:
        chi: The generalized anomaly.
        alpha: (optional) The reciprocal of the semimajor axis.
        r0: (optional) The radius at the initial time.
        sigma0: (optional) The sigma function at the initial time.
        deltat: (optional) The time since the initial time.
        mu: (optional) The gravitational parameter.

    Returns:
        The left-hand side of the universal Kepler equation written in the
        form f(chi) = 0.

    Notes:
        The universal Kepler equation reads
        $$
        r_0 U_1(\chi, \alpha) + \sigma_0 U_2(\chi, \alpha) + U_3(\chi, \alpha) - \sqrt{\mu} \Delta t = 0
        $$
        where $\Delta t$ is the time since the initial time, $\chi$ is the
        generalized anomaly, $\alpha = \frac{1}{a}$ is the reciprocal of the
        semimajor axis, $\sigma_0$ is the sigma function at the initial time,
        $r_0$ is the norm of the position vector at the initial time, $\mu$
        is the gravitational parameter, and $U_1$, $U_2$, and $U_3$ are the
        universal functions.

    References:
        Battin, 1999, pp.178.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> chi = 1.0
        >>> alpha = 1.0
        >>> sigma0 = 0.0
        >>> r0 = 1.0
        >>> mu = 1.0
        >>> deltat = 1.0
        >>> adx.kepler_equ_uni(chi, alpha, r0, sigma0, deltat, mu)
        Array(0., dtype=float32, weak_type=True)

        With broadcasting:

        >>> chi = jnp.array([1.0, 2.0])
        >>> alpha = 1.
        >>> sigma0 = jnp.array([0.0, 0.0])
        >>> r0 = jnp.array([1.0, 1.0])
        >>> deltat = jnp.array([1.0, 1.0])
        >>> mu = jnp.array([1.0, 1.0])
        >>> adx.kepler_equ_uni(chi, alpha, r0, sigma0, deltat, mu)
        Array([0., 1.], dtype=float32)
    """
    radial_term = r0 * ufunc1(chi, alpha)
    sigma_term = sigma0 * ufunc2(chi, alpha)
    elapsed_term = jnp.sqrt(mu) * deltat
    return radial_term + sigma_term + ufunc3(chi, alpha) - elapsed_term
def dchi(chi: ArrayLike, alpha: DTypeLike, r0: ArrayLike, sigma0: ArrayLike) -> Array:
    r"""Evaluate the Newton-step denominator used for the universal Kepler equation.

    The universal functions U0, U1 and U2 are evaluated at ``(chi, alpha)``
    and handed, together with ``r0`` and ``sigma0``, to ``equ_of_orb_uvi``.
    ``solve_kepler_uni`` divides the residual of ``kepler_equ_uni`` by this
    value, i.e. it uses it as the derivative of that residual with respect
    to ``chi``.

    Args:
        chi: The generalized anomaly.
        alpha: The reciprocal of the semimajor axis.
        r0: The radius at the initial time.
        sigma0: The sigma function at the initial time.

    Returns:
        The value returned by ``equ_of_orb_uvi(U0, U1, U2, r0, sigma0)``.
        NOTE(review): presumably the orbit radius expressed through the
        universal functions -- confirm against ``_orb_integrals``.
    """
    u0 = ufunc0(chi, alpha)
    u1 = ufunc1(chi, alpha)
    u2 = ufunc2(chi, alpha)
    return equ_of_orb_uvi(u0, u1, u2, r0, sigma0)
def generalized_anomaly(
    alpha: ArrayLike,
    sigma: ArrayLike,
    sigma0: ArrayLike,
    deltat: ArrayLike = 0,
    mu: ArrayLike = 1,
) -> Array:
    r"""Compute the generalized anomaly from the sigma functions and elapsed time.

    Args:
        alpha: The reciprocal of the semimajor axis.
        sigma: The sigma function at the current time.
        sigma0: The sigma function at the initial time.
        deltat: (optional) The time since the initial time.
        mu: (optional) The gravitational parameter.

    Returns:
        The generalized anomaly ``alpha*sqrt(mu)*deltat + sigma - sigma0``.

    Notes:
        The generalized anomaly is defined as
        $$
        \chi = \alpha \sqrt{\mu} \Delta t + \sigma - \sigma_0
        $$
        where $\chi$ is the generalized anomaly, $\alpha = \frac{1}{a}$ is
        the reciprocal of the semimajor axis, $\sigma$ is the sigma function
        at the current time, $\sigma_0$ is the sigma function at the initial
        time, $\mu$ is the gravitational parameter, and $\Delta t$ is the
        time since the initial time.

    References:
        Battin, 1999, pp.179.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> alpha = 1.0
        >>> sigma = 1.0
        >>> sigma0 = 0.0
        >>> mu = 1.0
        >>> deltat = 1.0
        >>> adx.generalized_anomaly(alpha, sigma, sigma0, deltat, mu)
        Array(2., dtype=float32, weak_type=True)

        With broadcasting:

        >>> alpha = jnp.array([1.0, 1.0])
        >>> sigma = jnp.array([1.0, 2.0])
        >>> sigma0 = jnp.array([0.0, 0.0])
        >>> mu = jnp.array([1.0, 1.0])
        >>> deltat = jnp.array([1.0, 1.0])
        >>> adx.generalized_anomaly(alpha, sigma, sigma0, deltat, mu)
        Array([2., 3.], dtype=float32)
    """
    time_term = alpha * jnp.sqrt(mu) * deltat
    return time_term + sigma - sigma0
def solve_kepler_elps(
    M: DTypeLike, e: DTypeLike, tol: DTypeLike = 1e-6, max_iter: int = 20
) -> Array:
    r"""Returns the eccentric anomaly for an elliptical orbit.

    Kepler's equation is solved with Newton's method. Iteration stops when
    two successive iterates differ by at most ``tol`` or after ``max_iter``
    steps, whichever comes first.

    Args:
        M: Mean anomaly (scalar).
        e: Eccentricity of the orbit, 0 <= e < 1 (scalar).
        tol: (optional) Tolerance on the change of ``E`` between two
            successive Newton iterates.
        max_iter: (optional) Maximum number of iterations.

    Returns:
        The eccentric anomaly for an elliptical orbit.

    Notes:
        The eccentric anomaly is calculated by solving Kepler's equation for
        elliptical orbits:
        $$
        E - e \sin E = M
        $$
        where $E$ is the eccentric anomaly, $e$ is the eccentricity, and $M$
        is the mean anomaly.

        The loop predicate handed to ``jax.lax.while_loop`` must be a scalar
        boolean, so the inputs are expected to be scalars; use ``jax.vmap``
        to solve a batch of problems.

    References:
        Vallado, 2013, pp.76.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> e = 0.37255
        >>> M = 3.6029
        >>> adx.solve_kepler_elps(M,e)
        Array(3.479..., dtype=float32, weak_type=True)
    """
    # Starting guess: series expansion of E in powers of e, truncated after
    # the e**2 term.
    E0 = M + e * jnp.sin(M) + e**2 * jnp.sin(2 * M) / 2

    def cond_fn(val: tuple[int, Array, Array]) -> Array:
        iter_count, E, E_prev = val
        not_converged = jnp.abs(E - E_prev) > tol
        under_max_iter = iter_count < max_iter
        return not_converged & under_max_iter

    def body_fn(val: tuple[int, Array, Array]) -> tuple[int, Array, Array]:
        iter_count, E, _ = val
        f = E - e * jnp.sin(E) - M
        # Derivative of f with respect to E (the same expression as dE).
        f_prime = 1 - e * jnp.cos(E)
        return (iter_count + 1, E - f / f_prime, E)

    # Seed the "previous iterate" slot with +inf so that the first Newton
    # step is always taken. A finite offset such as E0 + 2*tol is absorbed
    # by floating-point rounding once |E0| is large compared with tol
    # (float32: |E0| >= 64 with the default tol), which made the loop exit
    # immediately and return the unrefined starting guess.
    _, E, _ = jax.lax.while_loop(cond_fn, body_fn, (0, E0, E0 + jnp.inf))
    return E
def solve_kepler_hypb(
    N: DTypeLike, e: DTypeLike, tol: DTypeLike = 1e-6, max_iter: int = 50
) -> Array:
    r"""Returns the hyperbolic eccentric anomaly for a hyperbolic orbit.

    The hyperbolic Kepler equation is solved with Newton's method. Iteration
    stops when the absolute residual ``|e*sinh(H) - H - N|`` is at most
    ``tol`` or after ``max_iter`` steps, whichever comes first.

    Args:
        N: Hyperbolic mean anomaly (scalar).
        e: Eccentricity of the orbit, e > 1 (scalar).
        tol: (optional) Tolerance on the absolute residual of Kepler's
            equation.
        max_iter: (optional) Maximum number of iterations.

    Returns:
        The hyperbolic eccentric anomaly for a hyperbolic orbit.

    Notes:
        The hyperbolic eccentric anomaly is calculated by solving Kepler's
        equation for hyperbolic orbits:
        $$
        e \sinh H - H = N
        $$
        where $H$ is the hyperbolic eccentric anomaly, $e$ is the
        eccentricity, and $N$ is the hyperbolic mean anomaly.

        The loop predicate handed to ``jax.lax.while_loop`` must be a scalar
        boolean, so the inputs are expected to be scalars; use ``jax.vmap``
        to solve a batch of problems.

    References:
        Battin, 1999, pp.168.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> e = 2.7696
        >>> N = 40.69
        >>> adx.solve_kepler_hypb(N,e)
        Array(3.463..., dtype=float32, weak_type=True)
    """
    # Starting guess H0 = asinh(2N/e). arcsinh(x) equals
    # log(x + sqrt(x**2 + 1)) but does not cancel catastrophically for large
    # negative x, where the explicit form evaluates log(0) = -inf.
    H0 = jnp.arcsinh(2 * N / e)

    def cond_fn(val: tuple[int, Array]) -> Array:
        iter_count, H = val
        not_converged = jnp.abs(kepler_equ_hypb(H, e, N)) > tol
        under_max_iter = iter_count < max_iter
        return not_converged & under_max_iter

    def body_fn(val: tuple[int, Array]) -> tuple[int, Array]:
        iter_count, H = val
        # Newton update: H <- H - f(H) / f'(H).
        H_new = H - kepler_equ_hypb(H, e, N) / dH(H, e)
        return (iter_count + 1, H_new)

    # Convergence is judged on the residual alone, so only the iteration
    # counter and the current iterate need to be carried through the loop.
    _, H = jax.lax.while_loop(cond_fn, body_fn, (0, H0))
    return H
def solve_kepler_uni(
    deltat: DTypeLike,
    alpha: DTypeLike,
    r0: DTypeLike,
    sigma0: DTypeLike,
    mu: DTypeLike = 1,
    tol: DTypeLike = 1e-6,
    max_iter: int = 50,
) -> Array:
    r"""Returns the generalized anomaly for a universal orbit equation.

    The universal Kepler equation is solved with Newton's method, starting
    from ``sqrt(mu) * |alpha| * deltat``. Iteration stops when the absolute
    residual of ``kepler_equ_uni`` is at most ``tol`` or after ``max_iter``
    steps, whichever comes first.

    Args:
        deltat: The time since the initial time.
        alpha: The reciprocal of the semimajor axis.
        r0: The radius at the initial time.
        sigma0: The sigma function at the initial time.
        mu: (optional) The gravitational parameter.
        tol: (optional) Tolerance on the absolute residual of the universal
            Kepler equation.
        max_iter: (optional) Maximum number of iterations.

    Returns:
        The generalized anomaly for a universal orbit equation.

    Notes:
        The generalized anomaly is calculated by solving the universal orbit
        equation:
        $$
        r_0 U_1(\chi, \alpha) + \sigma_0 U_2(\chi, \alpha) + U_3(\chi, \alpha) - \sqrt{\mu} \Delta t = 0
        $$
        where $\chi$ is the generalized anomaly, $\alpha = \frac{1}{a}$ is
        the reciprocal of the semimajor axis, $\sigma_0$ is the sigma
        function at the initial time, $r_0$ is the norm of the position
        vector at the initial time, $\mu$ is the gravitational parameter,
        $U_1$, $U_2$ and $U_3$ are the universal functions, and $\Delta t$
        is the time since the initial time.

        The loop predicate handed to ``jax.lax.while_loop`` must be a scalar
        boolean, so the inputs are expected to be scalars; use ``jax.vmap``
        to solve a batch of problems.

    References:
        Battin, 1999, pp.178.

    Examples:
        A simple example:

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> r0_vec = jnp.array([1.0, 0.0, 0.0])
        >>> v0_vec = jnp.array([0.0, 1.1, 0.0])
        >>> mu = 1.0
        >>> deltat = jnp.pi*0.5
        >>> r0 = jnp.linalg.vector_norm(r0_vec)
        >>> v0 = jnp.linalg.vector_norm(v0_vec)
        >>> alpha = 1.0 / adx.semimajor_axis(r0, v0, mu)
        >>> sigma0 = adx.twobody.sigma_fn(r0_vec, v0_vec, mu)
        >>> chi = adx.solve_kepler_uni(deltat, alpha.item(), r0.item(), sigma0.item(), mu)
        >>> assert adx.kepler_equ_uni(chi,alpha,r0, sigma0, deltat, mu) < 1e-6
    """
    chi0 = jnp.sqrt(mu) * jnp.abs(alpha) * deltat

    def residual(chi: Array) -> Array:
        # Residual of the universal Kepler equation at the current iterate.
        return kepler_equ_uni(chi, alpha, r0, sigma0, deltat, mu)

    def cond_fn(val: tuple[int, Array]) -> Array:
        iter_count, chi = val
        not_converged = jnp.abs(residual(chi)) > tol
        under_max_iter = iter_count < max_iter
        return not_converged & under_max_iter

    def body_fn(val: tuple[int, Array]) -> tuple[int, Array]:
        iter_count, chi = val
        # Newton update; dchi supplies the derivative of the residual.
        chi_new = chi - residual(chi) / dchi(chi, alpha, r0, sigma0)
        return (iter_count + 1, chi_new)

    # Convergence is judged on the residual alone, so only the iteration
    # counter and the current iterate need to be carried through the loop.
    _, chi = jax.lax.while_loop(cond_fn, body_fn, (0, chi0))
    return chi