-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_training_dataset.py
187 lines (156 loc) · 8.78 KB
/
make_training_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
""" Script for generating the full formal training dataset. """
import os
from pathlib import Path
from contextlib import redirect_stdout
import hashlib
import numpy as np
# pylint: disable=no-member
import astropy.units as u
# Tell the libraries where the JKTEBOP executable lives.
# The conda yaml based env sets this but it's not set for venvs.
if not "JKTEBOP_DIR" in os.environ:
os.environ["JKTEBOP_DIR"] = "~/jktebop/"
# pylint: disable=wrong-import-position
# Put these after the above environ statements so the values are picked up if needed
from traininglib import datasets, plots
from ebop_maven.libs import orbital
from ebop_maven.libs.tee import Tee
# By splitting the dataset over multiple files we have the option of using a subset of the dataset
# using a wildcard match, for example "trainset00?.tfrecord" picks up the first 10 files only.
DATASET_SIZE = 500000
FILE_COUNT = DATASET_SIZE // 10000
FILE_PREFIX = "trainset"
dataset_dir = Path(f"./datasets/formal-training-dataset-{DATASET_SIZE // 1000}k/")
dataset_dir.mkdir(parents=True, exist_ok=True)
def generate_instances_from_distributions(label: str):
    """
    Generates system instances by picking from random distributions over the
    JKTEBOP parameter range.

    Runs forever: each iteration draws one candidate system and yields its
    parameter dictionary. Any digits in the label become the set id used to
    prefix instance ids, and a hash of the label seeds the RNG so each
    labelled set draws a distinct, reproducible random stream.

    :label: a useful label to use within messages
    :returns: a generator over instance parameter dictionaries, one per system
    """
    # pylint: disable=too-many-locals, invalid-name
    generated_counter = 0
    set_id = ''.join(filter(str.isdigit, label))

    # Don't use the built-in hash() function; it's not consistent across processes!!!
    # NOTE: int.from_bytes() relies on byteorder defaulting to "big" (Python 3.11+).
    seed = int.from_bytes(hashlib.shake_128(label.encode("utf8")).digest(8))
    rng = np.random.default_rng(seed)

    while True: # infinite loop; we will continue to yield new instances until generator is closed
        # The "label" params are rA_plus_rB, k, J, ecosw, esinw and bP (or inc, depending on model)

        # Extending range of rA_plus_rB beyond 0.2+0.2 we need to be able to predict systems where
        # JKTEBOP may not be suitable (we may not know this unless the predictions are reliable).
        # The distributions for k and J are chosen through testing, being those which yield a
        # model capable of making good predictions over the various testing datasets.
        rA_plus_rB      = rng.uniform(low=0.001, high=0.45001)
        k               = rng.normal(loc=0.5, scale=0.8)
        rA              = rA_plus_rB / (1 + k)          # Not used directly as labels, but useful
        rB              = rA_plus_rB - rA

        # Simple uniform dist for inc (JKTEBOP bottoms out at 50 deg)
        inc             = rng.uniform(low=50., high=90.00001) * u.deg
        J               = rng.normal(loc=0.5, scale=0.8)

        # We need a version of JKTEBOP which supports negative L3 input values
        # (not so for version 43) in order to train a model to predict L3.
        L3              = rng.normal(0., 0.1)
        L3              = 0 # continue to override until revised JKTEBOP released

        # The qphot mass ratio value (MB/MA) affects the lightcurve via the ellipsoidal effect
        # due to distortion of the stars' shape. Set to -100 to force spherical stars or derive
        # a value from other params. We're using the k-q relations of Demircan & Kahraman (1991)
        # Both <1.66 M_sun (k=q^0.935), both >1.66 M_sun (k=q^0.542), MB-low/MA-high (k=q^0.724)
        # and approx' single rule is k = q^0.715 which we use here (tests find this works best).
        # Inverting that rule gives q ~= k^(1/0.715) ~= k^1.4, used as the mean of the draw below.
        qphot           = rng.normal(loc=k**1.4, scale=0.3) if k > 0 else 0

        # We generate ecc and omega (argument of periastron) from appropriate distributions.
        # They're not used directly as labels, but they make up ecosw and esinw which are.
        # Eccentricity is uniform selection, restricted with eqn 5 of Wells & Prsa (2024) which
        # reduces the max eccentricity with the separation plus 10% are fixed at 0 to ensure
        # sufficient examples. This trains better than simple uniform or normal distributions tried.
        ecc             = rng.choice([0, rng.uniform(low=0., high=1-(1.5*rA_plus_rB))], p=[0.1, 0.9])
        omega           = rng.uniform(low=0., high=360.) * u.deg

        # Now we can calculate the derived values, sufficient to check we've a usable system
        inc_rad         = inc.to(u.rad).value
        omega_rad       = omega.to(u.rad).value
        esinw           = ecc * np.sin(omega_rad)
        ecosw           = ecc * np.cos(omega_rad)
        imp_prm         = orbital.impact_parameter(rA, inc, ecc, None, esinw, orbital.EclipseType.BOTH)

        # Create the pset dictionary.
        generated_counter += 1
        inst_id = f"{set_id}/{generated_counter:06d}"
        yield {
            "id":           inst_id,

            # Basic system params for generating the model light-curve
            # The keys (& those for LD below) are those expected by make_dateset
            "rA_plus_rB":   rA_plus_rB,
            "k":            k,
            "inc":          inc.to(u.deg).value,
            "qphot":        qphot,
            "ecosw":        ecosw,
            "esinw":        esinw,
            "J":            J,
            "L3":           L3,

            **datasets.default_limb_darkening_params,

            # Further params for potential use as labels/features
            "sini":         np.sin(inc_rad),
            "cosi":         np.cos(inc_rad),
            "rA":           rA,
            "rB":           rB,
            "ecc":          ecc,
            "omega":        omega.to(u.deg).value,
            "bP":           imp_prm[0],
            "bS":           imp_prm[1],
            "phiS":         orbital.secondary_eclipse_phase(ecosw, ecc),
            "dS_over_dP":   orbital.ratio_of_eclipse_duration(esinw),
        }
def is_usable_instance(k: float=0, J: float=0, qphot: float=0, ecc: float=0,
bP: float= None, bS: float=None,
rA: float=1., rB: float=1., inc: float=0,
**_ # Used to ignore any unexpected **params
) -> bool:
"""
Checks various parameter values to decide whether this represents a usable instance.
Checks on;
- is system physically plausible
- will it generate eclipses
- is it suitable for modelling with JKTEBOP
"""
# pylint: disable=invalid-name, too-many-arguments, unused-argument
usable = False
# Use invalid values as defaults so that if any are missing we fail
# Physically plausible (qphot of -100 is a magic number to force spherical)
usable = k > 0 and J > 0 and (qphot > 0 or qphot == -100) and ecc < 1
# Will eclipse
if usable:
usable = all(b is not None and b <= 1 + k for b in [bP, bS])
# Compatible with JKTEBOP restrictions
# Soft restriction of rA & rB both <= 0.23 as its model is not well suited to r >~ 0.2
# Hard restrictions of rA+rB < 0.8 (covered by above), inc > 50, k <= 100
if usable:
usable = rA <= 0.23 and rB <= 0.23 and inc > 50 and k <= 100
return usable
# ------------------------------------------------------------------------------
# Makes the formal training dataset based on the above generator function which
# samples parameter distributions over JKTEBOP's usable range.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Tee console output to a log file alongside the dataset. A with block
    # ensures the log handle is flushed & closed on exit (the previous code
    # opened the file inline and never closed it).
    with open(dataset_dir / "dataset.log", "w", encoding="utf8") as log_file:
        with redirect_stdout(Tee(log_file)):
            datasets.make_dataset(instance_count=DATASET_SIZE,
                                  file_count=FILE_COUNT,
                                  output_dir=dataset_dir,
                                  generator_func=generate_instances_from_distributions,
                                  check_func=is_usable_instance,
                                  file_prefix=FILE_PREFIX,
                                  valid_ratio=0.2,
                                  test_ratio=0,
                                  max_workers=5,
                                  save_param_csvs=True,
                                  verbose=True,
                                  simulate=False)

            # Histograms are generated from the CSV files as they cover params not saved to tfrecord
            csvs = sorted(dataset_dir.glob(f"**/{FILE_PREFIX}*.csv"))
            plots.plot_dataset_histograms(csvs, cols=5).savefig(dataset_dir/"train-histogram-full.png")
            plots.plot_dataset_histograms(csvs, ["rA_plus_rB", "k", "J", "inc", "ecosw", "esinw"],
                                          cols=2).savefig(dataset_dir/"train-histogram-main.eps")

            # Simple diagnostic plot of the mags feature of a small sample of the instances.
            print("Plotting a sample of the set's mags features")
            dataset_files = sorted(dataset_dir.glob(f"**/training/{FILE_PREFIX}000.tfrecord"))
            fig = plots.plot_dataset_instance_mags_features(dataset_files, cols=5, max_instances=50)
            fig.savefig(dataset_dir / "sample.png", dpi=150)
            fig.clf()