diff --git a/paddlenlp/data/dist_dataloader.py b/paddlenlp/data/dist_dataloader.py
index a6330ce1fe08..01f1828b535a 100644
--- a/paddlenlp/data/dist_dataloader.py
+++ b/paddlenlp/data/dist_dataloader.py
@@ -66,6 +66,7 @@ def __init__(
         eval = kwargs.pop("eval", False)
         is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+        self._pp_data_group = kwargs.pop("pp_data_group", None)
 
         if dataset is None:
             dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
@@ -78,10 +79,8 @@ def __init__(
 
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
-            self._pp_data_group = self._init_dataloader_comm_group()
             self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
-            self._pp_data_group = None
             self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
@@ -132,18 +131,6 @@ def __len__(self):
         else:
             raise ValueError("raise error for `paddlenlp.trainer.trainer_utils.has_length`")
 
-    def _init_dataloader_comm_group(self):
-        topo = self._hcg._topo
-        parallel_comm_group = None
-        parallel_groups = topo.get_comm_list("pipe")
-
-        for group in parallel_groups:
-            ranks = [group[0], group[-1]]
-            comm_group = paddle.distributed.new_group(ranks=ranks)
-            if paddle.distributed.get_rank() in ranks:
-                parallel_comm_group = comm_group
-        return parallel_comm_group
-
     def __iter__(self):
         return self
 
@@ -212,3 +199,16 @@ def __next__(self):
                 logger.debug(e)
         data = self._broadcast_data(data)
         return data
+
+
+def init_dataloader_comm_group():
+    hcg = fleet.get_hybrid_communicate_group()
+    topo = hcg._topo
+    parallel_groups = topo.get_comm_list("pipe")
+
+    for group in parallel_groups:
+        ranks = [group[0], group[-1]]
+        comm_group = paddle.distributed.new_group(ranks=ranks)
+        if paddle.distributed.get_rank() in ranks:
+            parallel_comm_group = comm_group
+    return parallel_comm_group
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index a5301e290d08..b45542ccc38f 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -79,6 +79,7 @@
     DataCollatorWithPadding,
     DistDataLoader,
     default_data_collator,
+    init_dataloader_comm_group,
 )
 
 from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel
@@ -440,6 +441,10 @@ def fn(layer):
 
         model.apply(fn)
 
+        self._pp_data_group = None
+        if self.args.pipeline_parallel_degree > 1 and self.args.distributed_dataloader:
+            self._pp_data_group = init_dataloader_comm_group()
+
         default_label_names = (
             ["start_positions", "end_positions"]
             if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
@@ -1537,6 +1542,7 @@ def get_train_dataloader(self):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_configs = {}
        if is_iterable_dataset:  # For iterable dataset
             if self.args.dataset_world_size > 1 and train_dataset is not None:
                 train_dataset = IterableDatasetShard(
@@ -1549,9 +1555,7 @@ def get_train_dataloader(self):
 
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
-                additional_configs = {"is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
@@ -1563,11 +1567,13 @@ def get_train_dataloader(self):
             train_sampler = self._get_train_sampler()
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
+                additional_configs = {"pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_sampler=train_sampler,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
             )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
@@ -1623,6 +1629,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_configs = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and eval_dataset is not None:
                 eval_dataset = IterableDatasetShard(
@@ -1632,11 +1639,10 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
+
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_size=self.args.per_device_eval_batch_size,
@@ -1648,9 +1654,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
             eval_sampler = self._get_eval_sampler(eval_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_sampler=eval_sampler,
@@ -1683,6 +1687,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_config = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and test_dataset is not None:
                 test_dataset = IterableDatasetShard(
@@ -1695,9 +1700,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
 
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 test_dataset,
                 batch_size=self.args.per_device_eval_batch_size * self.world_size,
@@ -1709,9 +1712,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_sampler = self._get_eval_sampler(test_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "pp_data_group": self._pp_data_group}
             # We use the same batch_size as for eval.
             return _DataLoader(
                 test_dataset,
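
For context, a minimal sketch of the calling pattern this diff introduces: the pipe-parallel data group is now built once by the module-level `init_dataloader_comm_group()` (called from `Trainer.__init__`) and forwarded to each `DistDataLoader` via the new `pp_data_group` kwarg, instead of every `DistDataLoader` creating its own group in the removed `_init_dataloader_comm_group()`. The dataset and sampler names below are placeholders, and the snippet assumes a hybrid-parallel fleet environment with pipeline parallelism has already been initialized; it is not part of the diff itself.

```python
# Sketch only: assumes fleet.init(...) already set up hybrid parallelism with
# pipeline_parallel_degree > 1; train_dataset / train_sampler are placeholders.
from paddlenlp.data import DistDataLoader, init_dataloader_comm_group

# Create the first/last-pipeline-stage data group once per process.
pp_data_group = init_dataloader_comm_group()

# Reuse the same group for every dataloader instead of re-creating it
# inside each DistDataLoader.__init__ (the pre-PR behavior).
train_loader = DistDataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    pp_data_group=pp_data_group,
)
eval_loader = DistDataLoader(
    eval_dataset,
    batch_sampler=eval_sampler,
    eval=True,
    pp_data_group=pp_data_group,
)
```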