Skip to content

Commit

Permalink
Separate runtime code from test code
Browse files Browse the repository at this point in the history
Expose proximal::runtime::ladmmSolver so we can call L-ADMM algorithm
with early termination. Test the code with test.cpp
  • Loading branch information
antonysigma authored and SteveDiamond committed Oct 26, 2022
1 parent e2e2f8d commit e3d08d7
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 132 deletions.
4 changes: 2 additions & 2 deletions proximal/halide/src/algorithm/problem-interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ using FuncTuple = std::array<Halide::Func, N>;

#if __cplusplus >= 202002L

#warning Halide 10.0 is incompatible to C++20. Upgrade to newer versions. /** A
s compute graph composed of various LinOps written in Halide.
#warning Halide 10.0 is incompatible to C++20. Upgrade to newer versions.
/** A compute graph composed of various LinOps written in Halide.
*
* The (L-)ADMM solvers expects the linear mapping between variable z and u, by
* z_i = K_i * u.
Expand Down
80 changes: 80 additions & 0 deletions proximal/halide/src/user-problem/ladmm-runtime.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "ladmm-runtime.h"

#include <HalideBuffer.h>

#include "ladmm_iter.h"
#include "problem-config.h"

using Halide::Runtime::Buffer;

namespace proximal {
namespace runtime {

constexpr auto W = problem_config::input_width;
constexpr auto H = problem_config::input_height;

signals_t
ladmmSolver(Buffer<const float>& input, const size_t iter_max, const float eps_abs,
const float eps_rel) {
Buffer<float> v(W, H, 1);
Buffer<float> z0(W, H, 1, 2);
Buffer<float> z1(W, H, 1);
Buffer<float> u0(W, H, 1, 2);
Buffer<float> u1(W, H, 1);

// Set zeros
for (auto* buf : {&v, &z0, &z1, &u0, &u1}) {
buf->fill(0.0f);
}

Buffer<float> z0_new(W, H, 1, 2);
Buffer<float> z1_new(W, H, 1);
Buffer<float> u0_new(W, H, 1, 2);
Buffer<float> u1_new(W, H, 1);

Buffer<float> v_new(W, H, 1);

std::vector<float> r(iter_max);
std::vector<float> s(iter_max);
std::vector<float> eps_pri(iter_max);
std::vector<float> eps_dual(iter_max);

for (size_t i = 0; i < iter_max; i++) {
auto _r = Buffer<float>::make_scalar(r.data() + i);
auto _s = Buffer<float>::make_scalar(s.data() + i);
auto _eps_pri = Buffer<float>::make_scalar(eps_pri.data() + i);
auto _eps_dual = Buffer<float>::make_scalar(eps_dual.data() + i);

const auto error = ladmm_iter(input, v, z0, z1, u0, u1, v_new, z0_new, z1_new, u0_new,
u1_new, _r, _s, _eps_pri, _eps_dual);

if (error) {
return {error, {}, {}, {}, {}, {}};
}

// Terminate the algorithm early, if optimal solution is reached.
const bool converged = (r[i] < eps_pri[i]) && (s[i] < eps_dual[i]);
if (converged) {
for (auto* v : {&r, &s, &eps_pri, &eps_dual}) {
v->resize(i + 1);
}
break;
}

if (i != iter_max - 1) {
// This iteration's v_new becomes current v in the next iteration.
std::swap(v, v_new);
std::swap(u0, u0_new);
std::swap(u1, u1_new);
std::swap(z0, z0_new);
std::swap(z1, z1_new);
}
}

constexpr int success = 0;
return {success, v_new, r, s, eps_pri, eps_dual};
}

} // namespace runtime

} // namespace proximal
35 changes: 35 additions & 0 deletions proximal/halide/src/user-problem/ladmm-runtime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <HalideBuffer.h>

namespace proximal {
namespace runtime {

using Halide::Runtime::Buffer;

struct signals_t {
int error_code;
Buffer<float> v_new;
std::vector<float> r;
std::vector<float> s;
std::vector<float> eps_pri;
std::vector<float> eps_dual;
};

/** Runtime function to call (L-)ADMM, with early termination.
*
* Halide being a non-Turing complete language, is unable to dynamically
* terminate a while-loop. Therefore, we generate the Halide-optimized AOT
* pipeline with a fixed (e.g. 10) iterations, returning the convergence
* metrics at iteration #10.
*
* Then, we check the convergence criteria, and terminate the for-loop when the
* criteria are met. Otherwise, repeat for another (10) iterations.
*
* Reference: https://stackoverflow.com/a/33472074
*/
signals_t ladmmSolver(Buffer<const float>& input, const size_t iter_max = 100,
const float eps_abs = 1e-3, const float eps_rel = 1e-3);
} // namespace runtime

} // namespace proximal
14 changes: 11 additions & 3 deletions proximal/halide/src/user-problem/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,23 @@ solver_bin = custom_target(
build_by_default: true,
)

test_runtime_exe = executable('test-runtime',
ladmm_runtime_lib = library('ladmm-runtime',
sources: [
# TODO(Antony): to separate the ADMM runtime library from the test code.
'test-runtime.cpp',
'ladmm-runtime.cpp',
solver_bin,
],
dependencies: halide_runtime_dep,
)

test_runtime_exe = executable('test-ladmm-runtime',
sources: [
# TODO(Antony): to separate the ADMM runtime library from the test code.
'test.cpp',
],
cpp_args: [
'-DRAW_IMAGE_PATH="@0@"'.format(parrot_img),
],
link_with: ladmm_runtime_lib,
dependencies: [
halide_runtime_dep,
dependency('libpng'),
Expand Down
127 changes: 0 additions & 127 deletions proximal/halide/src/user-problem/test-runtime.cpp

This file was deleted.

54 changes: 54 additions & 0 deletions proximal/halide/src/user-problem/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#include <HalideBuffer.h>

#include <iostream>

#include "halide_image_io.h"
#include "ladmm-runtime.h"
#include "problem-config.h"

using Halide::Runtime::Buffer;
using Halide::Tools::load_and_convert_image;
using proximal::runtime::ladmmSolver;

namespace {

constexpr auto W = problem_config::input_width;
constexpr auto H = problem_config::input_height;

#ifndef RAW_IMAGE_PATH
#error Path to the raw image must be defined with -DRAW_IMAGE_PATH="..." in the compile command.
#endif

constexpr char raw_image_path[]{RAW_IMAGE_PATH};

constexpr bool verbose = true;

} // namespace

int
main() {
Buffer<float> raw_image = load_and_convert_image(raw_image_path);

raw_image.add_dimension();
Buffer<const float> normalized = std::move(raw_image);

const auto max_n_iter = 50;
const auto [error_code, denoised, r, s, eps_pri, eps_dual] =
ladmmSolver(normalized, max_n_iter);

// TODO(Antony): use std::ranges::zip_view
for (size_t i = 0; i < r.size(); i++) {
const bool converged = (r[i] < eps_pri[i]) && (s[i] < eps_dual[i]);

std::cout << "{r, eps_pri, s, eps_dual}[" << i << "] = " << r[i] << '\t' << eps_pri[i]
<< '\t' << s[i] << '\t' << eps_dual[i] << (converged ? "\tconverged" : "")
<< '\n';
}

std::cout << "Top-left pixel = " << denoised(0, 0, 0) << '\n';

Buffer<float> output = std::move(denoised);
Halide::Tools::convert_and_save_image(output, "denoised.png");

return 0;
}

0 comments on commit e3d08d7

Please sign in to comment.