Skip to content

Commit

Permalink
Fixing ProgressBar() (#181)
Browse files Browse the repository at this point in the history
* Initial commit.

* Minor fixes.

* Style changes.

* Minor style changes.

* Minor Changes

* Changes in HISTORY.md and adding a test.

* minor changes.

* Update HISTORY.md

Co-Authored-By: Ryan Curtin <[email protected]>

* Update include/ensmallen_bits/callbacks/progress_bar.hpp

Co-Authored-By: Ryan Curtin <[email protected]>

* Incorporating the changes.

* Solving style issues

* Solving merge conflicts.

* Adding tests and correcting the epoch conditionals.

* Minor changes in tests.

* Minor style changes

* Minor changes.

* Style changes.

* Changes in callback tests.

* Style Changes.

* Update include/ensmallen_bits/callbacks/progress_bar.hpp

Co-Authored-By: Ryan Curtin <[email protected]>

* Update tests/callbacks_test.cpp

Co-Authored-By: Ryan Curtin <[email protected]>

* Update tests/callbacks_test.cpp

Co-Authored-By: Ryan Curtin <[email protected]>

* Changes in tests.

Co-authored-by: Ryan Curtin <[email protected]>
  • Loading branch information
gaurav-singh1998 and rcurtin authored Apr 16, 2020
1 parent a8c2976 commit a005414
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 15 deletions.
3 changes: 3 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
### ensmallen ?.??.?: "???
###### ????-??-??
* Fix total number of epochs and time estimation for ProgressBar callback
([#181](https://github.com/mlpack/ensmallen/pull/181)).

* Handle SpSubview_col and SpSubview_row in Armadillo 9.870
([#194](https://github.com/mlpack/ensmallen/pull/194)).

Expand Down
26 changes: 11 additions & 15 deletions include/ensmallen_bits/callbacks/progress_bar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,9 @@ class ProgressBar
if (function.NumFunctions() % optimizer.BatchSize() > 0)
epochSize++;

if (!optimizer.MaxIterations())
{
Warn << "Maximum number of iterations not defined (no limit),"
<< " no progress bar shown." << std::endl;
}
else
{
epochs = optimizer.MaxIterations() / epochSize;
if (optimizer.MaxIterations() % epochSize > 0)
epochs++;
}
epochs = optimizer.MaxIterations() / function.NumFunctions();
if (optimizer.MaxIterations() % function.NumFunctions() > 0)
epochs++;

stepTimer.tic();
}
Expand Down Expand Up @@ -138,8 +130,12 @@ class ProgressBar
{
if (newEpoch)
{
output << "Epoch " << epoch << "/" << epochs << "\n";
output.flush();
output << "Epoch " << epoch;
if (epochs > 0)
{
output << "/" << epochs;
}
output << '\n';
newEpoch = false;
}

Expand All @@ -161,8 +157,8 @@ class ProgressBar
}
}

output << "] " << progress << "% - ETA: " << (size_t) stepTimer.toc() *
(epochSize - step + 1) % 60 << "s - loss: " <<
output << "] " << progress << "% - ETA: " << (size_t) (stepTimer.toc() *
(epochSize - step + 1)) % 60 << "s - loss: " <<
objective / (double) step << "\r";
output.flush();

Expand Down
52 changes: 52 additions & 0 deletions tests/callbacks_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,3 +562,55 @@ TEST_CASE("TimerStopCallbackTest", "[CallbacksTest]")
// Add some time to account for the function to return.
REQUIRE(timer.toc() < 2);
}

/**
* Make sure the ProgressBar callback will show the progress on the specified
* output stream if the MaxIterations parameter of the optimizer is 0.
*/
TEST_CASE("ProgressBarCallbackNoMaxIterationsTest", "[CallbacksTest]")
{
SGDTestFunction f;
arma::mat coordinates = f.GetInitialPoint();

StandardSGD s(0.0003, 1, 0, DBL_MAX, true);

std::stringstream stream;
s.Optimize(f, coordinates, ProgressBar(10, stream));

REQUIRE(stream.str().length() > 0);
}

/**
* Make sure the ProgressBar callback will show the progress on the specified
* output stream with the correct epoch number if the MaxIterations parameter
* of the optimizer is 0.
*/
TEST_CASE("ProgressBarCallbackNoMaxIterationsEpochTest", "[CallbacksTest]")
{
SGDTestFunction f;
arma::mat coordinates = f.GetInitialPoint();

StandardSGD s(0.0003, 1, 0, DBL_MAX, true);

std::stringstream stream;
s.Optimize(f, coordinates, ProgressBar(10, stream));
REQUIRE(stream.str().find("Epoch 1") != std::string::npos);
REQUIRE(stream.str().find("Epoch 1/") == std::string::npos);
}

/**
* Make sure the ProgressBar callback will show the progress on the specified
* output stream with the correct epoch number if the MaxIterations parameter
* of the optimizer is not equal to 0.
*/
TEST_CASE("ProgressBarCallbackEpochTest", "[CallbacksTest]")
{
SGDTestFunction f;
arma::mat coordinates = f.GetInitialPoint();

StandardSGD s(0.0003, 1, 1, 1e-9, true);

std::stringstream stream;
s.Optimize(f, coordinates, ProgressBar(10, stream));
REQUIRE(stream.str().find("Epoch 1/1") != std::string::npos);
}

0 comments on commit a005414

Please sign in to comment.