Skip to content

Commit

Permalink
Make IComparable<T> a compiled template
Browse files Browse the repository at this point in the history
We simplify template authoring/implementation by requiring template types to be called TSelf and optional value type TId. This simplifies implementation and documentation, and should be properly documented and enforced via an analyzer.
  • Loading branch information
kzu committed Dec 7, 2024
1 parent 56185cf commit 39bfb79
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 140 deletions.
9 changes: 8 additions & 1 deletion src/StructId.Analyzer/AnalysisExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,14 @@ public static string GetTypeName(this ITypeSymbol type, string? containingNamesp
return typeName;
}

public static string ToFileName(this ITypeSymbol type) => type.ToDisplayString(FullNameNullable).Replace('+', '.');
public static string ToFileName(this ITypeSymbol type)
{
if (type.ContainingNamespace == null || type.ContainingNamespace.IsGlobalNamespace)
return type.Name;

var name = type.MetadataName.Replace('+', '.');
return $"{type.ContainingNamespace.ToFullName()}.{name}";
}

public static bool IsStructId(this ITypeSymbol type) => type.AllInterfaces.Any(x => x.Name == "IStructId");

Expand Down
10 changes: 0 additions & 10 deletions src/StructId.Analyzer/ComparableGenerator.cs

This file was deleted.

89 changes: 32 additions & 57 deletions src/StructId.Analyzer/TemplatedGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
Expand All @@ -14,8 +12,12 @@
namespace StructId;

[Generator(LanguageNames.CSharp)]
public class TemplatedGenerator : IIncrementalGenerator
public partial class TemplatedGenerator : IIncrementalGenerator
{
static Regex TSelfExpr = new($@"\bTSelf\b", RegexOptions.Compiled | RegexOptions.Multiline);

static Regex TIdExpr = new($@"\bTId\b", RegexOptions.Compiled | RegexOptions.Multiline);

/// <summary>
/// Provides access to some common types and properties used in the compilation.
/// </summary>
Expand Down Expand Up @@ -55,30 +57,30 @@ record IdTemplate(INamedTypeSymbol StructId, INamedTypeSymbol TId, Template Temp

record Template(INamedTypeSymbol TSelf, INamedTypeSymbol TId, AttributeData Attribute, KnownTypes KnownTypes)
{
ConcurrentDictionary<INamedTypeSymbol, string> text = new(SymbolEqualityComparer.Default);
string? code;

public INamedTypeSymbol? CustomTId { get; init; }
public INamedTypeSymbol? OriginalTId { get; init; }

// A custom TId is a file-local type declaration.
public bool IsCustomTId => CustomTId?.DeclaringSyntaxReferences
public bool IsLocalTId => OriginalTId?.DeclaringSyntaxReferences
.All(x => x.GetSyntax() is TypeDeclarationSyntax decl && decl.Modifiers.Any(m => m.IsKind(SyntaxKind.FileKeyword))) == true;

public Regex NameExpr { get; } = new Regex($@"\b{TSelf.Name}\b", RegexOptions.Compiled | RegexOptions.Multiline);

public string GetText(INamedTypeSymbol tid) => text.GetOrAdd(tid, tid
=> GetTemplateCode(TSelf, TId, CustomTId, tid, Attribute, KnownTypes));
public string Text
{
get => code ??= GetTemplateCode(TSelf, TId, OriginalTId, Attribute, KnownTypes);
}

static string GetTemplateCode(INamedTypeSymbol self,
INamedTypeSymbol templateIdType, INamedTypeSymbol? customIdType, INamedTypeSymbol idTypeInstance,
AttributeData attribute, KnownTypes known)
static string GetTemplateCode(INamedTypeSymbol self, INamedTypeSymbol id,
INamedTypeSymbol? originalId, AttributeData attribute, KnownTypes known)
{
if (self.DeclaringSyntaxReferences[0].GetSyntax() is not TypeDeclarationSyntax declaration)
return "";

// Remove the TId/TValue if present in the same syntax tree.
var toremove = templateIdType.DeclaringSyntaxReferences.Select(x => x.GetSyntax()).ToList();
if (customIdType != null)
toremove.AddRange(customIdType.DeclaringSyntaxReferences.Select(x => x.GetSyntax()));
var toremove = id.DeclaringSyntaxReferences.Select(x => x.GetSyntax()).ToList();
// The target id might not be the same as the original id (which can be a local TId)
if (originalId != null)
toremove.AddRange(originalId.DeclaringSyntaxReferences.Select(x => x.GetSyntax()));

// Also the [TStructId<T>] attribute applied to the template itself
if (attribute.ApplicationSyntaxReference?.GetSyntax().FirstAncestorOrSelf<AttributeListSyntax>() is { } attr)
Expand Down Expand Up @@ -131,50 +133,21 @@ static string GetTemplateCode(INamedTypeSymbol self,
return null!;
});

if (idTypeInstance.Equals(templateIdType, SymbolEqualityComparer.Default))
return root.SyntaxTree.GetRoot().ToFullString().Trim();

if (!idTypeInstance.ImplementsExplicitly(templateIdType))
return root.SyntaxTree.GetRoot().ToFullString().Trim();

// rewrite Value references to explicit casts just in case the
// target type is implemented explicitly.

var tid = templateIdType;
if (tid.IsUnboundGenericType && tid.TypeParameters.Length == 1)
{
try
{
// bind to namedTypeSymbol
tid = tid.ConstructedFrom.Construct(idTypeInstance);
}
catch (Exception ex)
{
Debug.WriteLine(ex.ToString());
}
}

root = new ValueRewriter(tid).Visit(root);

var code = root.SyntaxTree.GetRoot().ToFullString().Trim();

return code;
}
}

class ValueRewriter(ITypeSymbol idType) : CSharpSyntaxRewriter
class ValueTypeRewriter(INamedTypeSymbol originalType, INamedTypeSymbol targetType) : CSharpSyntaxRewriter
{
public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node)
// rewrite references to the original type with the target type
public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node)
{
// Cover both the this.Value scenario
if (node.Name.Identifier.Text == "Value")
return ParenthesizedExpression(CastExpression(ParseTypeName(idType.ToFullName()), node));

// As well as the Value.[Member] scenario
if (node.Expression is IdentifierNameSyntax name && name.Identifier.Text == "Value")
return node.WithExpression(ParenthesizedExpression(CastExpression(ParseTypeName(idType.ToFullName()), name)));
if (node.Identifier.Text == originalType.Name)
return IdentifierName(targetType.ToFullName());

return base.VisitMemberAccessExpression(node);
return base.VisitIdentifierName(node);
}
}

Expand Down Expand Up @@ -216,11 +189,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var idType = (INamedTypeSymbol)structId.GetMembers().OfType<IPropertySymbol>().First(p => p.Name == "Value").Type;
attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId));

// Otherwise, if idType is a local type, this should fail (but an analyzer will take care of that).
// The id type isn't declared in the same file, so we don't do anything fancy with it.
if (idType.DeclaringSyntaxReferences.Length == 0)
return new Template(structId, idType, attribute, known);

// otherwise, the idType is a file-local type with a single interface
// Otherwise, the idType is a file-local type with a single interface
var type = idType.DeclaringSyntaxReferences[0].GetSyntax(cancellation) as TypeDeclarationSyntax;
var iface = type?.BaseList?.Types.FirstOrDefault()?.Type;
if (type == null || iface == null)
Expand All @@ -236,7 +209,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

return new Template(structId, ifaceType, attribute, known)
{
CustomTId = idType
OriginalTId = idType
};
})
.Collect();
Expand Down Expand Up @@ -270,7 +243,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
// the struct id's value type, such as implementing multiple interfaces. In
// this case, the tid would never equal or inherit from the template's TId,
// but we want instead to check for base type compatibility plus all interfaces.
(template.IsCustomTId &&
(template.IsLocalTId &&
// TId is a derived class of the template's TId base type (i.e. object or ValueType)
tid.Is(template.TId.BaseType) &&
// All template provided TId interfaces must be implemented by the struct id's TId
Expand All @@ -284,8 +257,10 @@ public void Initialize(IncrementalGeneratorInitializationContext context)

void GenerateCode(SourceProductionContext context, IdTemplate source)
{
var hintName = $"{source.StructId.ToFileName()}-{source.Template.TSelf.Name}.cs";
var output = source.Template.NameExpr.Replace(source.Template.GetText(source.TId), source.StructId.Name);
var hintName = $"{source.StructId.ToFileName()}/{source.Template.TId.ToFileName()}.cs";
var output = TIdExpr.Replace(
TSelfExpr.Replace(source.Template.Text, source.StructId.Name),
source.TId.ToFullName());

if (source.StructId.ContainingNamespace.Equals(source.StructId.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default))
{
Expand Down
5 changes: 5 additions & 0 deletions src/StructId.FunctionalTests/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ public void EqualityTest()

Assert.Equal(id1, id2);
Assert.True(id1 == id2);

var user1 = new UserId(1);
var user2 = new UserId(2);

Assert.True(user1 < user2);
}

[Fact]
Expand Down
2 changes: 1 addition & 1 deletion src/StructId.FunctionalTests/IIdTemplate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using StructId.Functional;

[TStructId]
file partial record struct IIdTemplate(Guid Value) : IId
file partial record struct TSelf(Guid Value) : IId
{
public Guid Id => Value;
}
64 changes: 0 additions & 64 deletions src/StructId.Tests/ComparableGeneratorTests.cs

This file was deleted.

Loading

0 comments on commit 39bfb79

Please sign in to comment.