diff --git a/src/nn/activations.rs b/src/nn/activations.rs
index 6c74de689..ab428e4a1 100644
--- a/src/nn/activations.rs
+++ b/src/nn/activations.rs
@@ -26,11 +26,14 @@ macro_rules! activation_impls {
 activation_impls!(ReLU, try_relu, #[doc="Calls [relu()]."]);
 activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
+activation_impls!(ReLU6, try_relu6, #[doc="Calls [relu6()]."]);
+activation_impls!(HardSwish, try_hard_swish, #[doc="Calls [hard_swish()]."]);
 activation_impls!(Sin, try_sin, #[doc="Calls [sin()]."]);
 activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
 activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);
 activation_impls!(Exp, try_exp, #[doc="Calls [exp()]."]);
 activation_impls!(Sigmoid, try_sigmoid, #[doc="Calls [sigmoid()]."]);
+activation_impls!(HardSigmoid, try_hard_sigmoid, #[doc="Calls [hard_sigmoid()]."]);
 activation_impls!(Tanh, try_tanh, #[doc="Calls [tanh()]."]);
 activation_impls!(Square, try_square, #[doc="Calls [square()]."]);
 activation_impls!(Sqrt, try_sqrt, #[doc="Calls [sqrt()]."]);
diff --git a/src/tensor_ops/hard_sigmoid/cpu_kernel.rs b/src/tensor_ops/hard_sigmoid/cpu_kernel.rs
new file mode 100644
index 000000000..95e61c1ff
--- /dev/null
+++ b/src/tensor_ops/hard_sigmoid/cpu_kernel.rs
@@ -0,0 +1,21 @@
+use crate::tensor_ops::cpu_kernels::UnaryDerivative;
+
+impl<F: num_traits::Float> UnaryDerivative<F> for super::HardSigmoidKernelOp {
+    const DF_USES_FX: bool = true;
+    const HAS_CONST_DF: bool = false;
+    #[inline(always)]
+    fn f(&self, &x: &F) -> F {
+        (x + F::from(3.0).unwrap())
+            .max(F::zero())
+            .min(F::from(6.0).unwrap())
+            / F::from(6.0).unwrap()
+    }
+    #[inline(always)]
+    fn df(&self, &fx: &F) -> F {
+        if fx > F::zero() && fx < F::one() {
+            F::one() / F::from(6.0).unwrap()
+        } else {
+            F::zero()
+        }
+    }
+}
diff --git a/src/tensor_ops/hard_sigmoid/cuda_kernel.rs b/src/tensor_ops/hard_sigmoid/cuda_kernel.rs
new file mode 100644
index 000000000..b2c761c35
--- /dev/null
+++ b/src/tensor_ops/hard_sigmoid/cuda_kernel.rs
@@ -0,0 +1,11 @@
+use super::HardSigmoidKernelOp as HardSigmoid;
+use crate::tensor_ops::cuda_kernels::cuda_unary;
+
+unsafe impl cudarc::driver::DeviceRepr for HardSigmoid {}
+
+const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/hard_sigmoid.ptx"));
+
+#[cfg(feature = "f16")]
+cuda_unary!(df(f(x)) HardSigmoid, half::f16, PTX, "hard_sigmoid_fwd_f16", "hard_sigmoid_bwd_f16");
+cuda_unary!(df(f(x)) HardSigmoid, f32, PTX, "hard_sigmoid_fwd_f32", "hard_sigmoid_bwd_f32");
+cuda_unary!(df(f(x)) HardSigmoid, f64, PTX, "hard_sigmoid_fwd_f64", "hard_sigmoid_bwd_f64");
diff --git a/src/tensor_ops/hard_sigmoid/hard_sigmoid.cu b/src/tensor_ops/hard_sigmoid/hard_sigmoid.cu
new file mode 100644
index 000000000..15a02631c
--- /dev/null
+++ b/src/tensor_ops/hard_sigmoid/hard_sigmoid.cu
@@ -0,0 +1,31 @@
+#include "unary_op_macros.cuh"
+
+struct HardSigmoidKernelOp {};
+
+template<typename T>
+__device__ __forceinline__ T hard_sigmoid_fwd(T x) {
+    T zero = 0.0;
+    T three = 3.0;
+    T six = 6.0;
+    return ming(maxg(x + three, zero), six) / six;
+}
+
+template<typename T>
+__device__ __forceinline__ T hard_sigmoid_bwd(T y) {
+    T one_sixth = 1.0 / 6.0;
+    T zero = 0.0;
+    T one = 1.0;
+    return y > zero ? y < one ? one_sixth : zero : zero;
+}
one_sixth : zero : zero; +} + +UNARY_OP(__half, hard_sigmoid_fwd_f16, hard_sigmoid_bwd_f16, HardSigmoidKernelOp, + hard_sigmoid_fwd(x), + hard_sigmoid_bwd(y)) + +UNARY_OP(float, hard_sigmoid_fwd_f32, hard_sigmoid_bwd_f32, HardSigmoidKernelOp, + hard_sigmoid_fwd(x), + hard_sigmoid_bwd(y)) + +UNARY_OP(double, hard_sigmoid_fwd_f64, hard_sigmoid_bwd_f64, HardSigmoidKernelOp, + hard_sigmoid_fwd(x), + hard_sigmoid_bwd(y)) diff --git a/src/tensor_ops/hard_sigmoid/mod.rs b/src/tensor_ops/hard_sigmoid/mod.rs new file mode 100644 index 000000000..058ca9866 --- /dev/null +++ b/src/tensor_ops/hard_sigmoid/mod.rs @@ -0,0 +1,54 @@ +mod cpu_kernel; + +#[cfg(feature = "cuda")] +mod cuda_kernel; + +use super::ops::{try_unary_op, UnaryKernel}; +use crate::{shapes::*, tensor::*}; + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct HardSigmoidKernelOp; + +/// [Hard Sigmoid](https://arxiv.org/abs/1905.02244). `relu6(x + 3) / 6`. +/// +/// Examples: +/// ```rust +/// # use dfdx::prelude::*; +/// # let dev: Cpu = Default::default(); +/// let t = dev.tensor([-4.0, -1.0, 0.0, 1.0, 2.0, 4.0]); +/// let r = t.hard_sigmoid(); +/// ``` +pub fn hard_sigmoid, T: Tape>( + t: Tensor, +) -> Tensor { + t.hard_sigmoid() +} + +impl, T: Tape> Tensor { + /// See [hard_sigmoid] + pub fn hard_sigmoid(self) -> Self { + self.try_hard_sigmoid().unwrap() + } + /// See [hard_sigmoid] + pub fn try_hard_sigmoid(self) -> Result { + try_unary_op(HardSigmoidKernelOp, self) + } +} + +#[cfg(test)] +mod tests { + use crate::{tensor::*, tensor_ops::*, tests::*}; + + #[test] + fn test_hard_sigmoid() { + let dev: TestDevice = Default::default(); + let x = dev + .tensor([-4.0, -1.0, 0.0, 1.0, 4.0]) + .to_dtype::(); + let r = x.leaky_trace().hard_sigmoid(); + assert_close_to_literal!(r, [0.0, 0.3333333, 0.5, 0.6666666, 1.0]); + let g = r.mean().backward(); + assert_close_to_literal!(g.get(&x), [0.0, 0.033333335, 0.033333335, 0.033333335, 0.0]); + } +} diff --git a/src/tensor_ops/hard_swish/cpu_kernel.rs b/src/tensor_ops/hard_swish/cpu_kernel.rs new file mode 100644 index 000000000..000845249 --- /dev/null +++ b/src/tensor_ops/hard_swish/cpu_kernel.rs @@ -0,0 +1,21 @@ +use crate::tensor_ops::cpu_kernels::UnaryDerivative; + +impl UnaryDerivative for super::HardSwishKernelOp { + const DF_USES_FX: bool = false; + const HAS_CONST_DF: bool = false; + #[inline(always)] + fn f(&self, &x: &F) -> F { + x * (x + F::from(3.0).unwrap()) + .max(F::zero()) + .min(F::from(6.0).unwrap()) + / F::from(6.0).unwrap() + } + #[inline(always)] + fn df(&self, &x: &F) -> F { + if x > F::from(-3.0).unwrap() { + (x + x + F::from(3.0).unwrap()) / F::from(6.0).unwrap() + } else { + F::zero() + } + } +} diff --git a/src/tensor_ops/hard_swish/cuda_kernel.rs b/src/tensor_ops/hard_swish/cuda_kernel.rs new file mode 100644 index 000000000..08dcecdeb --- /dev/null +++ b/src/tensor_ops/hard_swish/cuda_kernel.rs @@ -0,0 +1,29 @@ +use super::HardSwishKernelOp as HardSwish; +use crate::tensor_ops::cuda_kernels::cuda_unary; + +unsafe impl cudarc::driver::DeviceRepr for HardSwish {} + +const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/hard_swish.ptx")); + +#[cfg(feature = "f16")] +cuda_unary!( + HardSwish, + half::f16, + PTX, + "hard_swish_fwd_f16", + "hard_swish_bwd_f16" +); +cuda_unary!( + HardSwish, + f32, + PTX, + "hard_swish_fwd_f32", + "hard_swish_bwd_f32" +); +cuda_unary!( + HardSwish, + f64, + PTX, + "hard_swish_fwd_f64", + "hard_swish_bwd_f64" +); diff --git a/src/tensor_ops/hard_swish/hard_swish.cu b/src/tensor_ops/hard_swish/hard_swish.cu new file mode 
diff --git a/src/tensor_ops/hard_swish/hard_swish.cu b/src/tensor_ops/hard_swish/hard_swish.cu
new file mode 100644
index 000000000..c4b0418fd
--- /dev/null
+++ b/src/tensor_ops/hard_swish/hard_swish.cu
@@ -0,0 +1,32 @@
+#include "unary_op_macros.cuh"
+
+struct HardSwishKernelOp {};
+
+template<typename T>
+__device__ __forceinline__ T hard_swish_fwd(T x) {
+    T zero = 0.0;
+    T three = 3.0;
+    T six = 6.0;
+    return x * ming(maxg(x + three, zero), six) / six;
+}
+
+template<typename T>
+__device__ __forceinline__ T hard_swish_bwd(T x) {
+    T minus_three = -3.0;
+    T zero = 0.0;
+    T three = 3.0;
+    T six = 6.0;
+    return x > minus_three ? ((x + x + three) / six) : zero;
+}
+
+UNARY_OP(__half, hard_swish_fwd_f16, hard_swish_bwd_f16, HardSwishKernelOp,
+        hard_swish_fwd(x),
+        hard_swish_bwd(x))
+
+UNARY_OP(float, hard_swish_fwd_f32, hard_swish_bwd_f32, HardSwishKernelOp,
+        hard_swish_fwd(x),
+        hard_swish_bwd(x))
+
+UNARY_OP(double, hard_swish_fwd_f64, hard_swish_bwd_f64, HardSwishKernelOp,
+        hard_swish_fwd(x),
+        hard_swish_bwd(x))
diff --git a/src/tensor_ops/hard_swish/mod.rs b/src/tensor_ops/hard_swish/mod.rs
new file mode 100644
index 000000000..db78fa12e
--- /dev/null
+++ b/src/tensor_ops/hard_swish/mod.rs
@@ -0,0 +1,54 @@
+mod cpu_kernel;
+
+#[cfg(feature = "cuda")]
+mod cuda_kernel;
+
+use super::ops::{try_unary_op, UnaryKernel};
+use crate::{shapes::*, tensor::*};
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct HardSwishKernelOp;
+
+/// [Hard Swish](https://paperswithcode.com/method/hard-swish). `x * (relu6(x + 3) / 6)`.
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let t = dev.tensor([-4.0, -1.0, 0.0, 1.0, 4.0]);
+/// let r = t.hard_swish();
+/// ```
+pub fn hard_swish<S: Shape, E: Dtype, D: UnaryKernel<HardSwishKernelOp, E>, T: Tape<E, D>>(
+    t: Tensor<S, E, D, T>,
+) -> Tensor<S, E, D, T> {
+    t.hard_swish()
+}
+
+impl<S: Shape, E: Dtype, D: UnaryKernel<HardSwishKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// See [hard_swish]
+    pub fn hard_swish(self) -> Self {
+        self.try_hard_swish().unwrap()
+    }
+    /// See [hard_swish]
+    pub fn try_hard_swish(self) -> Result<Self, D::Err> {
+        try_unary_op(HardSwishKernelOp, self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{tensor::*, tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_hard_swish() {
+        let dev: TestDevice = Default::default();
+        let x = dev
+            .tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
+            .to_dtype::<TestDtype>();
+        let r = x.leaky_trace().hard_swish();
+        assert_close_to_literal!(r, [0.0, -0.33333334, 0.0, 0.6666667, 4.0]);
+        let g = r.mean().backward();
+        assert_close_to_literal!(g.get(&x), [0.0, 0.033333335, 0.1, 0.16666667, 0.36666667]);
+    }
+}
diff --git a/src/tensor_ops/mod.rs b/src/tensor_ops/mod.rs
index 4f35eb737..9244c126a 100644
--- a/src/tensor_ops/mod.rs
+++ b/src/tensor_ops/mod.rs
@@ -166,6 +166,8 @@ mod div;
 mod dropout;
 mod exp;
 mod gelu;
+mod hard_sigmoid;
+mod hard_swish;
 mod huber_error;
 mod ln;
 mod log_softmax;
@@ -186,6 +188,7 @@ mod prelu;
 mod realize_to;
 mod recip;
 mod relu;
+mod relu6;
 mod reshape_to;
 mod roll;
 mod select_and_gather;
@@ -223,6 +226,8 @@ pub use div::{div, TryDiv};
 pub use dropout::dropout;
 pub use exp::exp;
 pub use gelu::gelu;
+pub use hard_sigmoid::hard_sigmoid;
+pub use hard_swish::hard_swish;
 pub use huber_error::huber_error;
 pub use ln::ln;
 pub use log_softmax::log_softmax;
@@ -243,6 +248,7 @@ pub use prelu::{leakyrelu, prelu, TryPReLU};
 pub use realize_to::RealizeTo;
 pub use recip::recip;
 pub use relu::relu;
+pub use relu6::relu6;
 pub use reshape_to::ReshapeTo;
 pub use roll::Roll;
 pub use select_and_gather::{GatherTo, SelectTo};
diff --git a/src/tensor_ops/relu6/cpu_kernel.rs b/src/tensor_ops/relu6/cpu_kernel.rs
new file mode 100644
index 000000000..c9b040491
--- /dev/null
+++ b/src/tensor_ops/relu6/cpu_kernel.rs
@@ -0,0 +1,18 @@
+use crate::tensor_ops::cpu_kernels::UnaryDerivative;
+
+impl<F: num_traits::Float> UnaryDerivative<F> for super::ReLU6KernelOp {
+    const DF_USES_FX: bool = false;
+    const HAS_CONST_DF: bool = false;
+    #[inline(always)]
+    fn f(&self, x: &F) -> F {
+        x.max(F::zero()).min(F::from(6.0).unwrap())
+    }
+    #[inline(always)]
+    fn df(&self, x: &F) -> F {
+        if x > &F::zero() && x < &F::from(6.0).unwrap() {
+            F::one()
+        } else {
+            F::zero()
+        }
+    }
+}
diff --git a/src/tensor_ops/relu6/cuda_kernel.rs b/src/tensor_ops/relu6/cuda_kernel.rs
new file mode 100644
index 000000000..d7e01c25c
--- /dev/null
+++ b/src/tensor_ops/relu6/cuda_kernel.rs
@@ -0,0 +1,17 @@
+use super::ReLU6KernelOp;
+use crate::tensor_ops::cuda_kernels::cuda_unary;
+
+unsafe impl cudarc::driver::DeviceRepr for ReLU6KernelOp {}
+
+const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/relu6.ptx"));
+
+#[cfg(feature = "f16")]
+cuda_unary!(
+    ReLU6KernelOp,
+    half::f16,
+    PTX,
+    "relu6_fwd_f16",
+    "relu6_bwd_f16"
+);
+cuda_unary!(ReLU6KernelOp, f32, PTX, "relu6_fwd_f32", "relu6_bwd_f32");
+cuda_unary!(ReLU6KernelOp, f64, PTX, "relu6_fwd_f64", "relu6_bwd_f64");
diff --git a/src/tensor_ops/relu6/mod.rs b/src/tensor_ops/relu6/mod.rs
new file mode 100644
index 000000000..66f766cf4
--- /dev/null
+++ b/src/tensor_ops/relu6/mod.rs
@@ -0,0 +1,56 @@
+mod cpu_kernel;
+
+#[cfg(feature = "cuda")]
+mod cuda_kernel;
+
+use super::ops::{try_unary_op, UnaryKernel};
+use crate::{shapes::*, tensor::*};
+
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct ReLU6KernelOp;
+
+/// Modification of [Rectified Linear Unit (ReLU)](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)). `min(max(0, t), 6)`
+///
+/// Examples:
+/// ```rust
+/// # use dfdx::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let t = dev.tensor([-1.0, 0.0, 1.0, 2.0, 7.0]);
+/// let r = t.relu6();
+/// assert_eq!(r.array(), [0.0, 0.0, 1.0, 2.0, 6.0]);
+/// ```
+pub fn relu6<S: Shape, E: Dtype, D: UnaryKernel<ReLU6KernelOp, E>, T: Tape<E, D>>(
+    t: Tensor<S, E, D, T>,
+) -> Tensor<S, E, D, T> {
+    t.relu6()
+}
+
+impl<S: Shape, E: Dtype, D: UnaryKernel<ReLU6KernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
+    /// See [relu6]
+    pub fn relu6(self) -> Self {
+        self.try_relu6().unwrap()
+    }
+    /// See [relu6]
+    pub fn try_relu6(self) -> Result<Self, D::Err> {
+        try_unary_op(ReLU6KernelOp, self)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::{tensor::*, tensor_ops::*, tests::*};
+
+    #[test]
+    fn test_relu6() {
+        let dev: TestDevice = Default::default();
+        let x = dev
+            .tensor([-2.0, -1.0, 0.0, 1.0, 2.0, 7.0])
+            .to_dtype::<TestDtype>();
+        let r = x.leaky_trace().relu6();
+        assert_close_to_literal!(r, [0.0, 0.0, 0.0, 1.0, 2.0, 6.0]);
+        // NOTE: call .exp() to make sure we cover cases where .relu6() uses the result's gradient
+        let g = r.exp().mean().backward();
+        assert_close_to_literal!(g.get(&x), [0.0, 0.0, 0.0, 0.45304698, 1.2315094, 0.0]);
+    }
+}
diff --git a/src/tensor_ops/relu6/relu6.cu b/src/tensor_ops/relu6/relu6.cu
new file mode 100644
index 000000000..9411cbf8d
--- /dev/null
+++ b/src/tensor_ops/relu6/relu6.cu
@@ -0,0 +1,30 @@
+#include "unary_op_macros.cuh"
+
+struct ReLU6KernelOp {};
+
+template<typename T>
+__device__ __forceinline__ T relu6_fwd(T x) {
+    T zero = 0.0;
+    T six = 6.0;
+    return ming(maxg(x, zero), six);
+}
+
+template<typename T>
+__device__ __forceinline__ T relu6_bwd(T x) {
+    T zero = 0.0;
+    T one = 1.0;
+    T six = 6.0;
+    return x > zero ? x < six ? one : zero : zero;
+}
one : zero : zero; +} + +UNARY_OP(__half, relu6_fwd_f16, relu6_bwd_f16, ReLU6KernelOp, + relu6_fwd(x), + relu6_bwd(x)) + +UNARY_OP(float, relu6_fwd_f32, relu6_bwd_f32, ReLU6KernelOp, + relu6_fwd(x), + relu6_bwd(x)) + +UNARY_OP(double, relu6_fwd_f64, relu6_bwd_f64, ReLU6KernelOp, + relu6_fwd(x), + relu6_bwd(x)) diff --git a/src/tensor_ops/utilities/device.rs b/src/tensor_ops/utilities/device.rs index f150e71e9..ebd0ad1e1 100644 --- a/src/tensor_ops/utilities/device.rs +++ b/src/tensor_ops/utilities/device.rs @@ -87,7 +87,10 @@ pub trait Device: + UnaryKernel + UnaryKernel + UnaryKernel + + UnaryKernel + + UnaryKernel + UnaryKernel + + UnaryKernel + UnaryKernel + UnaryKernel + UnaryKernel