-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add models_utils.py and fix come bugs with tensors and device in gnn_…
…models.py
- Loading branch information
1 parent
c2a7a42
commit 3f066f6
Showing
3 changed files
with
65 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import torch | ||
from torch_geometric.nn import MessagePassing | ||
|
||
|
||
def apply_message_gradient_capture(layer, name): | ||
""" | ||
# Example how get Tensors | ||
# for name, layer in self.gnn.named_children(): | ||
# if isinstance(layer, MessagePassing): | ||
# print(f"{name}: {layer.get_message_gradients()}") | ||
""" | ||
original_message = layer.message | ||
layer.message_gradients = {} | ||
|
||
def capture_message_gradients(x_j, *args, **kwargs): | ||
x_j = x_j.requires_grad_() | ||
if not layer.training: | ||
return original_message(x_j=x_j, *args, **kwargs) | ||
|
||
def save_message_grad(grad): | ||
layer.message_gradients[name] = grad.detach() | ||
x_j.register_hook(save_message_grad) | ||
return original_message(x_j=x_j, *args, **kwargs) | ||
layer.message = capture_message_gradients | ||
|
||
def get_message_gradients(): | ||
return layer.message_gradients | ||
layer.get_message_gradients = get_message_gradients | ||
|
||
|
||
def apply_decorator_to_graph_layers(model): | ||
# TODO Kirill add more options | ||
""" | ||
Example how use this def | ||
apply_decorator_to_graph_layers(gnn) | ||
""" | ||
for name, layer in model.named_children(): | ||
if isinstance(layer, MessagePassing): | ||
apply_message_gradient_capture(layer, name) | ||
elif isinstance(layer, torch.nn.Module): | ||
apply_decorator_to_graph_layers(layer) | ||
|