diff --git a/dfdx-core/src/shapes/shape.rs b/dfdx-core/src/shapes/shape.rs index 184337cd..c3e27121 100644 --- a/dfdx-core/src/shapes/shape.rs +++ b/dfdx-core/src/shapes/shape.rs @@ -69,6 +69,30 @@ where } } +impl core::ops::Sub> for usize { + type Output = usize; + fn sub(self, _: Const) -> Self::Output { + self.size() - N + } +} +impl core::ops::Sub for Const { + type Output = usize; + fn sub(self, rhs: usize) -> Self::Output { + N - rhs.size() + } +} + +#[cfg(feature = "nightly")] +impl core::ops::Sub> for Const +where + Const<{ M - N }>: Sized, +{ + type Output = Const<{ M - N }>; + fn sub(self, _: Const) -> Self::Output { + Const + } +} + impl core::ops::Mul> for usize { type Output = usize; fn mul(self, _: Const) -> Self::Output {