[Experimental Feature] DeepSpeed for LLM with standalone and distributed modes (#653) #684

Merged · 1 commit · Sep 4, 2023
18 changes: 13 additions & 5 deletions federatedscope/core/auxiliaries/logging.py
@@ -57,7 +57,7 @@ def filter(self, record):
         return True
 
 
-def update_logger(cfg, clear_before_add=False):
+def update_logger(cfg, clear_before_add=False, rank=0):
     root_logger = logging.getLogger("federatedscope")
 
     # clear all existing handlers and add the default stream
@@ -70,11 +70,16 @@ def update_logger(cfg, clear_before_add=False):
         root_logger.addHandler(handler)
 
     # update level
-    if cfg.verbose > 0:
-        logging_level = logging.INFO
+    if rank == 0:
+        if cfg.verbose > 0:
+            logging_level = logging.INFO
+        else:
+            logging_level = logging.WARN
+            root_logger.warning("Skip DEBUG/INFO messages")
     else:
-        logging_level = logging.WARN
-        root_logger.warning("Skip DEBUG/INFO messages")
+        root_logger.warning(f"Using deepspeed, and we will disable "
+                            f"subprocesses {rank} logger.")
+        logging_level = logging.CRITICAL
     root_logger.setLevel(logging_level)
 
     # ================ create outdir to save log, exp_config, models, etc,.
@@ -88,6 +93,9 @@ def update_logger(cfg, clear_before_add=False):
         cfg.expname = f"{cfg.expname}_{cfg.expname_tag}"
     cfg.outdir = os.path.join(cfg.outdir, cfg.expname)
 
+    if rank != 0:
+        return
+
     # if exist, make directory with given name and time
     if os.path.isdir(cfg.outdir) and os.path.exists(cfg.outdir):
         outdir = os.path.join(cfg.outdir, "sub_exp" +
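Under DeepSpeed, `main.py` is forked once per GPU, so only rank 0 should keep verbose logging and own the output directory; the other ranks are silenced at CRITICAL and return before any directory is created. A minimal sketch of the gating logic (the helper `configure_rank_aware_logger` is illustrative, not part of FederatedScope):

```python
import logging
import os


def configure_rank_aware_logger(verbose, rank):
    # Hypothetical standalone mirror of the rank check added to
    # update_logger(); only rank 0 emits INFO/WARN messages.
    logger = logging.getLogger("federatedscope")
    if rank == 0:
        level = logging.INFO if verbose > 0 else logging.WARN
    else:
        # DeepSpeed subprocesses log nothing below CRITICAL.
        level = logging.CRITICAL
    logger.setLevel(level)
    return logger


rank = int(os.environ.get("RANK", "0"))  # exported by the deepspeed launcher
configure_rank_aware_logger(verbose=1, rank=rank)
```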
4 changes: 4 additions & 0 deletions federatedscope/core/auxiliaries/utils.py
@@ -179,6 +179,10 @@ def get_resource_info(filename):
     return device_info
 
 
+def get_ds_rank():
+    return int(os.environ.get("RANK", "0"))
+
+
 def add_prefix_to_path(prefix, path):
     directory, file = os.path.split(path)
     return os.path.join(directory, prefix + file)
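`get_ds_rank` reads the global rank from the `RANK` environment variable, which the `deepspeed` launcher exports to every subprocess; without a launcher it falls back to 0, so single-process runs behave exactly as before. A quick self-contained check (setting `RANK` by hand to simulate a launched worker):

```python
import os


def get_ds_rank():
    # Same logic as the new helper in federatedscope/core/auxiliaries/utils.py
    return int(os.environ.get("RANK", "0"))


assert get_ds_rank() == 0      # no launcher: default to rank 0
os.environ["RANK"] = "2"       # simulate a worker spawned by `deepspeed`
assert get_ds_rank() == 2
```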
4 changes: 4 additions & 0 deletions federatedscope/core/cmd_args.py
@@ -17,6 +17,10 @@ def parse_args(args=None):
                         required=False,
                         default=None,
                         type=str)
+    parser.add_argument('--local_rank',
+                        type=int,
+                        default=-1,
+                        help='local rank passed from distributed launcher')
     parser.add_argument(
         '--help',
         nargs="?",
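The `deepspeed` launcher appends `--local_rank=<n>` to every spawned process's argument list, so `parse_args` must accept it even though FederatedScope itself reads the `RANK` environment variable. A sketch of the parsing behavior in isolation (the `--cfg` flag is reproduced from the surrounding parser for context):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', required=False, default=None, type=str)
parser.add_argument('--local_rank',
                    type=int,
                    default=-1,
                    help='local rank passed from distributed launcher')

# The launcher effectively runs, once per GPU:
#   python federatedscope/main.py --local_rank=0 --cfg <config>.yaml
args = parser.parse_args(['--local_rank=0', '--cfg', 'llama_ds.yaml'])
assert args.local_rank == 0    # stays -1 when started without a launcher
```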
8 changes: 8 additions & 0 deletions federatedscope/core/configs/cfg_llm.py
@@ -1,3 +1,4 @@
+import json
 import logging
 
 from federatedscope.core.configs.config import CN
@@ -26,6 +27,13 @@ def extend_llm_cfg(cfg):
     cfg.llm.chat.max_history_len = 10
     cfg.llm.chat.max_len = 100
 
+    # ---------------------------------------------------------------------- #
+    # Deepspeed related options
+    # ---------------------------------------------------------------------- #
+    cfg.llm.deepspeed = CN()
+    cfg.llm.deepspeed.use = False
+    cfg.llm.deepspeed.ds_config = ''
+
     # ---------------------------------------------------------------------- #
     # Adapters for LLM
     # ---------------------------------------------------------------------- #
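The new `cfg.llm.deepspeed` node carries only an on/off switch and a path to a standard DeepSpeed JSON file; the added `import json` suggests the file is parsed somewhere inside FederatedScope. A sketch of how a trainer might consume these keys — the loading code below is illustrative, not the actual llmtrainer, and assumes the repo-relative path exists:

```python
import json

# Stand-ins for values read from a parsed FederatedScope cfg object.
deepspeed_use = True
ds_config_path = 'federatedscope/llm/baseline/deepspeed/ds_config.json'

if deepspeed_use:
    with open(ds_config_path) as f:
        ds_config = json.load(f)  # dict handed to deepspeed.initialize(...)
    print(ds_config['zero_optimization']['stage'])  # -> 3
```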
11 changes: 8 additions & 3 deletions federatedscope/core/fed_runner.py
@@ -9,7 +9,8 @@
 from federatedscope.core.workers import Server, Client
 from federatedscope.core.gpu_manager import GPUManager
 from federatedscope.core.auxiliaries.model_builder import get_model
-from federatedscope.core.auxiliaries.utils import get_resource_info
+from federatedscope.core.auxiliaries.utils import get_resource_info, \
+    get_ds_rank
 from federatedscope.core.auxiliaries.feat_engr_builder import \
     get_feat_engr_wrapper
 
@@ -133,6 +134,10 @@ def run(self):
         """
         raise NotImplementedError
 
+    @property
+    def ds_rank(self):
+        return get_ds_rank()
+
     def _setup_server(self, resource_info=None, client_resource_info=None):
         """
         Set up and instantiate the server.
@@ -518,7 +523,7 @@ def _set_up(self):
 
         self.server_address = {
            'host': self.cfg.distribute.server_host,
-            'port': self.cfg.distribute.server_port
+            'port': self.cfg.distribute.server_port + self.ds_rank
         }
         if self.cfg.distribute.role == 'server':
             self.server = self._setup_server(resource_info=sampled_resource)
@@ -527,7 +532,7 @@
             # the server has been set up and number with #0
             self.client_address = {
                 'host': self.cfg.distribute.client_host,
-                'port': self.cfg.distribute.client_port
+                'port': self.cfg.distribute.client_port + self.ds_rank
             }
             self.client = self._setup_client(resource_info=sampled_resource)
 
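In distributed mode every DeepSpeed rank runs the full runner and therefore opens its own gRPC endpoint; offsetting the configured base port by the rank keeps the sockets from colliding, which is what the port-range comments in the YAML configs below hint at. Roughly:

```python
import os

base_server_port = 50051
base_client_port = 50061

# One process per GPU under deepspeed; RANK is 0, 1, ..., world_size - 1.
rank = int(os.environ.get("RANK", "0"))

# Each rank binds its own port, e.g. rank 0 -> 50061, rank 1 -> 50062,
# so concurrent subprocesses never fight over the same socket.
server_address = {'host': '127.0.0.1', 'port': base_server_port + rank}
client_address = {'host': '127.0.0.1', 'port': base_client_port + rank}
```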
19 changes: 10 additions & 9 deletions federatedscope/core/trainers/torch_trainer.py
@@ -28,12 +28,12 @@
 
 class GeneralTorchTrainer(Trainer):
     def get_model_para(self):
-        if self.cfg.federate.process_num > 1:
+        if self.cfg.federate.process_num > 1 or \
+                self.cfg.federate.share_local_model or \
+                self.cfg.llm.deepspeed.use:
             return self._param_filter(self.ctx.model.state_dict())
         else:
-            return self._param_filter(
-                self.ctx.model.state_dict() if self.cfg.federate.
-                share_local_model else self.ctx.model.cpu().state_dict())
+            return self._param_filter(self.ctx.model.cpu().state_dict())
 
     def setup_data(self, ctx):
         """
@@ -463,8 +463,9 @@ def discharge_model(self):
         Discharge the model from GPU device
         """
         # Avoid memory leak
-        if not self.cfg.federate.share_local_model:
-            if torch is None:
-                pass
-            else:
-                self.ctx.model.to(torch.device("cpu"))
+        if torch is None:
+            return
+
+        if not self.cfg.federate.share_local_model and \
+                not self.cfg.llm.deepspeed.use:
+            self.ctx.model.to(torch.device("cpu"))
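Both changes follow from the fact that ZeRO-3 owns parameter placement: the engine shards and offloads weights itself, so the trainer must neither pull a DeepSpeed-managed model to CPU when exporting parameters nor park it on CPU between rounds. A toy illustration of the new guard in `discharge_model`, with the config flags inlined as plain booleans for clarity:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)

share_local_model = False
deepspeed_in_use = True   # emulate cfg.llm.deepspeed.use

# The old code moved the model to CPU whenever it was not shared; with
# DeepSpeed that would bypass the engine's sharding/offload bookkeeping,
# so the CPU move now happens only when neither flag is set.
if not share_local_model and not deepspeed_in_use:
    model.to(torch.device("cpu"))
```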
5 changes: 5 additions & 0 deletions federatedscope/core/workers/base_worker.py
@@ -1,4 +1,5 @@
 from federatedscope.core.monitors.monitor import Monitor
+from federatedscope.core.auxiliaries.utils import get_ds_rank
 
 
 class Worker(object):
@@ -68,3 +69,7 @@ def mode(self):
     @mode.setter
     def mode(self, value):
         self._mode = value
+
+    @property
+    def ds_rank(self):
+        return get_ds_rank()
8 changes: 5 additions & 3 deletions federatedscope/core/workers/client.py
@@ -10,11 +10,12 @@
 from federatedscope.core.auxiliaries.trainer_builder import get_trainer
 from federatedscope.core.secret_sharing import AdditiveSecretSharing
 from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
-    calculate_time_cost, add_prefix_to_path
+    calculate_time_cost, add_prefix_to_path, get_ds_rank
 from federatedscope.core.workers.base_client import BaseClient
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+if get_ds_rank() == 0:
+    logger.setLevel(logging.INFO)
 
 
 class Client(BaseClient):
@@ -560,7 +561,8 @@ def callback_funcs_for_evaluate(self, message: Message):
            if update_best_this_round and self._cfg.federate.save_client_model:
                 path = add_prefix_to_path(f'client_{self.ID}_',
                                           self._cfg.federate.save_to)
-                self.trainer.save_model(path, self.state)
+                if self.ds_rank == 0:
+                    self.trainer.save_model(path, self.state)
 
         self.history_results = merge_dict_of_results(
             self.history_results, formatted_eval_res['Results_raw'])
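Every DeepSpeed subprocess executes the same client code, so without a guard all ranks would race to write the identical checkpoint file. The save is therefore gated on rank 0, shown here via a hypothetical wrapper that mirrors the pattern used in both client.py and server.py:

```python
import os


def save_on_rank_zero(trainer, path, round_idx):
    # Hypothetical helper, not part of FederatedScope; `trainer` is
    # assumed to expose save_model(path, round) as the trainers do.
    if int(os.environ.get("RANK", "0")) == 0:
        trainer.save_model(path, round_idx)
    # non-zero ranks skip disk I/O entirely, avoiding concurrent writes
```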
18 changes: 11 additions & 7 deletions federatedscope/core/workers/server.py
@@ -13,13 +13,14 @@
 from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
 from federatedscope.core.auxiliaries.sampler_builder import get_sampler
 from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
-    Timeout, merge_param_dict, add_prefix_to_path
+    Timeout, merge_param_dict, add_prefix_to_path, get_ds_rank
 from federatedscope.core.auxiliaries.trainer_builder import get_trainer
 from federatedscope.core.secret_sharing import AdditiveSecretSharing
 from federatedscope.core.workers.base_server import BaseServer
 
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+if get_ds_rank() == 0:
+    logger.setLevel(logging.INFO)
 
 
 class Server(BaseServer):
@@ -90,7 +91,8 @@ def __init__(self,
                 self._monitor.the_larger_the_better)
 
         if self._cfg.federate.share_local_model \
-                and not self._cfg.federate.process_num > 1:
+                and not self._cfg.federate.process_num > 1 \
+                and not self._cfg.llm.deepspeed.use:
             if self._cfg.train.is_enable_half:
                 model = model.half()
             # put the model to the specified device
@@ -410,7 +412,8 @@ def check_and_save(self):
                 self._cfg.federate.save_freq > 0:
             path = add_prefix_to_path(f'{self.state}_',
                                       self._cfg.federate.save_to)
-            self.aggregator.save_model(path, self.state)
+            if self.ds_rank == 0:
+                self.aggregator.save_model(path, self.state)
 
         if should_stop or self.state == self.total_round_num:
             logger.info('Server: Final evaluation is finished! Starting '
@@ -531,7 +534,7 @@ def save_best_results(self):
         To Save the best evaluation results.
         """
         # Save final round model
-        if self._cfg.federate.save_to != '':
+        if self._cfg.federate.save_to != '' and self.ds_rank == 0:
             self.aggregator.save_model(
                 add_prefix_to_path('final_', self._cfg.federate.save_to),
                 self.state)
@@ -645,7 +648,7 @@ def merge_eval_results_from_all_clients(self):
             # When the frequency of evaluations is high,
             # the frequency of writing to disk in the early stages
             # may also be high
-            if self._cfg.federate.save_to != '':
+            if self._cfg.federate.save_to != '' and self.ds_rank == 0:
                 self.aggregator.save_model(self._cfg.federate.save_to,
                                            self.state)
 
@@ -843,7 +846,8 @@ def trigger_for_start(self):
                             self.models[0])) / 1024.0 * 8.
                 except Exception as error:
                     model_size = 1.0
-                    logger.warning(f'{error} in calculate model size.')
+                    logger.warning(f'Error {error} in calculate model '
+                                   f'size.')
             else:
                 # TODO: calculate model size for TF Model
                 model_size = 1.0
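Aside from the rank-0 save guards, the touched `trigger_for_start` block estimates per-round communication cost from the pickled model, in kilobits; the PR only rewords its warning message. The surrounding calculation, runnable in isolation:

```python
import pickle
import sys

import torch.nn as nn

model = nn.Linear(10, 10)  # stand-in for self.models[0]

try:
    # bytes / 1024 * 8 -> kilobits, the unit used for resource accounting
    model_size = sys.getsizeof(pickle.dumps(model)) / 1024.0 * 8.
except Exception as error:
    model_size = 1.0  # fall back to a nominal size on pickling failures
print(f"estimated model size: {model_size:.1f} kbit")
```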
46 changes: 46 additions & 0 deletions federatedscope/llm/baseline/deepspeed/ds_config.json
@@ -0,0 +1,46 @@
{
    "train_batch_size": 4,
    "steps_per_print": 2000,
    "fp16": {"enabled": true},
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [
                0.8,
                0.999
            ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "wall_clock_breakdown": false
}
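This config enables fp16 plus ZeRO stage 3 with both optimizer state and parameters offloaded to CPU, which is what makes a 13B LLaMA trainable on modest GPU memory. A sketch of how such a file is typically consumed; the exact wiring inside FederatedScope's llmtrainer may differ, and this only runs under the `deepspeed` launcher with a working distributed backend:

```python
import deepspeed
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the LLM

# deepspeed.initialize accepts the JSON path (or an equivalent dict) and
# returns an engine that handles fp16, ZeRO-3 sharding and CPU offload.
engine, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config='federatedscope/llm/baseline/deepspeed/ds_config.json')
```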
51 changes: 51 additions & 0 deletions federatedscope/llm/baseline/deepspeed/llama_client.yaml
@@ -0,0 +1,51 @@
# deepspeed --master_port 29501 federatedscope/main.py --cfg federatedscope/llm/baseline/deepspeed/llama_client.yaml
use_gpu: True
early_stop:
  patience: 0
federate:
  mode: distributed
  client_num: 1
  total_round_num: 500
  save_to: "llama_ds.ckpt"
data:
  root: data/
  type: 'alpaca@llm'
  splits: [0.98,0.01,0.01]
  splitter: 'iid'
distribute:
  use: True
  server_host: '127.0.0.1'
  server_port: 50051 # [50051, 50051 + client_num]
  client_host: '127.0.0.1'
  client_port: 50061 # [50061, 50061 + client_num]
  role: 'client'
  data_idx: 1
  grpc_max_send_message_length: 1048576000
  grpc_max_receive_message_length: 1048576000
llm:
  tok_len: 1000
  chat:
    max_len: 2000
  adapter:
    use: True
    args: [ { 'adapter_package': 'peft', 'adapter_method': 'lora', 'r': 8, 'lora_alpha': 32, 'lora_dropout': 0.1 } ]
  deepspeed:
    use: True
    ds_config: 'federatedscope/llm/baseline/deepspeed/ds_config.json'
dataloader:
  batch_size: 1
model:
  type: 'decapoda-research/llama-13b-hf@huggingface_llm'
train:
  local_update_steps: 30
  batch_or_epoch: batch
  optimizer:
    lr: 0.0003
    weight_decay: 0.0
  is_enable_half: True
trainer:
  type: llmtrainer
eval:
  freq: 5
  metrics: ['loss']
  count_flops: False
42 changes: 42 additions & 0 deletions federatedscope/llm/baseline/deepspeed/llama_ds.yaml
@@ -0,0 +1,42 @@
# deepspeed federatedscope/main.py --cfg federatedscope/llm/baseline/deepspeed/llama_ds.yaml
use_gpu: True
device: 0
early_stop:
  patience: 0
federate:
  mode: standalone
  client_num: 1
  total_round_num: 500
  save_to: "llama_ds.ckpt"
data:
  root: data/
  type: 'alpaca@llm'
  splits: [0.98,0.01,0.01]
  splitter: 'iid'
llm:
  tok_len: 1000
  chat:
    max_len: 2000
  adapter:
    use: True
    args: [ { 'adapter_package': 'peft', 'adapter_method': 'lora', 'r': 8, 'lora_alpha': 32, 'lora_dropout': 0.1 } ]
  deepspeed:
    use: True
    ds_config: 'federatedscope/llm/baseline/deepspeed/ds_config.json'
dataloader:
  batch_size: 1
model:
  type: 'decapoda-research/llama-13b-hf@huggingface_llm'
train:
  local_update_steps: 30
  batch_or_epoch: batch
  optimizer:
    lr: 0.0003
    weight_decay: 0.0
  is_enable_half: True
trainer:
  type: llmtrainer
eval:
  freq: 5
  metrics: ['loss']
  count_flops: False