Skip to content

Commit

Permalink
Move hipblaslt_op from compile_hipblaslt.cpp to compile_hipblaslt.hpp
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca committed Nov 13, 2024
1 parent 3e4c192 commit 8cd6b7e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
27 changes: 0 additions & 27 deletions src/targets/gpu/compile_hipblaslt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,11 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct hipblaslt_op
{
operation op = op::identity{};

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}

std::string name() const { return "gpu::hipblaslt_op"; }

shape compute_shape(std::vector<shape> inputs) const
{
inputs.push_back(inputs.back());
return op.compute_shape(inputs);
}

std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
MIGRAPHX_REGISTER_OP(hipblaslt_op);

static size_t compile(migraphx::context& ctx, operation& op, instruction_ref ins)
{
auto v = op.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
Expand Down
27 changes: 27 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/compile_hipblaslt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

#include <migraphx/config.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/register_op.hpp>
#include <string>

namespace migraphx {
Expand All @@ -37,6 +39,31 @@ struct operation;

namespace gpu {

struct hipblaslt_op
{
operation op = op::identity{};

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}

std::string name() const { return "gpu::hipblaslt_op"; }

shape compute_shape(std::vector<shape> inputs) const
{
inputs.push_back(inputs.back());
return op.compute_shape(inputs);
}

std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
MIGRAPHX_REGISTER_OP(hipblaslt_op);

struct compile_hipblaslt
{
context* ctx = nullptr;
Expand Down

0 comments on commit 8cd6b7e

Please sign in to comment.