-
Notifications
You must be signed in to change notification settings - Fork 1
/
profile_fnn_regression.cpp
59 lines (52 loc) · 1.87 KB
/
profile_fnn_regression.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/***************************************************************************
* profile_fnn_regression.cpp
*
* Copyright 2021 Mirco De Marchi
*
****************************************************************************/
/*
* This file is part of EdgeLearning.
*
* EdgeLearning is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* EdgeLearning is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with EdgeLearning. If not, see <https://www.gnu.org/licenses/>.
*/
#include "profile_fnn.hpp"

#include <utility>
/**
 * \brief Hidden-layer description used by the execution-time regression
 * profile started in main().
 *
 * Contains a single Dense entry named "hidden_layer0" with ReLU activation.
 * The value 32 is presumably the number of units of that layer -- confirm
 * against the Dense descriptor declaration.
 */
const NeuralNetworkDescriptor execution_time_hidden_layers_descriptor(
{
Dense{"hidden_layer0", 32, ActivationType::ReLU },
}
);
/**
 * \brief ProfileFNN specialisation for regression: the loss is fixed to
 * LossType::MSE and the base class is constructed with the "regression" label.
 * \tparam OT Optimizer type forwarded to ProfileFNN.
 */
template <OptimizerType OT>
class ProfileFNNRegression : public ProfileFNN<LossType::MSE, OT>
{
public:
    /**
     * \brief Build the profile by delegating every argument to ProfileFNN.
     * \param dataset_type                 Dataset selector, forwarded as is.
     * \param hidden_layers_descriptor_vec Hidden-layer descriptors to profile.
     *        Taken by value and moved into the base class, so the vector is
     *        not deep-copied a second time.
     * \param default_setting              Training setting, forwarded as is.
     */
    ProfileFNNRegression(
        ProfileDataset::Type dataset_type,
        std::vector<NeuralNetworkDescriptor> hidden_layers_descriptor_vec,
        ProfileNN::TrainingSetting default_setting)
        : ProfileFNN<LossType::MSE, OT>(
            "regression",
            dataset_type,
            // The parameter is a local copy that is never used again here.
            std::move(hidden_layers_descriptor_vec),
            default_setting)
    { }
};
int main() {
SizeType EPOCHS = 20;
SizeType BATCH_SIZE = 128;
NumType LEARNING_RATE = 0.01;
ProfileFNNRegression<OptimizerType::GRADIENT_DESCENT>(
ProfileDataset::Type::CSV_EXECUTION_TIME,
{execution_time_hidden_layers_descriptor},
{EPOCHS, BATCH_SIZE, LEARNING_RATE}).run();
}