diff --git a/jitcxde_common/_jitcxde.py b/jitcxde_common/_jitcxde.py index b1b84cd..12081a5 100644 --- a/jitcxde_common/_jitcxde.py +++ b/jitcxde_common/_jitcxde.py @@ -73,6 +73,42 @@ def __init__(self,n=None,verbose=True,module_location=None): # • False if a compile attempt was made but not succesful # • True if a successful compile attempt was made + def _check_dynvar_dict(self,dictionary,name,length): + if not set(dictionary.keys()) == {self.dynvar(i) for i in range(length)}: + raise ValueError(f"If {name} is a dictionary, its keys must be y(0), y(1), …, y(n) where n is the number of entries.") + + def _generator_func_from_dynvar_dict(self,dictionary,name,length): + """ + returns a generator function that yields: + dictionary[dynvar(0)], dictionary[dynvar(1)], …, dictionary[dynvar(length)] + + Parameters + ---------- + name: string + the name of the dictionary for error messages + """ + self._check_dynvar_dict(dictionary,name,length) + def generator_func(): + for i in range(length): + yield dictionary[self.dynvar(i)] + return generator_func + + def _list_from_dynvar_dict(self,dictionary,name,length): + """ + returns the list + [ dictionary[dynvar(0)], dictionary[dynvar(1)], …, dictionary[dynvar(length)] ] + + Parameters + ---------- + name: string + the name of the dictionary for error messages + """ + self._check_dynvar_dict(dictionary,name,length) + return [ + dictionary[self.dynvar(i)] + for i in range(length) + ] + def _handle_input(self,f_sym,n_basic=False): """ Converts f_sym to a generator function if necessary. @@ -96,11 +132,7 @@ def _handle_input(self,f_sym,n_basic=False): else: self.n = length if isinstance(f_sym,dict): - if not set(f_sym.keys()) == {self.dynvar(i) for i in range(length)}: - raise ValueError("If f_sym is a dictionary, its keys must be y(0), y(1), …, y(n) where n is the number of entries.") - def new_f_sym(): - for i in range(length): - yield f_sym[self.dynvar(i)] + new_f_sym = self._generator_func_from_dynvar_dict(f_sym,"f_sym",length) else: def new_f_sym(): gen = f_sym() if isgeneratorfunction(f_sym) else f_sym diff --git a/tests/test_code.py b/tests/test_code.py index 5e194e5..d6f906b 100644 --- a/tests/test_code.py +++ b/tests/test_code.py @@ -165,6 +165,21 @@ def test_dict_spurious_equation(self): faulty_f = { y(0):1, y(1):1, x:1 } with self.assertRaises(ValueError): jitcxde_tester(faulty_f) + + def test_dict_tester_missing_equation(self): + tester = jitcxde_tester(f) + faulty_dict = { y(0):1, y(2):1 } + for i in range(10): + with self.assertRaises(ValueError): + tester._check_dynvar_dict(faulty_dict,"",i) + + def test_dict_spurious_equation(self): + tester = jitcxde_tester(f) + x = symengine.Symbol("x") + faulty_dict = { y(0):1, y(1):1, x:1 } + for i in range(10): + with self.assertRaises(ValueError): + tester._check_dynvar_dict(faulty_dict,"",i) if __name__ == "__main__": unittest.main(buffer=True)