diff --git a/wire-runtime-swift/src/main/swift/AnyMessage.swift b/wire-runtime-swift/src/main/swift/AnyMessage.swift index 05bd68faa6..5ef98e0ef0 100644 --- a/wire-runtime-swift/src/main/swift/AnyMessage.swift +++ b/wire-runtime-swift/src/main/swift/AnyMessage.swift @@ -117,5 +117,27 @@ extension AnyMessage: Codable { case typeURL = "@type" case value = "value" } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.typeURL = try container.decode(String.self, forKey: .typeURL) + if typeURL == FieldMask.protoMessageTypeURL() { + let fieldMask = try container.decode(FieldMask.self, forKey: .value) + self.value = try ProtoEncoder().encode(fieldMask) + } else { + self.value = try container.decode(Data.self, forKey: .value) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(typeURL, forKey: .typeURL) + if typeURL == FieldMask.protoMessageTypeURL() { + let fieldMask = try ProtoDecoder().decode(FieldMask.self, from: value) + try container.encode(fieldMask, forKey: .value) + } else { + try container.encode(value, forKey: .value) + } + } } #endif diff --git a/wire-runtime-swift/src/main/swift/FieldMask.swift b/wire-runtime-swift/src/main/swift/FieldMask.swift index d9c803b51f..6b8071dc38 100644 --- a/wire-runtime-swift/src/main/swift/FieldMask.swift +++ b/wire-runtime-swift/src/main/swift/FieldMask.swift @@ -126,9 +126,7 @@ extension FieldMask: Codable { var result = "" for scalar in value.unicodeScalars { if scalar.value >= 65 && scalar.value <= 90 { - if !result.isEmpty { - result.append("_") - } + result.append("_") result.append(String(scalar).lowercased()) } else { result.append(String(scalar).lowercased()) @@ -138,3 +136,13 @@ extension FieldMask: Codable { } } #endif + +extension ProtoReader { + public func decode(_ type: FieldMask.Type, mergingInto existing: FieldMask?) throws -> FieldMask { + let decoded = try decode(type) + guard let existing = existing else { + return decoded + } + return FieldMask(paths: existing.paths + decoded.paths) + } +} diff --git a/wire-runtime-swift/src/test/swift/AnyMessageTests.swift b/wire-runtime-swift/src/test/swift/AnyMessageTests.swift index bd182a61f4..8f69d24f78 100644 --- a/wire-runtime-swift/src/test/swift/AnyMessageTests.swift +++ b/wire-runtime-swift/src/test/swift/AnyMessageTests.swift @@ -34,6 +34,23 @@ final class AnyMessageTests: XCTestCase { XCTAssertEqual(fieldMask, try any.unpack(FieldMask.self)) } + func testFieldMaskJSONUsesSpecialValueEncoding() throws { + let fieldMask = FieldMask(paths: ["masked_name"]) + let any = try AnyMessage.pack(fieldMask) + + let encoded = try JSONEncoder().encode(any) + let decoded = try JSONDecoder().decode( + AnyMessage.self, + from: Foundation.Data(#"{"@type":"type.googleapis.com/google.protobuf.FieldMask","value":"maskedName"}"#.utf8) + ) + let encodedObject = try XCTUnwrap(JSONSerialization.jsonObject(with: encoded) as? [String: String]) + + XCTAssertEqual(encodedObject["@type"], FieldMask.protoMessageTypeURL()) + XCTAssertEqual(encodedObject["value"], "maskedName") + XCTAssertEqual(try JSONDecoder().decode(AnyMessage.self, from: encoded), any) + XCTAssertEqual(decoded, any) + } + func testUnpackingAnyToCorrectType() throws { let data = Data(json_data: "") let person = Person(name: "foo bar", id: 12345, data: data) diff --git a/wire-runtime-swift/src/test/swift/FieldMaskTests.swift b/wire-runtime-swift/src/test/swift/FieldMaskTests.swift index b878069457..f7898c33c1 100644 --- a/wire-runtime-swift/src/test/swift/FieldMaskTests.swift +++ b/wire-runtime-swift/src/test/swift/FieldMaskTests.swift @@ -67,6 +67,16 @@ final class FieldMaskTests: XCTestCase { XCTAssertEqual(decoded, FieldMask(paths: ["photo"])) } + func testJSONPreservesLeadingUppercaseAsUnderscore() throws { + let jsonData = Foundation.Data(#""foo.Bar""#.utf8) + + let fieldMask = try JSONDecoder().decode(FieldMask.self, from: jsonData) + let encoded = try JSONEncoder().encode(fieldMask) + + XCTAssertEqual(fieldMask, FieldMask(paths: ["foo._bar"])) + XCTAssertEqual(String(data: encoded, encoding: .utf8), #""foo.Bar""#) + } + func testGeneratedMessageWithFieldMask() throws { let fieldMask = FieldMask(paths: ["user.display_name", "photo"]) let otherFieldMask = FieldMask(paths: ["updated_at.seconds"]) @@ -83,4 +93,12 @@ final class FieldMaskTests: XCTestCase { XCTAssertEqual(decoded.masks, [fieldMask, otherFieldMask]) XCTAssertEqual(decoded.masks_by_id, [1: fieldMask, 2: otherFieldMask]) } + + func testGeneratedMessageMergesDuplicateSingularFieldMask() throws { + let data = Foundation.Data(hexEncoded: "0a030a01610a030a0162")! + + let decoded = try ProtoDecoder().decode(MessageContainingFieldMask.self, from: data) + + XCTAssertEqual(decoded.mask, FieldMask(paths: ["a", "b"])) + } } diff --git a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/internal/RuntimeMessageAdapter.kt b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/internal/RuntimeMessageAdapter.kt index 4c967d9a2c..8e4946813f 100644 --- a/wire-runtime/src/commonMain/kotlin/com/squareup/wire/internal/RuntimeMessageAdapter.kt +++ b/wire-runtime/src/commonMain/kotlin/com/squareup/wire/internal/RuntimeMessageAdapter.kt @@ -169,13 +169,24 @@ class RuntimeMessageAdapter( val field = fields[tag] try { if (field != null) { - val adapter = if (field.isMap) { - field.adapter + if (field.isMap) { + val value = field.adapter.decode(reader) + field.value(builder, value) } else { - field.singleAdapter + val singleAdapter = field.singleAdapter + if ((field.isMessage || singleAdapter == ProtoAdapter.FIELD_MASK) && + !field.label.isRepeated && + !field.label.isOneOf + ) { + @Suppress("UNCHECKED_CAST") + val adapter = singleAdapter as ProtoAdapter + val value = decodeMessageOrMerge(adapter, reader, field.getFromBuilder(builder)) + field.set(builder, value) + } else { + val value = singleAdapter.decode(reader) + field.value(builder, value!!) + } } - val value = adapter.decode(reader) - field.value(builder, value!!) } else { val fieldEncoding = reader.peekFieldEncoding()!! val value = fieldEncoding.rawProtoAdapter().decode(reader) diff --git a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/DynamicSerializationTest.kt b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/DynamicSerializationTest.kt index ad3993baaa..679ac0bb3f 100644 --- a/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/DynamicSerializationTest.kt +++ b/wire-schema/src/jvmTest/kotlin/com/squareup/wire/schema/DynamicSerializationTest.kt @@ -23,6 +23,7 @@ import com.squareup.wire.durationOfSeconds import com.squareup.wire.ofEpochSecond import kotlin.test.Ignore import kotlin.test.Test +import okio.ByteString.Companion.decodeHex import okio.ByteString.Companion.encodeUtf8 import okio.Path.Companion.toPath @@ -87,6 +88,30 @@ class DynamicSerializationTest { assertThat(adapter.decode(adapter.encode(expected))).isEqualTo(expected) } + @Test + fun singularFieldMaskOccurrencesAreMerged() { + val schema = buildSchema { + add( + "message.proto".toPath(), + """ + |syntax = "proto3"; + |import "google/protobuf/field_mask.proto"; + | + |message Message { + | google.protobuf.FieldMask field_mask_field = 1; + |} + | + """.trimMargin(), + ) + } + + val adapter = schema.protoAdapter(typeName = "Message", includeUnknown = true) + val encoded = "0a030a01610a030a0162".decodeHex() + + assertThat(adapter.decode(encoded)) + .isEqualTo(mapOf("field_mask_field" to FieldMask(listOf("a", "b")))) + } + @Ignore // Not supported. @Test fun mapTest() { diff --git a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt index 2268b73fd6..fd43446bf7 100644 --- a/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt +++ b/wire-swift-generator/src/main/java/com/squareup/wire/swift/SwiftGenerator.kt @@ -635,6 +635,10 @@ class SwiftGenerator private constructor( } else { if (field.isRepeated) { decoder.add("try $reader.decode(into: &%N", field.safeName) + } else if (field.type == ProtoType.FIELD_MASK) { + val typeName = field.typeName.makeNonOptional() + + decoder.add("%N = try $reader.decode(%T.self, mergingInto: %N", field.safeName, typeName, field.safeName) } else { val typeName = field.typeName.makeNonOptional() diff --git a/wire-swift-generator/src/test/java/com/squareup/wire/swift/SwiftGeneratorTest.kt b/wire-swift-generator/src/test/java/com/squareup/wire/swift/SwiftGeneratorTest.kt index 2142b65400..c8ec5b2ae2 100644 --- a/wire-swift-generator/src/test/java/com/squareup/wire/swift/SwiftGeneratorTest.kt +++ b/wire-swift-generator/src/test/java/com/squareup/wire/swift/SwiftGeneratorTest.kt @@ -90,7 +90,7 @@ class SwiftGeneratorTest { assertThat(code).contains("public var mask: FieldMask?") assertThat(code).contains("public var masks: [FieldMask]") assertThat(code).contains("public var masks_by_id: [Int32 : FieldMask]") - assertThat(code).contains("mask = try protoReader.decode(FieldMask.self)") + assertThat(code).contains("mask = try protoReader.decode(FieldMask.self, mergingInto: mask)") assertThat(code).contains("try protoReader.decode(into: &masks)") assertThat(code).contains("try protoReader.decode(into: &masks_by_id, keyEncoding: .variable)") assertThat(code).doesNotContain("@ProtoDefaulted") diff --git a/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift b/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift index a2dc2173ed..3e45cda50f 100644 --- a/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift +++ b/wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift @@ -2394,7 +2394,7 @@ extension AllTypes.Storage : Proto2Codable { case 15: opt_bytes = try protoReader.decode(Foundation.Data.self) case 16: opt_nested_enum = try protoReader.decode(AllTypes.NestedEnum.self) case 17: opt_nested_message = try protoReader.decode(AllTypes.NestedMessage.self) - case 18: opt_field_mask = try protoReader.decode(FieldMask.self) + case 18: opt_field_mask = try protoReader.decode(FieldMask.self, mergingInto: opt_field_mask) case 101: req_int32 = try protoReader.decode(Int32.self, encoding: .variable) case 102: req_uint32 = try protoReader.decode(UInt32.self, encoding: .variable) case 103: req_sint32 = try protoReader.decode(Int32.self, encoding: .signed) diff --git a/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/WireJsonTest.kt b/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/WireJsonTest.kt index 8b0f78c401..9cfd57e5b3 100644 --- a/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/WireJsonTest.kt +++ b/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/WireJsonTest.kt @@ -31,6 +31,7 @@ import com.squareup.wire.proto3.alltypes.AllTypes as AllTypesProto3 import java.io.File import java.util.Collections import okio.ByteString +import okio.ByteString.Companion.decodeHex import okio.buffer import okio.source import org.junit.Assert.fail @@ -315,6 +316,12 @@ class WireJsonTest { assertThat(jsonLibrary.fromJson(json, FieldMask::class.java)).isEqualTo(value) } + @Test fun singularFieldMaskOccurrencesAreMerged() { + val value = ContainsFieldMask.ADAPTER.decode("0a030a01610a030a0162".decodeHex()) + + assertThat(value.mask).isEqualTo(FieldMask(listOf("a", "b"))) + } + @Test fun anyMessageWithUnregisteredTypeOnReading() { try { jsonLibrary.fromJson(PIZZA_DELIVERY_UNKNOWN_TYPE_JSON, PizzaDelivery::class.java) diff --git a/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/internal/FieldMaskJsonFormatterTest.kt b/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/internal/FieldMaskJsonFormatterTest.kt index cb06797dc5..e1168820b3 100644 --- a/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/internal/FieldMaskJsonFormatterTest.kt +++ b/wire-tests/wire-json-shared-kotlin-tests/com/squareup/wire/internal/FieldMaskJsonFormatterTest.kt @@ -34,6 +34,11 @@ class FieldMaskJsonFormatterTest { .isEqualTo(FieldMask(listOf("user.display_name", "photo", "foo_bar.baz_qux"))) } + @Test fun `leading uppercase path segment decodes with leading underscore`() { + assertThat(fromString("foo.Bar")) + .isEqualTo(FieldMask(listOf("foo._bar"))) + } + @Test fun `empty paths are skipped`() { assertThat(toStringOrNumber(FieldMask(listOf("", "photo", "")))) .isEqualTo("photo")