Fix offsite tuning eval (#674)
rayrayraykk authored Aug 1, 2023
1 parent ec9026d · commit c09bfe0
Showing 3 changed files with 31 additions and 13 deletions.
federatedscope/llm/misc/fschat.py (25 changes: 13 additions & 12 deletions)
@@ -22,22 +22,23 @@ def __init__(self, config):
         self.tokenizer, _ = get_tokenizer(model_name, config.data.root,
                                           config.llm.tok_len)
         self.model = get_llm(config)
-        if config.llm.offsite_tuning.use:
-            from federatedscope.llm.offsite_tuning.utils import \
-                wrap_offsite_tuning_for_eval
-            self.model = wrap_offsite_tuning_for_eval(self.model, config)
 
         self.device = f'cuda:{config.device}'
         self.add_special_tokens = True
 
-        try:
-            ckpt = torch.load(config.federate.save_to, map_location='cpu')
-            if 'model' and 'cur_round' in ckpt:
-                self.model.load_state_dict(ckpt['model'])
-            else:
-                self.model.load_state_dict(ckpt)
-        except Exception as error:
-            print(f"{error}, will use raw model.")
+        if config.llm.offsite_tuning.use:
+            from federatedscope.llm.offsite_tuning.utils import \
+                wrap_offsite_tuning_for_eval
+            self.model = wrap_offsite_tuning_for_eval(self.model, config)
+        else:
+            try:
+                ckpt = torch.load(config.federate.save_to, map_location='cpu')
+                if 'model' and 'cur_round' in ckpt:
+                    self.model.load_state_dict(ckpt['model'])
+                else:
+                    self.model.load_state_dict(ckpt)
+            except Exception as error:
+                print(f"{error}, will use raw model.")
 
         if config.train.is_enable_half:
             self.model.half()
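
A note on the checkpoint test: with this change, checkpoint loading for the offsite-tuning case moves into wrap_offsite_tuning_for_eval (see utils.py below), and fschat.py only loads config.federate.save_to directly in the non-offsite branch. One quirk carried over from the old code: `if 'model' and 'cur_round' in ckpt:` does not check both keys. Python parses it as `'model' and ('cur_round' in ckpt)`, and the non-empty string 'model' is always truthy, so only 'cur_round' is actually tested. A minimal sketch of the explicit check (the checkpoint path is hypothetical):

    import torch

    ckpt = torch.load('exp/ckpt.pt', map_location='cpu')  # hypothetical path
    # Check each key explicitly; "'model' and 'cur_round' in ckpt" only tests 'cur_round'.
    if 'model' in ckpt and 'cur_round' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt

The original behaves the same whenever checkpoints containing 'cur_round' also contain 'model', but the explicit form states the intent.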
federatedscope/llm/offsite_tuning/server.py (3 changes: 3 additions & 0 deletions)
@@ -49,6 +49,9 @@ def __init__(self,
                              monitor=Monitor(
                                  config,
                                  monitored_object=self))
+        # No need for this attr
+        if hasattr(adap_model, 'teacher'):
+            del adap_model.teacher
 
         self.raw_model = model
         super(OffsiteTuningServer,
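
Why the teacher attribute is deleted: the teacher submodule holds the frozen full-model layers used only during emulator alignment, so the server drops it once it is no longer needed. Deleting a registered nn.Module attribute also removes its parameters from state_dict(), which shrinks anything saved or broadcast later; utils.py applies the same trick before save_model below. A minimal sketch of the effect (AdapterModel and its layers are made up for illustration):

    import torch.nn as nn

    class AdapterModel(nn.Module):          # hypothetical stand-in for adap_model
        def __init__(self):
            super().__init__()
            self.student = nn.Linear(8, 8)  # trainable emulator side
            self.teacher = nn.Linear(8, 8)  # frozen side, alignment only

    m = AdapterModel()
    print(len(m.state_dict()))  # 4 entries: student and teacher weights + biases
    del m.teacher               # nn.Module.__delattr__ drops the submodule
    print(len(m.state_dict()))  # 2 entries: teacher tensors no longer serialized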
federatedscope/llm/offsite_tuning/utils.py (16 changes: 15 additions & 1 deletion)
@@ -277,6 +277,7 @@ def build_cfg_for_alignment(config):
     logger.info('Alignment finished!')
 
     # Save aligned model
+    del adap_model.teacher
     adap_model.save_model(cfg.llm.offsite_tuning.emu_align.save_to)
 
     # Make student un-trainable
@@ -302,7 +303,8 @@ def wrap_offsite_tuning_for_eval(model, config):
         emulator_r=emulator_r,
         **offsite_tuning_kwargs)
     # Load kd model if ckpt exits
-    if config.llm.offsite_tuning.emu_align.use:
+    if config.llm.offsite_tuning.emu_align.use and \
+            config.llm.offsite_tuning.eval_type == 'emu':
         if config.llm.offsite_tuning.emu_align.restore_from != '':
             try:
                 ckpt = torch.load(
@@ -314,9 +316,21 @@
             except Exception as error:
                 logger.warning(error)
+
+    # Load ckpt for eval
+    try:
+        ckpt = torch.load(config.federate.save_to, map_location='cpu')
+        if 'model' and 'cur_round' in ckpt:
+            adap_model.load_state_dict(ckpt['model'])
+        else:
+            adap_model.load_state_dict(ckpt)
+    except Exception as error:
+        logger.warning(f"{error}, will use raw model.")
+
     if config.llm.offsite_tuning.eval_type == 'emu':
         model = adap_model
+        del model.teacher
     elif config.llm.offsite_tuning.eval_type == 'full':
         # Raw model load adapter from adapter_and_emulator
         new_model_state_dict = model.state_dict()
         for key, value in zip(model.state_dict().keys(),
                               adap_model.state_dict().values()):
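
A note on the zip-based weight copy: the new "Load ckpt for eval" block reuses the same key test as fschat.py and shares its quirk (`'model' and 'cur_round' in ckpt` only tests 'cur_round'). The 'full' branch then copies tuned weights back into the raw model by zipping the raw model's state_dict keys with the adapter model's state_dict values. That pairing is purely positional, so it is only sound when both state_dicts have the same number and order of entries, which holds here because adap_model wraps the same architecture. A sketch of the pattern with a cheap shape check added (both models are hypothetical):

    import torch.nn as nn

    raw = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))   # hypothetical raw model
    adap = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # hypothetical tuned copy

    new_state = raw.state_dict()
    for key, value in zip(raw.state_dict().keys(), adap.state_dict().values()):
        # Positional pairing: key i of raw is assumed to match value i of adap.
        assert new_state[key].shape == value.shape
        new_state[key] = value
    raw.load_state_dict(new_state)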
