Skip to content

Commit

Permalink
feat(trainer): data_path_visitor add date filter (#1086)
Browse files Browse the repository at this point in the history
Co-authored-by: gezhengqiang <[email protected]>
  • Loading branch information
gejielun and Gezq authored May 23, 2023
1 parent 1101610 commit 436e495
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
16 changes: 15 additions & 1 deletion fedlearner/trainer/data_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import tensorflow.compat.v1 as tf
from fedlearner.common import fl_logging
from fedlearner.common import trainer_master_service_pb2 as tm_pb
from fedlearner.common.common import convert_time_string_to_datetime
from fedlearner.data_join.data_block_visitor import DataBlockVisitor
from fedlearner.trainer.utils import match_date


kvstore_type = os.environ.get('KVSTORE_TYPE', 'etcd')
Expand Down Expand Up @@ -337,17 +339,29 @@ def __init__(self,
local_data_path: str,
wildcard: str,
epoch_num: int = 1,
shuffle_type=None):
shuffle_type=None,
start_date=None,
end_date=None):
fl_logging.info("create DataVisitor by data_path: %s", data_path)
if not tf.io.gfile.exists(data_path):
raise ValueError("data_path not found: %s"%data_path)

if start_date:
start_date = convert_time_string_to_datetime(str(start_date))
if end_date:
end_date = convert_time_string_to_datetime(str(end_date))
datablocks = []
for dirname, _, filenames in tf.io.gfile.walk(data_path):
for filename in filenames:
if not fnmatch(os.path.join(dirname, filename), wildcard):
continue
subdirname = os.path.relpath(dirname, data_path)
try:
cur_date = datetime.strptime(subdirname, '%Y%m%d')
if not match_date(cur_date, start_date, end_date):
continue
except Exception:
fl_logging.info('subdirname is not the format of time')
block_id = os.path.join(subdirname, filename)
datablock = _RawDataBlock(
id=block_id, data_path=os.path.join(dirname, filename),
Expand Down
4 changes: 3 additions & 1 deletion fedlearner/trainer/trainer_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,9 @@ def _create_data_visitor(args):
args.local_data_path,
wildcard=args.data_path_wildcard,
epoch_num=args.epoch_num,
shuffle_type=shuffle_type)
shuffle_type=shuffle_type,
start_date=start_date,
end_date=end_date)
if not visitor:
raise ValueError("cannot found any data to train, "
"please specify [--data-source] or "
Expand Down
12 changes: 12 additions & 0 deletions fedlearner/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from __future__ import division
from __future__ import print_function

from datetime import datetime
from typing import Optional

import sys
import numpy as np

Expand Down Expand Up @@ -122,3 +125,12 @@ def _compute_slot_config(unsorted_slot_config, groups=None, use_fid_v2=False):
'slot_weight_offset': slot_weight_offset,
'output_size': offset
}


def match_date(cur_date: datetime, start_date: Optional[datetime],
end_date: Optional[datetime]) -> bool:
if start_date and cur_date < start_date:
return False
if end_date and cur_date > end_date:
return False
return True

0 comments on commit 436e495

Please sign in to comment.