From b221396aaa6a471f8346cc7d0e11802fe1cb7daa Mon Sep 17 00:00:00 2001 From: Florian Nachtigall Date: Fri, 23 Jun 2023 15:17:08 +0200 Subject: [PATCH 1/2] Suppress warning if base treatment is default scalar If multiple treatments are specified, a warning is given if a scalar is provided as treatment value. Previously, since this check currently included the base treatment T0, for which the default is 0, a warning was emitted even if the base treatment was not specified. Now, suppress misleading warnings for the base treatment being 0. --- econml/_cate_estimator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/econml/_cate_estimator.py b/econml/_cate_estimator.py index 3c453b291..f65f15b81 100644 --- a/econml/_cate_estimator.py +++ b/econml/_cate_estimator.py @@ -591,7 +591,9 @@ def effect(self, X=None, *, T0, T1): Note that when Y is a vector rather than a 2-dimensional array, the corresponding singleton dimension will be collapsed (so this method will return a vector) """ - X, T0, T1 = self._expand_treatments(X, T0, T1) + X, T1 = self._expand_treatments(X, T1) + is_default = ndim(T0) == 0 and T0 == 0 + _, T0 = self._expand_treatments(None, T0, suppress_warn=is_default) # TODO: what if input is sparse? - there's no equivalent to einsum, # but tensordot can't be applied to this problem because we don't sum over m eff = self.const_marginal_effect(X) @@ -847,12 +849,12 @@ def _postfit(self, Y, T, *args, **kwargs): if self.transformer: self._set_transformed_treatment_names() - def _expand_treatments(self, X=None, *Ts, transform=True): + def _expand_treatments(self, X=None, *Ts, transform=True, suppress_warn=False): X, *Ts = check_input_arrays(X, *Ts) n_rows = 1 if X is None else shape(X)[0] outTs = [] for T in Ts: - if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1: + if (ndim(T) == 0) and self._d_t_in and self._d_t_in[0] > 1 and not suppress_warn: warn("A scalar was specified but there are multiple treatments; " "the same value will be used for each treatment. Consider specifying" "all treatments, or using the const_marginal_effect method.") From 9a5bfbbcae2c031c415a57bb82bfae643420be30 Mon Sep 17 00:00:00 2001 From: Florian Nachtigall Date: Fri, 23 Jun 2023 16:10:00 +0200 Subject: [PATCH 2/2] Remove dead code --- econml/_cate_estimator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/econml/_cate_estimator.py b/econml/_cate_estimator.py index f65f15b81..be559ad8e 100644 --- a/econml/_cate_estimator.py +++ b/econml/_cate_estimator.py @@ -601,7 +601,6 @@ def effect(self, X=None, *, T0, T1): # of rows of T was not taken into account if X is None: eff = np.repeat(eff, shape(T0)[0], axis=0) - m = shape(eff)[0] dT = T1 - T0 einsum_str = 'myt,mt->my' if ndim(dT) == 1: