-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor/handle solve args #748
base: main
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
This reverts commit 0b06cc4.
@@ -435,7 +435,7 @@ def _prepare( | |||
cost_matrix_rank: Optional[int] = None, | |||
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None, | |||
# problem | |||
alpha: float = 0.5, | |||
alpha: Optional[float] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we just make a comment behind this that default is 0.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm actually wondering whether we should have a default alpha
whenever we don't want it to be 1.0
I.e. always set it explicitly in the classes which use GW, wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I set it like this so that it doesn't change the current behaviour.
I.e. always set it explicitly in the classes which use GW, wdyt?
You mean to make non-optional? I personally prefer non-optional parameters, especially if the class is an internal solver. I also think we should rename GWSolver
to FGWSolver
(because it technically can solve fgw and gw) and just set alpha=1
when in a GWProblem
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that makes sense
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, overall really nice. Let's see what Mike and Marco say on the ott-jax side!
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
hi @MUCDK ,
So good news is we currently do a good job on partitioning the
kwargs
for solve. In solve we give anykwarg
we don't know to eitherSinkhornSolver
orGWSolver
constructors.SinkhornSolver
usesSinkhorn
orLRSinkhorn
fromottjax
, these classes don't havekwargs
in their constructors so when usingSinkhornSolver
as a backend we are good.GWSolver
usesGromovWasserstein
orLRGromovWasserstein
fromottjax
. The parent class of these classWassersteinSolver
don't throw an error on unrecognized args. The tests will pass after the ottjax PR merges.Here is the PR in
ott-jax
: ott-jax/ott#579Other things done:
CompoundProblem
or any other more abstract class. It's handled inGWSolver
as it should.Additionally closes:
solve
methods #720