"""
This module contains the `RetrievalObject` class, which is the main
class of the pyretlife package.
"""
__author__ = "Konrad, Alei, Burr, Molliere, Quanz"
__copyright__ = "Copyright 2022, Konrad, Alei, Molliere, Quanz"
__maintainer__ = "Björn S. Konrad, Eleonora Alei, Zachary Burr"
__email__ = "konradb@ethz.ch, elalei@phys.ethz.ch, zaburr@phys.ethz.ch"
__status__ = "Development"
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------
import importlib
import json
import os
import sys
from pathlib import Path
import numpy as np
from typing import Union, Tuple
from astropy import constants as const
from petitRADTRANS.radtrans import Radtrans
from pyretlife.retrieval.atmospheric_variables import (
calculate_gravity,
calculate_polynomial_profile,
calculate_vae_profile,
calculate_guillot_profile,
calculate_isothermal_profile,
calculate_madhuseager_profile,
calculate_mod_madhuseager_profile,
calculate_line_profile,
calculate_spline_profile,
calculate_adiabat_profile,
calculate_abundances,
condense_water,
set_log_ground_pressure,
assign_cloud_parameters,
calculate_mmw_VMR,
calculate_inert,
convert_VMR_to_MMR,
)
from pyretlife.retrieval.configuration_ingestion import (
read_config_file,
check_if_configs_match,
save_configuration,
save_input_spectra,
save_github_commit_string,
save_retrieved_parameters,
save_environment_variables,
populate_dictionaries,
make_output_folder,
load_data,
get_check_opacity_path,
get_check_prt_path,
get_retrieval_path,
set_prt_opacity,
)
from pyretlife.retrieval.likelihood_validation import (
validate_pt_profile,
validate_cube_finite,
validate_positive_mass,
validate_positive_temperatures,
validate_sum_of_abundances,
validate_spectrum_goodness,
validate_clouds,
validate_abundances,
)
from pyretlife.retrieval.priors import assign_priors
from pyretlife.retrieval.radiative_transfer import (
define_linelists,
calculate_moon_flux,
assign_reflectance_emissivity,
calculate_emission_flux,
scale_flux_to_distance,
rebin_spectrum,
)
from pyretlife.retrieval.units import (UnitsUtil,
convert_spectrum,
convert_knowns_and_parameters,
)
# -----------------------------------------------------------------------------
# DEFINITIONS
# -----------------------------------------------------------------------------
[docs]
class RetrievalObject:
    """
    This class binds together all the different parts of the retrieval.

    Args:
        run_retrieval: Boolean flag (default `True`) stored on the
            instance as `self.run_retrieval`. It is forwarded to
            `load_data` (as `retrieval=...`) when the input spectra
            are loaded in `load_configuration`.

    Attributes:
        config: The configuration (i.e., the contents of the YAML or
            INI file for a given retrieval) as a dictionary.
        config_default: The hard-coded default configuration, read
            from `configs/config_default.yaml`.
        input_prt_path: Path to the petitRADTRANS installation.
        input_opacity_path: Path to the opacity data.
        input_retrieval_path: Value returned by `get_retrieval_path`;
            the default configuration is looked up below it.
        knowns, parameters, settings: Dictionaries filled by
            `populate_dictionaries` from the configuration files.
        instrument: Dictionary returned by `load_data`; iterated per
            instrument when the log-likelihood is evaluated.
        units: `UnitsUtil` object used for the unit conversions.
        rt_object: The petitRADTRANS `Radtrans` object, created by
            `petitRADTRANS_initialization`.
        press, temp: Pressure grid and temperature profile.
        temp_vars, chem_vars, phys_vars, cloud_vars, scat_vars,
        moon_vars: Per-category dictionaries holding the known and the
            sampled parameter values.
        ...

    TODO: Keep adding attributes here to document them.
    TODO: Maybe also add an `__repr__` method to this class?
    """
def __init__(
    self,
    run_retrieval: bool = True,
):
    """
    Create an empty retrieval object and prepare the environment.

    No configuration file is read here (see `load_configuration`).
    The constructor only creates the empty state containers, resolves
    the opacity / petitRADTRANS / retrieval paths through the
    ingestion helpers, and exposes the first two to the Python
    process (`sys.path` and the `pRT_input_data_path` environment
    variable).

    Args:
        run_retrieval: Flag stored as `self.run_retrieval`.
    """
    # State that is only filled in by later method calls
    for placeholder in (
        "vae_pt",
        "moon_flux",
        "MMW",
        "inert",
        "temp",
        "press",
        "moon_vars",
        "scat_vars",
        "cloud_vars",
        "phys_vars",
        "chem_vars",
        "temp_vars",
        "config",
        "config_default",
        "rt_object",
    ):
        setattr(self, placeholder, None)

    # Store constructor arguments
    self.run_retrieval = run_retrieval

    # Containers that `populate_dictionaries` / `load_data` fill later
    self.knowns, self.parameters, self.settings, self.instrument = (
        {},
        {},
        {},
        {},
    )

    # Get and check the goodness of the environmental variables
    self.input_opacity_path = get_check_opacity_path()
    self.input_prt_path = get_check_prt_path()
    self.input_retrieval_path = get_retrieval_path()

    # Make the petitRADTRANS installation importable and publish the
    # opacity location (presumably read by petitRADTRANS itself; this
    # replaces the outdated `set_prt_opacity` call)
    sys.path.append(str(self.input_prt_path))
    os.environ["pRT_input_data_path"] = str(self.input_opacity_path)

    # Create a units object to enable unit conversions
    self.units = UnitsUtil()
[docs]
def load_configuration(self, config_file: str):
    """
    Read the default and the user configuration and load the data.

    The default configuration (`configs/config_default.yaml` below
    `self.input_retrieval_path`) is ingested first and the user
    configuration second, both through `populate_dictionaries`, which
    updates `self.knowns`, `self.parameters`, `self.settings` and
    `self.units`. If the settings contain `data_files`, the input
    spectra are loaded into `self.instrument`.

    Args:
        config_file: Path to the configuration file of this run.

    Raises:
        RuntimeError: If `check_if_configs_match` reports a mismatch
            for the user configuration.
    """
    # Load standard configurations (hard-coded). The path is joined with
    # pathlib so that `input_retrieval_path` may be a `str` or a `Path`.
    default_config_path = (
        Path(self.input_retrieval_path) / "configs" / "config_default.yaml"
    )
    self.config_default = read_config_file(file_path=default_config_path)

    # Read in the configuration and check if there is already one in the file
    self.config = read_config_file(file_path=Path(config_file))

    # Check if configuration file exists and if it matches
    if not check_if_configs_match(config=self.config):
        raise RuntimeError("Config exists and does not match!")

    # Save config into the four dictionaries: defaults first, then the
    # user configuration on top of them
    for config in (self.config_default, self.config):
        (
            self.knowns,
            self.parameters,
            self.settings,
            self.units,
        ) = populate_dictionaries(
            config,
            self.knowns,
            self.parameters,
            self.settings,
            self.units,
        )

    if "data_files" in self.settings:
        self.instrument = load_data(
            self.settings, self.units, retrieval=self.run_retrieval
        )

    # IF CLOUDS, ASSIGN P0 WHEN NOT PROVIDED
    # TODO P0_test()
    # TODO implement validation
    # validate_config(self.config)
[docs]
def unit_conversion(self):
    """
    Run the unit conversion on the input spectra, the knowns and the
    retrieved parameters, using the `self.units` object.

    The converted dictionaries replace `self.instrument`,
    `self.knowns` and `self.parameters`.
    """
    units = self.units
    self.instrument = convert_spectrum(self.instrument, units)
    # Knowns and parameters share the same conversion routine
    for attribute in ("knowns", "parameters"):
        converted = convert_knowns_and_parameters(
            getattr(self, attribute), units
        )
        setattr(self, attribute, converted)
[docs]
def assign_knowns(self):
    """
    Sort the known parameters into the per-category dictionaries.

    Resets `self.temp_vars`, `self.chem_vars`, `self.phys_vars`,
    `self.cloud_vars`, `self.scat_vars` and `self.moon_vars` to empty
    dictionaries and copies the `"truth"` value of every entry of
    `self.knowns` into the dictionary selected by the entry's
    `"type"`. Entries with an unrecognised type are ignored.

    Cloud parameters are handled as follows:

    * names containing `Pcloud` or `_cloud_top`, and the name
      `cloud_fraction`, are stored directly under their own name;
    * every other name is split at its first two underscores. The
      first two components form the key of a nested dictionary in
      `self.cloud_vars`; a third component, if present, is the key
      inside that nested dictionary. Without a third component the
      value is stored as `"abundance"` and additionally written to
      `self.chem_vars` under the first component.

    If `self.settings["parameterization"]` is `"input"`, the
    pressure-temperature profile is read with `np.loadtxt` from
    `self.temp_vars["input_path"]`.
    """
    self.temp_vars = {}
    self.chem_vars = {}
    self.phys_vars = {}
    self.cloud_vars = {}
    self.scat_vars = {}
    self.moon_vars = {}

    # Categories whose values are copied over without further processing
    plain_targets = {
        "TEMPERATURE PARAMETERS": self.temp_vars,
        "CHEMICAL COMPOSITION PARAMETERS": self.chem_vars,
        "PHYSICAL PARAMETERS": self.phys_vars,
        "SCATTERING PARAMETERS": self.scat_vars,
        "MOON PARAMETERS": self.moon_vars,
    }

    # Add the known parameters to the dictionaries
    for par, entry in self.knowns.items():
        kind = entry["type"]
        if kind in plain_targets:
            plain_targets[kind][par] = entry["truth"]
        elif kind == "CLOUD PARAMETERS":
            if (
                ("Pcloud" in par)
                or (par == "cloud_fraction")
                or ("_cloud_top" in par)
            ):
                self.cloud_vars[par] = entry["truth"]
                continue
            # TODO review this snippet
            components = par.split("_", 2)
            cloud_key = "_".join(components[:2])
            cloud_entry = self.cloud_vars.setdefault(cloud_key, {})
            if len(components) > 2:
                cloud_entry[components[2]] = entry["truth"]
            else:
                # No property suffix: the value is the cloud abundance,
                # which is also registered as a chemical species.
                cloud_entry["abundance"] = entry["truth"]
                self.chem_vars[components[0]] = entry["truth"]

    # in case the PT profile is known, assign it already
    if self.settings["parameterization"] == "input":
        self.press, self.temp = np.loadtxt(
            self.temp_vars["input_path"], unpack=True
        )
[docs]
def assign_prior_functions(self):
    """
    Pass `self.parameters` through `assign_priors` and store the
    result back in `self.parameters`.
    """
    # self.parameters = read_input_prior(self.parameters)
    with_priors = assign_priors(self.parameters)
    self.parameters = with_priors
    # TODO Check that all priors are valid (invalid_prior function in priors.py otherwise)
[docs]
def petitRADTRANS_initialization(self):
    """
    Initializes the rt_object given the wavelength range.

    The species lists are obtained from `define_linelists`; the
    `Radtrans` object is then built on a logarithmic pressure grid
    from `10**settings["log_top_pressure"]` to `10**0` with
    `settings["n_layers"]` layers. Scattering in emission is switched
    on if any entry of `settings["include_scattering"]` equals `True`.

    Everything printed to `sys.stdout` while the `Radtrans` object is
    constructed is discarded. The original stream is restored, and the
    null device closed, even if the construction raises.
    """
    (
        used_line_species,
        used_rayleigh_species,
        used_cia_species,
        used_cloud_species,
    ) = define_linelists(
        self.config, self.settings, self.input_opacity_path
    )

    # Line species are passed in reverse alphabetical order
    line_species = sorted(used_line_species, reverse=True)
    pressures = np.logspace(
        self.settings["log_top_pressure"],
        0,
        self.settings["n_layers"],
        base=10,
    )
    scattering_in_emission = (
        True in self.settings["include_scattering"].values()
    )

    # TODO implement verbose output
    old_stdout = sys.stdout
    with open(os.devnull, "w") as devnull:
        sys.stdout = devnull
        try:
            self.rt_object = Radtrans(
                pressures=pressures,
                line_species=line_species,
                rayleigh_species=sorted(used_rayleigh_species),
                gas_continuum_contributors=sorted(used_cia_species),
                cloud_species=sorted(used_cloud_species),
                wavelength_boundaries=self.settings["wavelength_range"],
                line_opacity_mode="c-k",
                scattering_in_emission=scattering_in_emission,
            )
        finally:
            # Always hand the real stdout back, also on failure
            sys.stdout = old_stdout
[docs]
def unity_cube_to_prior_space(self, cube):
    """
    Map a sample from the unit hypercube to the prior space.

    Every entry of a copy of `cube` is transformed with the prior
    function of the parameter at the same position in
    `self.parameters`, called as
    `prior["function"](value, prior["prior_specs"])`.

    The planet mass `M_pl` is transformed last because its prior may
    depend on the planet radius: for a prior of kind `"2d-uniform"`
    the function is called with the transformed `R_pl` instead of the
    prior specifications. If `R_pl` is not a retrieved parameter, the
    value `self.knowns["R_pl"]["input_truth"]` is used.

    Args:
        cube: Sample from the unit hypercube; must provide `.copy()`
            and be indexable in the order of `self.parameters`.

    Returns:
        The transformed copy of `cube`; `cube` itself is not modified.
    """
    cube_copy = cube.copy()
    R_pl = None
    idx_M_pl = None

    for idx, (par, entry) in enumerate(self.parameters.items()):
        if par == "M_pl":
            # Deferred: may need the transformed radius (2D prior)
            idx_M_pl = idx
            continue
        prior = entry["prior"]
        cube_copy[idx] = prior["function"](
            cube_copy[idx], prior["prior_specs"]
        )
        if par == "R_pl":
            R_pl = cube_copy[idx]

    if idx_M_pl is not None:  # Need to treat M separately to allow for 2D prior
        prior = self.parameters["M_pl"]["prior"]
        if prior["kind"] == "2d-uniform":
            if R_pl is None:
                R_pl = self.knowns["R_pl"]["input_truth"]
            cube_copy[idx_M_pl] = prior["function"](cube_copy[idx_M_pl], R_pl)
        else:
            cube_copy[idx_M_pl] = prior["function"](
                cube_copy[idx_M_pl], prior["prior_specs"]
            )
    return cube_copy
[docs]
def assign_cube_to_parameters(self, cube):
    """
    Write a sample of the retrieved parameters into the per-category
    dictionaries.

    `cube[i]` belongs to the i-th key of `self.parameters`. The value
    is stored in `self.temp_vars`, `self.chem_vars`, `self.phys_vars`,
    `self.cloud_vars`, `self.scat_vars` or `self.moon_vars`, selected
    by the parameter's `"type"`. Parameters with an unrecognised type
    are ignored. The dictionaries must already exist (they are created
    by `assign_knowns`).

    Cloud parameters follow the same naming rules as in
    `assign_knowns`: names containing `Pcloud` or `_cloud_top`, and
    `cloud_fraction`, are stored under their own name. Any other name
    is split at its first two underscores into a nested-dictionary key
    (first two components) and a property key (third component).
    Without a third component the value is stored as `"abundance"` and
    also written to `self.chem_vars` under the first component.

    Args:
        cube: Indexable sequence of parameter values in the order of
            `self.parameters`.
    """
    # Categories whose values are copied over without further processing
    plain_targets = {
        "TEMPERATURE PARAMETERS": self.temp_vars,
        "CHEMICAL COMPOSITION PARAMETERS": self.chem_vars,
        "PHYSICAL PARAMETERS": self.phys_vars,
        "SCATTERING PARAMETERS": self.scat_vars,
        "MOON PARAMETERS": self.moon_vars,
    }

    for idx, (par, entry) in enumerate(self.parameters.items()):
        kind = entry["type"]
        if kind in plain_targets:
            plain_targets[kind][par] = cube[idx]
        elif kind == "CLOUD PARAMETERS":
            if (
                ("Pcloud" in par)
                or (par == "cloud_fraction")
                or ("_cloud_top" in par)
            ):
                self.cloud_vars[par] = cube[idx]
                continue
            components = par.split("_", 2)
            cloud_key = "_".join(components[:2])
            cloud_entry = self.cloud_vars.setdefault(cloud_key, {})
            if len(components) > 2:
                cloud_entry[components[2]] = cube[idx]
            else:
                # No property suffix: the value is the cloud abundance,
                # which is also registered as a chemical species.
                cloud_entry["abundance"] = cube[idx]
                self.chem_vars[components[0]] = cube[idx]
[docs]
def calculate_pt_profile(
    self, parameterization, log_ground_pressure, log_top_pressure, layers
):
    """
    Creates the pressure-temperature profile from the temperature
    parameters and the pressure.

    The pressure grid `self.press` is logarithmically spaced from
    `10**log_top_pressure` to `10**log_ground_pressure` with `layers`
    points. `self.temp` is then computed by the profile function that
    belongs to `parameterization`. The choice is made from the
    `parameterization` argument alone (both callers in this class
    pass `self.settings["parameterization"]`).

    Args:
        parameterization: Name of the profile parameterization; one of
            `"polynomial"`, `"vae_pt"`, `"guillot"`, `"isothermal"`,
            `"madhuseager"`, `"mod_madhuseager"`, `"spline"`,
            `"adiabat"` or `"line"`.
        log_ground_pressure: log10 of the pressure at the bottom of
            the grid.
        log_top_pressure: log10 of the pressure at the top of the grid.
        layers: Number of pressure layers.

    Raises:
        ValueError: If `parameterization` is none of the names above.
    """
    self.press = np.array(
        np.logspace(log_top_pressure, log_ground_pressure, layers, base=10)
    )

    # NOTE(review): "vae_pt_flow" is handled by `vae_initialization`
    # but has no branch here and therefore raises - confirm intent.
    if parameterization == "polynomial":
        self.temp = calculate_polynomial_profile(self.press, self.temp_vars)
    # TODO understand what is going on here
    elif parameterization == "vae_pt":
        self.temp = calculate_vae_profile(
            self.press, self.vae_pt, self.temp_vars
        )
    elif parameterization == "guillot":
        self.temp = calculate_guillot_profile(self.press, self.temp_vars)
    elif parameterization == "isothermal":
        self.temp = calculate_isothermal_profile(self.press, self.temp_vars)
    elif parameterization == "madhuseager":
        self.temp = calculate_madhuseager_profile(self.press, self.temp_vars)
    elif parameterization == "mod_madhuseager":
        self.temp = calculate_mod_madhuseager_profile(
            self.press, self.temp_vars
        )
    elif parameterization == "spline":
        self.temp = calculate_spline_profile(
            self.press, self.temp_vars, self.phys_vars, self.settings
        )
    elif parameterization == "adiabat":
        self.temp = calculate_adiabat_profile(
            self.press, self.temp_vars, self.phys_vars
        )
    elif parameterization == "line":
        self.temp = calculate_line_profile(
            self.press, self.temp_vars, self.phys_vars
        )
    else:
        raise ValueError("Unknown PT setting!")
    return
[docs]
def calculate_abundances(self):
    """
    Compute the gas abundances and the cloud properties.

    Steps:

    1. `self.abundances_VMR` and `self.chem_vars_VMR` are obtained
       from the module-level function `calculate_abundances` (the
       name imported from `atmospheric_variables`).
    2. If `self.settings['condensation']` is set, `condense_water` is
       applied. It receives the pressure and temperature arrays in
       reversed order and the drying parameter
       `self.chem_vars['H2O_Drying']` (0 if that key is absent).
       Otherwise `self.condensation_pressures` stays `None`.
    3. `assign_cloud_parameters` returns the updated abundances and
       cloud variables together with `self.cloud_radii`,
       `self.cloud_lnorm`, `self.cloud_Pcloud` and
       `self.cloud_fraction`.
    """
    self.abundances_VMR, self.chem_vars_VMR = calculate_abundances(
        self.chem_vars, self.press, self.settings
    )

    self.condensation_pressures = None
    if self.settings['condensation']:
        if 'H2O_Drying' in self.chem_vars:
            drying = self.chem_vars['H2O_Drying']
        else:
            drying = 0
        condensed = condense_water(
            self.abundances_VMR,
            self.press[::-1],
            self.temp[::-1],
            self.phys_vars,
            self.settings,
            drying=drying,
        )
        self.abundances_VMR, self.condensation_pressures = condensed

    cloud_output = assign_cloud_parameters(
        self.abundances_VMR,
        self.cloud_vars,
        self.press,
        self.phys_vars,
        self.condensation_pressures,
    )
    (
        self.abundances_VMR,
        self.cloud_vars,
        self.cloud_radii,
        self.cloud_lnorm,
        self.cloud_Pcloud,
        self.cloud_fraction,
    ) = cloud_output
[docs]
def calculate_spectrum(self, em_contr=False):
    """
    Run the forward model for the current atmospheric state.

    The inert gas fraction, the mean molecular weight and the mass
    fractions are derived from `self.abundances_VMR`. The pressure
    grid of `self.rt_object` is set to `self.press * 1e6`. The moon
    flux (if `settings["include_moon"]`) and the surface reflectance
    and emissivity (if direct-light or thermal scattering is enabled)
    are prepared. Then `calculate_emission_flux` is evaluated without
    a cloud deck. If `self.cloud_fraction` is not `None`, a second
    evaluation with `Pcloud=self.cloud_Pcloud` is made and the two
    fluxes are mixed linearly with the cloud fraction.

    Args:
        em_contr: Passed on to `calculate_emission_flux`; if true the
            emission contribution is returned as well.

    Returns:
        `(wavelength, flux)`, or `(wavelength, flux, em_contr)` when
        `em_contr` is true. The returned contribution is always the
        one of the cloud-free evaluation. The wavelength is
        `c / freq * 1e4` with `c` in cgs units.
    """
    self.inert = calculate_inert(self.abundances_VMR)
    self.MMW = calculate_mmw_VMR(self.abundances_VMR, self.settings, self.inert)
    abundances_MMR, _ = convert_VMR_to_MMR(
        self.abundances_VMR, self.settings, self.inert, self.MMW
    )

    # initialize calculated pressure
    # (factor 1e6: presumably bar -> cgs units - TODO confirm)
    self.rt_object.pressures = self.press * 1e6

    if self.settings["include_moon"]:
        self.moon_flux = calculate_moon_flux(
            self.rt_object.frequencies, self.moon_vars
        )

    scattering = self.settings["include_scattering"]
    if scattering["direct_light"] or scattering["thermal"]:
        reflectance, emissivity = assign_reflectance_emissivity(
            self.scat_vars, self.rt_object.frequencies
        )
        self.rt_object.reflectance = reflectance
        self.rt_object.emissivity = emissivity

    def forward_model(Pcloud):
        # Calculate the forward model; this returns the frequency
        # and the flux F_nu in erg/cm^2/s/Hz.
        return calculate_emission_flux(
            self.rt_object,
            self.settings,
            self.temp,
            abundances_MMR,
            self.phys_vars["g"],
            self.MMW,
            self.cloud_radii,
            self.cloud_lnorm,
            self.scat_vars,
            em_contr=em_contr,
            Pcloud=Pcloud,
        )

    freq, clear_flux, clear_em_contr = forward_model(None)
    flux = clear_flux
    if self.cloud_fraction is not None:
        freq, cloudy_flux, _ = forward_model(self.cloud_Pcloud)
        flux = (
            clear_flux * (1 - self.cloud_fraction)
            + cloudy_flux * self.cloud_fraction
        )

    wavelength = const.c.cgs.value / freq * 1e4

    if em_contr:
        # The contribution function is not mixed with the cloudy one
        return wavelength, flux, clear_em_contr
    return wavelength, flux
[docs]
def distance_scale_spectrum(self):
    """
    Scale the model flux (and the moon flux) to the system distance.

    Nothing happens if `self.phys_vars["d_syst"]` is `None`.
    Otherwise `self.rt_object.flux` is replaced by the output of
    `scale_flux_to_distance` for the planet radius `R_pl`, and, if
    `settings["include_moon"]` is set, `self.moon_flux` by the output
    for the moon radius `R_m`.
    """
    if self.phys_vars["d_syst"] is None:
        return

    distance = self.phys_vars["d_syst"]

    # WARNING! THIS CONVERTS UNITS OF PRT SPECTRUM from cm-2 to m-2
    planet_radius = self.phys_vars["R_pl"]
    self.rt_object.flux = scale_flux_to_distance(
        self.rt_object.flux, planet_radius, distance
    )

    if not self.settings["include_moon"]:
        return
    self.moon_flux = scale_flux_to_distance(
        self.moon_flux, self.moon_vars["R_m"], distance
    )
[docs]
def calculate_log_likelihood(self, cube):
    """
    Calculates the log(likelihood) of the forward model generated
    with parameters and known variables.

    The sample is written into the parameter dictionaries, the
    atmosphere (PT profile, abundances, clouds) is built and the
    spectrum is computed and scaled to the system distance. Whenever
    one of the `validate_*` checks flags the sample, the evaluation
    stops and `-1e99` is returned.

    The likelihood is the sum over all instruments of
    `-0.5 * sum(((model - data) / error) ** 2)`, where the model is
    rebinned to each instrument; the normalisation term is not
    included. If `settings["include_moon"]` is enabled (boolean `True`,
    as the other methods of this class treat it, or the legacy string
    `"True"`), the rebinned moon flux is added to the model.

    Args:
        cube: Sample of the retrieved parameters in prior space, in
            the order of `self.parameters`.

    Returns:
        The log-likelihood, or `-1e99` for a rejected sample.
    """
    rejected = -1e99

    # Generate dictionaries for the different classes of parameters
    # and add the known parameters as well as a sample of the
    # retrieved parameters to them
    self.assign_cube_to_parameters(cube)
    self.phys_vars = set_log_ground_pressure(
        self.phys_vars, self.config, self.knowns
    )

    # TODO expand on tests here
    # test goodness of random draw
    if validate_pt_profile(self.settings, self.temp_vars, self.phys_vars):
        return rejected
    if validate_cube_finite(cube):
        return rejected
    if validate_positive_mass(self.phys_vars):
        return rejected

    self.phys_vars = calculate_gravity(self.phys_vars, self.config)

    if self.settings["parameterization"] != "input":
        self.calculate_pt_profile(
            parameterization=self.settings["parameterization"],
            log_ground_pressure=self.phys_vars["log_P0"],
            log_top_pressure=self.settings["log_top_pressure"],
            layers=self.settings["n_layers"],
        )
    if validate_positive_temperatures(self.temp):
        return rejected

    self.calculate_abundances()
    # if validate_abundances(self.abundances_VMR,self.chem_vars_VMR):
    #    return -1e99
    if validate_clouds(self.press, self.temp, self.cloud_vars):
        return rejected
    if validate_sum_of_abundances(self.abundances_VMR):
        return rejected

    self.rt_object.wavelength, self.rt_object.flux = self.calculate_spectrum()
    if validate_spectrum_goodness(self.rt_object.flux):
        return rejected

    self.distance_scale_spectrum()

    # The setting is used as a boolean in `calculate_spectrum` and
    # `distance_scale_spectrum`; the string form is kept for
    # backward compatibility.
    include_moon = self.settings["include_moon"] in (True, "True")

    # Calculate total log-likelihood (sum over instruments)
    log_likelihood = 0.0
    for inst in self.instrument.keys():
        # Rebin the spectrum according to the input spectrum if wavelengths
        # differ strongly
        rebinned_flux = rebin_spectrum(
            self.instrument[inst],
            self.rt_object.wavelength,
            self.rt_object.flux,
        )
        if include_moon:
            rebinned_flux += rebin_spectrum(
                self.instrument[inst],
                self.rt_object.wavelength,
                self.moon_flux,
            )

        # Calculate log-likelihood
        log_likelihood += -0.5 * np.sum(
            (
                (rebinned_flux - self.instrument[inst]["flux"])
                / self.instrument[inst]["error"]
            )
            ** 2.0
            # + np.log(2*np.pi*self.instrument[inst]["error"]**2)
        )
    return log_likelihood
[docs]
def generate_new_spectrum(self):
    """
    Runs the forward model and calculates a (new) spectrum rather than
    running a retrieval.

    The atmosphere is built from the values currently stored in the
    parameter dictionaries, the spectrum is computed, scaled to the
    system distance and rebinned onto a wavelength grid that is
    equidistant in log-wavelength with step
    `1 / settings["resolution"]`. If `settings["include_moon"]` is
    enabled (boolean `True`, as the other methods of this class treat
    it, or the legacy string `"True"`), the rebinned moon flux is
    added. The two columns (wavelength, flux) are written with
    `np.savetxt` to `settings["output_folder"]`.

    Raises:
        ValueError: If the temperature, cloud or abundance validation
            flags the atmosphere.
    """
    # Calculate values from given parameters
    self.phys_vars = set_log_ground_pressure(
        self.phys_vars, self.config, self.knowns
    )
    self.phys_vars = calculate_gravity(self.phys_vars, self.config)

    if self.settings["parameterization"] != "input":
        self.calculate_pt_profile(
            parameterization=self.settings["parameterization"],
            log_ground_pressure=self.phys_vars["log_P0"],
            log_top_pressure=self.settings["log_top_pressure"],
            layers=self.settings["n_layers"],
        )
    if validate_positive_temperatures(self.temp):
        raise ValueError(
            "PT profile parameters resulted in negative temperatures!"
        )

    self.calculate_abundances()
    # if validate_abundances(self.abundances_VMR,self.chem_vars_VMR):
    #    return -1e99
    if validate_clouds(self.press, self.temp, self.cloud_vars):
        raise ValueError("Issue with cloud positioning!")
    if validate_sum_of_abundances(self.abundances_VMR):
        raise ValueError("Gas abundances do not sum to 1!")

    # Perform radiative transfer
    self.rt_object.wavelength, self.rt_object.flux = self.calculate_spectrum()
    # if validate_spectrum_goodness(self.rt_object.flux):
    #    return -1e99
    self.distance_scale_spectrum()

    # Rebin to match desired spectral resolution
    instrument = {}
    instrument["wavelength"] = np.exp(
        np.arange(
            np.log(self.rt_object.wavelength[0]),
            np.log(self.rt_object.wavelength[-1]),
            1 / self.settings["resolution"],
        )
    )
    rebinned_flux = rebin_spectrum(
        instrument,
        self.rt_object.wavelength,
        self.rt_object.flux,
    )

    # The setting is used as a boolean in `calculate_spectrum` and
    # `distance_scale_spectrum`; the string form is kept for
    # backward compatibility.
    if self.settings["include_moon"] in (True, "True"):
        rebinned_flux += rebin_spectrum(
            instrument,
            self.rt_object.wavelength,
            self.moon_flux,
        )

    # Save output to file
    # NOTE(review): the target is the "output_folder" setting itself;
    # confirm that it holds a file name and not a directory.
    np.savetxt(
        self.settings["output_folder"],
        np.column_stack((instrument["wavelength"], rebinned_flux)),
    )
[docs]
def vae_initialization(self):
    """
    Load the VAE pressure-temperature model if one is requested.

    Nothing happens unless `settings["parameterization"]` is
    `"vae_pt"` or `"vae_pt_flow"`. In both cases
    `pt_vae.VAE_PT_Model_Flow` is created from the network file
    `vae_pt_models/Flow/<settings["vae_net"]>` next to this module and
    stored in `self.vae_pt`. For `"vae_pt_flow"` the path of
    `flow-state-dict.pt` in the same directory is passed as
    `flow_path` in addition.
    """
    # TODO see if it can be improved
    parameterization = self.settings["parameterization"]
    if parameterization != "vae_pt" and parameterization != "vae_pt_flow":
        return

    # Imported lazily: only needed for the VAE parameterizations
    from pyretlife.retrieval import pt_vae as vae

    flow_dir = (
        os.path.dirname(os.path.realpath(__file__)) + "/vae_pt_models/Flow/"
    )
    network_path = flow_dir + self.settings["vae_net"]

    if parameterization == "vae_pt":
        self.vae_pt = vae.VAE_PT_Model_Flow(
            network_path,
        )
    else:
        self.vae_pt = vae.VAE_PT_Model_Flow(
            network_path,
            flow_path=flow_dir + "flow-state-dict.pt",
        )