from numba import prange, njit, objmode
from numba.np.ufunc.parallel import (
_get_thread_id as get_thread_id,
get_num_threads,
)
import numpy as np
from tardis.montecarlo.montecarlo_numba.r_packet import (
RPacket,
PacketStatus,
)
from tardis.montecarlo.montecarlo_numba.numba_interface import (
PacketCollection,
VPacketCollection,
RPacketTracker,
NumbaModel,
numba_plasma_initialize,
Estimators,
)
from tardis.montecarlo import (
montecarlo_configuration as montecarlo_configuration,
)
from tardis.montecarlo.montecarlo_numba.single_packet_loop import (
single_packet_loop,
)
from tardis.montecarlo.montecarlo_numba import njit_dict
from numba.typed import List
from tardis.util.base import update_iterations_pbar, update_packet_pbar
def montecarlo_radial1d(
    model,
    plasma,
    iteration,
    no_of_packets,
    total_iterations,
    show_progress_bars,
    runner,
):
    """
    Run one Monte Carlo iteration and store its results on ``runner``.

    The runner's packet arrays, shell geometry and estimator arrays are
    wrapped into the containers expected by ``montecarlo_main_loop``. The
    loop's outputs are then written back onto ``runner``:

    - virtual-packet logs only when ``montecarlo_configuration.VPACKET_LOGGING``
      is enabled and ``number_of_vpackets > 0``;
    - r-packet trackers only when ``montecarlo_configuration.RPACKET_TRACKING``
      is enabled.

    Parameters
    ----------
    model
        Provides ``time_explosion`` (converted to seconds for the loop).
    plasma
        Passed to ``numba_plasma_initialize`` together with the runner's
        ``line_interaction_type``.
    iteration : int
        Current iteration number, forwarded to the packet progress bar.
    no_of_packets : int
        Forwarded to the packet progress bar.
    total_iterations : int
        Forwarded to the packet progress bar.
    show_progress_bars : bool
        Whether the packet progress bar is updated inside the loop.
    runner
        Monte Carlo runner that supplies the inputs and receives the outputs.
    """
    # Wrap the runner-owned arrays in the numba-compatible containers.
    packet_collection = PacketCollection(
        runner.input_r,
        runner.input_nu,
        runner.input_mu,
        runner.input_energy,
        runner._output_nu,
        runner._output_energy,
    )
    numba_model = NumbaModel(
        runner.r_inner_cgs,
        runner.r_outer_cgs,
        runner.v_inner_cgs,
        runner.v_outer_cgs,
        model.time_explosion.to("s").value,
    )
    numba_plasma = numba_plasma_initialize(plasma, runner.line_interaction_type)
    estimators = Estimators(
        runner.j_estimator,
        runner.nu_bar_estimator,
        runner.j_blue_estimator,
        runner.Edotlu_estimator,
        runner.photo_ion_estimator,
        runner.stim_recomb_estimator,
        runner.bf_heating_estimator,
        runner.stim_recomb_cooling_estimator,
        runner.photo_ion_estimator_statistics,
    )

    packet_seeds = montecarlo_configuration.packet_seeds
    number_of_vpackets = montecarlo_configuration.number_of_vpackets

    # The eight virtual-packet logs sit between the five packet outputs and
    # the trailing r-packet trackers in the returned tuple.
    (
        v_packets_energy_hist,
        last_interaction_type,
        last_interaction_in_nu,
        last_line_interaction_in_id,
        last_line_interaction_out_id,
        *virt_packet_logs,
        rpacket_trackers,
    ) = montecarlo_main_loop(
        packet_collection,
        numba_model,
        numba_plasma,
        estimators,
        runner.spectrum_frequency.value,
        number_of_vpackets,
        packet_seeds,
        montecarlo_configuration.VPACKET_LOGGING,
        iteration=iteration,
        show_progress_bars=show_progress_bars,
        no_of_packets=no_of_packets,
        total_iterations=total_iterations,
    )

    runner._montecarlo_virtual_luminosity.value[:] = v_packets_energy_hist
    runner.last_interaction_type = last_interaction_type
    runner.last_interaction_in_nu = last_interaction_in_nu
    runner.last_line_interaction_in_id = last_line_interaction_in_id
    runner.last_line_interaction_out_id = last_line_interaction_out_id

    if montecarlo_configuration.VPACKET_LOGGING and number_of_vpackets > 0:
        # Names in the same order as the logs in ``virt_packet_logs``.
        virt_packet_attr_names = (
            "virt_packet_nus",
            "virt_packet_energies",
            "virt_packet_initial_mus",
            "virt_packet_initial_rs",
            "virt_packet_last_interaction_in_nu",
            "virt_packet_last_interaction_type",
            "virt_packet_last_line_interaction_in_id",
            "virt_packet_last_line_interaction_out_id",
        )
        # Each log holds one array per packet; flatten it into one array.
        for attr_name, per_packet_arrays in zip(
            virt_packet_attr_names, virt_packet_logs
        ):
            setattr(
                runner, attr_name, np.concatenate(per_packet_arrays).ravel()
            )

    update_iterations_pbar(1)

    # Only expose the trackers when r-packet tracking is enabled.
    if montecarlo_configuration.RPACKET_TRACKING:
        runner.rpacket_tracker = rpacket_trackers
@njit(**njit_dict)
def montecarlo_main_loop(
    packet_collection,
    numba_model,
    numba_plasma,
    estimators,
    spectrum_frequency,
    number_of_vpackets,
    packet_seeds,
    virtual_packet_logging,
    iteration,
    show_progress_bars,
    no_of_packets,
    total_iterations,
):
    """
    This is the main loop of the MonteCarlo routine that generates packets
    and sends them through the ejecta.

    Packets are processed in parallel (``prange``). Each thread accumulates
    estimator contributions and virtual-packet energies into its own
    buffers, which are summed after the loop.

    Parameters
    ----------
    packet_collection : PacketCollection
        Input packet properties. Its ``packets_output_nu`` and
        ``packets_output_energy`` arrays are overwritten with the results.
    numba_model : NumbaModel
    numba_plasma : NumbaPlasma
    estimators : Estimators
        Shared estimators. The per-thread contributions are added into them
        with ``Estimators.increment`` after the packet loop.
    spectrum_frequency : numpy.ndarray
        Frequency grid (plain array, not a Quantity) used to histogram
        virtual-packet energies. It is assumed uniformly spaced: only
        ``spectrum_frequency[1] - spectrum_frequency[0]`` is used as the
        bin width.
    number_of_vpackets : int
        VPackets released per interaction
    packet_seeds : numpy.ndarray
        Random seed for each packet.
    virtual_packet_logging : bool
        Option to enable virtual packet logging.
    iteration : int
        Current iteration, forwarded to the packet progress bar.
    show_progress_bars : bool
        Whether the main thread updates the packet progress bar.
    no_of_packets : int
        Forwarded to the packet progress bar.
    total_iterations : int
        Forwarded to the packet progress bar.

    Returns
    -------
    tuple
        ``(v_packets_energy_hist, last_interaction_types,
        last_interaction_in_nus, last_line_interaction_in_ids,
        last_line_interaction_out_ids, virt_packet_nus,
        virt_packet_energies, virt_packet_initial_mus,
        virt_packet_initial_rs, virt_packet_last_interaction_in_nu,
        virt_packet_last_interaction_type,
        virt_packet_last_line_interaction_in_id,
        virt_packet_last_line_interaction_out_id, rpacket_trackers)``.

        The ``virt_packet_*`` lists hold one array per packet and are only
        filled when ``virtual_packet_logging`` is true; otherwise they are
        empty.
    """
    output_nus = np.empty_like(packet_collection.packets_input_nu)
    last_interaction_types = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )
    output_energies = np.empty_like(packet_collection.packets_output_nu)
    last_interaction_in_nus = np.empty_like(packet_collection.packets_output_nu)
    last_line_interaction_in_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )
    last_line_interaction_out_ids = (
        np.ones_like(packet_collection.packets_output_nu, dtype=np.int64) * -1
    )

    v_packets_energy_hist = np.zeros_like(spectrum_frequency)
    delta_nu = spectrum_frequency[1] - spectrum_frequency[0]

    # Pre-allocate one vpacket collection and one r-packet tracker per packet.
    vpacket_collections = List()
    rpacket_trackers = List()
    for i in range(len(output_nus)):
        vpacket_collections.append(
            VPacketCollection(
                i,
                spectrum_frequency,
                montecarlo_configuration.v_packet_spawn_start_frequency,
                montecarlo_configuration.v_packet_spawn_end_frequency,
                number_of_vpackets,
                montecarlo_configuration.temporary_v_packet_bins,
            )
        )
        rpacket_trackers.append(RPacketTracker())

    main_thread_id = get_thread_id()
    n_threads = get_num_threads()

    # Per-thread accumulators, indexed by the thread id.
    # NOTE(review): this assumes get_thread_id() returns 0..n_threads-1 inside
    # the prange loop.
    # They start from zero so the final reduction adds only the contributions
    # gathered here. Copying the shared arrays instead would re-add their
    # initial content once per thread.
    estimator_list = List()
    for i in range(n_threads):
        estimator_list.append(
            Estimators(
                np.zeros_like(estimators.j_estimator),
                np.zeros_like(estimators.nu_bar_estimator),
                np.zeros_like(estimators.j_blue_estimator),
                np.zeros_like(estimators.Edotlu_estimator),
                np.zeros_like(estimators.photo_ion_estimator),
                np.zeros_like(estimators.stim_recomb_estimator),
                np.zeros_like(estimators.bf_heating_estimator),
                np.zeros_like(estimators.stim_recomb_cooling_estimator),
                np.zeros_like(estimators.photo_ion_estimator_statistics),
            )
        )
    # Per-thread virtual-packet histograms. A data-dependent element update
    # on one shared array inside prange is a race, not a reduction.
    thread_v_packets_energy_hists = np.zeros(
        (n_threads, len(spectrum_frequency))
    )

    # Arrays for vpacket logging
    virt_packet_nus = []
    virt_packet_energies = []
    virt_packet_initial_mus = []
    virt_packet_initial_rs = []
    virt_packet_last_interaction_in_nu = []
    virt_packet_last_interaction_type = []
    virt_packet_last_line_interaction_in_id = []
    virt_packet_last_line_interaction_out_id = []

    for i in prange(len(output_nus)):
        tid = get_thread_id()
        if show_progress_bars:
            if tid == main_thread_id:
                with objmode:
                    update_amount = 1 * n_threads
                    update_packet_pbar(
                        update_amount,
                        current_iteration=iteration,
                        no_of_packets=no_of_packets,
                        total_iterations=total_iterations,
                    )

        # In single-packet mode every packet reuses the configured seed.
        if montecarlo_configuration.single_packet_seed != -1:
            seed = packet_seeds[montecarlo_configuration.single_packet_seed]
        else:
            seed = packet_seeds[i]
        np.random.seed(seed)

        r_packet = RPacket(
            numba_model.r_inner[0],
            packet_collection.packets_input_mu[i],
            packet_collection.packets_input_nu[i],
            packet_collection.packets_input_energy[i],
            seed,
            i,
        )
        local_estimators = estimator_list[tid]
        vpacket_collection = vpacket_collections[i]
        rpacket_tracker = rpacket_trackers[i]

        # Use this thread's estimators. Passing the shared ``estimators``
        # would race across prange iterations and leave the per-thread
        # accumulators that are reduced below untouched.
        single_packet_loop(
            r_packet,
            numba_model,
            numba_plasma,
            local_estimators,
            vpacket_collection,
            rpacket_tracker,
        )

        output_nus[i] = r_packet.nu
        last_interaction_in_nus[i] = r_packet.last_interaction_in_nu
        last_line_interaction_in_ids[i] = r_packet.last_line_interaction_in_id
        last_line_interaction_out_ids[i] = r_packet.last_line_interaction_out_id

        # Reabsorbed packets are flagged by a negative output energy.
        if r_packet.status == PacketStatus.REABSORBED:
            output_energies[i] = -r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type
        elif r_packet.status == PacketStatus.EMITTED:
            output_energies[i] = r_packet.energy
            last_interaction_types[i] = r_packet.last_interaction_type

        # Bin this packet's virtual packets into the thread-local histogram,
        # skipping those outside the spectrum frequency range.
        vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
        vpackets_energy = vpacket_collection.energies[: vpacket_collection.idx]
        v_packets_idx = np.floor(
            (vpackets_nu - spectrum_frequency[0]) / delta_nu
        ).astype(np.int64)
        for j, idx in enumerate(v_packets_idx):
            if (vpackets_nu[j] < spectrum_frequency[0]) or (
                vpackets_nu[j] > spectrum_frequency[-1]
            ):
                continue
            thread_v_packets_energy_hists[tid, idx] += vpackets_energy[j]

    # Serial reductions of the per-thread buffers.
    for t in range(n_threads):
        v_packets_energy_hist += thread_v_packets_energy_hists[t]
    for sub_estimator in estimator_list:
        estimators.increment(sub_estimator)

    if virtual_packet_logging:
        for vpacket_collection in vpacket_collections:
            vpackets_nu = vpacket_collection.nus[: vpacket_collection.idx]
            vpackets_energy = vpacket_collection.energies[
                : vpacket_collection.idx
            ]
            vpackets_initial_mu = vpacket_collection.initial_mus[
                : vpacket_collection.idx
            ]
            vpackets_initial_r = vpacket_collection.initial_rs[
                : vpacket_collection.idx
            ]
            virt_packet_nus.append(np.ascontiguousarray(vpackets_nu))
            virt_packet_energies.append(np.ascontiguousarray(vpackets_energy))
            virt_packet_initial_mus.append(
                np.ascontiguousarray(vpackets_initial_mu)
            )
            virt_packet_initial_rs.append(
                np.ascontiguousarray(vpackets_initial_r)
            )
            virt_packet_last_interaction_in_nu.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_nu[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_interaction_type.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_type[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_line_interaction_in_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_in_id[
                        : vpacket_collection.idx
                    ]
                )
            )
            virt_packet_last_line_interaction_out_id.append(
                np.ascontiguousarray(
                    vpacket_collection.last_interaction_out_id[
                        : vpacket_collection.idx
                    ]
                )
            )

    if montecarlo_configuration.RPACKET_TRACKING:
        for rpacket_tracker in rpacket_trackers:
            rpacket_tracker.finalize_array()

    packet_collection.packets_output_energy[:] = output_energies[:]
    packet_collection.packets_output_nu[:] = output_nus[:]

    return (
        v_packets_energy_hist,
        last_interaction_types,
        last_interaction_in_nus,
        last_line_interaction_in_ids,
        last_line_interaction_out_ids,
        virt_packet_nus,
        virt_packet_energies,
        virt_packet_initial_mus,
        virt_packet_initial_rs,
        virt_packet_last_interaction_in_nu,
        virt_packet_last_interaction_type,
        virt_packet_last_line_interaction_in_id,
        virt_packet_last_line_interaction_out_id,
        rpacket_trackers,
    )