Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints #593

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
146 changes: 92 additions & 54 deletions astroplan/constraints.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions astroplan/moon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)

# Standard library
from typing import Optional

# Third-party
import numpy as np
from astropy.coordinates import get_sun, get_body
from astropy.time import Time
import astropy.units as u
from astropy.units import Quantity

__all__ = ["moon_phase_angle", "moon_illumination"]


def moon_phase_angle(time, ephemeris=None):
def moon_phase_angle(time: Time, ephemeris: Optional[str] = None) -> Quantity[u.rad]:
"""
Calculate lunar orbital phase in radians.

Expand Down Expand Up @@ -41,7 +47,7 @@ def moon_phase_angle(time, ephemeris=None):
moon.distance - sun.distance*np.cos(elongation))


def moon_illumination(time, ephemeris=None):
def moon_illumination(time: Time, ephemeris: Optional[str] = None) -> Quantity[u.rad]:
"""
Calculate fraction of the moon illuminated.

Expand Down
173 changes: 107 additions & 66 deletions astroplan/observer.py

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions astroplan/periodic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
# Standard library
from typing import Optional, Union

# Third party
import numpy as np
import astropy.units as u
from astropy.time import Time
from astropy.units import Quantity

__all__ = ['PeriodicEvent', 'EclipsingSystem']

Expand All @@ -13,7 +18,8 @@ class PeriodicEvent(object):
A periodic event defined by an epoch and period.
"""
@u.quantity_input(period=u.day, duration=u.day)
def __init__(self, epoch, period, duration=None, name=None):
def __init__(self, epoch: Time, period: Quantity, duration: Optional[Quantity] = None,
name: Optional[str] = None):
"""

Parameters
Expand All @@ -32,7 +38,7 @@ def __init__(self, epoch, period, duration=None, name=None):
self.name = name
self.duration = duration

def phase(self, time):
def phase(self, time: Time) -> np.ndarray[float]:
"""
Phase of periodic event, on interval [0, 1). For example, the phase
could be an orbital phase for an eclipsing binary system.
Expand Down Expand Up @@ -66,8 +72,10 @@ class EclipsingSystem(PeriodicEvent):
barycentric correction error (<=16 minutes).
"""
@u.quantity_input(period=u.day, duration=u.day)
def __init__(self, primary_eclipse_time, orbital_period, duration=None,
name=None, eccentricity=None, argument_of_periapsis=None):
def __init__(self, primary_eclipse_time: Time, orbital_period: Quantity,
duration: Optional[Quantity] = None, name: Optional[str] = None,
eccentricity: Optional[float] = None,
argument_of_periapsis: Optional[float] = None):
"""
Parameters
----------
Expand Down Expand Up @@ -99,7 +107,7 @@ def __init__(self, primary_eclipse_time, orbital_period, duration=None,
argument_of_periapsis = np.pi/2
self.argument_of_periapsis = argument_of_periapsis

def in_primary_eclipse(self, time):
def in_primary_eclipse(self, time: Time) -> Union[np.ndarray[bool], bool]:
"""
Returns `True` when ``time`` is during a primary eclipse.

Expand All @@ -120,7 +128,7 @@ def in_primary_eclipse(self, time):
return ((phases < float(self.duration/self.period)/2) |
(phases > 1 - float(self.duration/self.period)/2))

def in_secondary_eclipse(self, time):
def in_secondary_eclipse(self, time: Time) -> Union[np.ndarray[bool], bool]:
r"""
Returns `True` when ``time`` is during a secondary eclipse

Expand Down Expand Up @@ -161,7 +169,7 @@ def in_secondary_eclipse(self, time):
return ((phases < secondary_eclipse_phase + float(self.duration/self.period)/2) &
(phases > secondary_eclipse_phase - float(self.duration/self.period)/2))

def out_of_eclipse(self, time):
def out_of_eclipse(self, time: Time) -> Union[np.ndarray[bool], bool]:
"""
Returns `True` when ``time`` is not during primary or secondary eclipse.

Expand All @@ -181,7 +189,7 @@ def out_of_eclipse(self, time):
return np.logical_not(np.logical_or(self.in_primary_eclipse(time),
self.in_secondary_eclipse(time)))

def next_primary_eclipse_time(self, time, n_eclipses=1):
def next_primary_eclipse_time(self, time: Time, n_eclipses: int = 1) -> Time:
"""
Time of the next primary eclipse after ``time``.

Expand All @@ -205,7 +213,7 @@ def next_primary_eclipse_time(self, time, n_eclipses=1):
np.arange(n_eclipses) * self.period)
return eclipse_times

def next_secondary_eclipse_time(self, time, n_eclipses=1):
def next_secondary_eclipse_time(self, time: Time, n_eclipses: int = 1) -> Time:
"""
Time of the next secondary eclipse after ``time``.

Expand Down Expand Up @@ -234,7 +242,7 @@ def next_secondary_eclipse_time(self, time, n_eclipses=1):
np.arange(n_eclipses) * self.period)
return eclipse_times

def next_primary_ingress_egress_time(self, time, n_eclipses=1):
def next_primary_ingress_egress_time(self, time: Time, n_eclipses: int = 1) -> Time:
"""
Calculate the times of ingress and egress for the next ``n_eclipses``
primary eclipses after ``time``
Expand Down Expand Up @@ -264,7 +272,7 @@ def next_primary_ingress_egress_time(self, time, n_eclipses=1):

return Time(ing_egr, format='jd', scale='utc')

def next_secondary_ingress_egress_time(self, time, n_eclipses=1):
def next_secondary_ingress_egress_time(self, time: Time, n_eclipses: int = 1) -> Time:
"""
Calculate the times of ingress and egress for the next ``n_eclipses``
secondary eclipses after ``time``
Expand Down
19 changes: 12 additions & 7 deletions astroplan/plots/finder.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
from typing import Optional, Union

import astropy.units as u

from matplotlib.axes import Axes
import numpy as np
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.units import Quantity

from ..target import FixedTarget

__all__ = ['plot_finder_image']


@u.quantity_input(fov_radius=u.deg)
def plot_finder_image(target, survey='DSS', fov_radius=10*u.arcmin,
log=False, ax=None, grid=False, reticle=False,
style_kwargs=None, reticle_style_kwargs=None):
def plot_finder_image(target: Union[FixedTarget, SkyCoord], survey: str = 'DSS',
fov_radius: Quantity = 10*u.arcmin, log: bool = False,
ax: Optional[Axes] = None, grid: bool = False, reticle: bool = False,
style_kwargs: Optional[dict] = None,
reticle_style_kwargs: Optional[dict] = None) -> Axes:
"""
Plot survey image centered on ``target``.

Expand Down
25 changes: 17 additions & 8 deletions astroplan/plots/sky.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import numpy as np
import warnings
from typing import Optional

import astropy.units as u
import numpy as np
from astropy.time import Time
import warnings
from astropy.units import Quantity
from matplotlib.axes import Axes

from ..exceptions import PlotBelowHorizonWarning
from ..observer import Observer
from ..target import FixedTarget
from ..utils import _set_mpl_style_sheet

__all__ = ['plot_sky', 'plot_sky_24hr']


@u.quantity_input(az_label_offset=u.deg)
def plot_sky(target, observer, time, ax=None, style_kwargs=None,
north_to_east_ccw=True, grid=True, az_label_offset=0.0*u.deg,
warn_below_horizon=False, style_sheet=None):
def plot_sky(target: FixedTarget, observer: Observer, time: Time, ax: Optional[Axes] = None,
style_kwargs: Optional[dict] = None, north_to_east_ccw: bool = True, grid: bool = True,
az_label_offset: Quantity = 0.0*u.deg,
warn_below_horizon: bool = False, style_sheet: Optional[dict] = None) -> Axes:
"""
Plots target positions in the sky with respect to the observer's location.

Expand Down Expand Up @@ -229,9 +236,11 @@ def plot_sky(target, observer, time, ax=None, style_kwargs=None,


@u.quantity_input(delta=u.hour)
def plot_sky_24hr(target, observer, time, delta=1*u.hour, ax=None,
style_kwargs=None, north_to_east_ccw=True, grid=True,
az_label_offset=0.0*u.deg, center_time_style_kwargs=None):
def plot_sky_24hr(target: FixedTarget, observer: Observer, time: Time, delta: Quantity = 1*u.hour,
ax: Axes = None, style_kwargs: Optional[dict] = None,
north_to_east_ccw: bool = True, grid: bool = True,
az_label_offset: Quantity = 0.0*u.deg,
center_time_style_kwargs: Optional[dict] = None) -> Axes:
"""
Plots target positions in the sky with respect to the observer's location
over a twenty-four hour period centered on ``time``.
Expand Down
49 changes: 29 additions & 20 deletions astroplan/plots/time_dependent.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import numpy as np
import operator
import astropy.units as u
from astropy.time import Time
from collections.abc import Sequence
import warnings
from collections.abc import Sequence
from typing import Optional

import astropy.units as u
import numpy as np
import pytz
from astropy.time import Time
from matplotlib.axes import Axes

from ..exceptions import PlotWarning
from ..observer import Observer
from ..scheduling import Schedule
from ..target import FixedTarget
from ..utils import _set_mpl_style_sheet

__all__ = ['plot_airmass', 'plot_schedule_airmass', 'plot_parallactic',
'plot_altitude']


def _secz_to_altitude(secant_z):
def _secz_to_altitude(secant_z: float) -> float:
"""
Convert airmass (approximated as the secant of the zenith angle) to
an altitude (aka elevation) in degrees.
Expand All @@ -35,7 +41,7 @@ def _secz_to_altitude(secant_z):
return np.degrees(np.pi/2 - np.arccos(1./secant_z))


def _has_twin(ax):
def _has_twin(ax: Axes) -> bool:
"""
Solution for detecting twin axes built on `ax`. Courtesy of
Jake Vanderplas http://stackoverflow.com/a/36209590/1340208
Expand All @@ -48,10 +54,12 @@ def _has_twin(ax):
return False


def plot_airmass(targets, observer, time, ax=None, style_kwargs=None,
style_sheet=None, brightness_shading=False,
altitude_yaxis=False, min_airmass=1.0, min_region=None,
max_airmass=3.0, max_region=None, use_local_tz=False):
def plot_airmass(targets: FixedTarget, observer: Observer, time: Time, ax: Optional[Axes] = None,
style_kwargs: Optional[dict] = None, style_sheet: Optional[dict] = None,
brightness_shading: bool = False, altitude_yaxis: bool = False,
min_airmass: float = 1.0, min_region: Optional[float] = None,
max_airmass: float = 3.0, max_region: Optional[float] = None,
use_local_tz: bool = False) -> Axes:
r"""
Plots airmass as a function of time for a given target.

Expand Down Expand Up @@ -276,10 +284,11 @@ def plot_airmass(targets, observer, time, ax=None, style_kwargs=None,
return ax


def plot_altitude(targets, observer, time, ax=None, style_kwargs=None,
style_sheet=None, brightness_shading=False,
airmass_yaxis=False, min_altitude=0, min_region=None,
max_altitude=90, max_region=None):
def plot_altitude(targets: FixedTarget, observer: Observer, time: Time, ax: Optional[Axes] = None,
style_kwargs: Optional[dict] = None, style_sheet: Optional[dict] = None,
brightness_shading: bool = False, airmass_yaxis: bool = False,
min_altitude: float = 0, min_region: Optional[float] = None,
max_altitude: float = 90, max_region: Optional[float] = None) -> Axes:
r"""
Plots altitude as a function of time for a given target.

Expand Down Expand Up @@ -467,7 +476,7 @@ def plot_altitude(targets, observer, time, ax=None, style_kwargs=None,
return ax


def plot_schedule_airmass(schedule, show_night=False):
def plot_schedule_airmass(schedule: Schedule, show_night: bool = False) -> Axes:
"""
Plots when observations of targets are scheduled to occur superimposed
upon plots of the airmasses of the targets.
Expand Down Expand Up @@ -526,8 +535,9 @@ def plot_schedule_airmass(schedule, show_night=False):
# TODO: make this output a `axes` object


def plot_parallactic(target, observer, time, ax=None, style_kwargs=None,
style_sheet=None):
def plot_parallactic(target: FixedTarget, observer: Observer, time: Time, ax: Optional[Axes] = None,
style_kwargs: Optional[dict] = None,
style_sheet: Optional[dict] = None) -> Axes:
"""
Plots parallactic angle as a function of time for a given target.

Expand Down Expand Up @@ -586,7 +596,6 @@ def plot_parallactic(target, observer, time, ax=None, style_kwargs=None,
_set_mpl_style_sheet(style_sheet)

import matplotlib.pyplot as plt

from matplotlib import dates

# Set up plot axes and style if needed.
Expand Down
Loading
Loading