Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Facilitate perfect forwarding during dispatch #1

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 30 additions & 94 deletions src/tensor/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@

#include "Tensor.h"
#include <array>
#include <utility>

namespace xt {

// Tensor version
template<class F, class ... T>
auto dispatch(Tensor& t, T&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Tensor&, T&...)>::type
template<typename F, typename TensorT, typename... Args>
auto dispatch(TensorT&& t, Args&&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, TensorT&&, Args&&...)>::type
{
using ReturnType = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Tensor&, T&...)>::type;
using ReturnT = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, TensorT&&, Args&&...)>::type;
using FunctionT = std::function<ReturnT (F&, TensorT&&, Args&&...)>;
F functor;

if(t.device() == kCPU) {
static std::array<std::function<ReturnType (F&, Tensor&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template cpu<uint8_t>,
&F::template cpu<int8_t>,
&F::template cpu<int16_t>,
Expand All @@ -21,10 +25,9 @@ auto dispatch(Tensor& t, T&... args) -> typename std::result_of<decltype(&F::tem
&F::template cpu<float>,
&F::template cpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, t, args...);
return dyn.at(t.type())(functor, std::forward<TensorT>(t), std::forward<Args>(args)...);
} else if(t.device() == kGPU) {
static std::array<std::function<ReturnType (F&, Tensor&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template gpu<uint8_t>,
&F::template gpu<int8_t>,
&F::template gpu<int16_t>,
Expand All @@ -33,54 +36,22 @@ auto dispatch(Tensor& t, T&... args) -> typename std::result_of<decltype(&F::tem
&F::template gpu<float>,
&F::template gpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, t, args...);
return dyn.at(t.type())(functor, std::forward<TensorT>(t), std::forward<Args>(args)...);
} else {
throw std::invalid_argument("unsupported device");
}
}

// Context, Tensor version
template<class F, class ... T>
auto dispatch(Context& ctx, Tensor& t, T&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, Tensor&, T&...)>::type
template<typename F, typename TensorT, typename... Args>
auto dispatch(Context& ctx, TensorT&& t, Args&&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, TensorT&&, Args&&...)>::type
{
using ReturnType = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, Tensor&, T&...)>::type;
if(t.device() == kCPU) {
static std::array<std::function<ReturnType (F&, Context&, Tensor&, T&...)>, 7> dyn = {{
&F::template cpu<uint8_t>,
&F::template cpu<int8_t>,
&F::template cpu<int16_t>,
&F::template cpu<int32_t>,
&F::template cpu<int64_t>,
&F::template cpu<float>,
&F::template cpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, t, args...);
} else if(t.device() == kGPU) {
static std::array<std::function<ReturnType (F&, Context&, Tensor&, T&...)>, 7> dyn = {{
&F::template gpu<uint8_t>,
&F::template gpu<int8_t>,
&F::template gpu<int16_t>,
&F::template gpu<int32_t>,
&F::template gpu<int64_t>,
&F::template gpu<float>,
&F::template gpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, ctx, t, args...);
} else {
throw std::invalid_argument("unsupported device");
}
}
using ReturnT = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, TensorT&&, Args&&...)>::type;
using FunctionT = std::function<ReturnT (F&, Context&, TensorT&&, Args&&...)>;
F functor;

// const Tensor version
template<class F, class ... T>
auto dispatch(const Tensor& t, T&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, const Tensor&, T&...)>::type
{
using ReturnType = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, const Tensor&, T&...)>::type;
if(t.device() == kCPU) {
static std::array<std::function<ReturnType (F&, const Tensor&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template cpu<uint8_t>,
&F::template cpu<int8_t>,
&F::template cpu<int16_t>,
Expand All @@ -89,10 +60,9 @@ auto dispatch(const Tensor& t, T&... args) -> typename std::result_of<decltype(&
&F::template cpu<float>,
&F::template cpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, t, args...);
return dyn.at(t.type())(functor, ctx, std::forward<TensorT>(t), std::forward<Args>(args)...);
} else if(t.device() == kGPU) {
static std::array<std::function<ReturnType (F&, const Tensor&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template gpu<uint8_t>,
&F::template gpu<int8_t>,
&F::template gpu<int16_t>,
Expand All @@ -101,54 +71,22 @@ auto dispatch(const Tensor& t, T&... args) -> typename std::result_of<decltype(&
&F::template gpu<float>,
&F::template gpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, t, args...);
} else {
throw std::invalid_argument("unsupported device");
}
}

// Context, const Tensor version
template<class F, class ... T>
auto dispatch(Context& ctx, const Tensor& t, T&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, const Tensor&, T&...)>::type
{
using ReturnType = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Context&, const Tensor&, T&...)>::type;
if(t.device() == kCPU) {
static std::array<std::function<ReturnType (F&, Context&, const Tensor&, T&...)>, 7> dyn = {{
&F::template cpu<uint8_t>,
&F::template cpu<int8_t>,
&F::template cpu<int16_t>,
&F::template cpu<int32_t>,
&F::template cpu<int64_t>,
&F::template cpu<float>,
&F::template cpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, ctx, t, args...);
} else if(t.device() == kGPU) {
static std::array<std::function<ReturnType (F&, Context&, const Tensor&, T&...)>, 7> dyn = {{
&F::template gpu<uint8_t>,
&F::template gpu<int8_t>,
&F::template gpu<int16_t>,
&F::template gpu<int32_t>,
&F::template gpu<int64_t>,
&F::template gpu<float>,
&F::template gpu<double>,
}};
F functor;
return dyn.at(t.type())(functor, ctx, t, args...);
return dyn.at(t.type())(functor, ctx, std::forward<TensorT>(t), std::forward<Args>(args)...);
} else {
throw std::invalid_argument("unsupported device");
}
}

// type/device version
template<class F, class ... T>
auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, T&...)>::type
template<typename F, typename... Args>
auto dispatch(TensorType ttype, TensorDevice tdev, Args&&... args) -> typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Args&&...)>::type
{
using ReturnType = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, T&...)>::type;
using ReturnT = typename std::result_of<decltype(&F::template cpu<int64_t>)(F&, Args&&...)>::type;
using FunctionT = std::function<ReturnT (F&, Args&&...)>;
F functor;

if(tdev == kCPU) {
static std::array<std::function<ReturnType (F&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template cpu<uint8_t>,
&F::template cpu<int8_t>,
&F::template cpu<int16_t>,
Expand All @@ -157,10 +95,9 @@ auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std::
&F::template cpu<float>,
&F::template cpu<double>,
}};
F functor;
return dyn.at(ttype)(functor, args...);
return dyn.at(ttype)(functor, std::forward<Args>(args)...);
} else if(tdev == kGPU) {
static std::array<std::function<ReturnType (F&, T&...)>, 7> dyn = {{
static std::array<FunctionT, 7> dyn = {{
&F::template gpu<uint8_t>,
&F::template gpu<int8_t>,
&F::template gpu<int16_t>,
Expand All @@ -169,8 +106,7 @@ auto dispatch(TensorType ttype, TensorDevice tdev, T&... args) -> typename std::
&F::template gpu<float>,
&F::template gpu<double>,
}};
F functor;
return dyn.at(ttype)(functor, args...);
return dyn.at(ttype)(functor, std::forward<Args>(args)...);
} else {
throw std::invalid_argument("unsupported device");
}
Expand Down
70 changes: 67 additions & 3 deletions src/tensor/test/basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

using namespace xt;

struct sum_op
struct sum_op_ref
{
template<typename T> Tensor cpu(Tensor& x)
{
Expand All @@ -19,12 +19,60 @@ struct sum_op
}
return sum;
};


template<typename T> Tensor gpu(Tensor& x)
{
throw std::invalid_argument("device not supported");
};
};

// Test functor: sums every element of a contiguous tensor, selected by
// element type T through dispatch(). Takes the tensor by const reference;
// only the CPU backend is implemented.
struct sum_op_const_ref
{
    // Accumulate all elements of `x` (viewed as T) and return the scalar
    // total. NOTE(review): relies on Tensor being constructible from a
    // scalar T — confirm against Tensor's constructors.
    // Throws std::invalid_argument if `x` is not contiguous.
    template<typename T> Tensor cpu(const Tensor& x)
    {
        if(!isContiguous(x)) {
            throw std::invalid_argument("contiguous tensor expected");
        }
        const T* data = x.data<T>();
        const int64_t count = numel(x);
        T total = 0;
        int64_t idx = 0;
        while(idx < count) {
            total += data[idx];
            ++idx;
        }
        return total;
    }

    // GPU backend intentionally unimplemented for this test functor.
    template<typename T> Tensor gpu(const Tensor& x)
    {
        throw std::invalid_argument("device not supported");
    }
};

// Test functor: sums every element of a contiguous tensor, selected by
// element type T through dispatch(). Takes the tensor by rvalue reference
// to exercise perfect forwarding in dispatch(); only the CPU backend is
// implemented.
struct sum_op_rvalue_ref
{
    // Accumulate all elements of `x` (viewed as T) and return the scalar
    // total as a Tensor.
    // Throws std::invalid_argument if `x` is not contiguous.
    template<typename T> Tensor cpu(Tensor&& x)
    {
        if(!isContiguous(x)) {
            throw std::invalid_argument("contiguous tensor expected");
        }
        T* x_p = x.data<T>();
        int64_t size = numel(x);
        T sum = 0;
        for(int64_t i = 0; i < size; i++) {
            sum += x_p[i];
        }
        return sum;
    }

    // GPU backend intentionally unimplemented for this test functor.
    // Fixed: parameter is Tensor&& (was const Tensor&&) to match cpu()'s
    // signature — a const rvalue reference is almost never intended and
    // would forbid moving from the argument.
    template<typename T> Tensor gpu(Tensor&& x)
    {
        throw std::invalid_argument("device not supported");
    }
};

static void test(TensorDevice device)
{
{
Expand Down Expand Up @@ -184,10 +232,26 @@ static void test(TensorDevice device)

if(device == kCPU)
{
std::cout << "manual sum:" << std::endl;
std::cout << "manual sum (ref dispatch):" << std::endl;
Tensor a = rand({3, 7}, kFloat, device);
std::cout << a << std::endl;
std::cout << dispatch<sum_op_ref>(a) << " == " << sum(a) << std::endl;
}

if(device == kCPU)
{
std::cout << "manual sum (const ref dispatch):" << std::endl;
const Tensor a = rand({3, 7}, kFloat, device);
std::cout << a << std::endl;
std::cout << dispatch<sum_op_const_ref>(a) << " == " << sum(a) << std::endl;
}

if(device == kCPU)
{
std::cout << "manual sum (rvalue ref dispatch):" << std::endl;
Tensor a = rand({3, 7}, kFloat, device);
std::cout << a << std::endl;
std::cout << dispatch<sum_op>(a) << " == " << sum(a) << std::endl;
std::cout << dispatch<sum_op_rvalue_ref>(std::move(a)) << " == " << sum(a) << std::endl;
}

{
Expand Down