Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve message and parameters in MSTEST0025 fixer #4301

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Text;
Expand Down Expand Up @@ -44,22 +45,28 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
context.RegisterCodeFix(
CodeAction.Create(
CodeFixResources.ReplaceWithFailAssertionFix,
ct => SwapArgumentsAsync(context.Document, invocationExpr, ct),
ct => UseAssertFailAsync(context.Document, invocationExpr, diagnostic.AdditionalLocations, ct),
nameof(PreferAssertFailOverAlwaysFalseConditionsFixer)),
context.Diagnostics);
}
}

private static async Task<Document> SwapArgumentsAsync(Document document, InvocationExpressionSyntax invocationExpr, CancellationToken cancellationToken)
private static async Task<Document> UseAssertFailAsync(Document document, InvocationExpressionSyntax invocationExpr, IReadOnlyList<Location> additionalLocations, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();

DocumentEditor editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false);
SyntaxGenerator generator = editor.Generator;

SyntaxNode newInvocationExpr = generator.InvocationExpression(
var newInvocationExpr = (InvocationExpressionSyntax)generator.InvocationExpression(
generator.MemberAccessExpression(generator.IdentifierName("Assert"), "Fail"));

if (additionalLocations.Count >= 1)
{
IEnumerable<ArgumentSyntax> arguments = additionalLocations.Select(location => (ArgumentSyntax)invocationExpr.FindNode(location.SourceSpan));
newInvocationExpr = newInvocationExpr.WithArgumentList(SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments)));
}

editor.ReplaceNode(invocationExpr, newInvocationExpr);

return editor.GetChangedDocument();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ internal static class WellKnownTypeNames
public const string SystemDescriptionAttribute = "System.ComponentModel.DescriptionAttribute";
public const string SystemIAsyncDisposable = "System.IAsyncDisposable";
public const string SystemIDisposable = "System.IDisposable";
public const string SystemNullable = "System.Nullable`1";
public const string SystemReflectionMethodInfo = "System.Reflection.MethodInfo";
public const string SystemRuntimeCompilerServicesITuple = "System.Runtime.CompilerServices.ITuple";
public const string SystemThreadingTasksTask = "System.Threading.Tasks.Task";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ private enum EqualityStatus
private const string ActualParameterName = "actual";
private const string ConditionParameterName = "condition";
private const string ValueParameterName = "value";
private const string MessageParameterName = "message";
private const string ParametersParameterName = "parameters";

private static readonly LocalizableResourceString Title = new(nameof(Resources.PreferAssertFailOverAlwaysFalseConditionsTitle), Resources.ResourceManager, typeof(Resources));
private static readonly LocalizableResourceString MessageFormat = new(nameof(Resources.PreferAssertFailOverAlwaysFalseConditionsMessageFormat), Resources.ResourceManager, typeof(Resources));
Expand All @@ -57,43 +59,74 @@ public override void Initialize(AnalysisContext context)
{
Compilation compilation = context.Compilation;
INamedTypeSymbol? assertSymbol = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.MicrosoftVisualStudioTestToolsUnitTestingAssert);
INamedTypeSymbol? nullableSymbol = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemNullable);
if (assertSymbol is not null)
{
context.RegisterOperationAction(context => AnalyzeOperation(context, assertSymbol, nullableSymbol), OperationKind.Invocation);
context.RegisterOperationAction(context => AnalyzeOperation(context, assertSymbol), OperationKind.Invocation);
}
});
}

private static void AnalyzeOperation(OperationAnalysisContext context, INamedTypeSymbol assertSymbol, INamedTypeSymbol? nullableSymbol)
private static void AnalyzeOperation(OperationAnalysisContext context, INamedTypeSymbol assertSymbol)
{
var operation = (IInvocationOperation)context.Operation;

if (assertSymbol.Equals(operation.TargetMethod.ContainingType, SymbolEqualityComparer.Default) &&
IsAlwaysFalse(operation, nullableSymbol))
IsAlwaysFalse(operation))
{
context.ReportDiagnostic(operation.CreateDiagnostic(Rule, operation.TargetMethod.Name));
context.ReportDiagnostic(operation.CreateDiagnostic(Rule, GetAdditionalLocations(operation), properties: null, operation.TargetMethod.Name));
}
}

private static bool IsAlwaysFalse(IInvocationOperation operation, INamedTypeSymbol? nullableSymbol)
private static ImmutableArray<Location> GetAdditionalLocations(IInvocationOperation operation)
{
IArgumentOperation? messageArg = operation.Arguments.FirstOrDefault(arg => arg.Parameter?.Name == MessageParameterName);
if (messageArg is null)
{
return ImmutableArray<Location>.Empty;
}

IArgumentOperation? parametersArg = operation.Arguments.FirstOrDefault(arg => arg.Parameter?.Name == ParametersParameterName);
if (parametersArg is null)
{
return ImmutableArray.Create(messageArg.Syntax.GetLocation());
}

if (parametersArg.ArgumentKind == ArgumentKind.ParamArray)
{
ImmutableArray<Location>.Builder builder = ImmutableArray.CreateBuilder<Location>();
builder.Add(messageArg.Syntax.GetLocation());
if (parametersArg.Value is IArrayCreationOperation { Initializer.ElementValues: { } elements })
{
foreach (IOperation element in elements)
{
builder.Add(element.Syntax.GetLocation());
}
}

return builder.ToImmutable();
}

return ImmutableArray.Create(messageArg.Syntax.GetLocation(), parametersArg.Syntax.GetLocation());
}

private static bool IsAlwaysFalse(IInvocationOperation operation)
=> operation.TargetMethod.Name switch
{
"IsTrue" => GetConditionArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: false } },
"IsFalse" => GetConditionArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: true } },
"AreEqual" => GetEqualityStatus(operation, ExpectedParameterName) == EqualityStatus.NotEqual,
"AreNotEqual" => GetEqualityStatus(operation, NotExpectedParameterName) == EqualityStatus.Equal,
"IsNotNull" => GetValueArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: null } },
"IsNull" => GetValueArgument(operation) is { } valueArgumentOperation && IsNotNullableType(valueArgumentOperation, nullableSymbol),
"IsNull" => GetValueArgument(operation) is { } valueArgumentOperation && IsNotNullableType(valueArgumentOperation),
_ => false,
};

private static bool IsNotNullableType(IArgumentOperation valueArgumentOperation, INamedTypeSymbol? nullableSymbol)
private static bool IsNotNullableType(IArgumentOperation valueArgumentOperation)
{
ITypeSymbol? valueArgType = valueArgumentOperation.Value.GetReferencedMemberOrLocalOrParameter().GetReferencedMemberOrLocalOrParameter();
return valueArgType is not null
&& valueArgType.NullableAnnotation == NullableAnnotation.NotAnnotated
&& !SymbolEqualityComparer.IncludeNullability.Equals(valueArgType.OriginalDefinition, nullableSymbol);
&& valueArgType.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T;
}

private static IArgumentOperation? GetArgumentWithName(IInvocationOperation operation, string name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,42 +57,41 @@ public override void Initialize(AnalysisContext context)
{
Compilation compilation = context.Compilation;
INamedTypeSymbol? assertSymbol = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.MicrosoftVisualStudioTestToolsUnitTestingAssert);
INamedTypeSymbol? nullableSymbol = compilation.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemNullable);
if (assertSymbol is not null)
{
context.RegisterOperationAction(context => AnalyzeOperation(context, assertSymbol, nullableSymbol), OperationKind.Invocation);
context.RegisterOperationAction(context => AnalyzeOperation(context, assertSymbol), OperationKind.Invocation);
}
});
}

private static void AnalyzeOperation(OperationAnalysisContext context, INamedTypeSymbol assertSymbol, INamedTypeSymbol? nullableSymbol)
private static void AnalyzeOperation(OperationAnalysisContext context, INamedTypeSymbol assertSymbol)
{
var operation = (IInvocationOperation)context.Operation;
if (assertSymbol.Equals(operation.TargetMethod.ContainingType, SymbolEqualityComparer.Default) &&
IsAlwaysTrue(operation, nullableSymbol))
IsAlwaysTrue(operation))
{
context.ReportDiagnostic(operation.CreateDiagnostic(Rule));
}
}

private static bool IsAlwaysTrue(IInvocationOperation operation, INamedTypeSymbol? nullableSymbol)
private static bool IsAlwaysTrue(IInvocationOperation operation)
=> operation.TargetMethod.Name switch
{
"IsTrue" => GetConditionArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: true } },
"IsFalse" => GetConditionArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: false } },
"AreEqual" => GetEqualityStatus(operation, ExpectedParameterName) == EqualityStatus.Equal,
"AreNotEqual" => GetEqualityStatus(operation, NotExpectedParameterName) == EqualityStatus.NotEqual,
"IsNull" => GetValueArgument(operation) is { Value.ConstantValue: { HasValue: true, Value: null } },
"IsNotNull" => GetValueArgument(operation) is { } valueArgumentOperation && IsNotNullableType(valueArgumentOperation, nullableSymbol),
"IsNotNull" => GetValueArgument(operation) is { } valueArgumentOperation && IsNotNullableType(valueArgumentOperation),
_ => false,
};

private static bool IsNotNullableType(IArgumentOperation valueArgumentOperation, INamedTypeSymbol? nullableSymbol)
private static bool IsNotNullableType(IArgumentOperation valueArgumentOperation)
{
ITypeSymbol? valueArgType = valueArgumentOperation.Value.GetReferencedMemberOrLocalOrParameter().GetReferencedMemberOrLocalOrParameter();
return valueArgType is not null
&& valueArgType.NullableAnnotation == NullableAnnotation.NotAnnotated
&& !SymbolEqualityComparer.IncludeNullability.Equals(valueArgType.OriginalDefinition, nullableSymbol);
&& valueArgType.OriginalDefinition.SpecialType != SpecialType.System_Nullable_T;
}

private static IArgumentOperation? GetArgumentWithName(IInvocationOperation operation, string name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,44 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail("message");
}
}
""";
await VerifyCS.VerifyCodeFixAsync(code, fixedCode);
}

public async Task WhenAssertIsTrueIsPassedFalse_WithMessageAndArgsAsParams_Diagnostic()
{
string code = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void TestMethod()
{
[|Assert.IsTrue(false, "message", "1")|];
[|Assert.IsTrue(false, "message", "1", "2")|];
[|Assert.IsTrue(false, "message", new object[] { "1", "2" })|];
[|Assert.IsTrue(message: "message", parameters: new object[] { "1", "2" }, condition: false)|];
}
}
""";
string fixedCode = """
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class MyTestClass
{
[TestMethod]
public void TestMethod()
{
Assert.Fail("message", "1");
Assert.Fail("message", "1", "2");
Assert.Fail("message", new object[] { "1", "2" });
Assert.Fail(message: "message", parameters: new object[] { "1", "2" });
}
}
""";
Expand Down Expand Up @@ -357,7 +394,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -483,7 +520,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail("message");
}
}
""";
Expand Down Expand Up @@ -515,7 +552,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -699,7 +736,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail("message");
}
}
""";
Expand Down Expand Up @@ -731,7 +768,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -934,7 +971,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail("message");
}
}
""";
Expand Down Expand Up @@ -966,7 +1003,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -998,7 +1035,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -1146,7 +1183,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail("message");
}
}
""";
Expand Down Expand Up @@ -1178,7 +1215,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down Expand Up @@ -1210,7 +1247,7 @@ public class MyTestClass
[TestMethod]
public void TestMethod()
{
Assert.Fail();
Assert.Fail(message: "message");
}
}
""";
Expand Down