feat: add scanner tracker
Signed-off-by: Angel Luu <[email protected]>
aluu317 committed Dec 17, 2024
1 parent 4441948 commit 37f61bb
Showing 8 changed files with 239 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
fms-accel = ["fms-acceleration>=0.1"]
gptq-dev = ["auto_gptq>0.4.2", "optimum>=1.15.0"]
scanner-dev = ["HFResourceScanner>=0.1.0"]


[tool.setuptools.packages.find]
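This adds an optional install extra, so the scanner dependency is only pulled in on request, e.g. pip install "fms-hf-tuning[scanner-dev]" (distribution name assumed from the repository; installing from a local checkout with pip install ".[scanner-dev]" works the same way).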
34 changes: 33 additions & 1 deletion tests/build/test_launch_script.py
@@ -22,6 +22,7 @@

# Third Party
import pytest
from transformers.utils.import_utils import _is_package_available

# First Party
from build.accelerate_launch import main
@@ -31,7 +32,10 @@
USER_ERROR_EXIT_CODE,
INTERNAL_ERROR_EXIT_CODE,
)
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.config.tracker_configs import (
FileLoggingTrackerConfig,
HFResourceScannerConfig,
)

SCRIPT = "tuning/sft_trainer.py"
MODEL_NAME = "Maykeye/TinyLLama-v0"
@@ -246,6 +250,34 @@ def test_lora_with_lora_post_process_for_vllm_set_to_true():
assert os.path.exists(new_embeddings_file_path)


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_launch_with_HFResourceScanner_enabled():
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {
**BASE_LORA_KWARGS,
**{
"output_dir": tempdir,
"save_model_dir": tempdir,
"lora_post_process_for_vllm": True,
"gradient_accumulation_steps": 1,
"trackers": ["hf_resource_scanner"],
},
}
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0

scanner_outfile = os.path.join(
tempdir, HFResourceScannerConfig.scanner_output_filename
)
assert os.path.exists(scanner_outfile)


def test_bad_script_path():
"""Check for appropriate error for an invalid training script location"""
with tempfile.TemporaryDirectory() as tempdir:
26 changes: 23 additions & 3 deletions tests/test_sft_trainer.py
@@ -356,6 +356,7 @@ def test_parse_arguments(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_copy)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert data_args.dataset_text_field == "output"
@@ -381,6 +382,7 @@ def test_parse_arguments_defaults(job_config):
_,
_,
_,
_,
) = sft_trainer.parse_arguments(parser, job_config_defaults)
assert str(model_args.torch_dtype) == "torch.bfloat16"
assert model_args.use_flash_attn is False
@@ -391,14 +393,14 @@ def test_parse_arguments_peft_method(job_config):
parser = sft_trainer.get_parser()
job_config_pt = copy.deepcopy(job_config)
job_config_pt["peft_method"] = "pt"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_pt
)
assert isinstance(tune_config, peft_config.PromptTuningConfig)

job_config_lora = copy.deepcopy(job_config)
job_config_lora["peft_method"] = "lora"
_, _, _, _, tune_config, _, _, _, _, _, _ = sft_trainer.parse_arguments(
_, _, _, _, tune_config, _, _, _, _, _, _, _ = sft_trainer.parse_arguments(
parser, job_config_lora
)
assert isinstance(tune_config, peft_config.LoraConfig)
@@ -881,12 +883,18 @@ def _test_run_inference(checkpoint_path):


def _validate_training(
tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
tempdir,
check_eval=False,
train_logs_file="training_logs.jsonl",
check_scanner_file=False,
):
assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
_validate_logfile(train_logs_file_path, check_eval)

if check_scanner_file:
_validate_hf_resource_scanner_file(tempdir)


def _validate_logfile(log_file_path, check_eval=False):
train_log_contents = ""
@@ -901,6 +909,18 @@ def _validate_logfile(log_file_path, check_eval=False):
assert "validation_loss" in train_log_contents


def _validate_hf_resource_scanner_file(tempdir):
scanner_file_path = os.path.join(tempdir, "scanner_output.json")
assert os.path.exists(scanner_file_path)
assert os.path.getsize(scanner_file_path) > 0

scanner_contents = ""
with open(scanner_file_path, encoding="utf-8") as f:
scanner_contents = f.read()

assert "ResourceScanner Memory Data:" in scanner_contents


def _get_checkpoint_path(dir_path):
return os.path.join(dir_path, "checkpoint-5")

85 changes: 85 additions & 0 deletions tests/trackers/test_hf_resource_scanner_tracker.py
@@ -0,0 +1,85 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/


# Standard
import copy
import tempfile

# Third Party
from transformers.utils.import_utils import _is_package_available
import pytest

# First Party
from tests.test_sft_trainer import (
DATA_ARGS,
MODEL_ARGS,
TRAIN_ARGS,
_get_checkpoint_path,
_test_run_causallm_ft,
_test_run_inference,
_validate_training,
)

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

## HF Resource Scanner Tracker Tests


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_run_with_hf_resource_scanner_tracker():
"""Ensure that training succeeds with a good tracker name"""
with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.trackers = ["hf_resource_scanner"]

_test_run_causallm_ft(train_args, MODEL_ARGS, DATA_ARGS, tempdir)
_test_run_inference(_get_checkpoint_path(tempdir))


@pytest.mark.skipif(
not _is_package_available("HFResourceScanner"),
reason="Only runs if HFResourceScanner is installed",
)
def test_sample_run_with_hf_resource_scanner_updated_filename():
"""Ensure that hf_resource_scanner output filename can be updated"""

with tempfile.TemporaryDirectory() as tempdir:
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

train_args.trackers = ["hf_resource_scanner"]

scanner_output_file = "scanner_output.json"

tracker_configs = TrackerConfigFactory(
hf_resource_scanner_config=HFResourceScannerConfig(
scanner_output_filename=scanner_output_file
)
)

sft_trainer.train(
MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
)

# validate that training ran and the scanner output file was written
_validate_training(tempdir, check_scanner_file=True)
6 changes: 6 additions & 0 deletions tuning/config/tracker_configs.py
@@ -16,6 +16,11 @@
from dataclasses import dataclass


@dataclass
class HFResourceScannerConfig:
scanner_output_filename: str = "scanner_output.json"


@dataclass
class FileLoggingTrackerConfig:
training_logs_filename: str = "training_logs.jsonl"
@@ -67,3 +72,4 @@ def __post_init__(self):
class TrackerConfigFactory:
file_logger_config: FileLoggingTrackerConfig = None
aim_config: AimConfig = None
hf_resource_scanner_config: HFResourceScannerConfig = None
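Taken together, the new config slots into the tracker factory alongside the existing file logger and Aim configs. A minimal sketch of enabling the scanner from user code, mirroring the test above (the config classes and the "hf_resource_scanner" tracker name come from this commit; the test fixtures are reused here only for brevity, and HFResourceScanner must be installed):

# Standard
import copy
import tempfile

# First Party
from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, TRAIN_ARGS

# Local
from tuning import sft_trainer
from tuning.config.tracker_configs import HFResourceScannerConfig, TrackerConfigFactory

with tempfile.TemporaryDirectory() as tempdir:
    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.output_dir = tempdir
    # Opt in to the scanner by name ...
    train_args.trackers = ["hf_resource_scanner"]
    # ... and optionally override where its report is written.
    tracker_configs = TrackerConfigFactory(
        hf_resource_scanner_config=HFResourceScannerConfig(
            scanner_output_filename="scanner_output.json"
        )
    )
    sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs)

The tests above then assert that the report shows up under the training output directory and is non-empty.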
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
@@ -51,6 +51,7 @@
from tuning.config.tracker_configs import (
AimConfig,
FileLoggingTrackerConfig,
HFResourceScannerConfig,
TrackerConfigFactory,
)
from tuning.data.setup_dataprocessor import process_dataargs
@@ -431,6 +432,7 @@ def get_parser():
peft_config.PromptTuningConfig,
FileLoggingTrackerConfig,
AimConfig,
HFResourceScannerConfig,
QuantizedLoraConfig,
FusedOpsAndKernelsConfig,
AttentionAndDistributedPackingConfig,
Expand Down Expand Up @@ -477,6 +479,8 @@ def parse_arguments(parser, json_config=None):
Configuration for training log file.
AimConfig
Configuration for AIM stack.
HFResourceScannerConfig
Configuration for HFResourceScanner.
QuantizedLoraConfig
Configuration for quantized LoRA (a form of PEFT).
FusedOpsAndKernelsConfig
@@ -496,6 +500,7 @@
prompt_tuning_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -512,6 +517,7 @@
prompt_tuning_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -537,6 +543,7 @@
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -558,6 +565,7 @@ def main():
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -570,7 +578,7 @@
logger.debug(
"Input args parsed: \
model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
tune_config %s, file_logger_config, %s aim_config %s, \
tune_config %s, file_logger_config %s, aim_config %s, hf_resource_scanner_config %s, \
quantized_lora_config %s, fusedops_kernels_config %s, \
attention_and_distributed_packing_config %s exp_metadata %s",
model_args,
@@ -580,6 +588,7 @@
tune_config,
file_logger_config,
aim_config,
hf_resource_scanner_config,
quantized_lora_config,
fusedops_kernels_config,
attention_and_distributed_packing_config,
@@ -611,6 +620,7 @@

combined_tracker_configs.file_logger_config = file_logger_config
combined_tracker_configs.aim_config = aim_config
combined_tracker_configs.hf_resource_scanner_config = hf_resource_scanner_config

if training_args.output_dir:
os.makedirs(training_args.output_dir, exist_ok=True)
47 changes: 47 additions & 0 deletions tuning/trackers/hf_resource_scanner_tracker.py
@@ -0,0 +1,47 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import logging

# Third Party
from HFResourceScanner import Scanner # pylint: disable=import-error

# Local
from .tracker import Tracker
from tuning.config.tracker_configs import HFResourceScannerConfig


class HFResourceScannerTracker(Tracker):
def __init__(self, tracker_config: HFResourceScannerConfig):
"""Tracker which encodes callback to scan for resources using HFResourceScanner
Args:
tracker_config (HFResourceScannerConfig): An instance of HFResourceScanner
tracker config which contains the location of output file.
"""
super().__init__(name="hf_resource_scanner", tracker_config=tracker_config)
# Get logger with root log level
self.logger = logging.getLogger()

def get_hf_callback(self):
"""Returns the HFResourceScanner object associated with this tracker.
Returns:
Scanner: The scanner callback, which inherits
transformers.TrainerCallback and records resource usage metrics to a file.
"""
output_filename = self.config.scanner_output_filename
self.hf_callback = Scanner(output_fmt=output_filename)
return self.hf_callback
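As a usage note, a short sketch of exercising this tracker in isolation (the class and config are from this commit; handing the returned callback to a transformers.Trainer is illustrative of how the framework consumes it):

# Local
from tuning.config.tracker_configs import HFResourceScannerConfig
from tuning.trackers.hf_resource_scanner_tracker import HFResourceScannerTracker

tracker = HFResourceScannerTracker(HFResourceScannerConfig())
scanner_callback = tracker.get_hf_callback()  # an HFResourceScanner.Scanner instance

# sft_trainer passes callbacks like this one through to the underlying
# transformers.Trainer, e.g. Trainer(..., callbacks=[scanner_callback]).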