# Source code for astrodynx.prop._cowell_method

from typing import Any, NamedTuple, Optional

from jaxtyping import ArrayLike, DTypeLike, PyTree

from astrodynx import diffeq

# A bare string only becomes the module docstring when it is the first
# statement of the module. Here it follows the imports, so bind it to
# ``__doc__`` explicitly to keep help() and autodoc working.
__doc__ = """General implementations of Cowell's method for propagating orbits under perturbing forces."""


class OrbDynx(NamedTuple):
    """Orbital dynamics configuration for Cowell's method propagation.

    This NamedTuple bundles what the propagators in this module hand to the
    ODE solver: the differential equation terms, the arguments forwarded to
    the vector field, and an optional event used to stop integration early.

    Attributes:
        terms: The differential equation terms defining the orbital dynamics.
            Typically an ODETerm containing the vector field function that
            computes accelerations from gravitational and perturbation forces.
        args: Arguments passed to the differential equation. Common entries
            include the gravitational parameter (mu), perturbation parameters
            (J2, R_eq), and event thresholds (rmin). Defaults to
            {"mu": 1.0}.
        event: Event detection configuration for terminating propagation
            early when specific conditions are met (e.g., ground impact,
            apogee passage). Defaults to None, i.e. no event.

    Notes:
        This class is designed to work with the ODE solver. The terms should
        define the complete orbital dynamics including all relevant forces
        and perturbations. The args parameter uses JAX's PyTree structure,
        allowing for efficient compilation and automatic differentiation of
        the propagation process.

        The default ``args`` dict is a single object created when the class
        is defined and shared by every instance built without an explicit
        ``args``. Construct a new dict instead of mutating ``orbdyn.args``
        in place.

    Examples:
        Basic two-body orbital dynamics:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )

        With J2 perturbations and event detection:

        >>> def perturbed_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     acc += adx.gravity.j2_acc(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(perturbed_field),
        ...     args={"mu": 1.0, "J2": 1e-3, "R_eq": 1.0, "rmin": 0.1},
        ...     event=diffeq.Event(adx.events.radius_islow)
        ... )
    """

    terms: diffeq.ODETerm
    # Shared default: evaluated once at class-definition time (see Notes).
    args: PyTree[Any] = {"mu": 1.0}
    event: Optional[diffeq.Event] = None


def cowell_method(
    orbdyn: OrbDynx,
    x0: ArrayLike,
    t1: DTypeLike,
    saveat: diffeq.SaveAt,
    dt0: Optional[DTypeLike] = None,
    max_steps: int = 4096,
    solver: diffeq.AbstractSolver = diffeq.Tsit5(),
    stepsize_controller: diffeq.AbstractStepSizeController = diffeq.PIDController(
        rtol=1e-8, atol=1e-8
    ),
) -> diffeq.Solution:
    """General Cowell's method orbital propagation function.

    Integrates ``orbdyn.terms`` from ``t0 = 0`` to ``t1`` starting at ``x0``,
    forwarding ``orbdyn.args`` to the vector field and ``orbdyn.event`` to the
    solver. The caller chooses the integrator, the step size controller and
    which outputs are saved.

    Args:
        orbdyn: Orbital dynamics configuration containing the differential
            equation terms, arguments, and optional event.
        x0: (6,) Initial state vector [x, y, z, vx, vy, vz] in canonical
            units. Position components are in distance units, velocity
            components are in distance/time units.
        t1: Final integration time in canonical time units. Must be positive
            for forward propagation.
        saveat: Configuration specifying when to save the solution during
            integration.
        dt0: Initial time step size in canonical time units. If None, the
            solver chooses an appropriate initial step size automatically.
            Fixed-step controllers such as ``diffeq.ConstantStepSize()`` need
            an explicit value (see :func:`fixed_steps`).
        max_steps: Maximum number of integration steps before terminating
            unconditionally. Prevents infinite loops in case of integration
            issues. Defaults to 4096.
        solver: Numerical integration method. Defaults to diffeq.Tsit5()
            (5th-order Runge-Kutta method).
        stepsize_controller: Step size controller for adaptive or fixed
            stepping. Defaults to a PID controller with rtol=atol=1e-8.

    Returns:
        Integration solution containing

        - ts: Array of time points where solution was saved (as selected by
          ``saveat``)
        - ys: Array of state vectors at each time point
        - stats: Integration statistics and diagnostics
        - result: Integration termination status

    Notes:
        The integration always starts at ``t0 = 0``, so ``t1`` is also the
        propagation duration.

        This function serves as a general interface for Cowell's method
        orbital propagation. Users can specify custom step size controllers
        and output times to suit their specific needs. For common use cases,
        the specialized functions (:func:`fixed_steps`,
        :func:`adaptive_steps`, :func:`custom_steps`, :func:`to_final`) are
        clearer and easier to use.

    Examples:
        Basic orbital propagation with dense output:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        >>> t1 = jnp.pi*2  # One orbital period
        >>> saveat=diffeq.SaveAt(dense=True,subs=diffeq.SubSaveAt(t1=True))
        >>> sol = adx.prop.cowell_method(
        ...     orbdyn, x0, t1,
        ...     saveat=saveat
        ... )
        >>> tf = sol.ts[-1]
        >>> yf = sol.ys[-1]
        >>> yf_dense = sol.evaluate(tf)
        >>> assert jnp.allclose(yf, yf_dense, atol=1e-7)
    """
    return diffeq.diffeqsolve(
        terms=orbdyn.terms,
        solver=solver,
        t0=0,
        t1=t1,
        dt0=dt0,
        y0=x0,
        args=orbdyn.args,
        max_steps=max_steps,
        stepsize_controller=stepsize_controller,
        saveat=saveat,
        event=orbdyn.event,
    )


def fixed_steps(
    orbdyn: OrbDynx,
    x0: ArrayLike,
    t1: DTypeLike,
    dt: DTypeLike,
    max_steps: int = 4096,
    solver: diffeq.AbstractSolver = diffeq.Tsit5(),
) -> diffeq.Solution:
    """Propagate orbital state using Cowell's method with fixed step sizes.

    Solves the orbital dynamics with a constant step size of ``dt`` from
    ``t0 = 0`` to ``t1``. Use this when uniform time sampling is wanted or
    when speed matters more than adaptive error control.

    Args:
        orbdyn: Orbital dynamics configuration containing the differential
            equation terms, arguments, and optional event.
        x0: (6,) Initial state vector [x, y, z, vx, vy, vz] in canonical
            units. Position components are in distance units, velocity
            components are in distance/time units.
        t1: Final integration time in canonical time units. Must be positive
            for forward propagation.
        dt: Fixed time step size for integration in canonical time units.
            Smaller values increase accuracy but require more computational
            time.
        max_steps: Maximum number of integration steps before terminating
            unconditionally. It must be at least about ``t1 / dt`` for the
            integration to reach ``t1``. Defaults to 4096.
        solver: Numerical integration method. Defaults to diffeq.Tsit5()
            (5th-order Runge-Kutta method).

    Returns:
        Integration solution containing

        - ts: Array of time points where solution was saved
        - ys: Array of state vectors at each time point
        - stats: Integration statistics and diagnostics
        - result: Integration termination status

    Notes:
        A constant step size controller is used, so the integrator takes
        ``dt``-sized steps regardless of local error. This can be cheaper than
        adaptive stepping but may lose accuracy where the dynamics change
        rapidly.

        The solution is saved at the initial time (t0=True) and after every
        integration step (steps=True). The per-step buffers are sized by
        ``max_steps``, so unused trailing entries of ``ts``/``ys`` are
        padding. Filter them with ``jnp.isfinite(sol.ts)``.

    Examples:
        Basic orbital propagation with fixed steps:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        >>> t1 = jnp.pi*2  # One orbital period
        >>> dt = 0.01  # Fixed step size
        >>> sol = adx.prop.fixed_steps(orbdyn, x0, t1, dt)
    """
    # Record the initial state plus the state after each constant-size step.
    every_step = diffeq.SaveAt(t0=True, steps=True)
    constant = diffeq.ConstantStepSize()
    return cowell_method(
        orbdyn,
        x0,
        t1,
        saveat=every_step,
        dt0=dt,
        max_steps=max_steps,
        solver=solver,
        stepsize_controller=constant,
    )


def adaptive_steps(
    orbdyn: OrbDynx,
    x0: ArrayLike,
    t1: DTypeLike,
    max_steps: int = 4096,
    solver: diffeq.AbstractSolver = diffeq.Tsit5(),
    stepsize_controller: diffeq.AbstractStepSizeController = diffeq.PIDController(
        rtol=1e-8, atol=1e-8
    ),
) -> diffeq.Solution:
    """Propagate orbital state using Cowell's method with adaptive step sizes.

    Solves the orbital dynamics differential equation from ``t0 = 0`` to
    ``t1`` with an adaptive step size controller. The controller adjusts the
    time step from local error estimates, balancing accuracy against
    computational cost.

    Args:
        orbdyn: Orbital dynamics configuration containing the differential
            equation terms, arguments, and optional event.
        x0: (6,) Initial state vector [x, y, z, vx, vy, vz] in canonical
            units. Position components are in distance units, velocity
            components are in distance/time units.
        t1: Final integration time in canonical time units. Must be positive
            for forward propagation.
        max_steps: Maximum number of integration steps before terminating
            unconditionally. Prevents infinite loops and controls
            computational cost. It also sizes the saved-output buffers.
            Defaults to 4096.
        solver: Numerical integration method. Defaults to diffeq.Tsit5()
            (5th-order Runge-Kutta method with embedded error estimation).
        stepsize_controller: Adaptive step size controller. Defaults to a
            PID controller with rtol=atol=1e-8.

    Returns:
        Integration solution containing

        - ts: Array of time points where solution was saved (variable
          spacing, padded past the last accepted step)
        - ys: Array of state vectors at each time point
        - stats: Integration statistics including step counts and rejections
        - result: Integration termination status

    Notes:
        The adaptive step size controller automatically adjusts the time step
        to maintain the specified error tolerances. This results in smaller
        steps during periods of rapid change (e.g., near periapsis) and
        larger steps during smoother motion (e.g., near apoapsis).

        No initial step size is supplied (``dt0=None``), so the solver
        selects the first step automatically.

        The solution is saved at the initial time and at every accepted step.
        The output buffers have a fixed length set by ``max_steps``, and the
        unused trailing entries are non-finite padding. Select the valid rows
        with ``jnp.isfinite(sol.ts)`` as in the examples.

    Examples:
        Orbital propagation with adaptive steps:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        >>> t1 = jnp.pi*2  # One orbital period
        >>> sol = adx.prop.adaptive_steps(orbdyn, x0, t1)
        >>> xf = sol.ys[jnp.isfinite(sol.ts)][-1]

        Eccentric orbit with J2 perturbations and event detection:

        >>> def perturbed_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     acc += adx.gravity.j2_acc(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(perturbed_field),
        ...     args = {"mu": 1.0, "rmin": 0.7, "J2": 1e-6, "R_eq": 1.0},
        ...     event = diffeq.Event(adx.events.radius_islow)
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 0.9, 0.0])
        >>> sol = adx.prop.adaptive_steps(orbdyn, x0, t1)
        >>> xf = sol.ys[jnp.isfinite(sol.ts)][-1]
        >>> expected = jnp.array([-0.59,0.36, 0.,-0.58,-1.16, 0.])
    """
    # Let the solver pick the first step; save x0 and every accepted step.
    return cowell_method(
        orbdyn,
        x0,
        t1,
        saveat=diffeq.SaveAt(t0=True, steps=True),
        dt0=None,
        max_steps=max_steps,
        solver=solver,
        stepsize_controller=stepsize_controller,
    )


def custom_steps(
    orbdyn: OrbDynx,
    x0: ArrayLike,
    t1: DTypeLike,
    ts: ArrayLike,
    solver: diffeq.AbstractSolver = diffeq.Tsit5(),
    stepsize_controller: diffeq.AbstractStepSizeController = diffeq.PIDController(
        rtol=1e-8, atol=1e-8
    ),
    max_steps: Optional[int] = None,
) -> diffeq.Solution:
    """Propagate orbital state using Cowell's method with custom output times.

    Solves the orbital dynamics differential equation from ``t0 = 0`` to
    ``t1`` and saves the solution at user-specified time points. Adaptive
    step size control keeps the integration accurate, and the output is
    interpolated to exactly the requested times.

    Args:
        orbdyn: Orbital dynamics configuration containing the differential
            equation terms, arguments, and optional event.
        x0: (6,) Initial state vector [x, y, z, vx, vy, vz] in canonical
            units. Position components are in distance units, velocity
            components are in distance/time units.
        t1: Final integration time in canonical time units. Must be positive
            and should be >= max(ts) for complete coverage.
        ts: Array of time points where the solution should be saved, in
            canonical time units. Can be irregularly spaced and does not need
            to include t=0 or t=t1.
        solver: Numerical integration method. Defaults to diffeq.Tsit5()
            (5th-order Runge-Kutta method).
        stepsize_controller: Adaptive step size controller. Defaults to a
            PID controller with rtol=atol=1e-8.
        max_steps: Maximum number of integration steps. Defaults to None,
            which places no bound on the number of steps. With diffrax's
            default adjoint, reverse-mode differentiation through the solve
            needs a finite bound, so pass an integer when backpropagating.

    Returns:
        Integration solution containing

        - ts: Array of requested time points (same as input ts)
        - ys: Array of interpolated state vectors at requested times
        - stats: Integration statistics from the adaptive stepping
        - result: Integration termination status

    Notes:
        This function is ideal when you need the orbital state at specific
        times (e.g., for comparison with observations, mission planning, or
        analysis at predetermined epochs). The integrator uses adaptive
        stepping internally but interpolates to provide output at exactly the
        requested times.

        Requested times beyond ``t1`` cannot be reached by the integration.
        Keep every entry of ``ts`` within ``[0, t1]`` and sorted in
        increasing order.

    Examples:
        Orbital state at specific observation times:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        >>> t1 = jnp.pi*2  # One orbital period
        >>> obs_times = jnp.array([0.5, 1.2, 2.8, 4.1, 5.9])
        >>> sol = adx.prop.custom_steps(orbdyn, x0, t1, obs_times)
    """
    # NOTE(review): diffrax appears to validate that saveat.ts lies within
    # [t0, t1] and is monotonic. Confirm whether out-of-range entries raise
    # rather than being silently skipped.
    return cowell_method(
        orbdyn,
        x0,
        t1,
        saveat=diffeq.SaveAt(ts=ts),
        dt0=None,
        max_steps=max_steps,
        solver=solver,
        stepsize_controller=stepsize_controller,
    )


def to_final(
    orbdyn: OrbDynx,
    x0: ArrayLike,
    t1: DTypeLike,
    solver: diffeq.AbstractSolver = diffeq.Tsit5(),
    stepsize_controller: diffeq.AbstractStepSizeController = diffeq.PIDController(
        rtol=1e-8, atol=1e-8
    ),
    max_steps: int = 4096,
) -> diffeq.Solution:
    """Propagate orbital state using Cowell's method to final time only.

    Solves the orbital dynamics differential equation from ``t0 = 0`` to
    ``t1`` and returns only the final state. This is the most
    memory-efficient option when intermediate trajectory points are not
    needed, such as for state transition matrix calculations or endpoint
    optimization problems.

    Args:
        orbdyn: Orbital dynamics configuration containing the differential
            equation terms, arguments, and optional event.
        x0: (6,) Initial state vector [x, y, z, vx, vy, vz] in canonical
            units. Position components are in distance units, velocity
            components are in distance/time units.
        t1: Final integration time in canonical time units. Must be positive
            for forward propagation.
        solver: Numerical integration method. Defaults to diffeq.Tsit5()
            (5th-order Runge-Kutta method).
        stepsize_controller: Adaptive step size controller. Defaults to a
            PID controller with rtol=atol=1e-8.
        max_steps: Maximum number of integration steps before terminating
            unconditionally. Prevents runaway computations while still
            allowing for complex orbital dynamics. Defaults to 4096.

    Returns:
        Integration solution containing

        - ts: Single-element array containing the final time (t1, or the
          termination time if ``orbdyn.event`` stops the solve early)
        - ys: Single state vector at that final time
        - stats: Integration statistics from the adaptive stepping
        - result: Integration termination status

    Notes:
        This function is optimized for memory efficiency by saving only the
        final state. It uses adaptive step size control internally but
        discards all intermediate results, making it ideal for:

        - State transition matrix computations
        - Optimization problems requiring only final states
        - Monte Carlo simulations with many trajectories
        - Sensitivity analysis using automatic differentiation

    Examples:
        Simple state propagation to final time:

        >>> import jax.numpy as jnp
        >>> from astrodynx import diffeq
        >>> import astrodynx as adx
        >>> def vector_field(t, x, args):
        ...     acc = adx.gravity.point_mass_grav(t, x, args)
        ...     return jnp.concatenate([x[3:], acc])
        >>> orbdyn = adx.prop.OrbDynx(
        ...     terms=diffeq.ODETerm(vector_field),
        ...     args={"mu": 1.0}
        ... )
        >>> x0 = jnp.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])
        >>> t1 = jnp.pi*2  # One orbital period
        >>> sol = adx.prop.to_final(orbdyn, x0, t1)
    """
    # Save only the terminal state; the solver picks the initial step.
    return cowell_method(
        orbdyn,
        x0,
        t1,
        saveat=diffeq.SaveAt(subs=diffeq.SubSaveAt(t1=True)),
        dt0=None,
        max_steps=max_steps,
        solver=solver,
        stepsize_controller=stepsize_controller,
    )