Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Fix Aggregate types in UDF #2775

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.CodeAnalysis;
using Microsoft.PowerFx.Core.App;
using Microsoft.PowerFx.Core.App.Controls;
using Microsoft.PowerFx.Core.App.ErrorContainers;
using Microsoft.PowerFx.Core.Binding;
using Microsoft.PowerFx.Core.Binding.BindInfo;
using Microsoft.PowerFx.Core.Entities;
Expand Down Expand Up @@ -57,6 +58,8 @@ public override bool IsServerDelegatable(CallNode callNode, TexlBinding binding)

public override bool SupportsParamCoercion => true;

public override bool HasPreciseErrors => true;

private const int MaxParameterCount = 30;

public TexlNode UdfBody { get; }
Expand All @@ -77,6 +80,27 @@ public override bool TryGetDataSource(CallNode callNode, TexlBinding binding, ou

public bool HasDelegationWarning => _binding?.ErrorContainer.GetErrors().Any(error => error.MessageKey.Contains("SuggestRemoteExecutionHint")) ?? false;

public override bool CheckTypes(CheckTypesContext context, TexlNode[] args, DType[] argTypes, IErrorContainer errors, out DType returnType, out Dictionary<TexlNode, DType> nodeToCoercedTypeMap)
{
if (!base.CheckTypes(context, args, argTypes, errors, out returnType, out nodeToCoercedTypeMap))
{
return false;
}

for (int i = 0; i < argTypes.Length; i++)
{
if ((argTypes[i].IsTableNonObjNull || argTypes[i].IsRecordNonObjNull) &&
!ParamTypes[i].Accepts(argTypes[i], exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules, true) &&
!argTypes[i].CoercesTo(ParamTypes[i], true, false, context.Features, true))
{
errors.EnsureError(DocumentErrorSeverity.Severe, args[i], TexlStrings.ErrBadSchema_ExpectedType, ParamTypes[i].GetKindString());
return false;
}
}

return true;
}

/// <summary>
/// Initializes a new instance of the <see cref="UserDefinedFunction"/> class.
/// </summary>
Expand Down Expand Up @@ -167,15 +191,22 @@ public void CheckTypesOnDeclaration(CheckTypesContext context, DType actualBodyR
Contracts.AssertValue(actualBodyReturnType);
Contracts.AssertValue(binding);

if (!ReturnType.Accepts(actualBodyReturnType, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules))
if (!ReturnType.Accepts(actualBodyReturnType, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules, true))
{
if (actualBodyReturnType.CoercesTo(ReturnType, true, false, context.Features))
if (actualBodyReturnType.CoercesTo(ReturnType, true, false, context.Features, true))
{
_binding.SetCoercedType(binding.Top, ReturnType);
}
else
{
var node = UdfBody is VariadicOpNode variadicOpNode ? variadicOpNode.Children.Last() : UdfBody;

if ((ReturnType.IsTable && actualBodyReturnType.IsTable) || (ReturnType.IsRecord && actualBodyReturnType.IsRecord))
{
binding.ErrorContainer.EnsureError(DocumentErrorSeverity.Severe, node, TexlStrings.ErrUDF_ReturnTypeSchemaDoesNotMatch, ReturnType.GetKindString());
return;
}

binding.ErrorContainer.EnsureError(DocumentErrorSeverity.Severe, node, TexlStrings.ErrUDF_ReturnTypeDoesNotMatch, ReturnType.GetKindString(), actualBodyReturnType.GetKindString());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ internal static class TexlStrings
public static ErrorResourceKey ErrUDF_DuplicateParameter = new ErrorResourceKey("ErrUDF_DuplicateParameter");
public static ErrorResourceKey ErrUDF_UnknownType = new ErrorResourceKey("ErrUDF_UnknownType");
public static ErrorResourceKey ErrUDF_ReturnTypeDoesNotMatch = new ErrorResourceKey("ErrUDF_ReturnTypeDoesNotMatch");
public static ErrorResourceKey ErrUDF_ReturnTypeSchemaDoesNotMatch = new ErrorResourceKey("ErrUDF_ReturnTypeSchemaDoesNotMatch");
public static ErrorResourceKey ErrUDF_TooManyParameters = new ErrorResourceKey("ErrUDF_TooManyParameters");
public static ErrorResourceKey ErrUDF_MissingReturnType = new ErrorResourceKey("ErrUDF_MissingReturnType");
public static ErrorResourceKey ErrUDF_MissingParamType = new ErrorResourceKey("ErrUDF_MissingParamType");
Expand Down
55 changes: 40 additions & 15 deletions src/libraries/Microsoft.PowerFx.Core/Types/DType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1851,12 +1851,13 @@ private bool AcceptsEntityType(DType type, bool usePowerFxV1CompatibilityRules)
/// <param name="useLegacyDateTimeAccepts">Legacy rules for accepting date/time types.</param>
/// <param name="usePowerFxV1CompatibilityRules">Use PFx v1 compatibility rules if enabled (less
/// permissive Accepts relationships).</param>
/// <param name="restrictiveAggregateTypes">restrictiveAggregateTypes.</param>
/// <returns>
/// True if <see cref="DType"/> accepts <paramref name="type"/>, false otherwise.
/// </returns>
public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules)
public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules, bool restrictiveAggregateTypes = false)
{
return Accepts(type, out _, out _, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules);
return Accepts(type, out _, out _, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules, restrictiveAggregateTypes);
}

/// <summary>
Expand Down Expand Up @@ -1888,7 +1889,7 @@ public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool
/// <returns>
/// True if <see cref="DType"/> accepts <paramref name="type"/>, false otherwise.
/// </returns>
public virtual bool Accepts(DType type, out KeyValuePair<string, DType> schemaDifference, out DType schemaDifferenceType, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules)
public virtual bool Accepts(DType type, out KeyValuePair<string, DType> schemaDifference, out DType schemaDifferenceType, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules, bool restrictiveAggregateTypes = false)
{
AssertValid();
type.AssertValid();
Expand Down Expand Up @@ -1941,7 +1942,7 @@ bool DefaultReturnValue(DType targetType) =>

if (Kind == type.Kind)
{
return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules);
return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes);
}

accepts = type.Kind == DKind.Unknown || type.Kind == DKind.Deferred;
Expand All @@ -1955,7 +1956,7 @@ bool DefaultReturnValue(DType targetType) =>

if (Kind == type.Kind || type.IsExpandEntity)
{
return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules);
return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes);
}

accepts = (IsMultiSelectOptionSet() && TypeTree.GetPairs().First().Value.OptionSetInfo == type.OptionSetInfo) || type.Kind == DKind.Unknown || type.Kind == DKind.Deferred;
Expand Down Expand Up @@ -2175,7 +2176,7 @@ bool DefaultReturnValue(DType targetType) =>
}

// Implements Accepts for Record and Table types.
private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree treeSrc, out KeyValuePair<string, DType> schemaDifference, out DType treeSrcSchemaDifferenceType, bool exact = true, bool useLegacyDateTimeAccepts = false, bool usePowerFxV1CompatibilityRules = false)
private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree treeSrc, out KeyValuePair<string, DType> schemaDifference, out DType treeSrcSchemaDifferenceType, bool exact = true, bool useLegacyDateTimeAccepts = false, bool usePowerFxV1CompatibilityRules = false, bool restrictiveAggregateTypes = false)
{
treeDst.AssertValid();
treeSrc.AssertValid();
Expand Down Expand Up @@ -2215,7 +2216,7 @@ private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree tre
return false;
}

if (!pairDst.Value.Accepts(type, out var recursiveSchemaDifference, out var recursiveSchemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules))
if (!pairDst.Value.Accepts(type, out var recursiveSchemaDifference, out var recursiveSchemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes))
{
if (!TryGetDisplayNameForColumn(parentType, pairDst.Key, out var colName))
{
Expand All @@ -2237,6 +2238,17 @@ private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree tre
}
}

if (restrictiveAggregateTypes)
{
foreach (var pairSrc in treeSrc)
{
if (!treeDst.Contains(pairSrc.Key))
{
return false;
}
}
}

return true;
}

Expand Down Expand Up @@ -3141,17 +3153,17 @@ public bool ContainsControlType(DPath path)
(n.Type.IsAggregate && n.Type.ContainsControlType(DPath.Root)));
}

public bool CoercesTo(DType typeDest, bool aggregateCoercion, bool isTopLevelCoercion, Features features)
public bool CoercesTo(DType typeDest, bool aggregateCoercion, bool isTopLevelCoercion, Features features, bool restrictiveAggregateTypes = false)
{
return CoercesTo(typeDest, out _, aggregateCoercion, isTopLevelCoercion, features);
return CoercesTo(typeDest, out _, aggregateCoercion, isTopLevelCoercion, features, restrictiveAggregateTypes);
}

public bool CoercesTo(DType typeDest, out bool isSafe, bool aggregateCoercion, bool isTopLevelCoercion, Features features)
public bool CoercesTo(DType typeDest, out bool isSafe, bool aggregateCoercion, bool isTopLevelCoercion, Features features, bool restrictiveAggregateTypes = false)
{
return CoercesTo(typeDest, out isSafe, out _, out _, out _, aggregateCoercion, isTopLevelCoercion, features);
return CoercesTo(typeDest, out isSafe, out _, out _, out _, aggregateCoercion, isTopLevelCoercion, features, restrictiveAggregateTypes);
}

public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coercionType, out KeyValuePair<string, DType> schemaDifference, out DType schemaDifferenceType, Features features, bool aggregateCoercion = true)
public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coercionType, out KeyValuePair<string, DType> schemaDifference, out DType schemaDifferenceType, Features features, bool aggregateCoercion = true, bool restrictiveAggregateTypes = false)
{
Contracts.Assert(IsAggregate);

Expand Down Expand Up @@ -3259,6 +3271,17 @@ public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coerci
isSafe &= fieldIsSafe;
}

if (restrictiveAggregateTypes)
{
foreach (var typedName in GetNames(DPath.Root))
{
if (!typeDest.TryGetType(typedName.Name, out _))
{
return false;
}
}
}

return isValid;
}

Expand All @@ -3273,7 +3296,8 @@ public virtual bool CoercesTo(
out DType schemaDifferenceType,
bool aggregateCoercion,
bool isTopLevelCoercion,
Features features)
Features features,
bool restrictiveAggregateTypes = false)
{
AssertValid();
Contracts.Assert(typeDest.IsValid);
Expand All @@ -3290,7 +3314,7 @@ public virtual bool CoercesTo(
return false;
}

if (typeDest.Accepts(this, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules))
if (typeDest.Accepts(this, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes))
{
coercionType = typeDest;
return true;
Expand Down Expand Up @@ -3335,7 +3359,8 @@ public virtual bool CoercesTo(
out schemaDifference,
out schemaDifferenceType,
features,
aggregateCoercion);
aggregateCoercion,
restrictiveAggregateTypes);
}

var subtypeCoerces = SubtypeCoercesTo(
Expand Down
4 changes: 4 additions & 0 deletions src/strings/PowerFxResources.en-US.resx
Original file line number Diff line number Diff line change
Expand Up @@ -4221,6 +4221,10 @@
<value>The stated function return type '{0}' does not match the return type of the function body '{1}'.</value>
<comment>This error message shows up when expected return type does not match with actual return type. The arguments '{0}' and '{1}' will be replaced with data types. For example, "The stated function return type 'Number' does not match the return type of the function body 'Table'"</comment>
</data>
<data name="ErrUDF_ReturnTypeSchemaDoesNotMatch" xml:space="preserve">
<value>The schema of stated function return type '{0}' does not match the schema of return type of the function body.</value>
<comment>This error message shows up when expected return type schema does not match with schema of actual return type. The arguments '{0}' will be replaced with aggregate data types. For example, "The schema of stated function return type 'Table' does not match the schema of return type of the function body."</comment>
</data>
<data name="ErrUDF_TooManyParameters" xml:space="preserve">
<value>Function {0} has too many parameters. User-defined functions support up to {1} parameters.</value>
<comment>This error message shows up when a user tries to define a function with too many parameters. {0} - the name of the user-defined function, {1} - the max number of parameters allowed.</comment>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1865,12 +1865,7 @@ protected override bool TryGetField(FormulaType fieldType, string fieldName, out
true,
42.0)]

// Functions accept record with more/less fields
[InlineData(
"People := Type([{Name: Text, Age: Number}]); countMinors(p: People): Number = CountRows(Filter(p, Age < 18));",
"countMinors([{Name: \"Bob\", Age: 21, Title: \"Engineer\"}, {Name: \"Alice\", Age: 25, Title: \"Manager\"}])",
true,
0.0)]
// Functions accept record with less fields
[InlineData(
"Employee := Type({Name: Text, Age: Number, Title: Text}); getAge(e: Employee): Number = e.Age;",
"getAge({Name: \"Bob\", Age: 21})",
Expand Down Expand Up @@ -1949,7 +1944,17 @@ protected override bool TryGetField(FormulaType fieldType, string fieldName, out
"f():TestEntity = Entity; g(e: TestEntity):Number = 1;",
"g(f())",
true,
1.0)]
1.0)]

// Aggregate types with more than expected fields are not allowed in UDF
[InlineData(
"f():T = {x: 5, y: 5}; T := Type({x: Number});",
"f().x",
false)]
[InlineData(
"People := Type([{Name: Text, Age: Number}]); countMinors(p: People): Number = CountRows(Filter(p, Age < 18));",
"countMinors([{Name: \"Bob\", Age: 21, Title: \"Engineer\"}, {Name: \"Alice\", Age: 25, Title: \"Manager\"}])",
false)]
public void UserDefinedTypeTest(string userDefinitions, string evalExpression, bool isValid, double expectedResult = 0)
{
var config = new PowerFxConfig();
Expand All @@ -1970,7 +1975,11 @@ public void UserDefinedTypeTest(string userDefinitions, string evalExpression, b
}
else
{
Assert.Throws<InvalidOperationException>(() => recalcEngine.AddUserDefinitions(userDefinitions, CultureInfo.InvariantCulture));
Assert.ThrowsAny<Exception>(() =>
{
recalcEngine.AddUserDefinitions(userDefinitions, CultureInfo.InvariantCulture);
recalcEngine.Eval(evalExpression, options: parserOptions);
});
}
}

Expand Down
Loading