Skip to content

Commit

Permalink
Infrastructure for dictionary inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerrit Ansmann committed Nov 28, 2018
1 parent 84d6a94 commit 9732cbd
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
42 changes: 37 additions & 5 deletions jitcxde_common/_jitcxde.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,42 @@ def __init__(self,n=None,verbose=True,module_location=None):
# • False if a compile attempt was made but not succesful
# • True if a successful compile attempt was made

def _check_dynvar_dict(self,dictionary,name,length):
if not set(dictionary.keys()) == {self.dynvar(i) for i in range(length)}:
raise ValueError(f"If {name} is a dictionary, its keys must be y(0), y(1), …, y(n) where n is the number of entries.")

def _generator_func_from_dynvar_dict(self,dictionary,name,length):
"""
returns a generator function that yields:
dictionary[dynvar(0)], dictionary[dynvar(1)], …, dictionary[dynvar(length)]
Parameters
----------
name: string
the name of the dictionary for error messages
"""
self._check_dynvar_dict(dictionary,name,length)
def generator_func():
for i in range(length):
yield dictionary[self.dynvar(i)]
return generator_func

def _list_from_dynvar_dict(self,dictionary,name,length):
"""
returns the list
[ dictionary[dynvar(0)], dictionary[dynvar(1)], …, dictionary[dynvar(length)] ]
Parameters
----------
name: string
the name of the dictionary for error messages
"""
self._check_dynvar_dict(dictionary,name,length)
return [
dictionary[self.dynvar(i)]
for i in range(length)
]

def _handle_input(self,f_sym,n_basic=False):
"""
Converts f_sym to a generator function if necessary.
Expand All @@ -96,11 +132,7 @@ def _handle_input(self,f_sym,n_basic=False):
else: self.n = length

if isinstance(f_sym,dict):
if not set(f_sym.keys()) == {self.dynvar(i) for i in range(length)}:
raise ValueError("If f_sym is a dictionary, its keys must be y(0), y(1), …, y(n) where n is the number of entries.")
def new_f_sym():
for i in range(length):
yield f_sym[self.dynvar(i)]
new_f_sym = self._generator_func_from_dynvar_dict(f_sym,"f_sym",length)
else:
def new_f_sym():
gen = f_sym() if isgeneratorfunction(f_sym) else f_sym
Expand Down
15 changes: 15 additions & 0 deletions tests/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,21 @@ def test_dict_spurious_equation(self):
faulty_f = { y(0):1, y(1):1, x:1 }
with self.assertRaises(ValueError):
jitcxde_tester(faulty_f)

def test_dict_tester_missing_equation(self):
tester = jitcxde_tester(f)
faulty_dict = { y(0):1, y(2):1 }
for i in range(10):
with self.assertRaises(ValueError):
tester._check_dynvar_dict(faulty_dict,"",i)

def test_dict_spurious_equation(self):
tester = jitcxde_tester(f)
x = symengine.Symbol("x")
faulty_dict = { y(0):1, y(1):1, x:1 }
for i in range(10):
with self.assertRaises(ValueError):
tester._check_dynvar_dict(faulty_dict,"",i)

if __name__ == "__main__":
unittest.main(buffer=True)
Expand Down

0 comments on commit 9732cbd

Please sign in to comment.