Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Nx.Random.shuffle repeating a single value in certain cases on GPU #1552

Merged
merged 7 commits into from
Oct 30, 2024

Conversation

polvalente
Copy link
Contributor

fixes #1551

In short, what was happening is that we were not testing a specific edge case that resulted in an invalid-shaped tensor for the sort_keys sub-input in sort_key_val, that in turn yielded an invalid shuffle when the input is 1D.

@polvalente polvalente self-assigned this Oct 30, 2024
nx/lib/nx/random.ex Outdated Show resolved Hide resolved
@jonatanklosko
Copy link
Member

I'm still trying to narrow it down, but the issue must be somewhere in the combination of random_bits and take_along_axis. For example, if we make this change:

-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = randint_split(keys[1], 0, uint32max, shape: tensor.shape)

It seems to work just fine.

@jonatanklosko
Copy link
Member

@polvalente actually, this also fixes it:

-sort_keys = random_bits(keys[1], shape: tensor.shape)
+sort_keys = random_bits(keys[1], shape: tensor.shape) |> Nx.as_type(:s32)

@jonatanklosko jonatanklosko changed the title fix: use different gather functions whether Nx.Random.shuffle is independent fix: Nx.Random.shuffle repeating a single value in certain cases on GPU Oct 30, 2024
Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:shipit:

@polvalente polvalente merged commit a21d30b into main Oct 30, 2024
6 of 8 checks passed
@polvalente polvalente deleted the pv-fix/random-shuffle-independent branch October 30, 2024 05:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Nx.Random.shuffle/3 fails for large tensors on cuda backend
2 participants