Skip to content

Commit

Permalink
Merge pull request #49 from HamletTanyavong/dev
Browse files Browse the repository at this point in the history
Update source generators
  • Loading branch information
HamletTanyavong authored Jan 31, 2024
2 parents 4483f90 + 2c7861e commit f1e1bf3
Show file tree
Hide file tree
Showing 16 changed files with 137 additions and 158 deletions.
2 changes: 1 addition & 1 deletion src/Mathematics.NET.SourceGenerators.Public/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public static NameSyntax CreateNameSyntaxFromNamespace(this string namespaceName

if (structDeclaration.Parent is NamespaceDeclarationSyntax namespaceDeclaration)
{
return namespaceDeclaration.Name;
return namespaceDeclaration.Name.WithoutTrailingTrivia();
}

return default;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,7 @@ public bool Equals(NameSyntax? x, NameSyntax? y)
{
return x.GetNameValueOrDefault() == y.GetNameValueOrDefault();
}

if (x is null && y is null)
{
return true;
}

return false;
return x == y;
}

public int GetHashCode(NameSyntax? obj) => obj?.GetNameValueOrDefault()?.GetHashCode() ?? 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ private void GenerateCode(SourceProductionContext context, ImmutableArray<Struct
DiagnosticMessage.CreateInvalidSymbolDeclarationDiagnosticDescriptor(),
info.StructDeclarationSyntax.Identifier.GetLocation()));
}
break;
continue;
}
var symbols = new SymbolBuilder(nameSyntax!, context, selectedInformation);
var symbols = new SymbolBuilder(nameSyntax, context, selectedInformation);
context.AddSource($"Symbols.{nameSyntax.GetNameValueOrDefault()}.g.cs", symbols.GenerateSource().GetText(Encoding.UTF8).ToString());
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// <copyright file="IndexSwapRewriter.cs" company="Mathematics.NET">
// <copyright file="BracketedArgumentIndexSwapRewriter.cs" company="Mathematics.NET">
// Mathematics.NET
// https://github.com/HamletTanyavong/Mathematics.NET
//
Expand Down Expand Up @@ -29,13 +29,13 @@

namespace Mathematics.NET.SourceGenerators.DifferentialGeometry;

/// <summary>A C# syntax rewriter that swaps tensor index orders</summary>
internal sealed class IndexSwapRewriter : CSharpSyntaxRewriter
/// <summary>A C# syntax rewriter that swaps indices in a bracketed argument list</summary>
internal sealed class BracketedArgumentIndexSwapRewriter : CSharpSyntaxRewriter
{
private readonly ArgumentSyntax _indexToContract;
private readonly ArgumentSyntax _indexToSwap;

public IndexSwapRewriter(BracketedArgumentListSyntax bracketedArgumentList, string indexName)
public BracketedArgumentIndexSwapRewriter(BracketedArgumentListSyntax bracketedArgumentList, string indexName)
{
_indexToContract = bracketedArgumentList.Arguments.Last(x => x.Expression is IdentifierNameSyntax name && name.Identifier.Text == indexName);
_indexToSwap = GetIndexToSwap(bracketedArgumentList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
// SOFTWARE.
// </copyright>

using System.Runtime.CompilerServices;
using Mathematics.NET.SourceGenerators.DifferentialGeometry.Models;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Mathematics.NET.SourceGenerators.DifferentialGeometry;

/// <summary>Syntax helper for differential geometry generators</summary>
Expand All @@ -39,40 +35,26 @@ internal static class DifGeoGeneratorExtensions
/// <returns>A member declaration syntax</returns>
internal static MemberDeclarationSyntax GenerateTwinContraction(this MemberDeclarationSyntax memberDeclaration)
{
FlipIndexRewriter walker = new();
FlipIndexPositionRewriter walker = new();
return (MemberDeclarationSyntax)walker.Visit(memberDeclaration);
}

/// <summary>Get the index structure of a tensor.</summary>
/// <param name="memberDeclaration">A member declaration syntax</param>
/// <param name="position">An integer representing the current parameter position—the position of the tensor in question in the parameter list</param>
/// <returns>An index structure</returns>
internal static IndexStructure GetIndexStructure(this MemberDeclarationSyntax memberDeclaration, int position)
{
if (memberDeclaration.ParameterList() is ParameterListSyntax paramList)
{
var args = paramList.Parameters[position].TypeArgumentList()!.Arguments;
var index = args.IndexOf(args.Last(x => x is GenericNameSyntax name && name.Identifier.Text == "Index"));
return new(index, args.Count);
}
return default;
}

/// <summary>Swap the current index with the index immediately to its right.</summary>
/// <summary>Swap the index to contract with the index immediately to its right.</summary>
/// <param name="typeArgumentListSyntax">A type argument list syntax</param>
/// <param name="index">An integer representing the current index position</param>
/// <returns>A type argument list syntax with the specified indices swapped</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static TypeArgumentListSyntax SwapCurrentIndexWithNextIndex(this TypeArgumentListSyntax typeArgumentListSyntax, int index)
internal static TypeArgumentListSyntax SwapContractIndexWithNextIndex(this TypeArgumentListSyntax typeArgumentListSyntax)
{
var args = typeArgumentListSyntax.Arguments;
var currentIndex = args[index];
var nextIndex = args[index + 1];

var newArgs = args.Replace(currentIndex, nextIndex);
nextIndex = newArgs[index + 1];
newArgs = newArgs.Replace(nextIndex, currentIndex);
TypeArgumentIndexSwapRewriter rewriter = new(typeArgumentListSyntax);
return (TypeArgumentListSyntax)rewriter.Visit(typeArgumentListSyntax);
}

return TypeArgumentList(newArgs);
/// <summary>Swap the index to contract with the index immediately to its right.</summary>
/// <param name="bracketedArgumentListSyntax">A bracketed argument list syntax</param>
/// <param name="iterationIndexName">The name of the iteration index</param>
/// <returns>A bracketed argument list syntax with the specified indices swapped</returns>
internal static BracketedArgumentListSyntax SwapIterationIndexWithNextIndex(this BracketedArgumentListSyntax bracketedArgumentListSyntax, string iterationIndexName)
{
BracketedArgumentIndexSwapRewriter rewriter = new(bracketedArgumentListSyntax, iterationIndexName);
return (BracketedArgumentListSyntax)rewriter.Visit(bracketedArgumentListSyntax);
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// <copyright file="FlipIndexRewriter.cs" company="Mathematics.NET">
// <copyright file="FlipIndexPositionRewriter.cs" company="Mathematics.NET">
// Mathematics.NET
// https://github.com/HamletTanyavong/Mathematics.NET
//
Expand Down Expand Up @@ -29,8 +29,8 @@

namespace Mathematics.NET.SourceGenerators.DifferentialGeometry;

/// <summary>A syntax walker that flips lower indices to upper indices and vice versa</summary>
internal sealed class FlipIndexRewriter : CSharpSyntaxRewriter
/// <summary>A C# syntax rewriter that flips lower indices to upper indices and vice versa</summary>
internal sealed class FlipIndexPositionRewriter : CSharpSyntaxRewriter
{
private static readonly IdentifierNameSyntax s_lower = SyntaxFactory.IdentifierName("Lower");
private static readonly IdentifierNameSyntax s_upper = SyntaxFactory.IdentifierName("Upper");
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,6 @@ namespace Mathematics.NET.SourceGenerators.DifferentialGeometry;
/// <summary>Tensor contractions builder</summary>
internal sealed class TensorContractionBuilder : TensorContractionBuilderBase
{
private static readonly GenericNameSyntax s_indexToContract = GenericName(
Identifier("Index"))
.WithTypeArgumentList(
TypeArgumentList(
SeparatedList<TypeSyntax>(new SyntaxNodeOrToken[] {
IdentifierName("Upper"),
Token(SyntaxKind.CommaToken),
IdentifierName("IC") })));

public TensorContractionBuilder(SourceProductionContext context, ImmutableArray<MethodInformation> methodInformationArray)
: base(context, methodInformationArray) { }

Expand Down Expand Up @@ -257,7 +248,7 @@ private static MemberDeclarationSyntax ResetTypeParameterConstraints(MemberDecla
.First();

var newArgs = args.RemoveNode(args.Arguments.Last(), SyntaxRemoveOptions.KeepNoTrivia);
newArgs = newArgs!.InsertNodesAfter(newArgs!.Arguments[2], [s_indexToContract]);
newArgs = newArgs!.InsertNodesAfter(newArgs!.Arguments[2], [s_rightIndex]);

return memberDeclaration.ReplaceNode(args, newArgs);
}
Expand All @@ -268,16 +259,14 @@ private static MemberDeclarationSyntax ResetTypeParameters(MemberDeclarationSynt
var args = param.TypeArgumentList()!;

var newArgs = args.RemoveNode(args.Arguments.Last(), SyntaxRemoveOptions.KeepNoTrivia);
newArgs = newArgs!.InsertNodesAfter(newArgs!.Arguments[2], [s_indexToContract]);
newArgs = newArgs!.InsertNodesAfter(newArgs!.Arguments[2], [s_rightIndex]);
return memberDeclaration.ReplaceNode(args, newArgs);
}

private static MemberDeclarationSyntax SwapIndices(MemberDeclarationSyntax memberDeclaration, IndexPosition position)
{
var indexStructure = memberDeclaration.GetIndexStructure((int)position);

memberDeclaration = SwapTypeParameters(memberDeclaration, position, indexStructure);
memberDeclaration = SwapTypeParameterConstraints(memberDeclaration, position, indexStructure);
memberDeclaration = SwapTypeParameters(memberDeclaration, position);
memberDeclaration = SwapTypeParameterConstraints(memberDeclaration, position);
memberDeclaration = SwapMultiplyExpressionComponents(memberDeclaration, position);

return memberDeclaration;
Expand All @@ -294,21 +283,17 @@ private static MemberDeclarationSyntax SwapMultiplyExpressionComponents(MemberDe
.First(x => x.IsKind(SyntaxKind.MultiplyExpression));

var forStatement = (ForStatementSyntax)multiplyExpression.Parent!.Parent!.Parent!.Parent!;
var variableName = forStatement.Declaration!.Variables[0].Identifier.Text;

var iterationIndexName = forStatement.Declaration!.Variables[0].Identifier.Text;
var args = position == IndexPosition.Left
? multiplyExpression.Left.DescendantNodes().OfType<BracketedArgumentListSyntax>().First()
: multiplyExpression.Right.DescendantNodes().OfType<BracketedArgumentListSyntax>().First();
var indexSwapper = new IndexSwapRewriter(args, variableName);
var newArgs = (BracketedArgumentListSyntax)indexSwapper.Visit(args);

return memberDeclaration.ReplaceNode(args, newArgs);
return memberDeclaration.ReplaceNode(args, args.SwapIterationIndexWithNextIndex(iterationIndexName));
}

private static MemberDeclarationSyntax SwapRightIndices(MemberDeclarationSyntax memberDeclaration)
=> SwapIndices(memberDeclaration, IndexPosition.Right);

private static MemberDeclarationSyntax SwapTypeParameterConstraints(MemberDeclarationSyntax memberDeclaration, IndexPosition position, IndexStructure indexStructure)
private static MemberDeclarationSyntax SwapTypeParameterConstraints(MemberDeclarationSyntax memberDeclaration, IndexPosition position)
{
var constraints = memberDeclaration
.ChildNodes()
Expand All @@ -320,17 +305,17 @@ private static MemberDeclarationSyntax SwapTypeParameterConstraints(MemberDeclar
.OfType<TypeArgumentListSyntax>()
.First();

var newArgs = args.SwapCurrentIndexWithNextIndex(indexStructure.ContractPosition);
var newArgs = args.SwapContractIndexWithNextIndex();
var newConstraints = constraints.ReplaceNode(args, newArgs);

return memberDeclaration.ReplaceNode(constraints, newConstraints);
}

private static MemberDeclarationSyntax SwapTypeParameters(MemberDeclarationSyntax memberDeclaration, IndexPosition position, IndexStructure indexStructure)
private static MemberDeclarationSyntax SwapTypeParameters(MemberDeclarationSyntax memberDeclaration, IndexPosition position)
{
var param = memberDeclaration.ParameterList()!.Parameters[(int)position];
var args = param.TypeArgumentList()!;
var newParam = param.ReplaceNode(args, args.SwapCurrentIndexWithNextIndex(indexStructure.ContractPosition));
var newParam = param.ReplaceNode(args, args.SwapContractIndexWithNextIndex());
return memberDeclaration.ReplaceNode(param, newParam);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,32 @@
// </copyright>

using System.Collections.Immutable;
using Microsoft.CodeAnalysis.CSharp;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Mathematics.NET.SourceGenerators.DifferentialGeometry;

/// <summary>A base class for tensor contraction builders</summary>
internal abstract class TensorContractionBuilderBase
{
private protected static readonly GenericNameSyntax s_leftIndex = GenericName(
Identifier("Index"))
.WithTypeArgumentList(
TypeArgumentList(
SeparatedList<TypeSyntax>(new SyntaxNodeOrToken[] {
IdentifierName("Lower"),
Token(SyntaxKind.CommaToken),
IdentifierName("IC") })));

private protected static readonly GenericNameSyntax s_rightIndex = GenericName(
Identifier("Index"))
.WithTypeArgumentList(
TypeArgumentList(
SeparatedList<TypeSyntax>(new SyntaxNodeOrToken[] {
IdentifierName("Upper"),
Token(SyntaxKind.CommaToken),
IdentifierName("IC") })));

private protected readonly SourceProductionContext _context;
private protected readonly ImmutableArray<MethodInformation> _methodInformationArray;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.CreateSyntaxProvider(CouldBeGenerateTensorContractionAttribute, GetTensorContractionOrNull)
.Where(x => x is not null);
var compilation = context.CompilationProvider.Combine(provider.Collect());
context.RegisterSourceOutput(compilation, (context, source) => GenerateCode(context, source.Left, source.Right));
context.RegisterSourceOutput(compilation, (context, source) => GenerateCode(context, source.Right));
}

private static bool CouldBeGenerateTensorContractionAttribute(SyntaxNode syntaxNode, CancellationToken token)
Expand All @@ -53,7 +53,7 @@ private static MethodInformation GetTensorContractionOrNull(GeneratorSyntaxConte
return new(attribute, (MethodDeclarationSyntax)attribute.Parent!.Parent!);
}

private void GenerateCode(SourceProductionContext context, Compilation compilation, ImmutableArray<MethodInformation> information)
private void GenerateCode(SourceProductionContext context, ImmutableArray<MethodInformation> information)
{
var tensorContractions = new TensorContractionBuilder(context, information);
context.AddSource("DifGeo.Contractions.g.cs", tensorContractions.GenerateSource().GetText(Encoding.UTF8).ToString());
Expand Down
Loading

0 comments on commit f1e1bf3

Please sign in to comment.