
Porting the scatter & sparse into Metal / MPS #464

Open
thegodone opened this issue Oct 29, 2024 · 3 comments
@thegodone

Is there ongoing work to port the scatter and sparse packages to Metal / MPS?

@Nobregaigor

Hi, I am also encountering issues with torch_scatter when using MPS:

--> 169 min_weights, _ = scatter_min(weights, src, dim=0, dim_size=num_nodes)

File /.../python3.12/site-packages/torch_scatter/scatter.py:75, in scatter_min(src, index, dim, out, dim_size)
     69 def scatter_min(
     70         src: torch.Tensor,
     71         index: torch.Tensor,
     72         dim: int = -1,
     73         out: Optional[torch.Tensor] = None,
     74         dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
---> 75     return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)

File /.../python3.12/site-packages/torch/_ops.py:1061, in OpOverloadPacket.__call__(self_, *args, **kwargs)
   1060     return _call_overload_packet_from_python(self_, args, kwargs)
-> 1061 return self_._op(*args, **(kwargs or {}))

RuntimeError: src.device().is_cpu() INTERNAL ASSERT FAILED at "csrc/cpu/scatter_cpu.cpp":11, please report a bug to PyTorch. src must be CPU tensor

I understand that scatter operations and sparse tensors may not yet have full MPS support. Is there any update on whether torch_scatter or related scatter/sparse packages are being actively ported to work with MPS?

Any insights or workarounds would be greatly appreciated. Thank you!
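Not a real fix, but a possible stopgap while there is no MPS kernel: run the scatter on CPU and move the results back to the original device afterwards. This is only a sketch; `scatter_min_on_cpu` is a made-up helper name, not part of torch_scatter, and the round trip obviously costs device transfers.

```python
import torch
from torch_scatter import scatter_min

# Hypothetical helper: torch_scatter ships only CPU/CUDA kernels, so do the
# reduction on CPU and move the outputs back to the tensors' original device.
def scatter_min_on_cpu(weights, index, dim, dim_size):
    out, argmin = scatter_min(weights.cpu(), index.cpu(), dim=dim, dim_size=dim_size)
    return out.to(weights.device), argmin.to(weights.device)
```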

@thegodone (Author)

There is this package available: https://github.com/mlx-graphs/mlx-graphs. It is a very good start, but not as complete as PyG in terms of convolution layers.

@sarmientoF commented Nov 27, 2024

@thegodone I've looked into mlx-graphs as an alternative to torch_scatter, but I could not find a way to get the index location when using the basic scatter operations. Do you know if it is possible using mlx-graphs/mlx?

I want to get out, argmax as in the code below:

import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

out, argmax = scatter_max(src, index, out=out)

Using mlx's scatter_max only returns out.
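For what it's worth, here is a sketch of how the argmax can be recovered without torch_scatter, using only the built-in Tensor.scatter_reduce (PyTorch >= 1.12). `scatter_max_with_argmax` is a made-up helper, and I have not checked whether scatter_reduce itself runs on the MPS backend; it uses the same src.size(-1) sentinel for empty slots that torch_scatter documents, but tie-breaking may differ.

```python
import torch

# Hypothetical helper (not an mlx or torch_scatter API): compute the per-slot
# maximum with the built-in scatter_reduce, then recover which source
# position produced it.
def scatter_max_with_argmax(src, index, dim_size):
    # Per-slot maximum along the last dim, starting from zeros to mirror
    # the src.new_zeros((2, 6)) initial `out` above.
    out = src.new_zeros((src.size(0), dim_size)).scatter_reduce(
        -1, index, src, reduce="amax", include_self=True
    )

    # A source position "wins" if its value equals the max of the slot it maps to.
    winners = src == out.gather(-1, index)

    # Scatter the winning positions with an "amin" reduction: ties keep the
    # smallest position, and slots nothing maps to keep the sentinel src.size(-1).
    pos = torch.arange(src.size(-1), device=src.device).expand_as(src)
    sentinel = torch.full_like(pos, src.size(-1))
    argmax = torch.full(
        (src.size(0), dim_size), src.size(-1), dtype=torch.long, device=src.device
    )
    argmax = argmax.scatter_reduce(
        -1, index, torch.where(winners, pos, sentinel), reduce="amin", include_self=True
    )
    return out, argmax

src = torch.tensor([[2.0, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max_with_argmax(src, index, dim_size=6)
```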
