"""
Read in configuration files.
"""
import os
import sys
import shutil
import glob
import numpy as np
import subprocess
import astropy.units as u
import json
# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------
from pathlib import Path
from typing import Union, Tuple
import warnings
import hashlib
import yaml
from deepdiff import DeepDiff
from pyretlife.retrieval.units import UnitsUtil
# -----------------------------------------------------------------------------
# DEFINITIONS
# -----------------------------------------------------------------------------
[docs]
def read_config_file(file_path: Union[Path, str]) -> dict:
"""
Read a configuration from a YAML file.
:param file_path: Union[Path, str]: Path to a file with the configuration.
:return The configuration as a dictionary.
"""
file_path = Path(file_path)
if file_path.suffix == ".yaml":
with open(file_path, "r") as yaml_file:
config = yaml.safe_load(yaml_file)
return config
else:
raise ValueError(
f"Unknown file extension: {file_path.suffix}. Please convert it to a .yaml file."
)
[docs]
def check_if_configs_match(config: dict) -> bool:
"""
The check_if_configs_match function checks if the config file in the retrieval directory
matches the input.yaml file that was used to run a previous simulation. If they match,
the function returns True; otherwise, it returns False.
:param config: the configuration dictionary
:return True if the files are the same, False if they are different.
"""
# Expected location of a config file; check if it exists
retrieval_dir = Path(config["RUN SETTINGS"]["output_folder"])
if not retrieval_dir.exists():
return True
# Find all the config files in the retrieval directory
try:
config_file = next(retrieval_dir.glob("input.yaml"))
except StopIteration:
return True
return not DeepDiff(
config,
read_config_file(config_file),
ignore_order=True,
)
# def compute_hash_of_config_file(file_path: Union[Path, str]) -> str:
# """
# Compute the hash of a configuration file.
#
# Args:
# file_path: Path to a file with the configuration.
#
# Returns:
# The hash of the configuration file.
# """
#
# file_path = Path(file_path)
# with open(file_path, "rb") as f:
# return hashlib.sha256(f.read()).hexdigest()
#
# def convert_ini_to_yaml(file_path: Union[Path, str]) -> None:
# pass
[docs]
def make_output_folder(folder_path: Union[Path, str]) -> None:
"""
The function creates a new folder at the specified path if it does not already exist.
:param folder_path: The `folder_path` parameter is the path to the folder where you want to create a new folder. It can be either a `Path` object or a string representing the path
:type folder_path: Union[Path, str]
"""
folder_path = Path(folder_path)
if not os.path.isdir(folder_path):
os.mkdir(folder_path)
[docs]
def save_configuration(input_path: Union[Path, str],output_path: Union[Path, str]):
"""
The function `save_configuration` copies a file from the input path to the output path if the output
path does not already exist.
:param input_path: The input_path parameter is the path to the configuration file that you want to
save. It can be either a Path object or a string representing the path to the file
:type input_path: Union[Path, str]
:param output_path: The output path is the path where the configuration file will be saved. It can
be either a `Path` object or a string representing the path
:type output_path: Union[Path, str]
"""
if not os.path.exists(output_path):
shutil.copyfile(input_path, output_path)
[docs]
def save_github_commit_string(input_retrieval_path: Union[Path, str],
output_path: Union[Path, str]):
if not os.path.exists(output_path + "/git_commit.txt"):
if input_retrieval_path != "":
os.system(
"git -C "
+ input_retrieval_path
+ " show --name-status >"
+ output_path
+ "/git_commit.txt"
)
[docs]
def save_retrieved_parameters(parameter_keys:list,
output_path: Union[Path, str]):
if not os.path.exists(output_path + "/params.json"):
with open(output_path + "/params.json", "w") as f:
json.dump(parameter_keys, f, indent=2)
[docs]
def save_environment_variables(environment_variables: dict,
output_path: Union[Path, str]):
if not os.path.exists(output_path + "/environment_variables.json"):
pyretlife_variables = {variable:environment_variables[variable] for variable in environment_variables.keys() if 'PYRETLIFE' in variable}
with open(output_path + "/environment_variables.json", "w") as f:
json.dump(pyretlife_variables, f, indent=2)
[docs]
def populate_dictionaries(
config: dict,
knowns: dict,
parameters: dict,
settings: dict,
units: UnitsUtil,
) -> Tuple[dict, dict, dict, UnitsUtil]:
if "USER-DEFINED UNITS" in config.keys():
for key in config["USER-DEFINED UNITS"]:
units.custom_unit(
key, u.Quantity(config["USER-DEFINED UNITS"][key])
)
linelist = []
for section in config.keys():
if section != "USER-DEFINED UNITS":
for subsection in config[section].keys():
if (
type(config[section][subsection]) is dict
and "prior" in config[section][subsection].keys()
):
parameters[subsection] = config[section][subsection]
if not "truth" in config[section][subsection].keys():
parameters[subsection]["truth"]=None
if "unit" in config[section][subsection].keys():
input_unit = u.Unit(config[section][subsection]["unit"])
else:
input_unit = units.return_units(
subsection, units.std_input_units
)
parameters[subsection]["unit"] = input_unit
parameters[subsection]["type"] = section
elif (
type(config[section][subsection]) is dict
and "truth" in config[section][subsection].keys()
):
knowns[subsection] = config[section][subsection]
if "unit" in config[section][subsection].keys():
input_unit = u.Unit(config[section][subsection]["unit"])
else:
input_unit = units.return_units(
subsection, units.std_input_units
)
knowns[subsection]["unit"] = input_unit
knowns[subsection]["type"] = section
else:
settings[subsection] = config[section][subsection]
# read lists if available. Can be a str or list.
if (
type(config[section][subsection]) is dict
and "lines" in config[section][subsection].keys()
):
if isinstance(config[section][subsection]["lines"], str):
linelist.append(config[section][subsection]["lines"])
else:
linelist.extend(config[section][subsection]["lines"])
settings["opacity_linelist"] = linelist
return knowns, parameters, settings, units
[docs]
def load_data(settings: dict, units: UnitsUtil, retrieval: bool = True) -> dict:
result_dir = settings["output_folder"]
instrument = {}
for data_file in settings["data_files"].keys():
input_string = settings["data_files"][data_file]["path"]
# Case handling for the retrieval plotting
if not retrieval:
if os.path.isfile(
result_dir
+ "/input_"
+ input_string.split("/")[-1].split(" ")[0]
):
input_string = (
result_dir + "/input_" + input_string.split("/")[-1]
)
else:
input_string = (
result_dir
+ "/input_spectrum.txt "
+ " ".join(input_string.split("/")[-1].split(" ")[1:])
)
input_data = np.genfromtxt(input_string)
# retrieve units
if "unit" in settings["data_files"][data_file].keys():
input_unit_wavelength = u.Unit(
settings["data_files"][data_file]["unit"].split(",")[0]
)
input_unit_flux = u.Unit(
settings["data_files"][data_file]["unit"].split(",")[1]
)
else:
input_unit_wavelength = units.return_units(
"wavelength", units.std_input_units
)
input_unit_flux = units.return_units("flux", units.std_input_units)
# trim spectrum
input_data = input_data[
input_data[:, 0]
>= (
settings["wavelength_range"][0]
* units.return_units("WMIN", units.std_input_units)
)
.to(input_unit_wavelength)
.value
]
input_data = input_data[
input_data[:, 0]
<= (
settings["wavelength_range"][1]
* units.return_units("WMAX", units.std_input_units)
)
.to(input_unit_wavelength)
.value
]
instrument[data_file] = {
"input_data": input_data,
"input_unit_wavelength": input_unit_wavelength,
"input_unit_flux": input_unit_flux,
}
return instrument
[docs]
def get_check_opacity_path() -> Path:
"""
The get_check_opacity_path function checks that the PYRETLIFE_OPACITY_PATH environment variable is set.
If it is not, an error message is printed and the program exits. If it is set, then a Path object pointing to
the opacity folder in this directory will be returned.
Returns:
The path to the opacity folder
"""
input_opacity_path = os.environ.get("PYRETLIFE_OPACITY_PATH")
if input_opacity_path is None:
raise RuntimeError("PYRETLIFE_OPACITY_PATH not set!")
if not Path(input_opacity_path).exists():
raise RuntimeError(
"PYRETLIFE_OPACITY_PATH set, but folder does not exist!"
)
if len(glob.glob(input_opacity_path + "/opacities/*")) == 0:
raise RuntimeError(
"PYRETLIFE_OPACITY_PATH set, but folder is not valid."
)
return Path(input_opacity_path)
[docs]
def get_check_prt_path() -> Path:
"""
The get_check_pRT_path function checks that the PYRETLIFE_PRT_PATH environment variable is set, and if so,
checks that it points to a valid folder. If all these conditions are met, then the function returns a Path object
pointing to this folder.
Returns:
The path to the petitRADTRANS folder
"""
input_prt_path = os.environ.get("PYRETLIFE_PRT_PATH")
if input_prt_path is None:
raise RuntimeError("PYRETLIFE_PRT_PATH not set!")
if not Path(input_prt_path).exists():
raise RuntimeError("PYRETLIFE_PRT_PATH set, but folder does not exist!")
if len(glob.glob(input_prt_path + "/petitRADTRANS/*")) == 0:
raise RuntimeError("PYRETLIFE_PRT_PATH set, but folder is not valid.")
return Path(input_prt_path)
[docs]
def get_retrieval_path() -> Union[Path, str]:
"""
The get_retrieval_path function checks that the PYRETLIFE_RETRIEVAL_PATH environment variable is set,
and if so, then the function returns a Path object pointing to this folder.
Returns:
The path to the retrieval folder
"""
input_retrieval_path = os.environ.get("PYRETLIFE_RETRIEVAL_PATH")
if input_retrieval_path is None:
warnings.warn("PYRETLIFE_RETRIEVAL_PATH not set. Skipping...")
input_retrieval_path = ""
return input_retrieval_path
if not Path(input_retrieval_path).exists():
warnings.warn(
"PYRETLIFE_RETRIEVAL_PATH set, but folder does not exist! Skipping..."
)
input_retrieval_path = ""
return input_retrieval_path
if not bool(
subprocess.check_output(
"git -C "
+ input_retrieval_path
+ " rev-parse --is-inside-work-tree",
shell=True,
)
.decode()
.strip("\n")
):
warnings.warn(
"PYRETLIFE_RETRIEVAL_PATH set, but not a Git Repository! Skipping..."
)
input_retrieval_path = ""
return input_retrieval_path
[docs]
def set_prt_opacity(
input_prt_path: Union[Path, str], input_opacity_path: Union[Path, str]
) -> None:
file_path = Path(input_prt_path) / "petitRADTRANS" / "path.txt"
with open(file_path, "r") as path_file:
orig_path = path_file.read()
# LEGACY: for older versions of petitRADTRANS
# if orig_path != "#\n" + str(input_opacity_path):
# with open(file_path, "w+") as input_data:
# input_data.write("#\n" + input_opacity_path)
# For new versions of pRT
os.environ["pRT_input_data_path"] = str(input_opacity_path)