Skip to content

Commit

Permalink
Rename Node to GradientNode
Browse files Browse the repository at this point in the history
  • Loading branch information
HamletTanyavong committed Dec 11, 2023
1 parent 31bfb25 commit cb900ae
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// <copyright file="Node.cs" company="Mathematics.NET">
// <copyright file="GradientNode.cs" company="Mathematics.NET">
// Mathematics.NET
// https://github.com/HamletTanyavong/Mathematics.NET
//
Expand Down Expand Up @@ -32,7 +32,7 @@ namespace Mathematics.NET.AutoDiff;
/// <summary>Represents a node on a gradient tape</summary>
/// <typeparam name="T">A type that implements <see cref="IComplex{T}"/></typeparam>
[StructLayout(LayoutKind.Sequential)]
internal readonly record struct Node<T>
internal readonly record struct GradientNode<T>
where T : IComplex<T>
{
/// <summary>The derivative of the left component of the binary operation</summary>
Expand All @@ -45,7 +45,7 @@ internal readonly record struct Node<T>
/// <summary>The parent index of the right node</summary>
public readonly int PY;

public Node(int index)
public GradientNode(int index)
{
DX = T.Zero;
DY = T.Zero;
Expand All @@ -54,7 +54,7 @@ public Node(int index)
PY = index;
}

public Node(T dx, int px, int py)
public GradientNode(T dx, int px, int py)
{
DX = dx;
DY = T.Zero;
Expand All @@ -63,7 +63,7 @@ public Node(T dx, int px, int py)
PY = py;
}

public Node(T dx, T dy, int px, int py)
public GradientNode(T dx, T dy, int px, int py)
{
DX = dx;
DY = dy;
Expand Down
8 changes: 4 additions & 4 deletions src/Mathematics.NET/AutoDiff/GradientTape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public record class GradientTape<T> : ITape<T>
{
// TODO: Measure performance with Stack<Node<T>> instead of List<Node<T>>
// TODO: Consider using array pools or something similar
private List<Node<T>> _nodes;
private List<GradientNode<T>> _nodes;
private int _variableCount;

public GradientTape()
Expand All @@ -97,8 +97,8 @@ public void PrintNodes(CancellationToken cancellationToken, int limit = 100)
{
const string tab = " ";

ReadOnlySpan<Node<T>> nodeSpan = CollectionsMarshal.AsSpan(_nodes);
Node<T> node;
ReadOnlySpan<GradientNode<T>> nodeSpan = CollectionsMarshal.AsSpan(_nodes);
GradientNode<T> node;

int i = 0;
while (i < Math.Min(_variableCount, limit))
Expand Down Expand Up @@ -146,7 +146,7 @@ public void ReverseAccumulation(out ReadOnlySpan<T> gradient, T seed)
throw new Exception("Gradient tape contains no root nodes");
}

ReadOnlySpan<Node<T>> nodes = CollectionsMarshal.AsSpan(_nodes);
ReadOnlySpan<GradientNode<T>> nodes = CollectionsMarshal.AsSpan(_nodes);
ref var start = ref MemoryMarshal.GetReference(nodes);

var length = nodes.Length;
Expand Down

0 comments on commit cb900ae

Please sign in to comment.