From f8a8ab185ebbcb7744d4b2ba5d9694b3fb5a688e Mon Sep 17 00:00:00 2001 From: Eric Rozell Date: Thu, 19 Nov 2020 12:08:02 -0500 Subject: [PATCH] Adds support for entity hierarchies in compare This change allows users to declaratively specify hierarchical entities in their expected utterance results. For example, a user may declare the following: ```json { "text": "Order a pepperoni pizza", "intent": "OrderFood", "entities": { "entity": "FoodItem", "startPos": 8, "endPos": 22, "children": [ { "entity": "Topping", "startPos": 8, "endPos": 16 }, { "entity": "FoodType", "startPos": 18, "endPos": 22 } ] } } ``` This would result in 3 test cases, one for the parent entity (the "FoodItem" entity), and two additional test cases for each of the two nested entities ("FoodItem::Topping" and "FoodItem::FoodType"). Child entity type names are prefixed by their parent entity type names in the format `parentType::childType`. As such, the recursive entity parsing for the LUIS V3 provider has been updated to use this convention. Fixes #335 --- .../JsonLabeledUtteranceConverterTests.cs | 59 +++++++- .../NLU.DevOps.Core.Tests.csproj | 3 +- src/NLU.DevOps.Core/EntityConverter.cs | 130 ++++++++++++++---- src/NLU.DevOps.Core/HierarchicalEntity.cs | 31 +++++ src/NLU.DevOps.Core/IHierarchicalEntity.cs | 19 +++ src/NLU.DevOps.Core/JsonEntities.cs | 33 ++++- .../JsonLabeledUtteranceConverter.cs | 3 + .../LuisNLUTestClientTests.cs | 8 +- src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs | 12 +- .../TestCaseSourceTests.cs | 4 +- 10 files changed, 255 insertions(+), 47 deletions(-) create mode 100644 src/NLU.DevOps.Core/HierarchicalEntity.cs create mode 100644 src/NLU.DevOps.Core/IHierarchicalEntity.cs diff --git a/src/NLU.DevOps.Core.Tests/JsonLabeledUtteranceConverterTests.cs b/src/NLU.DevOps.Core.Tests/JsonLabeledUtteranceConverterTests.cs index b25ba5a..bcc0471 100644 --- a/src/NLU.DevOps.Core.Tests/JsonLabeledUtteranceConverterTests.cs +++ b/src/NLU.DevOps.Core.Tests/JsonLabeledUtteranceConverterTests.cs @@ -4,9 +4,9 @@ namespace NLU.DevOps.Core.Tests { using System; - using System.Collections.Generic; using System.Linq; using FluentAssertions; + using FluentAssertions.Json; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Newtonsoft.Json.Serialization; @@ -87,6 +87,63 @@ public static void ConvertsUtteranceWithStartPosAndEndPosEntity() actual.Entities[0].MatchIndex.Should().Be(2); } + [Test] + public static void ConvertsUtteranceWithNestedEntities() + { + var text = "foo bar baz"; + + var leafEntity = new JObject + { + { "entity", "baz" }, + { "startPos", 8 }, + { "endPos", 10 }, + { "foo", new JArray(42) }, + { "bar", null }, + { "baz", 42 }, + { "qux", JValue.CreateUndefined() }, + }; + + var midEntity = new JObject + { + { "entityType", "bar" }, + { "matchText", "bar baz" }, + { "children", new JArray { leafEntity } }, + { "entityValue", new JObject { { "bar", "qux" } } }, + }; + + var entity = new JObject + { + { "entity", "foo" }, + { "startPos", 0 }, + { "endPos", 10 }, + { "children", new JArray { midEntity } }, + }; + + var json = new JObject + { + { "text", text }, + { "entities", new JArray { entity } }, + }; + + var serializer = CreateSerializer(); + var actual = json.ToObject(serializer); + actual.Text.Should().Be(text); + actual.Entities.Count.Should().Be(3); + actual.Entities[0].EntityType.Should().Be("foo"); + actual.Entities[0].MatchText.Should().Be(text); + actual.Entities[1].EntityType.Should().Be("foo::bar"); + actual.Entities[1].MatchText.Should().Be("bar baz"); + actual.Entities[1].EntityValue.Should().BeEquivalentTo(new JObject { { "bar", "qux" } }); + actual.Entities[2].EntityType.Should().Be("foo::bar::baz"); + actual.Entities[2].MatchText.Should().Be("baz"); + + var additionalProperties = actual.Entities[2].As().AdditionalProperties; + additionalProperties["foo"].As().Should().BeEquivalentTo(new JArray(42)); + additionalProperties["bar"].Should().BeNull(); + additionalProperties["baz"].Should().Be(42); + additionalProperties["qux"].Should().BeNull(); + } + private static JsonSerializer CreateSerializer() { var serializer = JsonSerializer.CreateDefault(); diff --git a/src/NLU.DevOps.Core.Tests/NLU.DevOps.Core.Tests.csproj b/src/NLU.DevOps.Core.Tests/NLU.DevOps.Core.Tests.csproj index 0d41f59..d20d1a8 100644 --- a/src/NLU.DevOps.Core.Tests/NLU.DevOps.Core.Tests.csproj +++ b/src/NLU.DevOps.Core.Tests/NLU.DevOps.Core.Tests.csproj @@ -15,7 +15,8 @@ - + + diff --git a/src/NLU.DevOps.Core/EntityConverter.cs b/src/NLU.DevOps.Core/EntityConverter.cs index ac0f256..ef7454f 100644 --- a/src/NLU.DevOps.Core/EntityConverter.cs +++ b/src/NLU.DevOps.Core/EntityConverter.cs @@ -4,6 +4,8 @@ namespace NLU.DevOps.Core { using System; + using System.Collections.Generic; + using System.Diagnostics; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -16,42 +18,33 @@ public EntityConverter(string utterance) private string Utterance { get; } + private string Prefix { get; set; } = string.Empty; + public override Entity ReadJson(JsonReader reader, Type objectType, Entity existingValue, bool hasExistingValue, JsonSerializer serializer) { + Debug.Assert(!hasExistingValue, "Entity instance can only be constructor initialized."); + var jsonObject = JObject.Load(reader); + return typeof(HierarchicalEntity).IsAssignableFrom(objectType) + ? this.ReadHierarchicalEntity(jsonObject, serializer) + : this.ReadEntity(jsonObject, objectType, serializer); + } + + public override void WriteJson(JsonWriter writer, Entity value, JsonSerializer serializer) + { + throw new NotImplementedException(); + } + + private Entity ReadEntity(JObject jsonObject, Type objectType, JsonSerializer serializer) + { var matchText = jsonObject.Value("matchText"); + var matchIndex = jsonObject.Value("matchIndex"); var startPosOrNull = jsonObject.Value("startPos"); var endPosOrNull = jsonObject.Value("endPos"); - if (matchText == null && startPosOrNull != null && endPosOrNull != null) + if (matchText == null && startPosOrNull.HasValue && endPosOrNull.HasValue) { - var startPos = startPosOrNull.Value; - var endPos = endPosOrNull.Value; - var length = endPos - startPos + 1; - if (!this.IsValid(startPos, endPos)) - { - throw new InvalidOperationException( - $"Invalid start position '{startPos}' or end position '{endPos}' for utterance '{this.Utterance}'."); - } - - matchText = this.Utterance.Substring(startPos, length); + (matchText, matchIndex) = this.GetMatchInfo(startPosOrNull.Value, endPosOrNull.Value); jsonObject.Add("matchText", matchText); - var matchIndex = 0; - var currentPos = 0; - while (true) - { - currentPos = this.Utterance.IndexOf(matchText, currentPos, StringComparison.InvariantCulture); - - // Because 'matchText' is derived from the utterance from 'startPos' and 'endPos', - // we are guaranteed to find a match at with index 'startPos'. - if (currentPos == startPos) - { - break; - } - - currentPos += length; - matchIndex++; - } - jsonObject.Add("matchIndex", matchIndex); jsonObject.Remove("startPos"); jsonObject.Remove("endPos"); @@ -76,9 +69,86 @@ public override Entity ReadJson(JsonReader reader, Type objectType, Entity exist } } - public override void WriteJson(JsonWriter writer, Entity value, JsonSerializer serializer) + private HierarchicalEntity ReadHierarchicalEntity(JObject jsonObject, JsonSerializer serializer) { - throw new NotImplementedException(); + var matchText = jsonObject.Value("matchText"); + var matchIndex = jsonObject.Value("matchIndex"); + var startPosOrNull = jsonObject.Value("startPos"); + var endPosOrNull = jsonObject.Value("endPos"); + if (matchText == null && startPosOrNull.HasValue && endPosOrNull.HasValue) + { + (matchText, matchIndex) = this.GetMatchInfo(startPosOrNull.Value, endPosOrNull.Value); + } + + var entityType = jsonObject.Value("entityType") ?? jsonObject.Value("entity"); + var childrenJson = jsonObject["children"]; + var children = default(IEnumerable); + if (childrenJson != null) + { + var prefix = $"{entityType}::"; + this.Prefix += prefix; + try + { + children = childrenJson.ToObject>(serializer); + } + finally + { + this.Prefix = this.Prefix.Substring(0, this.Prefix.Length - prefix.Length); + } + } + + var entity = new HierarchicalEntity($"{this.Prefix}{entityType}", jsonObject["entityValue"], matchText, matchIndex, children); + foreach (var property in jsonObject) + { + switch (property.Key) + { + case "children": + case "endPos": + case "entity": + case "entityType": + case "entityValue": + case "matchText": + case "matchIndex": + case "startPos": + break; + default: + var value = property.Value is JValue jsonValue ? jsonValue.Value : property.Value; + entity.AdditionalProperties.Add(property.Key, value); + break; + } + } + + return entity; + } + + private Tuple GetMatchInfo(int startPos, int endPos) + { + if (!this.IsValid(startPos, endPos)) + { + throw new InvalidOperationException( + $"Invalid start position '{startPos}' or end position '{endPos}' for utterance '{this.Utterance}'."); + } + + var length = endPos - startPos + 1; + var matchText = this.Utterance.Substring(startPos, length); + var matchIndex = 0; + var currentPos = 0; + while (true) + { + currentPos = this.Utterance.IndexOf(matchText, currentPos, StringComparison.InvariantCulture); + + // Because 'matchText' is derived from the utterance from 'startPos' and 'endPos', + // we are guaranteed to find a match at with index 'startPos'. + if (currentPos == startPos) + { + break; + } + + currentPos += length; + matchIndex++; + } + + return Tuple.Create(matchText, matchIndex); } private bool IsValid(int startPos, int endPos) diff --git a/src/NLU.DevOps.Core/HierarchicalEntity.cs b/src/NLU.DevOps.Core/HierarchicalEntity.cs new file mode 100644 index 0000000..fe520f4 --- /dev/null +++ b/src/NLU.DevOps.Core/HierarchicalEntity.cs @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace NLU.DevOps.Core +{ + using System.Collections.Generic; + using Newtonsoft.Json.Linq; + + /// + /// Entity appearing in utterance. + /// + public sealed class HierarchicalEntity : Entity, IHierarchicalEntity + { + /// + /// Initializes a new instance of the class. + /// + /// Entity type name. + /// Entity value, generally a canonical form of the entity. + /// Matching text in the utterance. + /// Occurrence index of matching token in the utterance. + /// Children entities. + public HierarchicalEntity(string entityType, JToken entityValue, string matchText, int matchIndex, IEnumerable children) + : base(entityType, entityValue, matchText, matchIndex) + { + this.Children = children; + } + + /// + public IEnumerable Children { get; } + } +} diff --git a/src/NLU.DevOps.Core/IHierarchicalEntity.cs b/src/NLU.DevOps.Core/IHierarchicalEntity.cs new file mode 100644 index 0000000..268c3fb --- /dev/null +++ b/src/NLU.DevOps.Core/IHierarchicalEntity.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace NLU.DevOps.Core +{ + using System.Collections.Generic; + using Models; + + /// + /// Entity with nested children. + /// + public interface IHierarchicalEntity : IEntity + { + /// + /// Gets the child entities. + /// + IEnumerable Children { get; } + } +} diff --git a/src/NLU.DevOps.Core/JsonEntities.cs b/src/NLU.DevOps.Core/JsonEntities.cs index 1e8f333..738868c 100644 --- a/src/NLU.DevOps.Core/JsonEntities.cs +++ b/src/NLU.DevOps.Core/JsonEntities.cs @@ -3,7 +3,10 @@ namespace NLU.DevOps.Core { + using System; using System.Collections.Generic; + using System.Linq; + using Models; using Newtonsoft.Json; /// @@ -15,20 +18,44 @@ public class JsonEntities /// Initializes a new instance of the class. /// /// Entities referenced in the utterance. - public JsonEntities(IReadOnlyList entities) + public JsonEntities(IEnumerable entities) { - this.Entities = entities; + this.Entities = FlattenChildren(entities)?.ToArray(); } /// /// Gets the entities referenced in the utterance. /// - public IReadOnlyList Entities { get; } + public IReadOnlyList Entities { get; } /// /// Gets the additional properties. /// [JsonExtensionData] public IDictionary AdditionalProperties { get; } = new Dictionary(); + + private static IEnumerable FlattenChildren(IEnumerable entities, string prefix = "") + { + if (entities == null) + { + return null; + } + + IEnumerable getChildren(IHierarchicalEntity entity) + { + yield return entity; + + var children = FlattenChildren(entity.Children, $"{prefix}{entity.EntityType}::"); + if (children != null) + { + foreach (var child in children) + { + yield return child; + } + } + } + + return entities.SelectMany(getChildren); + } } } diff --git a/src/NLU.DevOps.Core/JsonLabeledUtteranceConverter.cs b/src/NLU.DevOps.Core/JsonLabeledUtteranceConverter.cs index 5730bb2..c7fa870 100644 --- a/src/NLU.DevOps.Core/JsonLabeledUtteranceConverter.cs +++ b/src/NLU.DevOps.Core/JsonLabeledUtteranceConverter.cs @@ -4,6 +4,7 @@ namespace NLU.DevOps.Core { using System; + using System.Diagnostics; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -18,6 +19,8 @@ public class JsonLabeledUtteranceConverter : JsonConverter /// public override JsonLabeledUtterance ReadJson(JsonReader reader, Type objectType, JsonLabeledUtterance existingValue, bool hasExistingValue, JsonSerializer serializer) { + Debug.Assert(!hasExistingValue, "Utterance instance can only be constructor initialized."); + var jsonObject = JObject.Load(reader); var utterance = jsonObject.Value("text") ?? jsonObject.Value("query"); var entityConverter = new EntityConverter(utterance); diff --git a/src/NLU.DevOps.LuisV3.Tests/LuisNLUTestClientTests.cs b/src/NLU.DevOps.LuisV3.Tests/LuisNLUTestClientTests.cs index 2ebf2e3..ab2fc4d 100644 --- a/src/NLU.DevOps.LuisV3.Tests/LuisNLUTestClientTests.cs +++ b/src/NLU.DevOps.LuisV3.Tests/LuisNLUTestClientTests.cs @@ -402,19 +402,19 @@ public static async Task UtteranceWithNestedMLEntity() result.Text.Should().Be(test); result.Intent.Should().Be("RequestVacation"); result.Entities.Count.Should().Be(7); - result.Entities[0].EntityType.Should().Be("leave-type"); + result.Entities[0].EntityType.Should().Be("vacation-request::leave-type"); result.Entities[0].EntityValue.Should().BeEquivalentTo(@"[ ""sick"" ]"); result.Entities[0].MatchText.Should().Be("sick leave"); result.Entities[0].MatchIndex.Should().Be(0); - result.Entities[1].EntityType.Should().Be("days-number"); + result.Entities[1].EntityType.Should().Be("vacation-request::days-duration::days-number"); result.Entities[1].EntityValue.Should().BeEquivalentTo("6"); result.Entities[1].MatchText.Should().Be("6"); result.Entities[1].MatchIndex.Should().Be(0); - result.Entities[2].EntityType.Should().Be("days-duration"); + result.Entities[2].EntityType.Should().Be("vacation-request::days-duration"); result.Entities[2].EntityValue.Should().BeEquivalentTo(@"{ ""days-number"": [ 6 ] }"); result.Entities[2].MatchText.Should().Be("6 days"); result.Entities[2].MatchIndex.Should().Be(0); - result.Entities[3].EntityType.Should().Be("start-date"); + result.Entities[3].EntityType.Should().Be("vacation-request::start-date"); result.Entities[3].MatchText.Should().Be("starting march 5"); result.Entities[3].MatchIndex.Should().Be(0); result.Entities[4].EntityType.Should().Be("vacation-request"); diff --git a/src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs b/src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs index c0761b1..1afcb0d 100644 --- a/src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs +++ b/src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs @@ -102,7 +102,7 @@ private static IEnumerable GetEntities(string utterance, IDictionary getEntitiesForType(string type, object instances, JToken metadata) + IEnumerable getEntitiesForType(string prefix, string type, object instances, JToken metadata) { if (instances is JArray instancesJson) { @@ -111,14 +111,14 @@ IEnumerable getEntitiesForType(string type, object instances, JToken me .Zip( typeMetadata, (instance, instanceMetadata) => - getEntitiesRecursive(type, instance, instanceMetadata)) + getEntitiesRecursive(prefix, type, instance, instanceMetadata)) .SelectMany(e => e); } return Array.Empty(); } - IEnumerable getEntitiesRecursive(string entityType, JToken entityJson, JToken entityMetadata) + IEnumerable getEntitiesRecursive(string prefix, string entityType, JToken entityJson, JToken entityMetadata) { var startIndex = entityMetadata.Value("startIndex"); var length = entityMetadata.Value("length"); @@ -136,7 +136,7 @@ IEnumerable getEntitiesRecursive(string entityType, JToken entityJson, if (entityJson is JObject entityJsonObject && entityJsonObject.TryGetValue("$instance", out var innerMetadata)) { var children = ((IDictionary)entityJsonObject) - .SelectMany(pair => getEntitiesForType(pair.Key, pair.Value, innerMetadata)); + .SelectMany(pair => getEntitiesForType($"{prefix}{entityType}::", pair.Key, pair.Value, innerMetadata)); foreach (var child in children) { @@ -144,7 +144,7 @@ IEnumerable getEntitiesRecursive(string entityType, JToken entityJson, } } - yield return new Entity(entityType, entityValue, matchText, matchIndex) + yield return new Entity($"{prefix}{entityType}", entityValue, matchText, matchIndex) .WithScore(score); } @@ -159,7 +159,7 @@ IEnumerable getEntitiesRecursive(string entityType, JToken entityJson, } return entities.SelectMany(pair => - getEntitiesForType(pair.Key, pair.Value, globalMetadata)); + getEntitiesForType(string.Empty, pair.Key, pair.Value, globalMetadata)); } private static JToken PruneMetadata(JToken json) diff --git a/src/NLU.DevOps.ModelPerformance.Tests/TestCaseSourceTests.cs b/src/NLU.DevOps.ModelPerformance.Tests/TestCaseSourceTests.cs index a388f36..97a7adc 100644 --- a/src/NLU.DevOps.ModelPerformance.Tests/TestCaseSourceTests.cs +++ b/src/NLU.DevOps.ModelPerformance.Tests/TestCaseSourceTests.cs @@ -711,7 +711,7 @@ public static void GlobalLocalStrictIgnoreEntities( [Test] public static void NoFalsePositiveIntentsUnitTestMode() { - var expectedUtterance = new JsonLabeledUtterance(new JsonEntities(Array.Empty())); + var expectedUtterance = new JsonLabeledUtterance(new JsonEntities(Array.Empty())); var actualUtterance = new LabeledUtterance(null, "foo", null); var testSettings = new TestSettings(default(string), true); @@ -729,7 +729,7 @@ public static void NoFalsePositiveIntentsUnitTestMode() [Test] public static void HasFalsePositiveIntentsUnitTestMode() { - var jsonEntities = new JsonEntities(Array.Empty()) + var jsonEntities = new JsonEntities(Array.Empty()) { AdditionalProperties = {