# Source code for pyretlife.retrieval.run

"""
This module contains the `RetrievalObject` class, which is the main
class of the pyretlife package.
"""

__author__ = "Konrad, Alei, Burr, Molliere, Quanz"
__copyright__ = "Copyright 2022, Konrad, Alei, Molliere, Quanz"
__maintainer__ = "Björn S. Konrad, Eleonora Alei, Zachary Burr"
__email__ = "konradb@ethz.ch, elalei@phys.ethz.ch, zaburr@phys.ethz.ch"
__status__ = "Development"


# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

import importlib
import json
import os
import sys
from pathlib import Path
import numpy as np
from typing import Union, Tuple
from astropy import constants as const

from petitRADTRANS.radtrans import Radtrans

from pyretlife.retrieval.atmospheric_variables import (
    calculate_gravity,
    calculate_polynomial_profile,
    calculate_vae_profile,
    calculate_guillot_profile,
    calculate_isothermal_profile,
    calculate_madhuseager_profile,
    calculate_mod_madhuseager_profile,
    calculate_line_profile,
    calculate_spline_profile,
    calculate_adiabat_profile,
    calculate_abundances,
    condense_water,
    set_log_ground_pressure,
    assign_cloud_parameters,
    calculate_mmw_VMR,
    calculate_inert,
    convert_VMR_to_MMR,
)
from pyretlife.retrieval.configuration_ingestion import (
    read_config_file,
    check_if_configs_match,
    save_configuration,
    save_input_spectra,
    save_github_commit_string,
    save_retrieved_parameters,
    save_environment_variables,
    populate_dictionaries,
    make_output_folder,
    load_data,
    get_check_opacity_path,
    get_check_prt_path,
    get_retrieval_path,
    set_prt_opacity,
)
from pyretlife.retrieval.likelihood_validation import (
    validate_pt_profile,
    validate_cube_finite,
    validate_positive_mass,
    validate_positive_temperatures,
    validate_sum_of_abundances,
    validate_spectrum_goodness,
    validate_clouds,
    validate_abundances,
)
from pyretlife.retrieval.priors import assign_priors
from pyretlife.retrieval.radiative_transfer import (
    define_linelists,
    calculate_moon_flux,
    assign_reflectance_emissivity,
    calculate_emission_flux,
    scale_flux_to_distance,
    rebin_spectrum,
)
from pyretlife.retrieval.units import (UnitsUtil,
                                       convert_spectrum,
                                       convert_knowns_and_parameters,
                                       )


# -----------------------------------------------------------------------------
# DEFINITIONS
# -----------------------------------------------------------------------------


class RetrievalObject:
    """
    This class binds together all the different parts of the retrieval.

    Args:
        run_retrieval: Flag stored as `self.run_retrieval` and forwarded
            to `load_data()` as its `retrieval` argument.

    Attributes:
        config: The configuration (i.e., the contents of the YAML or INI
            file for a given retrieval) as a dictionary.
        config_default: The contents of `configs/config_default.yaml`
            below `input_retrieval_path`.
        input_prt_path: Path to the petitRADTRANS installation.
        input_opacity_path: Path to the opacity data.
        input_retrieval_path: Value returned by `get_retrieval_path()`.
        knowns, parameters, settings, units: Filled from the two
            configurations by `populate_dictionaries()`.
        instrument: Mapping returned by `load_data()`; the likelihood
            reads the `"flux"` and `"error"` entries of every value.
        rt_object: The petitRADTRANS `Radtrans` object.
        temp_vars, chem_vars, phys_vars, cloud_vars, scat_vars, moon_vars:
            Per-category parameter values, filled from the knowns and
            from every sampled cube.
        press, temp: The current pressure-temperature profile.
        ...

    TODO: Keep adding attributes here to document them.
    TODO: Maybe also add an `__repr__` method to this class?
    """

    # Parameter categories whose values are stored flat, keyed by the
    # parameter name, in the instance attribute named on the right.
    _SIMPLE_GROUPS = {
        "TEMPERATURE PARAMETERS": "temp_vars",
        "CHEMICAL COMPOSITION PARAMETERS": "chem_vars",
        "PHYSICAL PARAMETERS": "phys_vars",
        "SCATTERING PARAMETERS": "scat_vars",
        "MOON PARAMETERS": "moon_vars",
    }
    _CLOUD_GROUP = "CLOUD PARAMETERS"
    _RECOGNIZED_GROUPS = frozenset(_SIMPLE_GROUPS) | {_CLOUD_GROUP}

    # Log-likelihood returned for draws rejected by a validation check.
    _REJECTED = -1e99

    def __init__(
        self,
        run_retrieval: bool = True,
    ):
        """
        Create empty containers, resolve the installation paths from the
        environment and expose the opacity path to petitRADTRANS.

        No configuration is read here; call `load_configuration()` next.
        """

        # Store constructor arguments
        self.run_retrieval = run_retrieval

        # Placeholders that are filled by the other methods of the class
        self.vae_pt = None
        self.moon_flux = None
        self.MMW = None
        self.inert = None
        self.temp = None
        self.press = None
        self.moon_vars = None
        self.scat_vars = None
        self.cloud_vars = None
        self.phys_vars = None
        self.chem_vars = None
        self.temp_vars = None
        self.config = None
        self.config_default = None
        self.rt_object = None

        # Set by `calculate_abundances()`; defined here so that every
        # attribute of the object exists right after construction.
        self.abundances_VMR = None
        self.chem_vars_VMR = None
        self.condensation_pressures = None
        self.cloud_radii = None
        self.cloud_lnorm = None
        self.cloud_Pcloud = None
        self.cloud_fraction = None

        self.knowns = {}
        self.parameters = {}
        self.settings = {}
        self.instrument = {}

        # Get and check the goodness of the environmental variables
        self.input_opacity_path = get_check_opacity_path()
        self.input_prt_path = get_check_prt_path()
        self.input_retrieval_path = get_retrieval_path()
        sys.path.append(str(self.input_prt_path))
        # set_prt_opacity(self.input_prt_path, self.input_opacity_path) #outdated
        os.environ["pRT_input_data_path"] = str(self.input_opacity_path)

        # Create a units object to enable unit conversions
        self.units = UnitsUtil()

    # -------------------------------------------------------------------------
    # Private helpers
    # -------------------------------------------------------------------------

    def _default_config_path(self) -> Path:
        """Return the location of the hard-coded default configuration."""
        return (
            Path(self.input_retrieval_path) / "configs" / "config_default.yaml"
        )

    def _store_cloud_parameter(self, par, value):
        """Store one cloud parameter in `cloud_vars` (and `chem_vars`)."""
        if (
            ("Pcloud" in par)
            or (par == "cloud_fraction")
            or ("_cloud_top" in par)
        ):
            self.cloud_vars[par] = value
            return

        # TODO review this snippet
        # Names look like "<species>_<structure>[_<property>]": the first
        # two fields identify the cloud, the optional rest the property.
        fields = par.split("_", 2)
        cloud_entry = self.cloud_vars.setdefault("_".join(fields[:2]), {})
        if len(fields) > 2:
            cloud_entry[fields[2]] = value
        else:
            # Without a property field the value is the cloud abundance,
            # which is additionally registered under the species name.
            cloud_entry["abundance"] = value
            self.chem_vars[par.split("_", 1)[0]] = value

    def _store_parameter(self, par, par_type, value):
        """File `value` under `par` in the dictionary matching `par_type`.

        Parameter types that are not recognized are ignored.
        """
        if par_type in self._SIMPLE_GROUPS:
            getattr(self, self._SIMPLE_GROUPS[par_type])[par] = value
        elif par_type == self._CLOUD_GROUP:
            self._store_cloud_parameter(par, value)

    # -------------------------------------------------------------------------
    # Setup
    # -------------------------------------------------------------------------

    def load_configuration(self, config_file: str):
        """
        Read the default and the user configuration and sort their
        contents into `knowns`, `parameters`, `settings` and `units`.

        Args:
            config_file: Path to the configuration file of this run.

        Raises:
            RuntimeError: If `check_if_configs_match()` rejects the
                configuration that was read.
        """

        # Load standard configurations (hard-coded)
        self.config_default = read_config_file(
            file_path=self._default_config_path()
        )

        # Read in the configuration and check if there is already one in the file
        self.config = read_config_file(file_path=Path(config_file))

        # Check if configuration file exists and if it matches
        if not check_if_configs_match(config=self.config):
            raise RuntimeError("Config exists and does not match!")

        # Save config into the four dictionaries; the defaults go first
        # so that the run configuration is applied on top of them.
        for configuration in (self.config_default, self.config):
            (
                self.knowns,
                self.parameters,
                self.settings,
                self.units,
            ) = populate_dictionaries(
                configuration,
                self.knowns,
                self.parameters,
                self.settings,
                self.units,
            )

        if "data_files" in self.settings:
            self.instrument = load_data(
                self.settings, self.units, retrieval=self.run_retrieval
            )

        # IF CLOUDS, ASSIGN P0 WHEN NOT PROVIDED
        # TODO P0_test()
        # TODO implement validation
        # validate_config(self.config)

    def unit_conversion(self):
        """Convert the input spectra, knowns and parameters using `units`."""
        self.instrument = convert_spectrum(self.instrument, self.units)
        self.knowns = convert_knowns_and_parameters(self.knowns, self.units)
        self.parameters = convert_knowns_and_parameters(
            self.parameters, self.units
        )

    def assign_knowns(self):
        """
        Reset the per-category dictionaries and fill them with the
        `"truth"` value of every known parameter.

        With the `"input"` parameterization, the pressure-temperature
        profile is additionally loaded from `temp_vars["input_path"]`.
        """
        self.temp_vars = {}
        self.chem_vars = {}
        self.phys_vars = {}
        self.cloud_vars = {}
        self.scat_vars = {}
        self.moon_vars = {}

        # Add the known parameters to the dictionary
        for par, entry in self.knowns.items():
            if entry["type"] in self._RECOGNIZED_GROUPS:
                self._store_parameter(par, entry["type"], entry["truth"])

        # in case the PT profile is known, assign it already
        if self.settings["parameterization"] == "input":
            self.press, self.temp = np.loadtxt(
                self.temp_vars["input_path"], unpack=True
            )

    def assign_prior_functions(self):
        """Attach the prior functions to the retrieved parameters."""
        # self.parameters = read_input_prior(self.parameters)
        self.parameters = assign_priors(self.parameters)
        # TODO Check that all priors are valid (invalid_prior function in
        #  priors.py otherwise)

    def petitRADTRANS_initialization(self):
        """
        Initializes the rt_object given the wavelength range.
        """

        (
            used_line_species,
            used_rayleigh_species,
            used_cia_species,
            used_cloud_species,
        ) = define_linelists(
            self.config, self.settings, self.input_opacity_path
        )

        # TODO implement verbose output
        # Silence whatever Radtrans prints while it is being constructed.
        # The null device is closed and stdout is restored even when the
        # constructor raises.
        old_stdout = sys.stdout
        with open(os.devnull, "w") as devnull:
            sys.stdout = devnull
            try:
                self.rt_object = Radtrans(
                    pressures=np.logspace(
                        self.settings["log_top_pressure"],
                        0,
                        self.settings["n_layers"],
                        base=10,
                    ),
                    line_species=sorted(used_line_species)[::-1],
                    rayleigh_species=sorted(used_rayleigh_species),
                    gas_continuum_contributors=sorted(used_cia_species),
                    cloud_species=sorted(used_cloud_species),
                    wavelength_boundaries=self.settings["wavelength_range"],
                    line_opacity_mode="c-k",
                    scattering_in_emission=(
                        True in self.settings["include_scattering"].values()
                    ),
                )
            finally:
                sys.stdout = old_stdout

    def vae_initialization(self):
        """Load the VAE model when a `vae_pt*` parameterization is used."""
        # TODO see if it can be improved
        # NOTE(review): "vae_pt_flow" is initialized here but has no
        # branch in `calculate_pt_profile()` -- confirm it is supported.
        parameterization = self.settings["parameterization"]
        if parameterization not in ("vae_pt", "vae_pt_flow"):
            return

        from pyretlife.retrieval import pt_vae as vae

        model_folder = (
            os.path.dirname(os.path.realpath(__file__)) + "/vae_pt_models/Flow/"
        )
        if parameterization == "vae_pt":
            self.vae_pt = vae.VAE_PT_Model_Flow(
                model_folder + self.settings["vae_net"],
            )
        else:
            self.vae_pt = vae.VAE_PT_Model_Flow(
                model_folder + self.settings["vae_net"],
                flow_path=model_folder + "flow-state-dict.pt",
            )

    def saving_inputs_to_folder(self, config_file: Union[Path, str]):
        """
        Create the output folder and store the inputs of the run in it.

        Args:
            config_file: Path to the configuration file of this run; it
                is saved as `input_new.yaml`.
        """
        output_folder = self.settings["output_folder"]
        make_output_folder(output_folder)
        save_input_spectra(self.settings["data_files"], output_folder)

        config_folder = Path(self.config["RUN SETTINGS"]["output_folder"])
        save_configuration(
            input_path=self._default_config_path(),
            output_path=config_folder / "input_default_config.yaml",
        )
        save_configuration(
            input_path=config_file,
            output_path=config_folder / "input_new.yaml",
        )

        save_github_commit_string(self.input_retrieval_path, output_folder)
        save_retrieved_parameters(list(self.parameters.keys()), output_folder)
        save_environment_variables(os.environ, output_folder)

    # -------------------------------------------------------------------------
    # Sampling
    # -------------------------------------------------------------------------

    def unity_cube_to_prior_space(self, cube):
        """
        Apply the prior function of every parameter to its entry of a
        copy of `cube`; `cube` itself is left untouched.

        `M_pl` is transformed last because a `"2d-uniform"` prior needs
        the value of `R_pl` (sampled or, failing that, the known one).

        Args:
            cube: One value per entry of `self.parameters`, in the same
                order.

        Returns:
            The transformed copy of `cube`.
        """
        cube_copy = cube.copy()
        idx_M_pl = None
        R_pl = None

        for idx, (par, spec) in enumerate(self.parameters.items()):
            if par == "M_pl":
                idx_M_pl = idx
                continue
            prior = spec["prior"]
            cube_copy[idx] = prior["function"](
                cube_copy[idx], prior["prior_specs"]
            )
            if par == "R_pl":
                R_pl = cube_copy[idx]

        if idx_M_pl is not None:
            # Need to treat M separately to allow for 2D prior
            prior = self.parameters["M_pl"]["prior"]
            if prior["kind"] == "2d-uniform":
                if R_pl is None:
                    R_pl = self.knowns["R_pl"]["input_truth"]
                cube_copy[idx_M_pl] = prior["function"](
                    cube_copy[idx_M_pl], R_pl
                )
            else:
                cube_copy[idx_M_pl] = prior["function"](
                    cube_copy[idx_M_pl], prior["prior_specs"]
                )

        return cube_copy

    def assign_cube_to_parameters(self, cube):
        """
        Copy the sampled values into the per-category dictionaries.

        Args:
            cube: One value per entry of `self.parameters`, in the same
                order.
        """
        for idx, (par, spec) in enumerate(self.parameters.items()):
            if spec["type"] in self._RECOGNIZED_GROUPS:
                self._store_parameter(par, spec["type"], cube[idx])

    # -------------------------------------------------------------------------
    # Forward model
    # -------------------------------------------------------------------------

    def calculate_pt_profile(
        self, parameterization, log_ground_pressure, log_top_pressure, layers
    ):
        """
        Creates the pressure-temperature profile from the temperature
        parameters and the pressure.

        Args:
            parameterization: Name of the profile model to evaluate.
            log_ground_pressure: log10 of the last pressure grid point.
            log_top_pressure: log10 of the first pressure grid point.
            layers: Number of pressure grid points.

        Raises:
            ValueError: If `parameterization` is not a known model.
        """

        self.press = np.array(
            np.logspace(log_top_pressure, log_ground_pressure, layers, base=10)
        )

        # Every branch dispatches on the `parameterization` argument so
        # that the method honours what the caller asked for.
        if parameterization == "polynomial":
            self.temp = calculate_polynomial_profile(
                self.press, self.temp_vars
            )
        # TODO understand what is going on here
        elif parameterization == "vae_pt":
            self.temp = calculate_vae_profile(
                self.press, self.vae_pt, self.temp_vars
            )
        elif parameterization == "guillot":
            self.temp = calculate_guillot_profile(self.press, self.temp_vars)
        elif parameterization == "isothermal":
            self.temp = calculate_isothermal_profile(
                self.press, self.temp_vars
            )
        elif parameterization == "madhuseager":
            self.temp = calculate_madhuseager_profile(
                self.press, self.temp_vars
            )
        elif parameterization == "mod_madhuseager":
            self.temp = calculate_mod_madhuseager_profile(
                self.press, self.temp_vars
            )
        elif parameterization == "spline":
            self.temp = calculate_spline_profile(
                self.press, self.temp_vars, self.phys_vars, self.settings
            )
        elif parameterization == "adiabat":
            self.temp = calculate_adiabat_profile(
                self.press, self.temp_vars, self.phys_vars
            )
        elif parameterization == "line":
            self.temp = calculate_line_profile(
                self.press, self.temp_vars, self.phys_vars
            )
        else:
            raise ValueError("Unknown PT setting!")

    def calculate_abundances(self):
        """
        Compute the abundance profiles, optionally condense water, and
        derive the cloud properties for the current parameter values.
        """
        self.abundances_VMR, self.chem_vars_VMR = calculate_abundances(
            self.chem_vars, self.press, self.settings
        )

        self.condensation_pressures = None
        if self.settings["condensation"]:
            # The profile is handed over in reversed order; drying
            # defaults to 0 when `H2O_Drying` is not a parameter.
            self.abundances_VMR, self.condensation_pressures = condense_water(
                self.abundances_VMR,
                self.press[::-1],
                self.temp[::-1],
                self.phys_vars,
                self.settings,
                drying=self.chem_vars.get("H2O_Drying", 0),
            )

        (
            self.abundances_VMR,
            self.cloud_vars,
            self.cloud_radii,
            self.cloud_lnorm,
            self.cloud_Pcloud,
            self.cloud_fraction,
        ) = assign_cloud_parameters(
            self.abundances_VMR,
            self.cloud_vars,
            self.press,
            self.phys_vars,
            self.condensation_pressures,
        )

    def calculate_spectrum(self, em_contr=False):
        """
        Run the radiative transfer for the current atmosphere.

        Args:
            em_contr: If true, the emission contribution is requested
                from the forward model and returned as a third value.

        Returns:
            `(wavelength, flux)`, or `(wavelength, flux, em_contr)` if
            `em_contr` is true. When `cloud_fraction` is set, the flux is
            the `cloud_fraction`-weighted mix of a cloudy and a
            cloud-free calculation.
        """
        self.inert = calculate_inert(self.abundances_VMR)
        self.MMW = calculate_mmw_VMR(
            self.abundances_VMR, self.settings, self.inert
        )
        abundances_MMR, _ = convert_VMR_to_MMR(
            self.abundances_VMR, self.settings, self.inert, self.MMW
        )

        # initialize calculated pressure
        # (presumably bar -> cgs units as used by petitRADTRANS;
        # TODO confirm)
        self.rt_object.pressures = self.press * 1e6

        if self.settings["include_moon"]:
            self.moon_flux = calculate_moon_flux(
                self.rt_object.frequencies, self.moon_vars
            )

        if (
            self.settings["include_scattering"]["direct_light"]
            or self.settings["include_scattering"]["thermal"]
        ):
            (
                self.rt_object.reflectance,
                self.rt_object.emissivity,
            ) = assign_reflectance_emissivity(
                self.scat_vars, self.rt_object.frequencies
            )

        def forward_model(cloud_pressure):
            # Calculate the forward model; this returns the frequency
            # and the flux F_nu in erg/cm^2/s/Hz.
            return calculate_emission_flux(
                self.rt_object,
                self.settings,
                self.temp,
                abundances_MMR,
                self.phys_vars["g"],
                self.MMW,
                self.cloud_radii,
                self.cloud_lnorm,
                self.scat_vars,
                em_contr=em_contr,
                Pcloud=cloud_pressure,
            )

        freq, cloud_free_flux, cloud_free_em_contr = forward_model(None)

        if self.cloud_fraction is not None:
            freq, cloudy_flux, _ = forward_model(self.cloud_Pcloud)
            mixed_flux = (
                cloud_free_flux * (1 - self.cloud_fraction)
                + cloudy_flux * self.cloud_fraction
            )
        else:
            mixed_flux = cloud_free_flux

        # NOTE(review): the emission contribution is always the
        # cloud-free one; the cloud_fraction weighting is disabled.
        mixed_em_contr = cloud_free_em_contr

        wavelength = const.c.cgs.value / freq * 1e4

        if em_contr:
            return wavelength, mixed_flux, mixed_em_contr
        return wavelength, mixed_flux

    def distance_scale_spectrum(self):
        """
        Scale the model flux (and the moon flux, if included) to the
        distance of the system; does nothing if `d_syst` is `None`.
        """
        if self.phys_vars["d_syst"] is None:
            return

        # WARNING! THIS CONVERTS UNITS OF PRT SPECTRUM from cm-2 to m-2
        self.rt_object.flux = scale_flux_to_distance(
            self.rt_object.flux,
            self.phys_vars["R_pl"],
            self.phys_vars["d_syst"],
        )
        if self.settings["include_moon"]:
            self.moon_flux = scale_flux_to_distance(
                self.moon_flux,
                self.moon_vars["R_m"],
                self.phys_vars["d_syst"],
            )

    def calculate_log_likelihood(self, cube):
        """
        Calculates the log(likelihood) of the forward model generated
        with parameters and known variables.

        Args:
            cube: Parameter values in prior space, ordered like
                `self.parameters`.

        Returns:
            The log-likelihood summed over all instruments, or `-1e99`
            as soon as one of the validation checks flags the draw.
        """

        # Generate dictionaries for the different classes of parameters
        # and add the known parameters as well as a sample of the
        # retrieved parameters to them
        self.assign_cube_to_parameters(cube)

        self.phys_vars = set_log_ground_pressure(
            self.phys_vars, self.config, self.knowns
        )

        # TODO expand on tests here
        # test goodness of random draw
        if validate_pt_profile(self.settings, self.temp_vars, self.phys_vars):
            return self._REJECTED
        if validate_cube_finite(cube):
            return self._REJECTED
        if validate_positive_mass(self.phys_vars):
            return self._REJECTED

        self.phys_vars = calculate_gravity(self.phys_vars, self.config)

        if self.settings["parameterization"] != "input":
            self.calculate_pt_profile(
                parameterization=self.settings["parameterization"],
                log_ground_pressure=self.phys_vars["log_P0"],
                log_top_pressure=self.settings["log_top_pressure"],
                layers=self.settings["n_layers"],
            )

        if validate_positive_temperatures(self.temp):
            return self._REJECTED

        self.calculate_abundances()

        # TODO validate_abundances(self.abundances_VMR, self.chem_vars_VMR)
        #  is currently not applied
        if validate_clouds(self.press, self.temp, self.cloud_vars):
            return self._REJECTED
        if validate_sum_of_abundances(self.abundances_VMR):
            return self._REJECTED

        (
            self.rt_object.wavelength,
            self.rt_object.flux,
        ) = self.calculate_spectrum()

        if validate_spectrum_goodness(self.rt_object.flux):
            return self._REJECTED

        self.distance_scale_spectrum()

        # Calculate total log-likelihood (sum over instruments)
        log_likelihood = 0.0
        for inst in self.instrument.keys():
            # Rebin the spectrum according to the input spectrum if
            # wavelengths differ strongly
            rebinned_flux = rebin_spectrum(
                self.instrument[inst],
                self.rt_object.wavelength,
                self.rt_object.flux,
            )
            # Same truthiness test as in `calculate_spectrum()`, so the
            # moon flux computed there is actually added to the model.
            if self.settings["include_moon"]:
                rebinned_flux += rebin_spectrum(
                    self.instrument[inst],
                    self.rt_object.wavelength,
                    self.moon_flux,
                )

            # Calculate log-likelihood; the normalization term
            # np.log(2*np.pi*error**2) is deliberately left out.
            residuals = (
                rebinned_flux - self.instrument[inst]["flux"]
            ) / self.instrument[inst]["error"]
            log_likelihood += -0.5 * np.sum(residuals ** 2.0)

        return log_likelihood

    def generate_new_spectrum(self):
        """
        Runs the forward model and calculates a (new) spectrum rather
        than running a retrieval.

        Raises:
            ValueError: If the temperatures, the clouds or the sum of
                the abundances are flagged by their validation check.
        """

        # Calculate values from given parameters
        self.phys_vars = set_log_ground_pressure(
            self.phys_vars, self.config, self.knowns
        )
        self.phys_vars = calculate_gravity(self.phys_vars, self.config)

        if self.settings["parameterization"] != "input":
            self.calculate_pt_profile(
                parameterization=self.settings["parameterization"],
                log_ground_pressure=self.phys_vars["log_P0"],
                log_top_pressure=self.settings["log_top_pressure"],
                layers=self.settings["n_layers"],
            )

        if validate_positive_temperatures(self.temp):
            raise ValueError(
                "PT profile parameters resulted in negative temperatures!"
            )

        self.calculate_abundances()

        # TODO validate_abundances(self.abundances_VMR, self.chem_vars_VMR)
        #  is currently not applied
        if validate_clouds(self.press, self.temp, self.cloud_vars):
            raise ValueError("Issue with cloud positioning!")
        if validate_sum_of_abundances(self.abundances_VMR):
            raise ValueError("Gas abundances do not sum to 1!")

        # Perform radiative transfer
        (
            self.rt_object.wavelength,
            self.rt_object.flux,
        ) = self.calculate_spectrum()

        # TODO validate_spectrum_goodness(self.rt_object.flux) is
        #  currently not applied here
        self.distance_scale_spectrum()

        # Rebin to match desired spectral resolution: a wavelength grid
        # with a constant step of 1/resolution in log(wavelength)
        instrument = {
            "wavelength": np.exp(
                np.arange(
                    np.log(self.rt_object.wavelength[0]),
                    np.log(self.rt_object.wavelength[-1]),
                    1 / self.settings["resolution"],
                )
            )
        }
        rebinned_flux = rebin_spectrum(
            instrument,
            self.rt_object.wavelength,
            self.rt_object.flux,
        )
        # Same truthiness test as in `calculate_spectrum()`.
        if self.settings["include_moon"]:
            rebinned_flux += rebin_spectrum(
                instrument,
                self.rt_object.wavelength,
                self.moon_flux,
            )

        # Save output to file
        # NOTE(review): the spectrum is written to the path stored under
        # "output_folder" itself -- confirm this is meant to be a file.
        np.savetxt(
            self.settings["output_folder"],
            np.column_stack((instrument["wavelength"], rebinned_flux)),
        )