forked from foundation-model-stack/fms-hf-tuning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Change to custom aim callback to disable multiple instantiation for
FSDP. Add tracker factory. Signed-off-by: Dushyant Behl <[email protected]>
- Loading branch information
1 parent
4f93a7c
commit a28aa7b
Showing
3 changed files
with
102 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,29 @@ | ||
# Generic Tracker API | ||
|
||
from tuning.tracker.aimstack_tracker import AimStackTracker | ||
|
||
class Tracker: | ||
def __init__(self, tracker_config) -> None: | ||
self.config = tracker_config | ||
def __init__(self, name=None, tracker_config=None) -> None: | ||
if tracker_config is not None: | ||
self.config = tracker_config | ||
if name is None: | ||
self._name = "None" | ||
else: | ||
self._name = name | ||
|
||
def get_hf_callback(): | ||
return None | ||
|
||
def track(self, metric, name, stage): | ||
pass | ||
|
||
# Metadata passed here is supposed to be a KV object | ||
# Key being the name and value being the metric to track. | ||
def track_metadata(self, metadata=None): | ||
if metadata is None or not isinstance(metadata, dict): | ||
return | ||
for k, v in metadata.items(): | ||
self.track(name=k, metric=v) | ||
# Object passed here is supposed to be a KV object | ||
# for the parameters to be associated with a run | ||
def set_params(self, params, name): | ||
pass | ||
|
||
def get_tracker(tracker_name, tracker_config): | ||
if tracker_name == 'aim': | ||
if tracker_config is not None: | ||
tracker = AimStackTracker(tracker_config) | ||
else: | ||
tracker = Tracker() | ||
return tracker | ||
class TrackerFactory: | ||
def get_tracker(tracker_name, tracker_config): | ||
for T in Tracker.__subclasses__(): | ||
if T._name == tracker_name: | ||
return T(tracker_config) | ||
else: | ||
return Tracker() |