Skip to content

Commit

Permalink
typing improvements
Browse files — browse the repository at this point in the history
  • Loading branch information
kalessin committed Aug 21, 2023
1 parent 1890be9 commit 3d3d20b
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 66 deletions.
2 changes: 1 addition & 1 deletion lint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
result=0
flake8 shub_workflow/ tests/ --application-import-names=shub_workflow --import-order-style=pep8
result=$(($result | $?))
mypy --ignore-missing-imports --disable-error-code=method-assign shub_workflow/ tests/
mypy --ignore-missing-imports --disable-error-code=method-assign --check-untyped-defs shub_workflow/ tests/
result=$(($result | $?))
exit $result
2 changes: 1 addition & 1 deletion shub_workflow/clone_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def add_argparser_options(self):

def run(self):
if self.args.key:
keys = filter(lambda x: not self.is_cloned_by_jobkey(x), self.args.key)
keys = list(filter(lambda x: not self.is_cloned_by_jobkey(x), self.args.key))
elif self.args.tag_spider:
keys = []
project_id, tag, spider = self.args.tag_spider.split("/")
Expand Down
3 changes: 3 additions & 0 deletions shub_workflow/crawl.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ def __init__(self):
self.__next_job_seq = 1
self._jobuids = self.create_dupe_filter()

def get_delayed_jobs(self) -> List[FullJobParams]:
return deepcopy(self.__delayed_jobs)

@classmethod
def create_dupe_filter(cls) -> DupesFilterProtocol:
return BloomFilter(max_elements=1e6, error_rate=1e-6)
Expand Down
4 changes: 2 additions & 2 deletions shub_workflow/deliver/futils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def upload_file(path, dest, aws_key=None, aws_secret=None, aws_token=None, **kwa
gcstorage.upload_file(path, dest)


def get_glob(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs):
def get_glob(path, aws_key=None, aws_secret=None, aws_token=None, **kwargs) -> List[str]:
region = kwargs.pop("region", None)
if check_s3_path(path):
fs = S3FileSystem(**s3_credentials(aws_key, aws_secret, aws_token, region), **kwargs)
fp = [_S3_ATTRIBUTE + p for p in fs.glob(s3_path(path))]
else:
fp = iglob(path)
fp = list(iglob(path))

return fp

Expand Down
8 changes: 4 additions & 4 deletions shub_workflow/deliver/gcstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def set_credential_file_environ(module, resource, check_exists=True):
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credfile


def upload_file(src_path, dest_path):
def upload_file(src_path: str, dest_path: str):
storage_client = storage.Client()
try:
bucket_name, destination_blob_name = _GS_FOLDER_RE.match(dest_path).groups()
except AttributeError:
m = _GS_FOLDER_RE.match(dest_path)
if m is None:
raise ValueError(f"Invalid destination {dest_path}")
bucket_name, destination_blob_name = m.groups()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(destination_blob_name)
blob.upload_from_filename(src_path, retry=storage.retry.DEFAULT_RETRY)
Expand Down
4 changes: 3 additions & 1 deletion shub_workflow/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def run_job(self, job: TaskId, is_retry=False) -> Optional[JobKey]:
if task is not None:
idx = jobconf["index"]
return task.run(self, is_retry, index=idx)
return None
raise RuntimeError(f"Failed to run task {job}")

def _must_wait_time(self, job: TaskId) -> bool:
status = self.__pending_jobs[job]
Expand Down Expand Up @@ -300,6 +300,7 @@ def run_pending_jobs(self):
if job_can_run:
try:
jobid = self.run_job(task_id, status["is_retry"])
assert jobid is not None, f"Failed to run task {task_id}"
except Exception:
self._release_resources(task_id)
raise
Expand Down Expand Up @@ -330,6 +331,7 @@ def run_pending_jobs(self):
if job_can_run:
try:
jobid = self.run_job(task_id, status["is_retry"])
assert jobid is not None, f"Failed to run task {task_id}"
except Exception:
self._release_resources(task_id)
raise
Expand Down
5 changes: 4 additions & 1 deletion shub_workflow/graph/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class JobGraphDict(TypedDict):
origin: NotRequired[TaskId]
index: NotRequired[int]

spider: NotRequired[str]
spider_args: NotRequired[Dict[str, str]]


class BaseTask(abc.ABC):
def __init__(
Expand Down Expand Up @@ -283,7 +286,7 @@ def get_spider_args(self):
spider_args.update({"job_settings": self.__job_settings})
return spider_args

def as_jobgraph_dict(self):
def as_jobgraph_dict(self) -> JobGraphDict:
jdict = super().as_jobgraph_dict()
jdict.update({"spider": self.spider, "spider_args": self.get_spider_args()})
return jdict
Expand Down
8 changes: 8 additions & 0 deletions shub_workflow/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,14 @@ def _run_loops(self) -> Generator[bool, None, None]:
def base_loop_tasks(self):
...

@abc.abstractmethod
def _on_start(self):
...

@abc.abstractmethod
def _close(self):
...


class BaseLoopScript(BaseScript, BaseLoopScriptProtocol):

Expand Down
4 changes: 3 additions & 1 deletion shub_workflow/utils/sesemail.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import boto3
from botocore.client import Config

from shub_workflow.script import BaseScriptProtocol

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -100,7 +102,7 @@ def build_email_message(
return msg


class SESMailSenderMixin:
class SESMailSenderMixin(BaseScriptProtocol):
"""Use this mixin for enabling ses email sending capabilities on your script class"""

def __init__(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def workflow_loop(self):
self.assertEqual(manager.name, "my_fantasy_name")

manager._on_start()
self.assertFalse(manager._check_resume_workflow.called)
self.assertFalse(mocked_check_resume_workflow.called)

@patch("shub_workflow.base.WorkFlowManager._check_resume_workflow")
def test_check_resume_workflow_called(
Expand All @@ -72,7 +72,7 @@ def workflow_loop(self):
self.assertEqual(manager.name, "my_fantasy_name")

manager._on_start()
self.assertTrue(manager._check_resume_workflow.called)
self.assertTrue(mocked_check_resume_workflow.called)

def test_project_id_override(self, mocked_update_metadata, mocked_get_job_tags):
class TestManager(WorkFlowManager):
Expand Down
Loading

0 comments on commit 3d3d20b

Please sign in to comment.