Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
rayrayraykk committed Aug 15, 2023
2 parents a0fbdc4 + 2805af0 commit 9b97119
Show file tree
Hide file tree
Showing 73 changed files with 1,342 additions and 188 deletions.
5 changes: 5 additions & 0 deletions federatedscope/core/auxiliaries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,8 @@ def get_resource_info(filename):

def get_ds_rank():
return int(os.environ.get("RANK", "0"))


def add_prefix_to_path(prefix, path):
    """Return ``path`` with ``prefix`` prepended to its final filename.

    The directory portion of ``path`` is left untouched; only the last
    component gains the prefix (e.g. ``('final_', 'exp/model.pt')`` ->
    ``'exp/final_model.pt'``).
    """
    head, tail = os.path.split(path)
    return os.path.join(head, f'{prefix}{tail}')
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def extend_fl_setting_cfg(cfg):
cfg.federate.restore_from = ''
cfg.federate.save_to = ''
cfg.federate.save_freq = -1
cfg.federate.save_client_model = False
cfg.federate.join_in_info = [
] # The information requirements (from server) for join_in
cfg.federate.sampler = 'uniform' # the strategy for sampling client
Expand Down
44 changes: 43 additions & 1 deletion federatedscope/core/configs/cfg_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ def extend_llm_cfg(cfg):
cfg.llm = CN()
cfg.llm.tok_len = 128

# ---------------------------------------------------------------------- #
# Cache for LLM
# ---------------------------------------------------------------------- #
cfg.llm.cache = CN()
cfg.llm.cache.model = ''

# ---------------------------------------------------------------------- #
# Chat tools for LLM
# ---------------------------------------------------------------------- #
cfg.llm.chat = CN()
cfg.llm.chat.max_history_len = 10
cfg.llm.chat.max_len = 100
Expand All @@ -32,6 +41,9 @@ def extend_llm_cfg(cfg):
cfg.llm.adapter = CN()
cfg.llm.adapter.use = False
cfg.llm.adapter.args = [{}]
# Move adapter to `cpu` after training, which can save memory but cost
# more time.
cfg.llm.adapter.mv_to_cpu = False

# ---------------------------------------------------------------------- #
# Offsite-tuning related options
Expand All @@ -43,9 +55,39 @@ def extend_llm_cfg(cfg):
cfg.llm.offsite_tuning.emu_l = 1 # Index of emulator layer left
cfg.llm.offsite_tuning.emu_r = 10 # Index of emulator layer right

# Used in `eval`
cfg.llm.offsite_tuning.eval_type = 'emu' # Choose one of `[emu, full]`

# Emulator alignment will use dataset in Server
cfg.llm.offsite_tuning.emu_align = CN()
cfg.llm.offsite_tuning.emu_align.use = False
cfg.llm.offsite_tuning.emu_align.restore_from = ''
cfg.llm.offsite_tuning.emu_align.save_to = ''
cfg.llm.offsite_tuning.emu_align.exit_after_align = False

# Server held-out data
cfg.llm.offsite_tuning.emu_align.data = CN()
cfg.llm.offsite_tuning.emu_align.data.root = 'data'
cfg.llm.offsite_tuning.emu_align.data.type = 'alpaca@llm'
cfg.llm.offsite_tuning.emu_align.data.splits = [0.8, 0.1, 0.1]

cfg.llm.offsite_tuning.emu_align.train = CN()
cfg.llm.offsite_tuning.emu_align.train.local_update_steps = 10
cfg.llm.offsite_tuning.emu_align.train.batch_or_epoch = 'batch'
cfg.llm.offsite_tuning.emu_align.train.lm_loss_weight = 0.1
cfg.llm.offsite_tuning.emu_align.train.kd_loss_weight = 0.9

cfg.llm.offsite_tuning.emu_align.train.optimizer = CN(new_allowed=True)
cfg.llm.offsite_tuning.emu_align.train.optimizer.type = 'SGD'
cfg.llm.offsite_tuning.emu_align.train.optimizer.lr = 0.01


def assert_llm_cfg(cfg):
    """Sanity-check the LLM-related configuration.

    Currently only inspects the offsite-tuning emulator-alignment options:
    when alignment is enabled together with a non-empty ``restore_from``
    checkpoint, a warning is emitted because restoring skips emulator
    training. No exception is raised; this is purely advisory.

    Args:
        cfg: the global config node containing ``cfg.llm``.
    """
    # NOTE: the leftover `pass` from the pre-refactor stub was removed —
    # it was dead code preceding the real validation logic.
    if cfg.llm.offsite_tuning.emu_align.use:
        if cfg.llm.offsite_tuning.emu_align.restore_from != '':
            logger.warning(
                'Enabling `restore_from` in offsite_tuning emulator '
                'alignment will skip training the emulator.')


register_config("llm", extend_llm_cfg)
3 changes: 3 additions & 0 deletions federatedscope/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ def print_trainer_meta_info(self):
meta_info = tuple([(val.name, getattr(self, val.name))
for val in sign])
return f'{self.__class__.__name__}{meta_info}'

def save_model(self, path, cur_round):
    """Persist the trainer's model to ``path``.

    Abstract hook: the base trainer has no serialization logic, so
    subclasses must override this method.

    Args:
        path: destination file path for the checkpoint.
        cur_round: current round index (how it is recorded is up to the
            subclass — presumably stored with the checkpoint; confirm in
            concrete trainers).

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
55 changes: 33 additions & 22 deletions federatedscope/core/trainers/trainer_pFedMe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
import copy
try:
import torch
except ImportError:
torch = None

from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.optimizer import wrap_regularized_optimizer
from typing import Type


def get_trainable_parameter_list(model):
    """Deep-copy the trainable parameters of ``model``.

    Produces a list aligned one-to-one with ``model.parameters()``: each
    entry is a ``copy.deepcopy`` of the parameter when it requires
    gradients, and ``None`` for frozen parameters.
    """
    return [
        copy.deepcopy(param) if param.requires_grad else None
        for param in model.parameters()
    ]


def wrap_pFedMeTrainer(
base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
"""
Expand Down Expand Up @@ -81,7 +95,7 @@ def init_pFedMe_ctx(base_trainer):
# the local_model_tmp is used to be the referenced parameter when
# finding the approximate \theta in paper
# will be copied from model every run_routine
ctx.pFedMe_local_model_tmp = None
ctx.pFedMe_local_model_param_tmp = None


def _hook_on_fit_start_set_local_para_tmp(ctx):
Expand All @@ -95,7 +109,7 @@ def _hook_on_fit_start_set_local_para_tmp(ctx):
``wrap_regularized_optimizer`` and set compared parameter group
``ctx.pFedMe_outer_lr`` Initialize to \
``ctx.cfg.train.optimizer.lr``
``ctx.pFedMe_local_model_tmp`` Copy from ``ctx.model``
``ctx.pFedMe_local_model_param_tmp`` Copy from ``ctx.model``
================================== ===========================
"""
# the optimizer used in pFedMe is based on Moreau Envelopes regularization
Expand All @@ -106,13 +120,10 @@ def _hook_on_fit_start_set_local_para_tmp(ctx):
for g in ctx.optimizer.param_groups:
g['lr'] = ctx.cfg.personalization.lr
ctx.pFedMe_outer_lr = ctx.cfg.train.optimizer.lr

ctx.pFedMe_local_model_tmp = copy.deepcopy(ctx.model)
ctx.pFedMe_local_model_param_tmp = get_trainable_parameter_list(ctx.model)
# set the compared model data, then the optimizer will find approximate
# model using trainer.cfg.personalization.lr
compared_global_model_para = [{
"params": list(ctx.pFedMe_local_model_tmp.parameters())
}]
compared_global_model_para = [{"params": ctx.pFedMe_local_model_param_tmp}]
ctx.optimizer.set_compared_para_group(compared_global_model_para)


Expand Down Expand Up @@ -181,23 +192,22 @@ def _hook_on_epoch_end_update_local(ctx):
Attribute Operation
================================== ===========================
``ctx.model`` Update parameters by \
``ctx.pFedMe_local_model_tmp``
``ctx.pFedMe_local_model_param_tmp``
``ctx.optimizer`` Set compared parameter group
================================== ===========================
"""
# update local weight after finding approximate theta
for client_param, local_para_tmp in zip(
ctx.model.parameters(), ctx.pFedMe_local_model_tmp.parameters()):
local_para_tmp.data = local_para_tmp.data - \
ctx.optimizer.regular_weight * \
ctx.pFedMe_outer_lr * (local_para_tmp.data -
client_param.data)
for client_param, local_para_tmp in zip(ctx.model.parameters(),
ctx.pFedMe_local_model_param_tmp):
if client_param.requires_grad:
local_para_tmp.data = local_para_tmp.data - \
ctx.optimizer.regular_weight * \
ctx.pFedMe_outer_lr * (local_para_tmp.data -
client_param.data)

# set the compared model data, then the optimizer will find approximate
# model using trainer.cfg.personalization.lr
compared_global_model_para = [{
"params": list(ctx.pFedMe_local_model_tmp.parameters())
}]
compared_global_model_para = [{"params": ctx.pFedMe_local_model_param_tmp}]
ctx.optimizer.set_compared_para_group(compared_global_model_para)


Expand All @@ -209,12 +219,13 @@ def _hook_on_fit_end_update_local(ctx):
Attribute Operation
================================== ===========================
``ctx.model`` Update parameters by
``ctx.pFedMe_local_model_tmp``
``ctx.pFedMe_local_model_tmp`` Delete
``ctx.pFedMe_local_model_param_tmp``
``ctx.pFedMe_local_model_param_tmp`` Delete
================================== ===========================
"""
for param, local_para_tmp in zip(ctx.model.parameters(),
ctx.pFedMe_local_model_tmp.parameters()):
param.data = local_para_tmp.data
ctx.pFedMe_local_model_param_tmp):
if param.requires_grad:
param.data = local_para_tmp.data

del ctx.pFedMe_local_model_tmp
del ctx.pFedMe_local_model_param_tmp
16 changes: 12 additions & 4 deletions federatedscope/core/workers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
from federatedscope.core.secret_sharing import AdditiveSecretSharing
from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
calculate_time_cost
calculate_time_cost, add_prefix_to_path
from federatedscope.core.workers.base_client import BaseClient

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -551,9 +551,17 @@ def callback_funcs_for_evaluate(self, message: Message):
forms=['raw'],
return_raw=True)
logger.info(formatted_eval_res)
self._monitor.update_best_result(self.best_results,
formatted_eval_res['Results_raw'],
results_type=f"client #{self.ID}")
update_best_this_round = self._monitor.update_best_result(
self.best_results,
formatted_eval_res['Results_raw'],
results_type=f"client #{self.ID}",
)

if update_best_this_round and self._cfg.federate.save_client_model:
path = add_prefix_to_path(f'client_{self.ID}_',
self._cfg.federate.save_to)
self.trainer.save_model(path, self.state)

self.history_results = merge_dict_of_results(
self.history_results, formatted_eval_res['Results_raw'])
self.early_stopper.track_and_check(self.history_results[
Expand Down
74 changes: 52 additions & 22 deletions federatedscope/core/workers/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from federatedscope.core.auxiliaries.aggregator_builder import get_aggregator
from federatedscope.core.auxiliaries.sampler_builder import get_sampler
from federatedscope.core.auxiliaries.utils import merge_dict_of_results, \
Timeout, merge_param_dict, get_ds_rank
Timeout, merge_param_dict, add_prefix_to_path, get_ds_rank
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
from federatedscope.core.secret_sharing import AdditiveSecretSharing
from federatedscope.core.workers.base_server import BaseServer
Expand Down Expand Up @@ -93,6 +93,8 @@ def __init__(self,
if self._cfg.federate.share_local_model \
and not self._cfg.federate.process_num > 1 \
and not self._cfg.llm.deepspeed.use:
if self._cfg.train.is_enable_half:
model = model.half()
# put the model to the specified device
model.to(device)
# Build aggregator
Expand All @@ -107,7 +109,8 @@ def __init__(self,
f' {self._cfg.federate.restore_from}.')
else:
_ = self.aggregator.load_model(self._cfg.federate.restore_from)
logger.info("Restored the model from {}-th round's ckpt")
logger.info(f"Restored the model from "
f"{self._cfg.federate.restore_from}")

if int(config.model.model_num_per_trainer) != \
config.model.model_num_per_trainer or \
Expand Down Expand Up @@ -407,7 +410,8 @@ def check_and_save(self):
if self.state != self.total_round_num and \
self.state % self._cfg.federate.save_freq == 0 and \
self._cfg.federate.save_freq > 0:
path = f'{self.state}_' + self._cfg.federate.save_to
path = add_prefix_to_path(f'{self.state}_',
self._cfg.federate.save_to)
self.aggregator.save_model(path, self.state)

if should_stop or self.state == self.total_round_num:
Expand Down Expand Up @@ -528,10 +532,11 @@ def save_best_results(self):
"""
To Save the best evaluation results.
"""

# Save final round model
if self._cfg.federate.save_to != '':
self.aggregator.save_model(f'final_{self._cfg.federate.save_to}',
self.state)
self.aggregator.save_model(
add_prefix_to_path('final_', self._cfg.federate.save_to),
self.state)
formatted_best_res = self._monitor.format_eval_res(
results=self.best_results,
rnd="Final",
Expand Down Expand Up @@ -609,29 +614,42 @@ def merge_eval_results_from_all_clients(self):
del formatted_logs[key]
logger.info(formatted_logs)
formatted_logs_all_set.update(formatted_logs)
update_best_this_round = self._monitor.update_best_result(
self._monitor.update_best_result(
self.best_results,
metrics_all_clients,
results_type="unseen_client_best_individual"
if merge_type == "unseen" else "client_best_individual")

self._monitor.save_formatted_results(formatted_logs)

update_prior = -1 # Bigger the higher priority
update_prior_list = ['fairness', 'avg', 'weighted_avg']
update_best_this_round = False
for form in self._cfg.eval.report:
if form in update_prior_list:
update_prior_tmp = update_prior_list.index(form)
else:
update_prior_tmp = -1
if form != "raw":
metric_name = form + "_unseen" if merge_type == \
"unseen" else form
update_best_this_round_tmp = \
self._monitor.update_best_result(
self.best_results,
formatted_logs[f"Results_{metric_name}"],
results_type=f"unseen_client_summarized_{form}"
if merge_type == "unseen" else
f"client_summarized_{form}")
if update_prior_tmp >= update_prior:
update_prior = update_prior_tmp
update_best_this_round = update_best_this_round_tmp
if update_best_this_round:
# When the frequency of evaluations is high,
# the frequency of writing to disk in the early stages
# may also be high
if self._cfg.federate.save_to != '':
self.aggregator.save_model(self._cfg.federate.save_to,
self.state)
self._monitor.save_formatted_results(formatted_logs)
for form in self._cfg.eval.report:
if form != "raw":
metric_name = form + "_unseen" if merge_type == \
"unseen" else form
self._monitor.update_best_result(
self.best_results,
formatted_logs[f"Results_{metric_name}"],
results_type=f"unseen_client_summarized_{form}"
if merge_type == "unseen" else
f"client_summarized_{form}")

return formatted_logs_all_set

Expand Down Expand Up @@ -676,11 +694,23 @@ def broadcast_model_para(self,
self.models[model_idx_i])

skip_broadcast = self._cfg.federate.method in ["local", "global"]
if self.model_num > 1:
model_para = [{} if skip_broadcast else model.state_dict()
for model in self.models]
if self._cfg.federate.share_local_model and not \
self._cfg.federate.online_aggr:
if self.model_num > 1:
model_para = [
{} if skip_broadcast else copy.deepcopy(model.state_dict())
for model in self.models
]
else:
model_para = {} if skip_broadcast else copy.deepcopy(
self.models[0].state_dict())
else:
model_para = {} if skip_broadcast else self.models[0].state_dict()
if self.model_num > 1:
model_para = [{} if skip_broadcast else model.state_dict()
for model in self.models]
else:
model_para = {} if skip_broadcast else self.models[
0].state_dict()

# quantization
if msg_type == 'model_para' and not skip_broadcast and \
Expand Down
Loading

0 comments on commit 9b97119

Please sign in to comment.