Support MMoE #1811

Open: wants to merge 9 commits into base branch dev_train.
12 changes: 12 additions & 0 deletions .gitignore
@@ -498,3 +498,15 @@ opencl_program.cc
opencl_program.cc
platforms/mac/tnn.xcodeproj/project.xcworkspace/xcuserdata/darrenyao.xcuserdatad/UserInterfaceState.xcuserstate
platforms/mac/tnn.xcodeproj/xcuserdata/darrenyao.xcuserdatad/xcschemes/xcschememanagement.plist

# build output
platforms/ios/tnn.bundle/
platforms/ios/tnn.framework/
scripts/build_aarch64_macos/
scripts/build_macos_native/

# tmp dir
tmp/

# finetune_demo
scripts/finetune_demo**/
6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -22,7 +22,7 @@ set(TNN_VERSION "${TNN_MAJOR_VERSION}.${TNN_MINOR_VERSION}.${TNN_PATCH_VERSION}.

option(TNN_CPU_ENABLE "Enable Cpu" ON)
option(TNN_X86_ENABLE "Enable X86" OFF)
option(TNN_ARM_ENABLE "Enable Arm" OFF)
option(TNN_ARM_ENABLE "Enable Arm" ON)
option(TNN_ARM82_ENABLE "Enable Arm82" OFF)
option(TNN_METAL_ENABLE "Enable Metal" OFF)
option(TNN_OPENCL_ENABLE "Enable OpenCL" OFF)
@@ -38,7 +38,7 @@ option(TNN_SYMBOL_HIDE "Enable Hide Symbol Visibility" ON)
option(TNN_OPENMP_ENABLE "Enable OpenMP" OFF)
option(TNN_BUILD_SHARED "Build Shared Library" ON)
option(TNN_OPENVINO_BUILD_SHARED "Build Shared Openvino Library" OFF)
option(TNN_TEST_ENABLE "Enable Test" OFF)
option(TNN_TEST_ENABLE "Enable Test" ON)
option(TNN_UNIT_TEST_ENABLE "Enable Test" OFF)
option(TNN_PROFILER_ENABLE "Enable Profiler" OFF)
option(TNN_QUANTIZATION_ENABLE "Enable Quantization" OFF)
@@ -51,7 +51,7 @@ option(TNN_ONNX2TNN_ENABLE "Enable ONNX2TNN Converter" OFF)
option(TNN_TNN2MEM_ENABLE "Enable tnn2mem" OFF)
option(TNN_BUILD_BENCHMARK_TEST_LIB_ENABLE "Enable Build Benchmark Test Lib" OFF)
option(TNN_GLIBCXX_USE_CXX11_ABI_ENABLE "Enable Use CXX11 ABI" ON)
option(TNN_TRAIN_ENABLE "Enable train module" OFF)
option(TNN_TRAIN_ENABLE "Enable train module" ON)
option(TNN_METAL_FLOAT32 "Enable Metal Float32" OFF)
option(TNN_COREML_FLOAT32 "Enable Float32 CoreML Model" ON)
option(TNN_DYNAMIC_RANGE_QUANTIZATION_ENABLE "Enable Dynamic Range Quantization" OFF)
10 changes: 5 additions & 5 deletions include/tnn/core/common.h
@@ -172,11 +172,11 @@ struct PUBLIC TrainConfig {
// loss
LossFunc loss_func = LOSS_FUNC_DEFAULT;
// if loss_func is not default, the following informations are used to create loss layer
std::string target_layer = ""; // the layer whose output is used to calculate loss, default is the last layer
std::vector<std::string> target_layers; // the layers whose outputs are used to calculate losses, default is the last layer
bool auto_add_prob_layer = true; // add softmax or sigmoid layer before loss layer
// target used to calculate loss
std::string ground_truth_name = ""; // the ground truth, provide by model inputs
DimsVector ground_truth_shape = {}; // the shape of the ground truth
std::vector<std::string> ground_truth_names; // the ground truths, provided by model inputs
std::vector<DimsVector> ground_truth_shapes; // the shapes of the ground truths

// solver
SolverType solver_type = SOLVER_TYPE_SGD;
Expand All @@ -189,8 +189,8 @@ struct PUBLIC TrainConfig {
};

struct PUBLIC TrainingFeedback {
std::string loss_name = "";
float loss_value = 0.0;
std::vector<std::string> loss_names;
std::vector<float> loss_values;
std::string global_step_name = "";
int global_step_value = 0;
};
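
With these changes a caller configures one loss per task instead of a single target. Below is a minimal sketch of an MMoE-style setup with two task heads; the layer and input names ("ctr_logits", "cvr_logits", "ctr_label", "cvr_label") are hypothetical, and LOSS_FUNC_BINARY_CROSS_ENTROPY is assumed to be one of the non-default LossFunc values alongside LOSS_FUNC_DEFAULT shown above:

#include "tnn/core/common.h"

// Hypothetical two-task MMoE training configuration (names are made up).
TNN_NS::TrainConfig MakeMmoeTrainConfig() {
    TNN_NS::TrainConfig config;
    config.loss_func           = TNN_NS::LOSS_FUNC_BINARY_CROSS_ENTROPY;  // assumed non-default LossFunc value
    config.auto_add_prob_layer = true;  // insert a sigmoid/softmax before each loss layer

    // One entry per task: loss target layer, ground-truth input name, and its shape.
    config.target_layers       = {"ctr_logits", "cvr_logits"};
    config.ground_truth_names  = {"ctr_label", "cvr_label"};
    config.ground_truth_shapes = {{1, 1}, {1, 1}};

    config.solver_type = TNN_NS::SOLVER_TYPE_SGD;
    return config;
}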
3 changes: 1 addition & 2 deletions source/tnn/core/default_network.cc
@@ -608,9 +608,8 @@ Status DefaultNetwork::Forward() {
auto layer = layers_[cnt];
std::vector<Blob *> inputs = layer->GetInputBlobs();
std::vector<Blob *> outputs = layer->GetOutputBlobs();

{

#if DUMP_INPUT_BLOB
if (runtime_model_ == RUNTIME_MODE_NORMAL || runtime_model_ == RUNTIME_MODE_BACKWARD) {
// InputBlob data in dumped into files in NCHW_FLOAT format as default
6 changes: 4 additions & 2 deletions source/tnn/core/instance.cc
@@ -65,8 +65,10 @@ Status Instance::SaveTrainedModel(const std::string& model_path) {
Status Instance::GetTrainingFeedback(TrainingFeedback& feed_back) {
RETURN_ON_NEQ(network_->GetTrainingFeedback(feed_back), TNN_OK);
std::shared_ptr<Mat> mat;
GetOutputMat(mat, MatConvertParam(), feed_back.loss_name);
feed_back.loss_value = *(reinterpret_cast<float*>(mat->GetData()));
for (int i = 0; i < feed_back.loss_names.size(); ++i) {
GetOutputMat(mat, MatConvertParam(), feed_back.loss_names[i]);
feed_back.loss_values.push_back(*(reinterpret_cast<float*>(mat->GetData())));
}
GetOutputMat(mat, MatConvertParam(), feed_back.global_step_name);
feed_back.global_step_value = *(reinterpret_cast<float*>(mat->GetData()));
return TNN_OK;
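
On the caller side, the feedback now carries one value per loss. A rough sketch of a single training iteration reading them back follows; Instance::Forward() and Instance::GetTrainingFeedback() are part of the public API touched here, while Instance::TrainStep() is assumed to wrap the network-level TrainStep() shown further below:

#include <cstdio>
#include <memory>

#include "tnn/core/instance.h"

// One training iteration: forward pass, solver step, then per-loss feedback.
TNN_NS::Status RunTrainingStep(std::shared_ptr<TNN_NS::Instance> instance) {
    TNN_NS::Status status = instance->Forward();
    if (status != TNN_OK) return status;

    status = instance->TrainStep();  // assumed Instance-level wrapper of the network TrainStep
    if (status != TNN_OK) return status;

    TNN_NS::TrainingFeedback feedback;
    status = instance->GetTrainingFeedback(feedback);
    if (status != TNN_OK) return status;

    for (size_t i = 0; i < feedback.loss_names.size(); ++i) {
        printf("step %d  %s = %f\n", feedback.global_step_value,
               feedback.loss_names[i].c_str(), feedback.loss_values[i]);
    }
    return TNN_OK;
}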
5 changes: 5 additions & 0 deletions source/tnn/interpreter/net_structure.h
@@ -67,6 +67,11 @@ struct NetStructure {
std::set<std::string> blobs;
ModelType source_model_type = MODEL_TYPE_TNN;

#ifdef TNN_TRAIN
std::vector<std::string> loss_names;
std::vector<std::string> loss_grad_names;
#endif

public:
std::shared_ptr<NetStructure> Copy() {
std::shared_ptr<NetStructure> net_structure(new NetStructure());
134 changes: 78 additions & 56 deletions source/tnn/train/default_train_network.cc
@@ -11,7 +11,6 @@
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/train/default_train_network.h"

#include "tnn/train/gradient/gradient_layer.h"
@@ -37,17 +36,45 @@ Status DefaultTrainNetwork::Init(NetworkConfig &net_config, ModelConfig &model_c
enable_const_folder);
RETURN_ON_NEQ(ret, TNN_OK);

RETURN_ON_NEQ(CopyLossAndLossGradNames(interpreter), TNN_OK);

RETURN_ON_NEQ(InitTrainingStatus(), TNN_OK);

RETURN_ON_NEQ(InitRuntimeInfo(), TNN_OK);

return TNN_OK;
}

Status DefaultTrainNetwork::CopyLossAndLossGradNames(AbstractModelInterpreter *interpreter) {
auto default_interpreter = dynamic_cast<DefaultModelInterpreter *>(interpreter);
CHECK_PARAM_NULL(default_interpreter);

const NetStructure *net_structure = default_interpreter->GetNetStructure();
if (net_structure == NULL) {
LOGE("ERROR: network_ is nil, network_type may not support\n");
return Status(TNNERR_NULL_PARAM, "network_ is nil, network_type may not support");
}

loss_names_ = net_structure->loss_names;
loss_grad_names_ = net_structure->loss_grad_names;

if (loss_names_.empty()) {
LOGE("DefaultTrainNetwork::CopyLossAndLossGradNames ERROR, cannot get loss names\n");
return Status(TNNERR_TRAIN_ERROR, "cannot get loss names");
}
if (loss_grad_names_.empty()) {
LOGE("DefaultTrainNetwork::CopyLossAndLossGradNames ERROR, cannot get loss grad names\n");
return Status(TNNERR_TRAIN_ERROR, "cannot get loss grad names");
}
return TNN_OK;
}

Status DefaultTrainNetwork::GetAllInputBlobs(BlobMap &blobs) {
blob_manager_->GetAllInputBlobs(blobs);
// loss grads are assumed to be one
blobs.erase(loss_grad_name_);
for (auto loss_grad_name : loss_grad_names_) {
blobs.erase(loss_grad_name);
}
// global step init value is assumed to be zero
blobs.erase(global_step_init_name_);
return TNN_OK;
@@ -81,34 +108,24 @@ Status DefaultTrainNetwork::TrainStep() {
}

Status DefaultTrainNetwork::GetTrainingFeedback(TrainingFeedback &feed_back) {
feed_back.loss_name = loss_name_;
for (const auto & loss_name : loss_names_) {
feed_back.loss_names.push_back(loss_name);
}
feed_back.global_step_name = global_step_name_;
return TNN_OK;
}

Status DefaultTrainNetwork::InitTrainingStatus() {
LayerInfo *loss_layer = nullptr;
LayerInfo *loss_grad_layer = nullptr;
int cnt = 0;
std::vector<LayerInfo *> loss_layers;
std::vector<LayerInfo *> loss_grad_layers;
int cnt = 0;
for (auto layer : net_structure_->layers) {
if (layer->type == LAYER_GRADIENT) {
loss_grad_layer = layer.get();
break;
}
loss_layer = layer.get();
cnt++;
}
forward_layer_count_ = cnt;
if (!loss_layer) {
LOGE("DefaultTrainNetwork::InitTrainingStatus ERROR, cannot get loss layer\n");
return Status(TNNERR_TRAIN_ERROR, "cannot get loss layer");
}
if (!loss_grad_layer) {
LOGE("DefaultTrainNetwork::InitTrainingStatus ERROR, cannot get loss grad layer\n");
return Status(TNNERR_TRAIN_ERROR, "cannot get loss grad layer");
}
loss_name_ = loss_layer->outputs[0];
loss_grad_name_ = loss_grad_layer->inputs.back();

LayerInfo *solver_layer_info = net_structure_->layers.back().get();
if (!solver_layer_info) {
@@ -131,50 +148,55 @@ Status DefaultTrainNetwork::InitRuntimeInfo() {
}

Status DefaultTrainNetwork::SetLossGrad() {
Blob *loss_blob = blob_manager_->GetBlob(loss_name_);
if (!loss_blob) {
LOGE("DefaultTrainNetwork::SetLossGrad get loss_blob failed\n");
return Status(TNNERR_TRAIN_ERROR, "get loss_blob failed!");
}
auto loss_data_count = DimsVectorUtils::Count(loss_blob->GetBlobDesc().dims);
if (loss_data_count != 1) {
LOGE(
"DefaultTrainNetwork::SetLossGrad only support loss data count = 1 now, got %d. Try to change loss "
"function type or loss target layer!\n",
loss_data_count);
return Status(TNNERR_TRAIN_ERROR,
"loss data count not supported, try to change loss function type or loss target layer!");
}
for (int loss_idx = 0; loss_idx < loss_names_.size(); ++loss_idx) {
const auto loss_name = loss_names_[loss_idx];
const auto loss_grad_name = loss_grad_names_[loss_idx];

Blob *loss_blob = blob_manager_->GetBlob(loss_name);
if (!loss_blob) {
LOGE("DefaultTrainNetwork::SetLossGrad get loss_blob failed\n");
return Status(TNNERR_TRAIN_ERROR, "get loss_blob failed!");
}
auto loss_data_count = DimsVectorUtils::Count(loss_blob->GetBlobDesc().dims);
if (loss_data_count != 1) {
LOGE(
"DefaultTrainNetwork::SetLossGrad only support loss data count = 1 now, got %d. Try to change loss "
"function type or loss target layer!\n",
loss_data_count);
return Status(TNNERR_TRAIN_ERROR,
"loss data count not supported, try to change loss function type or loss target layer!");
}

std::shared_ptr<Mat> mat(new Mat(DEVICE_ARM, NCHW_FLOAT, {loss_data_count}));
if (!mat || !mat->GetData()) {
LOGE("DefaultTrainNetwork::SetLossGrad create mat failed\n");
return Status(TNNERR_TRAIN_ERROR, "create mat failed");
}
std::shared_ptr<Mat> mat(new Mat(DEVICE_ARM, NCHW_FLOAT, {loss_data_count}));
if (!mat || !mat->GetData()) {
LOGE("DefaultTrainNetwork::SetLossGrad create mat failed\n");
return Status(TNNERR_TRAIN_ERROR, "create mat failed");
}

// init loss grad as one
auto ptr = reinterpret_cast<float *>(mat->GetData());
for (int i = 0; i < loss_data_count; ++i) {
ptr[i] = 1.0;
}
// init loss grad as one
auto ptr = reinterpret_cast<float *>(mat->GetData());
for (int i = 0; i < loss_data_count; ++i) {
ptr[i] = 1.0;
}

Blob *loss_grad = blob_manager_->GetBlob(loss_grad_name_);
if (!loss_grad) {
LOGE("DefaultTrainNetwork::SetLossGrad get loss_grad failed\n");
return Status(TNNERR_TRAIN_ERROR, "get loss_grad failed!");
}
Blob *loss_grad = blob_manager_->GetBlob(loss_grad_name);
if (!loss_grad) {
LOGE("DefaultTrainNetwork::SetLossGrad get loss_grad failed\n");
return Status(TNNERR_TRAIN_ERROR, "get loss_grad failed!");
}

// create blob convert
std::shared_ptr<BlobConverter> blob_converter = std::make_shared<BlobConverter>(loss_grad);
// create blob convert
std::shared_ptr<BlobConverter> blob_converter = std::make_shared<BlobConverter>(loss_grad);

// get command queue
void *command_queue = nullptr;
RETURN_ON_NEQ(GetCommandQueue(&command_queue), TNN_OK);
// get command queue
void *command_queue = nullptr;
RETURN_ON_NEQ(GetCommandQueue(&command_queue), TNN_OK);

Status status = blob_converter->ConvertFromMatAsync(*(mat.get()), MatConvertParam(), command_queue);
if (status != TNN_OK) {
LOGE("DefaultTrainNetwork::SetLossGrad, ConvertFromMatAsync Error: %s\n", status.description().c_str());
return status;
Status status = blob_converter->ConvertFromMatAsync(*(mat.get()), MatConvertParam(), command_queue);
if (status != TNN_OK) {
LOGE("DefaultTrainNetwork::SetLossGrad, ConvertFromMatAsync Error: %s\n", status.description().c_str());
return status;
}
}

return TNN_OK;
6 changes: 4 additions & 2 deletions source/tnn/train/default_train_network.h
@@ -51,13 +51,15 @@ class DefaultTrainNetwork : public DefaultNetwork {
virtual Status SetGradientLayerRuntimeInfo();
virtual Status SetSolverLayerRuntimeInfo();

virtual Status CopyLossAndLossGradNames(AbstractModelInterpreter *interpreter);

std::map<Blob *, Blob *> input_to_grad_map_;
std::map<Blob *, RawBuffer *> grad_to_resource_map_;

std::vector<BaseLayer *> need_refresh_layers_;

std::string loss_name_;
std::string loss_grad_name_;
std::vector<std::string> loss_names_;
std::vector<std::string> loss_grad_names_;
std::string global_step_name_;
std::string global_step_init_name_;
