Skip to content

Commit

Permalink
WIP - Making changes to RESSPECT as I test working with the external laiss_resspect_classifier library.
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Nov 8, 2024
1 parent a1b8a6f commit bf849b9
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 99 deletions.
24 changes: 16 additions & 8 deletions src/resspect/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import io
import logging
import os
import pandas as pd
import tarfile
Expand All @@ -28,6 +29,7 @@

__all__ = ['DataBase']

logger = logging.getLogger(__name__)

class DataBase:
"""DataBase object, upon which the active learning loop is performed.
Expand Down Expand Up @@ -239,9 +241,9 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
"Bump", or "Malanchev". Default is "Bazin".
"""

if survey not in ['DES', 'LSST']:
raise ValueError('Only "DES" and "LSST" filters are ' + \
'implemented at this point!')
# if survey not in ['DES', 'LSST']:
# raise ValueError('Only "DES" and "LSST" filters are ' + \
# 'implemented at this point!')

# read matrix with features
if '.tar.gz' in path_to_features_file:
Expand All @@ -264,7 +266,9 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
# Create the filter-feature names based on the survey.
survey_filters = FILTER_SETS[survey]
self.features_names = feature_extractor_class.get_features(survey_filters)
self.metadata_names = feature_extractor_class.get_metadata_header()

#! This section needs to be made dynamic between this line and the following comment
self.metadata_names = ['id', 'redshift', 'type', 'code', 'orig_sample', 'queryable']
if 'objid' in data.keys():
self.metadata_names = ['objid', 'redshift', 'type', 'code', 'orig_sample', 'queryable']
Expand All @@ -275,6 +279,7 @@ def load_features_from_file(self, path_to_features_file: str, screen=False,
for name in self.telescope_names:
if 'cost_' + name in data.keys():
self.metadata_names = self.metadata_names + ['cost_' + name]
#! End of section that needs to be made dynamic

if sample == None:
self.features = data[self.features_names].values
Expand Down Expand Up @@ -414,14 +419,14 @@ def load_features(self, path_to_file: str, feature_extractor: str ='Bazin',
if feature_extractor == "photometry":
self.load_photometry_features(path_to_file, screen=screen,
survey=survey, sample=sample)
elif feature_extractor in FEATURE_EXTRACTOR_REGISTRY:
else: # feature_extractor in FEATURE_EXTRACTOR_REGISTRY:
self.load_features_from_file(
path_to_file, screen=screen, survey=survey,
sample=sample, feature_extractor=feature_extractor)
else:
feature_extractors = ', '.join(FEATURE_EXTRACTOR_REGISTRY.keys())
raise ValueError(f'Only {feature_extractors} or photometry features are implemented!'
'\n Feel free to add other options.')
# else:
# feature_extractors = ', '.join(FEATURE_EXTRACTOR_REGISTRY.keys())
# raise ValueError(f'Only {feature_extractors} or photometry features are implemented!'
# '\n Feel free to add other options.')

def load_plasticc_mjd(self, path_to_data_dir):
"""Return all MJDs from 1 file from PLAsTiCC simulations.
Expand Down Expand Up @@ -478,6 +483,9 @@ def identify_keywords(self):
id_name = 'id'
elif 'objid' in self.metadata_names:
id_name = 'objid'
else:
logger.warning('No object identification found in metadata - using first column as object identification!')
id_name = self.metadata_names[0]

return id_name

Expand Down
18 changes: 18 additions & 0 deletions src/resspect/feature_extractors/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from abc import ABC, abstractmethod


class ResspectFeatureExtractor(ABC):
    """Abstract base class for RESSPECT feature extractors.

    Concrete extractors must implement per-band fitting (``fit``),
    fitting across all bands (``fit_all``), and the class-level
    ``get_features`` helper that names the output feature columns.
    """

    @abstractmethod
    def fit(self, band: str) -> list:
        """Fit a single band and return the best-fit parameters.

        Parameters
        ----------
        band : str
            Broad-band filter identifier (e.g. 'g', 'r').

        Returns
        -------
        list
            Best-fit feature values for the given band.
        """
        # Consistent with the other abstract methods: raise instead of
        # silently returning None if called on a non-conforming subclass.
        raise NotImplementedError()

    @abstractmethod
    def fit_all(self):
        """Fit all available bands of the light curve."""
        raise NotImplementedError()

    @classmethod
    @abstractmethod
    def get_features(cls, filters: list) -> list[str]:
        """Return the per-filter feature column names.

        Parameters
        ----------
        filters : list
            Broad-band filter identifiers.

        Returns
        -------
        list[str]
            Concatenated filter+feature column names.
        """
        raise NotImplementedError()
63 changes: 44 additions & 19 deletions src/resspect/feature_extractors/feature_extractor_utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
import itertools
from typing import List

def make_features_header(
def make_column_headers(
filters: List[str],
features: List[str],
**kwargs
) -> list:
) -> List[str]:
"""
This function returns header list for given filters and features. The default
header names are: ['id', 'redshift', 'type', 'code', 'orig_sample'].
This function returns the full header list for given filters and features.
Parameters
----------
filters : list
Filter values. e.g. ['g', 'r', 'i', 'z']
features : list
Feature values. e.g. ['A', 'B']
with_cost : bool
Flag for adding cost values. Default is False
kwargs
Can include the following flags:
- override_primary_columns: List[str] of primary columns to override the default ones
Expand All @@ -27,33 +24,61 @@ def make_features_header(
Returns
-------
header
header list
List[str]
The complete metadata and feature header list
"""

header = []
header.extend(['id', 'redshift', 'type', 'code', 'orig_sample'])

# Create metadata columns
metadata_columns = make_metadata_column_names(**kwargs)
header += metadata_columns

# Create all pairs of filter + feature strings and append to the header
filter_features = make_filter_feature_names(filters, features)
header += filter_features

return header

def make_metadata_column_names(**kwargs) -> List[str]:
    """Build the list of metadata column names.

    The default header names are ['id', 'redshift', 'type', 'code',
    'orig_sample']. Using the keys in kwargs, additional columns can be
    added (or the defaults replaced entirely).

    Parameters
    ----------
    kwargs
        Can include the following flags:
        - override_primary_columns: List[str] of primary columns to override the default ones
        - with_queryable: flag for adding "queryable" column
        - with_last_rmag: flag for adding "last_rmag" column
        - with_cost: flag for adding "cost_4m" and "cost_8m" columns

    Returns
    -------
    List[str]
        metadata header list
    """
    metadata_columns = ['id', 'redshift', 'type', 'code', 'orig_sample']

    # There are rare instances where we need to override the primary columns.
    # Copy the override so the caller's list is never mutated by the
    # append/extend calls below.
    override = kwargs.get('override_primary_columns')
    if override:
        metadata_columns = list(override)

    if kwargs.get('with_queryable', False):
        metadata_columns.append('queryable')

    if kwargs.get('with_last_rmag', False):
        metadata_columns.append('last_rmag')

    #TODO: find where the 'with_cost' flag is used to make sure we apply there
    if kwargs.get('with_cost', False):
        metadata_columns.extend(['cost_4m', 'cost_8m'])

    return metadata_columns

def create_filter_feature_names(filters: List[str], features: List[str]) -> List[str]:
def make_filter_feature_names(filters: List[str], features: List[str]) -> List[str]:
"""This function returns the list of concatenated filters and features. e.g.
filter = ['g', 'r'], features = ['A', 'B'] => ['gA', 'gB', 'rA', 'rB']
Expand Down
145 changes: 82 additions & 63 deletions src/resspect/feature_extractors/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
read_plasticc_full_photometry_data,
)
from resspect.feature_extractors.feature_extractor_utils import (
create_filter_feature_names,
make_features_header,
make_filter_feature_names,
make_column_headers,
make_metadata_column_names,
)

warnings.filterwarnings("ignore", category=RuntimeWarning)
Expand All @@ -39,66 +40,64 @@

class LightCurve:
""" Light Curve object, holding meta and photometric data.
Attributes
----------
feature_names: list
Class attribute, a list of names of the feature extraction parameters.
features: list
List with the 5 best-fit feature extraction parameters in all filters.
Concatenated from blue to red.
dataset_name: str
Name of the survey or data set being analyzed.
exp_time: dict
Exposure time required to take a spectra.
Keywords indicate telescope e.g.['4m', '8m'].
filters: list
List of broad band filters.
full_photometry: pd.DataFrame
Photometry for a set of light curves read from file.
id: int
SN identification number.
id_name:
Column name of object identifier.
last_mag: float
r-band magnitude of last observed epoch.
photometry: pd.DataFrame
Photometry information.
Minimum keys --> [mjd, band, flux, fluxerr].
redshift: float
Redshift
sample: str
Original sample to which this light curve is assigned.
sim_peakmag: np.array
Simulated peak magnitude in each filter.
sim_pkmjd: float
Simulated day of maximum, observer frame.
sncode: int
Number identifying the SN model used in the simulation.
sntype: str
General classification, possibilities are: Ia, II or Ibc.
unique_ids: str or array
List of unique ids available in the photometry file.
Only used for PLAsTiCC data.
Methods
-------
calc_exp_time(telescope_diam: float, SNR: float, telescope_name: str)
Calculates time required to take a spectra in the last obs epoch.
check_queryable(mjd: float, r_lim: float)
Check if this light can be queried in a given day.
conv_flux_mag(flux: np.array)
Convert positive flux into magnitude.
load_snpcc_lc(path_to_data: str)
Reads header and photometric information for 1 light curve.
load_plasticc_lc(photo_file: str, snid: int)
Load photometric information for 1 PLAsTiCC light curve.
load_resspect_lc(photo_file: str, snid: int)
Load photometric information for 1 RESSPECT light curve.
plot_bump_fit(save: bool, show: bool, output_file: srt)
Plot photometric points and Bump fitted curve.
"""
# Attributes
# ----------
# feature_names: list
# Class attribute, a list of names of the feature extraction parameters.
# features: list
# List with the 5 best-fit feature extraction parameters in all filters.
# Concatenated from blue to red.
# dataset_name: str
# Name of the survey or data set being analyzed.
# exp_time: dict
# Exposure time required to take a spectra.
# Keywords indicate telescope e.g.['4m', '8m'].
# filters: list
# List of broad band filters.
# full_photometry: pd.DataFrame
# Photometry for a set of light curves read from file.
# id: int
# SN identification number.
# id_name:
# Column name of object identifier.
# last_mag: float
# r-band magnitude of last observed epoch.
# photometry: pd.DataFrame
# Photometry information.
# Minimum keys --> [mjd, band, flux, fluxerr].
# redshift: float
# Redshift
# sample: str
# Original sample to which this light curve is assigned.
# sim_peakmag: np.array
# Simulated peak magnitude in each filter.
# sim_pkmjd: float
# Simulated day of maximum, observer frame.
# sncode: int
# Number identifying the SN model used in the simulation.
# sntype: str
# General classification, possibilities are: Ia, II or Ibc.
# unique_ids: str or array
# List of unique ids available in the photometry file.
# Only used for PLAsTiCC data.

# Methods
# -------
# calc_exp_time(telescope_diam: float, SNR: float, telescope_name: str)
# Calculates time required to take a spectra in the last obs epoch.
# check_queryable(mjd: float, r_lim: float)
# Check if this light can be queried in a given day.
# conv_flux_mag(flux: np.array)
# Convert positive flux into magnitude.
# load_snpcc_lc(path_to_data: str)
# Reads header and photometric information for 1 light curve.
# load_plasticc_lc(photo_file: str, snid: int)
# Load photometric information for 1 PLAsTiCC light curve.
# load_resspect_lc(photo_file: str, snid: int)
# Load photometric information for 1 RESSPECT light curve.
# plot_bump_fit(save: bool, show: bool, output_file: str)
# Plot photometric points and Bump fitted curve.

feature_names = []

Expand Down Expand Up @@ -165,7 +164,24 @@ def get_features(cls, filters: list) -> list[str]:
-------
list
"""
return create_filter_feature_names(filters, cls.feature_names)

if cls.feature_names is None:
raise ValueError("Feature names not defined for this class.")

return make_filter_feature_names(filters, cls.feature_names)

@classmethod
def get_metadata_header(cls, **kwargs) -> list[str]:
    """
    Returns the metadata columns for the feature extractor,
    i.e. id, redshift, sntype, sncode, sample (plus any optional
    columns requested via kwargs, which are forwarded unchanged
    to ``make_metadata_column_names``).

    Returns
    -------
    list
        Metadata column names.
    """

    return make_metadata_column_names(**kwargs)

@classmethod
def get_feature_header(cls, filters: list, **kwargs) -> list[str]:
    """
    Returns the complete header (metadata plus per-filter feature
    columns) for the feature extractor.

    Parameters
    ----------
    filters : list
        Broad-band filter identifiers.
    kwargs
        Optional flags forwarded to the header builder
        (e.g. with_queryable, with_last_rmag, with_cost).

    Returns
    -------
    list
        Full header column names.
    """

    # Guard against subclasses that never defined their feature names.
    if cls.feature_names is None:
        raise ValueError("Feature names not defined for this class.")

    return make_column_headers(filters, cls.feature_names, **kwargs)

def _get_snpcc_photometry_raw_and_header(
self, lc_data: np.ndarray,
Expand Down
1 change: 1 addition & 0 deletions src/resspect/filter_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
"SNPCC": ["g", "r", "i", "z"],
"DES": ["g", "r", "i", "z"],
"LSST": ["u", "g", "r", "i", "z", "Y"],
"ZTF": ["g", "r"],
}
Loading

0 comments on commit bf849b9

Please sign in to comment.