# Copyright (C) 2023, 2024, 2025 Cecilio García Quirós
"""
Rewrite of the amplitude and phase ansatze to use numba.
There are functions for evaluation at a single point and
on a time array, the latter being parallelized with prange.
Unfortunately this requires some code duplication with respect to
the ansatze defined in internals.py.
"""
import numpy as np
from numba import njit, prange
from phenomxpy.config import NUMBA_USE_CACHE, NUMBA_USE_PARALLEL
# Build waveform A * exp(-i phase)
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def combine_amp_phase(amp, phase):
    """
    Numba version to build the waveform from arrays of amplitude and phase.

    Each output sample is ``amp[k] * exp(-1j * phase[k])``. The result is a
    complex array with ``len(amp)`` entries; ``phase`` is read at the same
    indices as ``amp``.
    """
    n_samples = len(amp)
    waveform = np.empty(n_samples, np.cdouble)
    for k in prange(n_samples):
        rotation = np.exp(-1j * phase[k])
        waveform[k] = amp[k] * rotation
    return waveform
#########################################
# AMPLITUDE ANSATZE
#########################################
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_inspiral_ansatz_amplitude(x, fac0, pn_real_coeffs, pn_imag_coeffs, pseudo_pn_coeffs):
    """
    Numba version of phenomt.internals.inspiral_ansatz_amplitude for one time step.

    The real part is a series in powers of ``sqrt(x)`` built from the nine
    ``pn_real_coeffs`` (the last one multiplies ``log(16 x) * x**3``) plus
    three ``pseudo_pn_coeffs`` terms at orders ``x**4``, ``x**4.5`` and
    ``x**5``. The imaginary part uses the seven ``pn_imag_coeffs`` starting
    at order ``sqrt(x)``. The complex sum is scaled by ``fac0 * x``.
    """
    # Integer powers of x
    sqrt_x = np.sqrt(x)
    x_2 = x * x
    x_3 = x_2 * x
    x_4 = x_2 * x_2
    x_5 = x_3 * x_2

    # Half-integer powers of x
    x_1p5 = x * sqrt_x
    x_2p5 = x_2 * sqrt_x
    x_3p5 = x_3 * sqrt_x
    x_4p5 = x_4 * sqrt_x

    # Real part, accumulated term by term in increasing order
    real_part = pn_real_coeffs[0] + pn_real_coeffs[1] * sqrt_x
    real_part += pn_real_coeffs[2] * x
    real_part += pn_real_coeffs[3] * x_1p5
    real_part += pn_real_coeffs[4] * x_2
    real_part += pn_real_coeffs[5] * x_2p5
    real_part += pn_real_coeffs[6] * x_3
    real_part += pn_real_coeffs[7] * x_3p5
    real_part += pn_real_coeffs[8] * np.log(16 * x) * x_3

    # Imaginary part
    imag_part = pn_imag_coeffs[0] * sqrt_x + pn_imag_coeffs[1] * x
    imag_part += pn_imag_coeffs[2] * x_1p5
    imag_part += pn_imag_coeffs[3] * x_2
    imag_part += pn_imag_coeffs[4] * x_2p5
    imag_part += pn_imag_coeffs[5] * x_3
    imag_part += pn_imag_coeffs[6] * x_3p5

    # Higher-order pseudo-PN terms only enter the real part
    pseudo_part = pseudo_pn_coeffs[0] * x_4 + pseudo_pn_coeffs[1] * x_4p5 + pseudo_pn_coeffs[2] * x_5
    real_part += pseudo_part

    complex_series = real_part + 1j * imag_part
    return fac0 * x * complex_series
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_inspiral_ansatz_amplitude_array(x, fac0, pn_real_coeffs, pn_imag_coeffs, pseudo_pn_coeffs):
    """
    Numba version of phenomt.internals.inspiral_ansatz_amplitude for an array of time steps.

    Applies ``numba_inspiral_ansatz_amplitude`` to every element of ``x``
    inside a ``prange`` loop and returns a complex array of length ``len(x)``.
    """
    n_points = len(x)
    amplitudes = np.empty(n_points, np.cdouble)
    for k in prange(n_points):
        amplitudes[k] = numba_inspiral_ansatz_amplitude(
            x[k], fac0, pn_real_coeffs, pn_imag_coeffs, pseudo_pn_coeffs
        )
    return amplitudes
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_amplitude(time, c1_prec, c2_prec, c3, c4_prec, alpha1RD_prec, tshift):
    """
    Numba version of phenomt.internals.ringdown_ansatz_amplitude for one time step.

    Returns ``exp(-alpha1RD_prec * (time - tshift))`` multiplied by
    ``c1_prec * tanh(c2_prec * (time - tshift) + c3) + c4_prec``.
    """
    shifted_time = time - tshift
    damping = np.exp(-alpha1RD_prec * shifted_time)
    envelope = c1_prec * np.tanh(c2_prec * shifted_time + c3) + c4_prec
    return damping * envelope
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_amplitude_array(times, c1_prec, c2_prec, c3, c4_prec, alpha1RD_prec, tshift):
    """
    Numba version of phenomt.internals.ringdown_ansatz_amplitude for an array of time steps.

    Applies ``numba_ringdown_ansatz_amplitude`` to every element of
    ``times`` inside a ``prange`` loop and returns an array of length
    ``len(times)``.
    """
    n_points = len(times)
    amplitudes = np.empty(n_points)
    for k in prange(n_points):
        amplitudes[k] = numba_ringdown_ansatz_amplitude(
            times[k], c1_prec, c2_prec, c3, c4_prec, alpha1RD_prec, tshift
        )
    return amplitudes
#########################################
# OMEGA ANSATZE
#########################################
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_pn_ansatz_omega(theta, coefficients):
    """
    Numba version of phenomt.internals.pn_ansatz_omega for one time step.

    Evaluates ``theta**3 / 4`` times a series in ``theta`` running from
    order 0 to order 7 (there is no linear term). The six ``coefficients``
    multiply orders 2 to 7, and order 6 carries an additional
    ``107 * log(theta) / 280`` contribution.
    """
    # Powers of theta
    th2 = theta * theta
    th3 = th2 * theta
    th4 = th2 * th2
    th5 = th3 * th2
    th6 = th3 * th3
    th7 = th4 * th3

    log_coeff = 107 * np.log(theta) / 280
    prefactor = th3 / 4

    # Series accumulated in increasing order of theta
    series = 1 + coefficients[0] * th2
    series += coefficients[1] * th3
    series += coefficients[2] * th4
    series += coefficients[3] * th5
    series += coefficients[4] * th6
    series += log_coeff * th6
    series += coefficients[5] * th7
    return prefactor * series
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_pn_ansatz_omega_array(theta, coefficients):
    """
    Numba version of phenomt.internals.pn_ansatz_omega for an array of time steps.

    Applies ``numba_pn_ansatz_omega`` to every element of ``theta`` inside a
    ``prange`` loop and returns an array of length ``len(theta)``.
    """
    n_points = len(theta)
    omegas = np.empty(n_points)
    for k in prange(n_points):
        omegas[k] = numba_pn_ansatz_omega(theta[k], coefficients)
    return omegas
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_inspiral_ansatz_omega(time, eta, pn_coefficients, coefficients):
    """
    Numba version of phenomt.internals.inspiral_ansatz_omega for one time step.

    With ``theta = (-eta * time / 5) ** (-1/8)``, returns ``theta**3 / 4``
    times the sum of two series: the one built from the six
    ``pn_coefficients`` (orders 0 to 7, including the ``log(theta)`` term at
    order 6) and the one built from the six ``coefficients`` (orders 8 to 13).
    """
    theta = np.power(-eta * time / 5, -1 / 8)

    # Powers of theta up to order 13
    th2 = theta * theta
    th3 = theta * th2
    th4 = theta * th3
    th5 = theta * th4
    th6 = theta * th5
    th7 = theta * th6
    th8 = th4 * th4
    th9 = th8 * theta
    th10 = th9 * theta
    th11 = th10 * theta
    th12 = th11 * theta
    th13 = th12 * theta

    log_coeff = 107 * np.log(theta) / 280
    prefactor = th3 / 4

    # Series from the PN coefficients (without the overall prefactor)
    pn_series = 1 + pn_coefficients[0] * th2
    pn_series += pn_coefficients[1] * th3
    pn_series += pn_coefficients[2] * th4
    pn_series += pn_coefficients[3] * th5
    pn_series += pn_coefficients[4] * th6
    pn_series += log_coeff * th6
    pn_series += pn_coefficients[5] * th7

    # Higher-order extension
    extra_series = coefficients[0] * th8 + coefficients[1] * th9
    extra_series += coefficients[2] * th10
    extra_series += coefficients[3] * th11
    extra_series += coefficients[4] * th12
    extra_series += coefficients[5] * th13

    return prefactor * (pn_series + extra_series)
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_inspiral_ansatz_omega_array(times, eta, pn_coefficients, pseudo_coefficients):
    """
    Numba version of phenomt.internals.inspiral_ansatz_omega for an array of time steps.

    Applies ``numba_inspiral_ansatz_omega`` to every element of ``times``
    inside a ``prange`` loop and returns an array of length ``len(times)``.
    """
    n_points = len(times)
    omegas = np.empty(n_points)
    for k in prange(n_points):
        omegas[k] = numba_inspiral_ansatz_omega(times[k], eta, pn_coefficients, pseudo_coefficients)
    return omegas
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_inspiral_ansatz_domega(time, eta, pn_coefficients, coefficients):
    """
    Numba version for the first derivative of phenomt.internals.inspiral_ansatz_omega for one time step.

    The result is the chain-rule product of the derivative of the frequency
    with respect to ``theta`` and the derivative of ``theta`` with respect
    to ``time``, where ``theta = (-eta * time / 5) ** (-1/8)``.
    """
    theta = np.power(-eta * time / 5, -1.0 / 8)

    # Powers of theta up to order 13
    th2 = theta * theta
    th3 = theta * th2
    th4 = theta * th3
    th5 = theta * th4
    th6 = theta * th5
    th7 = theta * th6
    th8 = th4 * th4
    th9 = th8 * theta
    th10 = th9 * theta
    th11 = th10 * theta
    th12 = th11 * theta
    th13 = th12 * theta

    log_theta = np.log(theta)

    # Polynomial part of the theta-derivative, with theta**2 / 4 factored out
    poly = 3 + 5 * pn_coefficients[0] * th2
    poly += 6 * pn_coefficients[1] * th3
    poly += 7 * pn_coefficients[2] * th4
    poly += 8 * pn_coefficients[3] * th5
    poly += 107.0 / 280 * th6
    poly += 9 * pn_coefficients[4] * th6
    poly += 10 * pn_coefficients[5] * th7
    poly += 11 * coefficients[0] * th8
    poly += 12 * coefficients[1] * th9
    poly += 13 * coefficients[2] * th10
    poly += 14 * coefficients[3] * th11
    poly += 15 * coefficients[4] * th12
    poly += 16 * coefficients[5] * th13

    # Logarithmic contribution is kept outside the factored polynomial
    domega_dtheta = 0.25 * th2 * poly + 963 * th8 * log_theta / 1120

    dtheta_dtime = 0.125 * np.power(5 / eta, 1.0 / 8) * np.power(-time, -9.0 / 8)
    return dtheta_dtime * domega_dtheta
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_omega(times, c1, c2, c3, c4, omegaRING):
    """
    Numba version of phenomt.internals.ringdown_ansatz_omega for one time step.

    With ``e = exp(-c2 * times)``, returns
    ``-c1 * c2 * (2 * c4 * e**2 + c3 * e) / (1 + c4 * e**2 + c3 * e) + omegaRING``.
    """
    decay = np.exp(-c2 * times)
    decay_sq = decay * decay
    numerator = -c1 * c2 * (2 * c4 * decay_sq + c3 * decay)
    denominator = 1 + c4 * decay_sq + c3 * decay
    ratio = numerator / denominator
    return ratio + omegaRING
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_omega_array(times, c1, c2, c3, c4, omegaRING):
    """
    Numba version of phenomt.internals.ringdown_ansatz_omega for an array of time steps.

    Applies ``numba_ringdown_ansatz_omega`` to every element of ``times``
    inside a ``prange`` loop and returns an array of length ``len(times)``.
    """
    n_points = len(times)
    omegas = np.empty(n_points)
    for k in prange(n_points):
        omegas[k] = numba_ringdown_ansatz_omega(times[k], c1, c2, c3, c4, omegaRING)
    return omegas
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_domega(time, c1, c2, c3, c4):
    """
    Numba version for the first derivative of phenomt.internals.ringdown_ansatz_omega for one time step.

    Written in terms of the growing exponential ``E = exp(c2 * time)``:
    ``c1 * c2**2 * E * (4 * c4 * E + c3 * (c4 + E**2)) / (c4 + E * (c3 + E))**2``.
    """
    growth = np.exp(c2 * time)
    growth_sq = growth * growth
    numerator = c1 * c2 * c2 * growth * (4 * c4 * growth + c3 * (c4 + growth_sq))
    base = c4 + growth * (c3 + growth)
    return numerator / (base * base)
#########################################
# PHASE ANSATZE
#########################################
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_pn_ansatz_22_phase(thetabar, eta, powers_of_5, pn_coefficients):
    """
    Numba version of phenomt.internals.pn_ansatz_phase for the 22 mode in one time step.

    The input ``thetabar`` is first multiplied by ``powers_of_5[1]``; the
    phase is then a series in that rescaled variable, divided by
    ``eta * theta**5`` and by 84.
    """
    # Rescaled expansion variable and its powers
    th = thetabar * powers_of_5[1]
    th2 = th * th
    th3 = th * th2
    th4 = th * th3
    th5 = th * th4
    th6 = th * th5
    th7 = th * th6
    log_th = np.log(th)

    # Series accumulated in the order the terms appear in the ansatz
    series = -168 - 280 * pn_coefficients[0] * th2
    series -= 420 * pn_coefficients[1] * th3
    series -= 840 * pn_coefficients[2] * th4
    series += 840 * pn_coefficients[3] * (log_th - 0.125 * np.log(5)) * th5
    series -= 321 * th6
    series += 840 * pn_coefficients[4] * th6
    series += 321 * log_th * th6
    series += 420 * pn_coefficients[5] * th7

    scaled = 1 / eta / th5 * series
    return scaled / 84.0
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_inspiral_22_phase(time, eta, powers_of_5, pn_coefficients, pseudo_pn_coefficients, phOffInsp):
    """
    Numba version of phenomt.internals.inspiral_ansatz_phase for the 22 mode in one time step.

    Uses ``thetabar = (-eta * time) ** (-1/8)`` together with the
    precomputed entries of ``powers_of_5``. The closed-form expression
    combines the ``pn_coefficients`` and ``pseudo_pn_coefficients``
    contributions, and ``phOffInsp`` is added to the result as a constant
    offset.
    """
    # Expansion variable and its powers
    tb = np.power(-eta * time, -1.0 / 8)
    tb2 = tb * tb
    tb3 = tb * tb2
    tb4 = tb * tb3
    tb5 = tb * tb4
    tb6 = tb * tb5
    tb7 = tb * tb6

    log_minus_t = np.log(-time)
    log_theta_bar = np.log(np.power(5, 0.125)) - 0.125 * (np.log(eta) + log_minus_t)

    # Series accumulated in the order the terms appear in the ansatz
    series = 3 * (-107 + 280 * pn_coefficients[4]) * powers_of_5[6]
    series += 321 * log_theta_bar * powers_of_5[6]
    series += 420 * pn_coefficients[5] * tb * powers_of_5[7]
    series += 56 * (25 * pseudo_pn_coefficients[0] + 3 * eta * time) * tb2
    series += 1050 * pseudo_pn_coefficients[1] * powers_of_5[1] * tb3
    series += 280 * (3 * pseudo_pn_coefficients[2] + eta * pn_coefficients[0] * time) * powers_of_5[2] * tb4
    series += 140 * (5 * pseudo_pn_coefficients[3] + 3 * eta * pn_coefficients[1] * time) * powers_of_5[3] * tb5
    series += 120 * (5 * pseudo_pn_coefficients[4] + 7 * eta * pn_coefficients[2] * time) * powers_of_5[4] * tb6
    series += 525 * pseudo_pn_coefficients[5] * powers_of_5[5] * tb7
    series += 105 * eta * pn_coefficients[3] * time * log_minus_t * powers_of_5[5] * tb7

    scaled = 1 / powers_of_5[5] / (eta * eta) / time / tb7 * series
    phase = -scaled / 84.0
    return phase + phOffInsp
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def njit_inspiral_22_phase(times, eta, powers_of_5, pn_coefficients, pseudo_pn_coefficients, phOffInsp):
    """
    Numba version of phenomt.internals.inspiral_ansatz_phase for the 22 mode in an array of time steps.

    Applies ``numba_inspiral_22_phase`` to every element of ``times`` inside
    a ``prange`` loop and returns an array of length ``len(times)``.
    """
    n_points = len(times)
    phases = np.empty(n_points)
    for k in prange(n_points):
        phases[k] = numba_inspiral_22_phase(
            times[k], eta, powers_of_5, pn_coefficients, pseudo_pn_coefficients, phOffInsp
        )
    return phases
[docs]
@njit(cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_phase(time, c1_prec, c2, c3, c4, omegaRING_prec, phOffRD):
    """
    Numba version of phenomt.internals.ringdown_ansatz_phase for one time step.

    With ``e = exp(-c2 * time)``, returns
    ``c1_prec * log((1 + c3 * e + c4 * e**2) / (1 + c3 + c4))
    + omegaRING_prec * time + phOffRD``.
    """
    decay = np.exp(-c2 * time)
    numerator = 1 + c3 * decay + c4 * decay * decay
    normalisation = 1 + c3 + c4
    log_ratio = np.log(numerator / normalisation)
    return c1_prec * log_ratio + omegaRING_prec * time + phOffRD
[docs]
@njit(parallel=NUMBA_USE_PARALLEL, cache=NUMBA_USE_CACHE)
def numba_ringdown_ansatz_phase_array(times, c1_prec, c2, c3, c4, omegaRING_prec, phOffRD):
    """
    Numba version of phenomt.internals.ringdown_ansatz_phase for an array of time steps.

    Applies ``numba_ringdown_ansatz_phase`` to every element of ``times``
    inside a ``prange`` loop and returns an array of length ``len(times)``.
    """
    n_points = len(times)
    phases = np.empty(n_points)
    for k in prange(n_points):
        phases[k] = numba_ringdown_ansatz_phase(times[k], c1_prec, c2, c3, c4, omegaRING_prec, phOffRD)
    return phases