From 734a4e7ee4fd3200c24473a12292fea3ec4f27bb Mon Sep 17 00:00:00 2001 From: "#W[_t" Date: Fri, 8 Apr 2022 18:28:56 +0800 Subject: [PATCH] Add batch splitting ratio check Check whether the splitting ratio adds up to one, e.g. `opt.batch_ratio="0.1-0.2"` would raise a error while "1"or "0.2-0.8" won't. --- dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dataset.py b/dataset.py index d98224ad38..c25eb4f04d 100644 --- a/dataset.py +++ b/dataset.py @@ -29,6 +29,8 @@ def __init__(self, opt): print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') assert len(opt.select_data) == len(opt.batch_ratio) + opt.batch_ratio = [float(x) for x in opt.batch_ratio] + assert np.sum(opt.batch_ratio) == 1.0 _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) self.data_loader_list = [] @@ -36,7 +38,7 @@ def __init__(self, opt): batch_size_list = [] Total_batch_size = 0 for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): - _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) + _batch_size = max(round(opt.batch_size * batch_ratio_d), 1) print(dashed_line) log.write(dashed_line + '\n') _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) @@ -54,7 +56,7 @@ def __init__(self, opt): _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(dataset_split), dataset_split)] selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' - selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' + selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {batch_ratio_d} (batch_ratio) = {_batch_size}' print(selected_d_log) log.write(selected_d_log + '\n') batch_size_list.append(str(_batch_size))