
Add keras implementation of softdtw #511

Open · wants to merge 2 commits into main
Conversation


@Ivorforce commented Mar 7, 2024

Let me know if this looks correct.
I tested it, and for equal arrays it returns a negative value. I think that's because softmin can return values lower than any of its inputs, but that makes sense given that it has to be differentiable.
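
For reference, here's a minimal NumPy sketch of the soft-min used by soft-DTW (the standard log-sum-exp form, with gamma as the smoothing parameter; the function name is mine), showing why identical inputs can come out below the true minimum:

import numpy as np

def softmin(values, gamma=1.0):
    # Smooth minimum: -gamma * log(sum(exp(-v / gamma))),
    # computed with the max subtracted for numerical stability.
    z = -np.asarray(values, dtype=float) / gamma
    zmax = z.max()
    return -gamma * (zmax + np.log(np.exp(z - zmax).sum()))

# For n equal values v, softmin returns v - gamma * log(n), which is
# strictly below v, so the loss can be negative even for equal arrays.
print(softmin([0.0, 0.0, 0.0]))  # ≈ -1.0986 (= -log 3)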

Anyway, here's a minimal usage example:

import keras
from keras import layers
import numpy as np

# SoftDTWLoss is the loss class added in this PR.
model = keras.Sequential([
    layers.InputLayer(input_shape=(15, 1)),  # identity model: the output is the input
])
model.compile(
    optimizer=keras.optimizers.Adam(0.001),
    loss=SoftDTWLoss(),
)
history = model.fit(
    np.arange(15, dtype=float)[None, :, None],  # batch of one length-15 univariate series
    np.arange(15, dtype=float)[None, :, None],  # target identical to the input
    epochs=6,
)
print(model.predict(np.arange(15, dtype=float)[None, :, None]))

TODO

  • Docstrings updated
  • Code cleaned up

@Ivorforce (Author) commented Mar 7, 2024

I added a rough first version of the gradient as well. There is definitely some cleanup left to do, but at least it now runs with the custom gradient function.

I am currently running into graph-size issues when using the loss on anything non-trivial, which I think is caused by DTW being so recursion-intensive. I may have to give up on using DTW for my own project unless I can work around that, but I'll leave the PR up either way, for future reference or for other people to fix and use in their projects.
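
For anyone picking this up: one way to avoid the deep recursion is to fill the dynamic-programming matrix iteratively instead. Here's a minimal NumPy sketch of that formulation (the function name and signature are mine, not part of this PR):

import numpy as np

def soft_dtw(x, y, gamma=1.0):
    """Soft-DTW between sequences x (m, d) and y (n, d), computed iteratively."""
    m, n = len(x), len(y)
    # Pairwise squared Euclidean distances between all time steps.
    dist = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # r[i, j] = soft-DTW cost of aligning the first i steps of x with the first j of y.
    r = np.full((m + 1, n + 1), np.inf)
    r[0, 0] = 0.0
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            # Soft-min over the three predecessor cells (log-sum-exp form).
            z = -np.array([r[i - 1, j - 1], r[i - 1, j], r[i, j - 1]]) / gamma
            zmax = z.max()
            r[i, j] = dist[i - 1, j - 1] - gamma * (zmax + np.log(np.exp(z - zmax).sum()))
    return r[m, n]

# Negative for identical sequences, matching the observation above.
print(soft_dtw(np.arange(15.0)[:, None], np.arange(15.0)[:, None]))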
