diff --git a/src/problem.cpp b/src/problem.cpp index ba84856850..fed48dfe88 100644 --- a/src/problem.cpp +++ b/src/problem.cpp @@ -179,11 +179,7 @@ Problem::FindSolutions(Handle& handle, const FindOptions& options, std::size_t m auto ret = std::visit( boost::hof::match( [&](const ConvolutionDescriptor& op_desc) { - if(op_desc.mode == miopenTranspose) - return MakeTransposed().FindSolutionsImpl( - handle, options, max_solutions, buffers, op_desc); - else - return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); + return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); }, [&](const SoftmaxDescriptor& op_desc) { return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc); @@ -481,17 +477,21 @@ std::vector Problem::FindSolutionsImpl(Handle& handle, const auto& w = buffers.at(miopenTensorConvolutionW); auto y = buffers.at(miopenTensorConvolutionY); - if(conv_desc.mode == miopenTranspose) - std::swap(x, y); - - const auto conv_problem = AsConvolution(); - - ValidateGroupCount(x_desc, w_desc, conv_desc); + const auto conv_problem = + conv_desc.mode == miopenTranspose ? MakeTransposed().AsConvolution() : AsConvolution(); std::size_t workspace_size; Allocator::ManageDataPtr owned_workspace; Data_t workspace; + if(conv_desc.mode == miopenTranspose) + { + std::swap(x, y); + std::swap(x_desc, y_desc); + } + + ValidateGroupCount(x_desc, w_desc, conv_desc); + if(options.preallocated_workspace) { workspace = options.preallocated_workspace->buffer;