Skip to content

Commit

Permalink
Adding find_dependent_helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerrit Ansmann committed Mar 5, 2018
1 parent 4f5fd7e commit ca80fe2
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
20 changes: 20 additions & 0 deletions jitcxde_common/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,26 @@ def filter_helpers(helpers,symbols):

return list(reversed(filtered_rev))

def find_dependent_helpers(helpers,dependency):
"""
Returns a list of helpers depending on `dependency` and their respective derivative (applying the chain rule).
"""

dependent_helpers = []

for helper in helpers:
derivative = sum(
(
helper[1].diff(other_helper[0]) * other_helper[1]
for other_helper in dependent_helpers
),
helper[1].diff(dependency)
)
if derivative != 0:
dependent_helpers.append( (helper[0], derivative) )

return dependent_helpers

def copy_helpers(helpers):
return [helper for helper in helpers]

21 changes: 19 additions & 2 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-

from jitcxde_common.helpers import sort_helpers, filter_helpers, copy_helpers
from jitcxde_common.helpers import sort_helpers, filter_helpers, copy_helpers, find_dependent_helpers
import unittest
from symengine import symbols
from symengine import symbols, sin, cos, Integer
from itertools import permutations

p,q,r,s,u,v = symbols("p q r s u v")
Expand Down Expand Up @@ -39,6 +39,23 @@ def test_independence(self):
copy.append([u,v])
assert copy!=chain

class FindDependentTest(unittest.TestCase):
def test_find_dependent_helpers(self):
helpers = [
( q, p ),
( r, sin(q) ),
( s, 3*p+r ),
( u, Integer(42) ),
]
control = [
( q, 1 ),
( r, cos(q) ),
( s, 3+cos(q) ),
]
dependent_helpers = find_dependent_helpers(helpers,p)
# This check is overly restrictive due to depending on the order and exact form of the result:
self.assertListEqual(dependent_helpers,control)

if __name__ == "__main__":
unittest.main(buffer=True)

Expand Down

0 comments on commit ca80fe2

Please sign in to comment.