Skip to content

Commit

Permalink
Fixes for changes relating to Numba error types. (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSKenyon committed Jan 26, 2024
1 parent 84e317a commit e7e03f1
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 0 deletions.
6 changes: 6 additions & 0 deletions quartical/gains/amplitude/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -449,6 +451,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -505,6 +509,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/complex/diag_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -428,6 +430,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -484,6 +488,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)

def impl(
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/complex/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -429,6 +431,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -485,6 +489,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)

def impl(
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/crosshand_phase/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -453,6 +455,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -509,6 +513,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/delay/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -493,6 +495,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -553,6 +557,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/delay_and_offset/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -487,6 +489,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -547,6 +551,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down
2 changes: 2 additions & 0 deletions quartical/gains/leakage/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)

def impl(
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/phase/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -460,6 +462,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -516,6 +520,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/rotation/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -463,6 +465,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -521,6 +525,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)

if corr_mode.literal_value == 4:
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/rotation_measure/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -474,6 +476,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -534,6 +538,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)

if corr_mode.literal_value == 4:
Expand Down
6 changes: 6 additions & 0 deletions quartical/gains/tec_and_offset/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def nb_compute_jhj_jhr(
corr_mode
):

coerce_literal(nb_compute_jhj_jhr, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
row_weights_idx = ms_inputs.fields.index('ROW_WEIGHTS')
row_weights_type = ms_inputs[row_weights_idx]
Expand Down Expand Up @@ -490,6 +492,8 @@ def compute_update(native_imdry, corr_mode):
@overload(compute_update, jit_options=PARALLEL_JIT_OPTIONS)
def nb_compute_update(native_imdry, corr_mode):

coerce_literal(nb_compute_update, ["corr_mode"])

# We want to dispatch based on this field so we need its type.
jhj = native_imdry[native_imdry.fields.index('jhj')]

Expand Down Expand Up @@ -550,6 +554,8 @@ def nb_finalize_update(
corr_mode
):

coerce_literal(nb_finalize_update, ["corr_mode"])

set_identity = factories.set_identity_factory(corr_mode)
param_to_gain = param_to_gain_factory(corr_mode)

Expand Down

0 comments on commit e7e03f1

Please sign in to comment.