diff --git a/dataset.py b/dataset.py index d98224ad38..c25eb4f04d 100644 --- a/dataset.py +++ b/dataset.py @@ -29,6 +29,8 @@ def __init__(self, opt): print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') assert len(opt.select_data) == len(opt.batch_ratio) + opt.batch_ratio = [float(x) for x in opt.batch_ratio] + assert np.sum(opt.batch_ratio) == 1.0 _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) self.data_loader_list = [] @@ -36,7 +38,7 @@ def __init__(self, opt): batch_size_list = [] Total_batch_size = 0 for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): - _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) + _batch_size = max(round(opt.batch_size * batch_ratio_d), 1) print(dashed_line) log.write(dashed_line + '\n') _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) @@ -54,7 +56,7 @@ def __init__(self, opt): _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(dataset_split), dataset_split)] selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' - selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' + selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {batch_ratio_d} (batch_ratio) = {_batch_size}' print(selected_d_log) log.write(selected_d_log + '\n') batch_size_list.append(str(_batch_size))