MobileNet ops #793

Open

wants to merge 4 commits into base: main
Changes from all commits
3 changes: 3 additions & 0 deletions src/nn/activations.rs
@@ -26,11 +26,14 @@ macro_rules! activation_impls {

activation_impls!(ReLU, try_relu, #[doc="Calls [relu()]."]);
activation_impls!(GeLU, try_gelu, #[doc="Calls [gelu()]."]);
activation_impls!(ReLU6, try_relu6, #[doc="Calls [relu6()]."]);
activation_impls!(HardSwish, try_hard_swish, #[doc="Calls [hard_swish()]."]);
activation_impls!(Sin, try_sin, #[doc="Calls [sin()]."]);
activation_impls!(Cos, try_cos, #[doc="Calls [cos()]."]);
activation_impls!(Ln, try_ln, #[doc="Calls [ln()]."]);
activation_impls!(Exp, try_exp, #[doc="Calls [exp()]."]);
activation_impls!(Sigmoid, try_sigmoid, #[doc="Calls [sigmoid()]."]);
activation_impls!(HardSigmoid, try_hard_sigmoid, #[doc="Calls [hard_sigmoid()]."]);
activation_impls!(Tanh, try_tanh, #[doc="Calls [tanh()]."]);
activation_impls!(Square, try_square, #[doc="Calls [square()]."]);
activation_impls!(Sqrt, try_sqrt, #[doc="Calls [sqrt()]."]);
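For reference, the three new unit structs work exactly like the existing `ReLU`/`Sigmoid` entries: each is a zero-sized layer whose `forward` calls the matching tensor op. A minimal usage sketch (hypothetical, not part of this diff; it assumes the usual `dfdx::prelude` `Module` and `AsArray` APIs):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let x = dev.tensor([-4.0f32, -1.0, 0.0, 1.0, 4.0]);

    // The new activations are zero-sized modules, usable on their own or
    // inside tuple models, just like the other entries in this macro list.
    let a = ReLU6.forward(x.clone());
    let b = HardSigmoid.forward(x.clone());
    let c = HardSwish.forward(x);

    println!("{:?}\n{:?}\n{:?}", a.array(), b.array(), c.array());
}
```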
21 changes: 21 additions & 0 deletions src/tensor_ops/hard_sigmoid/cpu_kernel.rs
@@ -0,0 +1,21 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::HardSigmoidKernelOp {
const DF_USES_FX: bool = true;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &x: &F) -> F {
(x + F::from(3.0).unwrap())
.max(F::zero())
.min(F::from(6.0).unwrap())
/ F::from(6.0).unwrap()
}
#[inline(always)]
fn df(&self, &fx: &F) -> F {
if fx > F::zero() && fx < F::one() {
F::one() / F::from(6.0).unwrap()
} else {
F::zero()
}
}
}
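For context on the math above: hard sigmoid is `clamp(x + 3, 0, 6) / 6`, an affine ramp with slope 1/6 between its two saturation points, so the derivative can be recovered from the output alone, which is what `DF_USES_FX = true` signals. A standalone sanity check in plain Rust (a sketch, independent of dfdx):

```rust
fn hard_sigmoid(x: f64) -> f64 {
    (x + 3.0).clamp(0.0, 6.0) / 6.0
}

// Same rule as `df` above: slope 1/6 strictly inside the ramp, 0 where
// saturated, decided purely from the output y = f(x).
fn d_hard_sigmoid_from_output(y: f64) -> f64 {
    if y > 0.0 && y < 1.0 { 1.0 / 6.0 } else { 0.0 }
}

fn main() {
    let eps = 1e-6;
    for x in [-4.0, -1.0, 0.0, 1.0, 4.0] {
        let numeric = (hard_sigmoid(x + eps) - hard_sigmoid(x - eps)) / (2.0 * eps);
        let analytic = d_hard_sigmoid_from_output(hard_sigmoid(x));
        assert!((numeric - analytic).abs() < 1e-4);
    }
    println!("hard_sigmoid derivative matches central differences");
}
```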
11 changes: 11 additions & 0 deletions src/tensor_ops/hard_sigmoid/cuda_kernel.rs
@@ -0,0 +1,11 @@
use super::HardSigmoidKernelOp as HardSigmoid;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for HardSigmoid {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/hard_sigmoid.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(df(f(x)) HardSigmoid, half::f16, PTX, "hard_sigmoid_fwd_f16", "hard_sigmoid_bwd_f16");
cuda_unary!(df(f(x)) HardSigmoid, f32, PTX, "hard_sigmoid_fwd_f32", "hard_sigmoid_bwd_f32");
cuda_unary!(df(f(x)) HardSigmoid, f64, PTX, "hard_sigmoid_fwd_f64", "hard_sigmoid_bwd_f64");
31 changes: 31 additions & 0 deletions src/tensor_ops/hard_sigmoid/hard_sigmoid.cu
@@ -0,0 +1,31 @@
#include "unary_op_macros.cuh"

struct HardSigmoidKernelOp {};

template<typename T>
__device__ __forceinline__ T hard_sigmoid_fwd(T x) {
T zero = 0.0;
T three = 3.0;
T six = 6.0;
return ming(maxg(x + three, zero), six) / six;
}

template<typename T>
__device__ __forceinline__ T hard_sigmoid_bwd(T y) {
T one_sixth = 1.0 / 6.0;
T zero = 0.0;
T one = 1.0;
return (y > zero && y < one) ? one_sixth : zero;
}

UNARY_OP(__half, hard_sigmoid_fwd_f16, hard_sigmoid_bwd_f16, HardSigmoidKernelOp,
hard_sigmoid_fwd(x),
hard_sigmoid_bwd(y))

UNARY_OP(float, hard_sigmoid_fwd_f32, hard_sigmoid_bwd_f32, HardSigmoidKernelOp,
hard_sigmoid_fwd(x),
hard_sigmoid_bwd(y))

UNARY_OP(double, hard_sigmoid_fwd_f64, hard_sigmoid_bwd_f64, HardSigmoidKernelOp,
hard_sigmoid_fwd(x),
hard_sigmoid_bwd(y))
54 changes: 54 additions & 0 deletions src/tensor_ops/hard_sigmoid/mod.rs
@@ -0,0 +1,54 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct HardSigmoidKernelOp;

/// [Hard Sigmoid](https://arxiv.org/abs/1905.02244). `relu6(x + 3) / 6`.
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-4.0, -1.0, 0.0, 1.0, 2.0, 4.0]);
/// let r = t.hard_sigmoid();
/// ```
pub fn hard_sigmoid<S: Shape, E: Dtype, D: UnaryKernel<HardSigmoidKernelOp, E>, T: Tape<E, D>>(
t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
t.hard_sigmoid()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<HardSigmoidKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
/// See [hard_sigmoid]
pub fn hard_sigmoid(self) -> Self {
self.try_hard_sigmoid().unwrap()
}
/// See [hard_sigmoid]
pub fn try_hard_sigmoid(self) -> Result<Self, D::Err> {
try_unary_op(HardSigmoidKernelOp, self)
}
}

#[cfg(test)]
mod tests {
use crate::{tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_hard_sigmoid() {
let dev: TestDevice = Default::default();
let x = dev
.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
.to_dtype::<TestDtype>();
let r = x.leaky_trace().hard_sigmoid();
assert_close_to_literal!(r, [0.0, 0.3333333, 0.5, 0.6666666, 1.0]);
let g = r.mean().backward();
assert_close_to_literal!(g.get(&x), [0.0, 0.033333335, 0.033333335, 0.033333335, 0.0]);
}
}
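Filling in the doc example above with the values the unit test checks (a sketch under the same assumptions as the doc comment; `.array()` is used here only to print the result):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let t = dev.tensor([-4.0f32, -1.0, 0.0, 1.0, 4.0]);

    // clamp(x + 3, 0, 6) / 6 saturates at 0 below x = -3 and at 1 above x = 3.
    let r = t.hard_sigmoid();
    println!("{:?}", r.array()); // approximately [0.0, 0.3333, 0.5, 0.6667, 1.0]
}
```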
21 changes: 21 additions & 0 deletions src/tensor_ops/hard_swish/cpu_kernel.rs
@@ -0,0 +1,21 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::HardSwishKernelOp {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, &x: &F) -> F {
x * (x + F::from(3.0).unwrap())
.max(F::zero())
.min(F::from(6.0).unwrap())
/ F::from(6.0).unwrap()
}
#[inline(always)]
fn df(&self, &x: &F) -> F {
// Piecewise slope of x * relu6(x + 3) / 6:
// 0 for x <= -3, (2x + 3) / 6 on (-3, 3), and 1 for x >= 3.
if x <= F::from(-3.0).unwrap() {
F::zero()
} else if x >= F::from(3.0).unwrap() {
F::one()
} else {
(x + x + F::from(3.0).unwrap()) / F::from(6.0).unwrap()
}
}
}
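The derivative above is piecewise because hard swish itself is: 0 below x = -3, the quadratic x * (x + 3) / 6 on [-3, 3], and the identity above x = 3, giving slopes 0, (2x + 3) / 6, and 1 on those regions. A standalone finite-difference check of that closed form (plain Rust, independent of dfdx):

```rust
fn hard_swish(x: f64) -> f64 {
    x * (x + 3.0).clamp(0.0, 6.0) / 6.0
}

// Piecewise slope: 0 for x <= -3, (2x + 3) / 6 on (-3, 3), 1 for x >= 3.
fn d_hard_swish(x: f64) -> f64 {
    if x <= -3.0 {
        0.0
    } else if x >= 3.0 {
        1.0
    } else {
        (2.0 * x + 3.0) / 6.0
    }
}

fn main() {
    let eps = 1e-6;
    // Stay away from the kinks at x = -3 and x = 3, where a central
    // difference straddles two pieces.
    for x in [-4.0, -1.0, 0.0, 1.0, 2.5, 4.0] {
        let numeric = (hard_swish(x + eps) - hard_swish(x - eps)) / (2.0 * eps);
        assert!((numeric - d_hard_swish(x)).abs() < 1e-4);
    }
    println!("hard_swish derivative matches central differences");
}
```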
29 changes: 29 additions & 0 deletions src/tensor_ops/hard_swish/cuda_kernel.rs
@@ -0,0 +1,29 @@
use super::HardSwishKernelOp as HardSwish;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for HardSwish {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/hard_swish.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(
HardSwish,
half::f16,
PTX,
"hard_swish_fwd_f16",
"hard_swish_bwd_f16"
);
cuda_unary!(
HardSwish,
f32,
PTX,
"hard_swish_fwd_f32",
"hard_swish_bwd_f32"
);
cuda_unary!(
HardSwish,
f64,
PTX,
"hard_swish_fwd_f64",
"hard_swish_bwd_f64"
);
32 changes: 32 additions & 0 deletions src/tensor_ops/hard_swish/hard_swish.cu
@@ -0,0 +1,32 @@
#include "unary_op_macros.cuh"

struct HardSwishKernelOp {};

template<typename T>
__device__ __forceinline__ T hard_swish_fwd(T x) {
T zero = 0.0;
T three = 3.0;
T six = 6.0;
return x * ming(maxg(x + three, zero), six) / six;
}

template<typename T>
__device__ __forceinline__ T hard_swish_bwd(T x) {
T minus_three = -3.0;
T zero = 0.0;
T one = 1.0;
T three = 3.0;
T six = 6.0;
// 0 for x <= -3, 1 for x >= 3, (2x + 3) / 6 in between.
return x <= minus_three ? zero : (x >= three ? one : ((x + x + three) / six));
}

UNARY_OP(__half, hard_swish_fwd_f16, hard_swish_bwd_f16, HardSwishKernelOp,
hard_swish_fwd(x),
hard_swish_bwd(x))

UNARY_OP(float, hard_swish_fwd_f32, hard_swish_bwd_f32, HardSwishKernelOp,
hard_swish_fwd(x),
hard_swish_bwd(x))

UNARY_OP(double, hard_swish_fwd_f64, hard_swish_bwd_f64, HardSwishKernelOp,
hard_swish_fwd(x),
hard_swish_bwd(x))
54 changes: 54 additions & 0 deletions src/tensor_ops/hard_swish/mod.rs
@@ -0,0 +1,54 @@
mod cpu_kernel;

#[cfg(feature = "cuda")]
mod cuda_kernel;

use super::ops::{try_unary_op, UnaryKernel};
use crate::{shapes::*, tensor::*};

#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct HardSwishKernelOp;

/// [Hard Swish](https://paperswithcode.com/method/hard-swish). `x * (relu6(x + 3) / 6)`.
///
/// Examples:
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let t = dev.tensor([-4.0, -1.0, 0.0, 1.0, 4.0]);
/// let r = t.hard_swish();
/// ```
pub fn hard_swish<S: Shape, E: Dtype, D: UnaryKernel<HardSwishKernelOp, E>, T: Tape<E, D>>(
t: Tensor<S, E, D, T>,
) -> Tensor<S, E, D, T> {
t.hard_swish()
}

impl<S: Shape, E: Dtype, D: UnaryKernel<HardSwishKernelOp, E>, T: Tape<E, D>> Tensor<S, E, D, T> {
/// See [hard_swish]
pub fn hard_swish(self) -> Self {
self.try_hard_swish().unwrap()
}
/// See [hard_swish]
pub fn try_hard_swish(self) -> Result<Self, D::Err> {
try_unary_op(HardSwishKernelOp, self)
}
}

#[cfg(test)]
mod tests {
use crate::{tensor::*, tensor_ops::*, tests::*};

#[test]
fn test_hard_swish() {
let dev: TestDevice = Default::default();
let x = dev
.tensor([-4.0, -1.0, 0.0, 1.0, 4.0])
.to_dtype::<TestDtype>();
let r = x.leaky_trace().hard_swish();
assert_close_to_literal!(r, [0.0, -0.33333334, 0.0, 0.6666667, 4.0]);
let g = r.mean().backward();
assert_close_to_literal!(g.get(&x), [0.0, 0.033333335, 0.1, 0.16666667, 0.2]);
}
}
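As with hard sigmoid, the doc example can be filled in with the forward values the test checks (a sketch under the same assumptions):

```rust
use dfdx::prelude::*;

fn main() {
    let dev: Cpu = Default::default();
    let t = dev.tensor([-4.0f32, -1.0, 0.0, 1.0, 4.0]);

    // x * clamp(x + 3, 0, 6) / 6: zero below x = -3, the identity above x = 3.
    let r = t.hard_swish();
    println!("{:?}", r.array()); // approximately [0.0, -0.3333, 0.0, 0.6667, 4.0]
}
```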
6 changes: 6 additions & 0 deletions src/tensor_ops/mod.rs
@@ -166,6 +166,8 @@ mod div;
mod dropout;
mod exp;
mod gelu;
mod hard_sigmoid;
mod hard_swish;
mod huber_error;
mod ln;
mod log_softmax;
@@ -186,6 +188,7 @@ mod prelu;
mod realize_to;
mod recip;
mod relu;
mod relu6;
mod reshape_to;
mod roll;
mod select_and_gather;
@@ -223,6 +226,8 @@ pub use div::{div, TryDiv};
pub use dropout::dropout;
pub use exp::exp;
pub use gelu::gelu;
pub use hard_sigmoid::hard_sigmoid;
pub use hard_swish::hard_swish;
pub use huber_error::huber_error;
pub use ln::ln;
pub use log_softmax::log_softmax;
@@ -243,6 +248,7 @@ pub use prelu::{leakyrelu, prelu, TryPReLU};
pub use realize_to::RealizeTo;
pub use recip::recip;
pub use relu::relu;
pub use relu6::relu6;
pub use reshape_to::ReshapeTo;
pub use roll::Roll;
pub use select_and_gather::{GatherTo, SelectTo};
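Besides the tensor methods, the hunks above re-export the new ops as free functions alongside `relu`, `gelu`, etc. A short sketch of that call style (hypothetical usage, not part of this diff):

```rust
use dfdx::prelude::*;
use dfdx::tensor_ops::{hard_sigmoid, hard_swish, relu6};

fn main() {
    let dev: Cpu = Default::default();
    let x = dev.tensor([-7.0f32, 2.0, 9.0]);

    // Free-function form, equivalent to x.relu6() / x.hard_sigmoid() / x.hard_swish().
    let a = relu6(x.clone());
    let b = hard_sigmoid(x.clone());
    let c = hard_swish(x);
    println!("{:?} {:?} {:?}", a.array(), b.array(), c.array());
}
```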
18 changes: 18 additions & 0 deletions src/tensor_ops/relu6/cpu_kernel.rs
@@ -0,0 +1,18 @@
use crate::tensor_ops::cpu_kernels::UnaryDerivative;

impl<F: num_traits::Float> UnaryDerivative<F> for super::ReLU6KernelOp {
const DF_USES_FX: bool = false;
const HAS_CONST_DF: bool = false;
#[inline(always)]
fn f(&self, x: &F) -> F {
x.max(F::zero()).min(F::from(6.0).unwrap())
}
#[inline(always)]
fn df(&self, x: &F) -> F {
if x > &F::zero() && x < &F::from(6.0).unwrap() {
F::one()
} else {
F::zero()
}
}
}
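For reference, this op is just a clamp: relu6(x) = min(max(x, 0), 6), with gradient 1 strictly inside (0, 6) and 0 outside. A plain-Rust restatement (a sketch, not part of this diff):

```rust
fn relu6(x: f64) -> f64 {
    x.clamp(0.0, 6.0)
}

// Gradient is 1 strictly between the clamping points, 0 where clamped.
fn d_relu6(x: f64) -> f64 {
    if x > 0.0 && x < 6.0 { 1.0 } else { 0.0 }
}

fn main() {
    for x in [-1.0, 0.5, 3.0, 6.5] {
        println!("relu6({x}) = {}, d/dx = {}", relu6(x), d_relu6(x));
    }
}
```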
17 changes: 17 additions & 0 deletions src/tensor_ops/relu6/cuda_kernel.rs
@@ -0,0 +1,17 @@
use super::ReLU6KernelOp;
use crate::tensor_ops::cuda_kernels::cuda_unary;

unsafe impl cudarc::driver::DeviceRepr for ReLU6KernelOp {}

const PTX: &str = include_str!(concat!(env!("OUT_DIR"), "/relu6.ptx"));

#[cfg(feature = "f16")]
cuda_unary!(
ReLU6KernelOp,
half::f16,
PTX,
"relu6_fwd_f16",
"relu6_bwd_f16"
);
cuda_unary!(ReLU6KernelOp, f32, PTX, "relu6_fwd_f32", "relu6_bwd_f32");
cuda_unary!(ReLU6KernelOp, f64, PTX, "relu6_fwd_f64", "relu6_bwd_f64");