# Source code for meepmeep.backends.numba.orbit3dd.light_travel_time

#  MeepMeep: fast orbit calculations for exoplanet modelling
#  Copyright (C) 2022-2026 Hannu Parviainen
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.

"""Multi-expansion-point light-travel-time correction evaluators with parameter derivatives."""

from numba import njit, prange, types, get_num_threads, get_thread_id
from numba.extending import overload
from numpy import zeros, pi, floor, ndarray

from ..point3d.zvelocity import zvel_c
from .zposition import _zpos_osd, _zpos_ow
from ..utils import mean_anomaly_at_transit_with_derivatives
from ._common import _is_1d_array


# Time taken by light to traverse one solar radius, in days (R_sun / c).
# Numerically ≈ 6.957e8 m / 299792458 m/s / 86400 s/day, i.e. the IAU 2015
# nominal solar radius divided by the speed of light. Multiplied by the
# stellar radius (in R_sun) and a z-displacement (in stellar radii) to give
# a delay in days. Kept in sync with ``orbit3d.LTT_DAYS_PER_RSUN``.
LTT_DAYS_PER_RSUN = 2.685885891543453e-05


@njit(fastmath=True)
def _ltt_transit_z_and_d(tpa, p, e, w, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Compute ``z(t_transit)`` and its full chain-rule derivative.

    Helper for the light-travel-time derivatives. The transit time is
    :math:`t_\\mathrm{transit} = t_\\mathrm{pa} + t_o` with
    :math:`t_o = M_\\mathrm{tr}(e, w) \\cdot p / (2\\pi)`. Because
    :math:`t_o` depends on the orbital parameters, the total derivative is

    .. math::

        \\frac{\\mathrm d}{\\mathrm d\\theta_k}\\bigl[z(t_\\mathrm{transit}(\\theta);\\theta)\\bigr]
            = v_z(t_\\mathrm{transit})\\,\\frac{\\mathrm d t_o}{\\mathrm d \\theta_k}
              + \\left.\\frac{\\partial z}{\\partial \\theta_k}\\right|_{t=t_\\mathrm{transit}}

    The non-zero entries of ``dto`` are ``dto/dp = M_tr / (2π)``,
    ``dto/de = (dM_tr/de) · p / (2π)`` and ``dto/dw = (dM_tr/dw) · p / (2π)``.
    The ``tc``, ``a``, ``i`` and ``lan`` entries are zero.

    Parameters
    ----------
    tpa : float
        Periastron time anchoring the expansion-point grid (see :func:`_pos_osd`).
    p, e, w : float
        Orbital period [days], eccentricity, argument of periastron [radians].
    dt, ep_table, ep_times, coeffs, dcoeffs :
        Multi-expansion-point dispatch arrays.

    Returns
    -------
    z_tr : float
        Line-of-sight planet coordinate at transit time.
    dz_tr_total : ndarray, shape (7,)
        Full total derivative of ``z(t_transit)`` w.r.t.
        ``(tc, p, a, i, e, w, lan)``.

    Notes
    -----
    ``dto`` is zero in slot 0 (``tc``), so that entry is exactly the partial
    derivative returned by :func:`_zpos_osd` at ``t_transit``. In the
    multi-expansion-point scheme it reflects a per-expansion-point phase shift
    at the expansion point containing ``t_transit``, not a global user-facing
    T0 shift.
    """
    m_tr, dm_tr_de, dm_tr_dw = mean_anomaly_at_transit_with_derivatives(e, w)
    two_pi = 2.0 * pi
    to = m_tr / two_pi * p
    t_transit = tpa + to

    # Evaluate z and its partial derivatives (∂z/∂θ)|_{t=t_transit}.
    z_tr, dz_tr_partial = _zpos_osd(t_transit, tpa, p, dt, ep_table, ep_times, coeffs, dcoeffs)
    # Velocity at transit, multiplying dt_transit/dθ in the chain-rule term.
    vz_tr = _zvel_os(t_transit, tpa, p, dt, ep_table, ep_times, coeffs)

    # dto/dθ in (tc, p, a, i, e, w, lan) order: only slots 1 (p), 4 (e) and
    # 5 (w) are non-zero.
    dto = zeros(7)
    dto[1] = m_tr / two_pi
    dto[4] = dm_tr_de * p / two_pi
    dto[5] = dm_tr_dw * p / two_pi

    dz_tr_total = zeros(7)
    for k in range(7):
        dz_tr_total[k] = vz_tr * dto[k] + dz_tr_partial[k]
    return z_tr, dz_tr_total


@njit(fastmath=True)
def _zvel_os(t, tpa, p, dt, ep_table, ep_times, coeffs):
    """Return the line-of-sight (z) velocity at a single time ``t``.

    This is a private, value-only copy of ``orbit3d.zvel_os``. It is duplicated
    here so that this module does not import ``orbit3d``, which would create an
    import cycle. :func:`_ltt_transit_z_and_d` uses it for the chain-rule term.

    Parameters
    ----------
    t : float
        Evaluation time.
    tpa, p, dt, ep_table, ep_times, coeffs :
        Expansion-point dispatch inputs, as for :func:`_pos_osd`. No
        ``dcoeffs`` is needed because no derivatives are produced.

    Returns
    -------
    vz : float
        Line-of-sight velocity [:math:`R_\\star/\\mathrm{day}`].
    """
    # Wrap t into a single orbit, measured from the tpa-anchored epoch start.
    n_orbits = floor((t - tpa) / p)
    t_fold = t - tpa - n_orbits * p

    # Find the expansion point covering t_fold, then evaluate the velocity
    # polynomial relative to that expansion point's time.
    cell = int(floor(t_fold / (dt * p)))
    iep = ep_table[cell]
    t_local = t_fold - ep_times[iep] * p
    return zvel_c(t_local, coeffs[iep])


@njit(fastmath=True)
def _light_travel_time_osd(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Scalar light-travel-time kernel behind :func:`light_travel_time_od`.

    For one time ``t``, returns ``(ltt, dltt)``: the correction in days and
    its gradient, shape (7,), w.r.t. ``(tc, p, a, i, e, w, lan)``. See
    :func:`light_travel_time_od` for the full description.
    """
    scale = -rstar * LTT_DAYS_PER_RSUN

    # z (and gradient) at the requested time and at the transit reference.
    z_now, dz_now = _zpos_osd(t, tpa, p, dt, ep_table, ep_times, coeffs, dcoeffs)
    z_ref, dz_ref = _ltt_transit_z_and_d(tpa, p, e, w, dt, ep_table, ep_times, coeffs, dcoeffs)

    grad = zeros(7)
    for ik in range(7):
        grad[ik] = scale * (dz_now[ik] - dz_ref[ik])
    return scale * (z_now - z_ref), grad


@njit(fastmath=True)
def light_travel_time_ovd(times, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Serial vector kernel behind :func:`light_travel_time_od`.

    Evaluates the light-travel-time correction and its gradient at every
    entry of ``times``. Returns ``(ltt, dltt)`` with shapes ``(N,)`` and
    ``(N, 7)``. See :func:`light_travel_time_od` for the full description.
    """
    npt = times.size
    scale = -rstar * LTT_DAYS_PER_RSUN

    # The transit reference does not depend on the evaluation time, so its
    # value and derivative chain are computed once, before the loop.
    z_ref, dz_ref = _ltt_transit_z_and_d(tpa, p, e, w, dt, ep_table, ep_times, coeffs, dcoeffs)

    ltt = zeros(npt)
    dltt = zeros((npt, 7))
    grad = zeros(7)  # scratch: _zpos_ow writes the z-gradient into it each call
    for ipt in range(npt):
        z_now = _zpos_ow(times[ipt], tpa, p, dt, ep_table, ep_times, coeffs, dcoeffs, grad)
        ltt[ipt] = scale * (z_now - z_ref)
        row = dltt[ipt]
        for ik in range(7):
            row[ik] = scale * (grad[ik] - dz_ref[ik])
    return ltt, dltt


@njit(fastmath=True, parallel=True)
def light_travel_time_ovdp(times, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Thread-parallel (``prange``) version of :func:`light_travel_time_ovd`.

    Returns the same ``(ltt, dltt)`` pair, with shapes ``(N,)`` and ``(N, 7)``.
    The serial kernel shares one gradient scratch buffer across iterations,
    which would be a data race under ``prange``. Here each thread instead
    writes into its own row of a scratch array, selected by
    ``get_thread_id()``. The transit reference is evaluated once, before the
    parallel loop.
    """
    npt = times.size
    scale = -rstar * LTT_DAYS_PER_RSUN
    z_ref, dz_ref = _ltt_transit_z_and_d(tpa, p, e, w, dt, ep_table, ep_times, coeffs, dcoeffs)

    ltt = zeros(npt)
    dltt = zeros((npt, 7))
    # One 7-element gradient scratch row per worker thread.
    scratch = zeros((get_num_threads(), 7))
    for ipt in prange(npt):
        grad = scratch[get_thread_id()]
        z_now = _zpos_ow(times[ipt], tpa, p, dt, ep_table, ep_times, coeffs, dcoeffs, grad)
        ltt[ipt] = scale * (z_now - z_ref)
        for ik in range(7):
            dltt[ipt, ik] = scale * (grad[ik] - dz_ref[ik])
    return ltt, dltt


def light_travel_time_od(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Light travel time correction with gradients.

    ``t`` may be a scalar time or a 1-D array of times. The call is routed to
    the scalar kernel (:func:`_light_travel_time_osd`) or the vector kernel
    (:func:`light_travel_time_ovd`). Inside ``@njit`` code the choice is made
    at compile time by the registered overload. From pure Python it is made at
    call time: any ``ndarray`` goes to the vector kernel and everything else
    to the scalar kernel.

    The correction is measured relative to primary transit:

    .. math::

        \\mathrm{ltt}(t) = -(z(t) - z(t_\\mathrm{transit}))\\,r_\\star\\,(R_\\odot / c)

    where :math:`t_\\mathrm{transit} = t_\\mathrm{pa} + M_\\mathrm{tr}(e, w)\\,p/(2\\pi)`.

    By design, no derivative with respect to ``rstar`` is returned. Only the
    seven orbital derivatives are returned, in the canonical
    ``(tc, p, a, i, e, w, lan)`` order. :func:`_ltt_transit_z_and_d` supplies
    the reference ``z(t_transit)`` and its derivatives. These include the
    chain rule through ``t_transit(p, e, w)``, using ``vz(t_transit)``.

    Parameters
    ----------
    t : float or ndarray
        Time(s) at which to evaluate the correction and its gradient.
    tpa : float
        Periastron time anchoring the expansion-point grid (see :func:`_pos_osd`).
    p : float
        Orbital period [days].
    e : float
        Eccentricity.
    w : float
        Argument of periastron [radians].
    rstar : float
        Stellar radius [R_sun].
    dt, ep_table, ep_times, coeffs, dcoeffs :
        Multi-expansion-point dispatch arrays.

    Returns
    -------
    ltt : float or ndarray
        Light travel time correction [days]; shape (N,) for an array ``t``.
    dltt : ndarray
        Gradient w.r.t. ``(tc, p, a, i, e, w, lan)``; shape (7,) for a scalar
        ``t`` and (N, 7) for an array ``t``.
    """
    args = (tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs)
    if not isinstance(t, ndarray):
        return _light_travel_time_osd(t, *args)
    return light_travel_time_ovd(t, *args)
@overload(light_travel_time_od, jit_options={'fastmath': True})
def _light_travel_time_od_overload(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
    """Numba typing-time dispatch for :func:`light_travel_time_od`.

    If ``t`` satisfies ``_is_1d_array`` (presumably a 1-D array type), the
    serial vector kernel :func:`light_travel_time_ovd` is selected. If ``t``
    is a real scalar, the scalar kernel :func:`_light_travel_time_osd` is
    selected.

    Integer scalars are accepted as well as floats. The pure-Python dispatcher
    sends every non-``ndarray`` argument, including plain ints, to the scalar
    kernel, and the jitted path should accept the same inputs rather than fail
    typing. Returning ``None`` for any other type tells Numba that no
    implementation matches that signature.
    """
    if _is_1d_array(t):
        def impl(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
            return light_travel_time_ovd(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs)
        return impl
    if isinstance(t, (types.Float, types.Integer)):
        def impl(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs):
            return _light_travel_time_osd(t, tpa, p, e, w, rstar, dt, ep_table, ep_times, coeffs, dcoeffs)
        return impl
    return None