Skip to content

Commit

Permalink
added tests to issue arrow-py#1145 and Format code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
psyuktha committed Oct 17, 2024
1 parent f0e6212 commit 69c3aa1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
12 changes: 7 additions & 5 deletions arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,8 @@ def range(
yield current

values = [getattr(current, f) for f in cls._ATTRS]
current = cls(*values, tzinfo=tzinfo).shift( # type: ignore[misc]
**{frame_relative: relative_steps}
current = cls(*values).shift(
check_imaginary=True, **{frame_relative: relative_steps}
)

if frame in ["month", "quarter", "year"] and current.day < original_day:
Expand Down Expand Up @@ -583,7 +583,9 @@ def span(
elif frame_absolute == "quarter":
floor = floor.shift(months=-((self.month - 1) % 3))

ceil = floor.shift(**{frame_relative: count * relative_steps})
ceil = floor.shift(
check_imaginary=True, **{frame_relative: count * relative_steps}
)

if bounds[0] == "(":
floor = floor.shift(microseconds=+1)
Expand Down Expand Up @@ -981,7 +983,7 @@ def replace(self, **kwargs: Any) -> "Arrow":

return self.fromdatetime(current)

def shift(self, check_imaginary=True, **kwargs: Any) -> "Arrow":
def shift(self, check_imaginary: bool = True, **kwargs: Any) -> "Arrow":
"""Returns a new :class:`Arrow <arrow.arrow.Arrow>` object with attributes updated
according to inputs.
Expand Down Expand Up @@ -1447,7 +1449,7 @@ def dehumanize(self, input_string: str, locale: str = "en_us") -> "Arrow":

time_changes = {k: sign_val * v for k, v in time_object_info.items()}

return current_time.shift(**time_changes)
return current_time.shift(check_imaginary=True, **time_changes)

# query functions

Expand Down
10 changes: 10 additions & 0 deletions tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,16 @@ def test_shift_negative_imaginary(self):
2011, 12, 31, 23, tzinfo="Pacific/Apia"
)

def test_shift_with_imaginary_check(self):
dt = arrow.Arrow(2024, 3, 10, 2, 30, tzinfo=tz.gettz("US/Eastern"))
shifted = dt.shift(hours=1)
assert shifted.datetime.hour == 3

def test_shift_without_imaginary_check(self):
dt = arrow.Arrow(2024, 3, 10, 2, 30, tzinfo=tz.gettz("US/Eastern"))
shifted = dt.shift(hours=1, check_imaginary=False)
assert shifted.datetime.hour == 3

@pytest.mark.skipif(
dateutil.__version__ < "2.7.1", reason="old tz database (2018d needed)"
)
Expand Down

0 comments on commit 69c3aa1

Please sign in to comment.