Skip to content

Commit

Permalink
adding copula param checks
Browse files Browse the repository at this point in the history
Took 1 hour 3 minutes
  • Loading branch information
tfm000 committed Oct 31, 2023
1 parent cd97810 commit f2115c9
Showing 1 changed file with 51 additions and 7 deletions.
58 changes: 51 additions & 7 deletions sklarpy/copulas/_prefit_dists.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,33 @@ def _get_mdists(self, mdists: Union[MarginalFitter, dict], d: int,
"MarginalFitter object")
return mdists

def _get_copula_params(self, copula_params: Union[Params, tuple],
data_array: np.ndarray) -> tuple:
"""Converts the user's copula-params input into tuple form and then
checks the parameters of the model and raises an error if one or more
is invalid.
Parameters
----------
copula_params: Union[Params, tuple]
The parameters of the multivariate distribution used to specify
your copula distribution. Can be a Params object of the specific
multivariate distribution or a tuple containing these parameters
in the correct order.
data_array: np.ndarray
numpy array of the multivariate data.
Returns
-------
copula_params_tuple: tuple
The parameters which define the multivariate distribution in the
copula model.
"""
params_tuple: tuple = self._mv_object._get_params(params=copula_params)
self._mv_object._check_dim(data=data_array, params=params_tuple)
return copula_params if isinstance(copula_params, tuple) \
else copula_params.to_tuple

def __mdist_calcs(self, func_strs: List[str], data: np.ndarray,
mdists: Union[MarginalFitter, dict], check: bool,
funcs_kwargs: dict = None) -> Dict[str, np.ndarray]:
Expand Down Expand Up @@ -213,6 +240,10 @@ def logpdf(self, x: Union[pd.DataFrame, np.ndarray],
# checking data
x_array: np.ndarray = self._get_data_array(data=x, is_u=False)

# checking copula params
copula_params_tuple: tuple = self._get_copula_params(
copula_params=copula_params, data_array=x_array)

# getting non-nan data
mask, masked_data, output = get_mask(data=x_array)

Expand All @@ -225,7 +256,7 @@ def logpdf(self, x: Union[pd.DataFrame, np.ndarray],

# calculating logpdf values
logpdf_values: np.ndarray = self.copula_logpdf(
u=res['cdf'], copula_params=copula_params,
u=res['cdf'], copula_params=copula_params_tuple,
match_datatype=False, **kwargs) + res['logpdf'].sum(axis=1)

# converting to correct output datatype
Expand Down Expand Up @@ -319,6 +350,10 @@ def __cdf_mccdf(self, mc_cdf: bool, x: Union[pd.DataFrame, np.ndarray],
# checking data
x_array: np.ndarray = self._get_data_array(data=x, is_u=False)

# checking copula params
copula_params_tuple: tuple = self._get_copula_params(
copula_params=copula_params, data_array=x_array)

# getting non-nan data
mask, masked_data, output = get_mask(data=x_array)

Expand All @@ -330,7 +365,7 @@ def __cdf_mccdf(self, mc_cdf: bool, x: Union[pd.DataFrame, np.ndarray],
mc_str: str = "mc_" if mc_cdf else ""
func: Callable = eval(f"self.copula_{mc_str}cdf")
copula_cdf_values: np.ndarray = func(
u=res['cdf'], copula_params=copula_params,
u=res['cdf'], copula_params=copula_params_tuple,
match_datatype=False, **kwargs)

# converting to correct output datatype
Expand Down Expand Up @@ -514,12 +549,17 @@ def copula_logpdf(self, u: Union[pd.DataFrame, np.ndarray],
# checking data
u_array: np.ndarray = self._get_data_array(data=u, is_u=True)

# checking copula params
copula_params_tuple: tuple = self._get_copula_params(
copula_params=copula_params, data_array=u_array)

# calculating copula logpdf
g: np.ndarray = self._u_to_g(u=u_array, copula_params=copula_params)
g: np.ndarray = self._u_to_g(u=u_array,
copula_params=copula_params_tuple)
g_logpdf: np.ndarray = self._mv_object.logpdf(
x=g, params=copula_params, match_datatype=False, **kwargs)
x=g, params=copula_params_tuple, match_datatype=False, **kwargs)
copula_logpdf_values: np.ndarray = g_logpdf - self._h_logpdf_sum(
g=g, copula_params=copula_params)
g=g, copula_params=copula_params_tuple)
return TypeKeeper(u).type_keep_from_1d_array(
array=copula_logpdf_values, match_datatype=match_datatype,
col_name=['copula_logpdf'])
Expand Down Expand Up @@ -601,11 +641,15 @@ def __copula_cdf_mccdf(self, mc_cdf: bool,
# checking data
u_array: np.ndarray = self._get_data_array(data=u, is_u=True)

# checking copula params
copula_params_tuple: tuple = self._get_copula_params(
copula_params=copula_params, data_array=u_array)

# calculating cdf values
g: np.ndarray = self._u_to_g(u_array, copula_params)
g: np.ndarray = self._u_to_g(u_array, copula_params_tuple)
mc_str: str = "mc_" if mc_cdf else ""
func: Callable = eval(f"self._mv_object.{mc_str}cdf")
copula_cdf_values: np.ndarray = func(x=g, params=copula_params,
copula_cdf_values: np.ndarray = func(x=g, params=copula_params_tuple,
match_datatype=False, **kwargs)
return TypeKeeper(u).type_keep_from_1d_array(
array=copula_cdf_values, match_datatype=match_datatype,
Expand Down

0 comments on commit f2115c9

Please sign in to comment.