# Source code for astrodynx.utils._rot_mat

import jax.numpy as jnp
from jax.typing import ArrayLike
from jax import Array


def rotmat3dx(angle: ArrayLike) -> Array:
    r"""Returns a 3x3 rotation matrix for a given angle around the x-axis.

    Args:
        angle: The angle in radians to rotate around the x-axis. May be a
            scalar or an array of any shape. A plain sequence of numbers is
            also accepted; it is converted with ``jnp.asarray``.

    Returns:
        The rotation matrix (or stack of matrices) that rotates vectors
        around the x-axis by the specified angle. A scalar angle yields an
        array of shape ``(3, 3)``. An angle array of shape ``S`` yields shape
        ``S + (3, 3)``, with one matrix per angle.

    Notes:
        The rotation matrix is defined as:
        $$
        R_x(\theta) = \begin{bmatrix}
        1 & 0 & 0 \\
        0 & \cos(\theta) & -\sin(\theta) \\
        0 & \sin(\theta) & \cos(\theta)
        \end{bmatrix}
        $$
        where $\theta$ is the angle of rotation. This matrix rotates vectors
        (active rotation). For example, $R_x(\pi/2)$ maps the y unit vector
        onto the z unit vector.

    References:
        Battin, 1999, pp.85.

    Examples:
        Creating a rotation matrix for a 90-degree rotation (π/2 radians):

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> angle = jnp.pi / 2
        >>> jnp.allclose(adx.utils.rotmat3dx(angle), jnp.array([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]]), atol=1e-7)
        Array(True, dtype=bool)

        Broadcasting with an array of angles (one matrix per angle, no loop):

        >>> angles = jnp.array([0.0, jnp.pi / 2])
        >>> results = adx.utils.rotmat3dx(angles)
        >>> results.shape
        (2, 3, 3)
        >>> expected0 = jnp.eye(3)
        >>> expected1 = jnp.array([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]])
        >>> jnp.allclose(results[0], expected0, atol=1e-7)
        Array(True, dtype=bool)
        >>> jnp.allclose(results[1], expected1, atol=1e-7)
        Array(True, dtype=bool)
    """
    # Convert up front so plain Python sequences work too. jnp.cos and
    # friends reject lists.
    angle = jnp.asarray(angle)
    c = jnp.cos(angle)
    s = jnp.sin(angle)
    # Build the constant entries from `c` so every entry shares the same
    # floating dtype and shape as the trigonometric ones.
    z = jnp.zeros_like(c)
    o = jnp.ones_like(c)
    # Stack each row along a new trailing axis (..., 3), then stack the rows
    # along axis -2 to get (..., 3, 3). Leading batch axes are preserved.
    return jnp.stack(
        [
            jnp.stack([o, z, z], axis=-1),
            jnp.stack([z, c, -s], axis=-1),
            jnp.stack([z, s, c], axis=-1),
        ],
        axis=-2,
    )
def rotmat3dy(angle: ArrayLike) -> Array:
    r"""Returns a 3x3 rotation matrix for a given angle around the y-axis.

    Args:
        angle: The angle in radians to rotate around the y-axis. May be a
            scalar or an array of any shape. A plain sequence of numbers is
            also accepted; it is converted with ``jnp.asarray``.

    Returns:
        The rotation matrix (or stack of matrices) that rotates vectors
        around the y-axis by the specified angle. A scalar angle yields an
        array of shape ``(3, 3)``. An angle array of shape ``S`` yields shape
        ``S + (3, 3)``, with one matrix per angle.

    Notes:
        The rotation matrix is defined as:
        $$
        R_y(\theta) = \begin{bmatrix}
        \cos(\theta) & 0 & \sin(\theta) \\
        0 & 1 & 0 \\
        -\sin(\theta) & 0 & \cos(\theta)
        \end{bmatrix}
        $$
        where $\theta$ is the angle of rotation. This matrix rotates vectors
        (active rotation). For example, $R_y(\pi/2)$ maps the z unit vector
        onto the x unit vector.

    References:
        Battin, 1999, pp.85.

    Examples:
        Creating a rotation matrix for a 90-degree rotation (π/2 radians):

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> angle = jnp.pi / 2
        >>> jnp.allclose(adx.utils.rotmat3dy(angle), jnp.array([[0., 0., 1.], [0., 1., 0.], [-1., 0., 0.]]), atol=1e-7)
        Array(True, dtype=bool)

        Broadcasting with an array of angles (one matrix per angle, no loop):

        >>> angles = jnp.array([0.0, jnp.pi / 2])
        >>> results = adx.utils.rotmat3dy(angles)
        >>> results.shape
        (2, 3, 3)
        >>> expected0 = jnp.eye(3)
        >>> expected1 = jnp.array([[0., 0., 1.], [0., 1., 0.], [-1., 0., 0.]])
        >>> jnp.allclose(results[0], expected0, atol=1e-7)
        Array(True, dtype=bool)
        >>> jnp.allclose(results[1], expected1, atol=1e-7)
        Array(True, dtype=bool)
    """
    # Convert up front so plain Python sequences work too. jnp.cos and
    # friends reject lists.
    angle = jnp.asarray(angle)
    c = jnp.cos(angle)
    s = jnp.sin(angle)
    # Build the constant entries from `c` so every entry shares the same
    # floating dtype and shape as the trigonometric ones.
    z = jnp.zeros_like(c)
    o = jnp.ones_like(c)
    # Stack each row along a new trailing axis (..., 3), then stack the rows
    # along axis -2 to get (..., 3, 3). Leading batch axes are preserved.
    return jnp.stack(
        [
            jnp.stack([c, z, s], axis=-1),
            jnp.stack([z, o, z], axis=-1),
            jnp.stack([-s, z, c], axis=-1),
        ],
        axis=-2,
    )
def rotmat3dz(angle: ArrayLike) -> Array:
    r"""Returns a 3x3 rotation matrix for a given angle around the z-axis.

    Args:
        angle: The angle in radians to rotate around the z-axis. May be a
            scalar or an array of any shape. A plain sequence of numbers is
            also accepted; it is converted with ``jnp.asarray``.

    Returns:
        The rotation matrix (or stack of matrices) that rotates vectors
        around the z-axis by the specified angle. A scalar angle yields an
        array of shape ``(3, 3)``. An angle array of shape ``S`` yields shape
        ``S + (3, 3)``, with one matrix per angle.

    Notes:
        The rotation matrix is defined as:
        $$
        R_z(\theta) = \begin{bmatrix}
        \cos(\theta) & -\sin(\theta) & 0 \\
        \sin(\theta) & \cos(\theta) & 0 \\
        0 & 0 & 1
        \end{bmatrix}
        $$
        where $\theta$ is the angle of rotation. This matrix rotates vectors
        (active rotation). For example, $R_z(\pi/2)$ maps the x unit vector
        onto the y unit vector.

    References:
        Battin, 1999, pp.85.

    Examples:
        Creating a rotation matrix for a 90-degree rotation (π/2 radians):

        >>> import jax.numpy as jnp
        >>> import astrodynx as adx
        >>> angle = jnp.pi / 2
        >>> jnp.allclose(adx.utils.rotmat3dz(angle), jnp.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]]), atol=1e-7)
        Array(True, dtype=bool)

        Broadcasting with an array of angles (one matrix per angle, no loop):

        >>> angles = jnp.array([0.0, jnp.pi / 2])
        >>> results = adx.utils.rotmat3dz(angles)
        >>> results.shape
        (2, 3, 3)
        >>> expected0 = jnp.eye(3)
        >>> expected1 = jnp.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])
        >>> jnp.allclose(results[0], expected0, atol=1e-7)
        Array(True, dtype=bool)
        >>> jnp.allclose(results[1], expected1, atol=1e-7)
        Array(True, dtype=bool)
    """
    # Convert up front so plain Python sequences work too. jnp.cos and
    # friends reject lists.
    angle = jnp.asarray(angle)
    c = jnp.cos(angle)
    s = jnp.sin(angle)
    # Build the constant entries from `c` so every entry shares the same
    # floating dtype and shape as the trigonometric ones.
    z = jnp.zeros_like(c)
    o = jnp.ones_like(c)
    # Stack each row along a new trailing axis (..., 3), then stack the rows
    # along axis -2 to get (..., 3, 3). Leading batch axes are preserved.
    return jnp.stack(
        [
            jnp.stack([c, -s, z], axis=-1),
            jnp.stack([s, c, z], axis=-1),
            jnp.stack([z, z, o], axis=-1),
        ],
        axis=-2,
    )