Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions wire-runtime-swift/src/main/swift/AnyMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,27 @@ extension AnyMessage: Codable {
case typeURL = "@type"
case value = "value"
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.typeURL = try container.decode(String.self, forKey: .typeURL)
if typeURL == FieldMask.protoMessageTypeURL() {
let fieldMask = try container.decode(FieldMask.self, forKey: .value)
self.value = try ProtoEncoder().encode(fieldMask)
} else {
self.value = try container.decode(Data.self, forKey: .value)
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(typeURL, forKey: .typeURL)
if typeURL == FieldMask.protoMessageTypeURL() {
let fieldMask = try ProtoDecoder().decode(FieldMask.self, from: value)
try container.encode(fieldMask, forKey: .value)
} else {
try container.encode(value, forKey: .value)
}
}
}
#endif
14 changes: 11 additions & 3 deletions wire-runtime-swift/src/main/swift/FieldMask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,7 @@ extension FieldMask: Codable {
var result = ""
for scalar in value.unicodeScalars {
if scalar.value >= 65 && scalar.value <= 90 {
if !result.isEmpty {
result.append("_")
}
result.append("_")
result.append(String(scalar).lowercased())
} else {
result.append(String(scalar).lowercased())
Expand All @@ -138,3 +136,13 @@ extension FieldMask: Codable {
}
}
#endif

extension ProtoReader {
public func decode(_ type: FieldMask.Type, mergingInto existing: FieldMask?) throws -> FieldMask {
let decoded = try decode(type)
guard let existing = existing else {
return decoded
}
return FieldMask(paths: existing.paths + decoded.paths)
}
}
17 changes: 17 additions & 0 deletions wire-runtime-swift/src/test/swift/AnyMessageTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ final class AnyMessageTests: XCTestCase {
XCTAssertEqual(fieldMask, try any.unpack(FieldMask.self))
}

func testFieldMaskJSONUsesSpecialValueEncoding() throws {
let fieldMask = FieldMask(paths: ["masked_name"])
let any = try AnyMessage.pack(fieldMask)

let encoded = try JSONEncoder().encode(any)
let decoded = try JSONDecoder().decode(
AnyMessage.self,
from: Foundation.Data(#"{"@type":"type.googleapis.com/google.protobuf.FieldMask","value":"maskedName"}"#.utf8)
)
let encodedObject = try XCTUnwrap(JSONSerialization.jsonObject(with: encoded) as? [String: String])

XCTAssertEqual(encodedObject["@type"], FieldMask.protoMessageTypeURL())
XCTAssertEqual(encodedObject["value"], "maskedName")
XCTAssertEqual(try JSONDecoder().decode(AnyMessage.self, from: encoded), any)
XCTAssertEqual(decoded, any)
}

func testUnpackingAnyToCorrectType() throws {
let data = Data(json_data: "")
let person = Person(name: "foo bar", id: 12345, data: data)
Expand Down
18 changes: 18 additions & 0 deletions wire-runtime-swift/src/test/swift/FieldMaskTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ final class FieldMaskTests: XCTestCase {
XCTAssertEqual(decoded, FieldMask(paths: ["photo"]))
}

func testJSONPreservesLeadingUppercaseAsUnderscore() throws {
let jsonData = Foundation.Data(#""foo.Bar""#.utf8)

let fieldMask = try JSONDecoder().decode(FieldMask.self, from: jsonData)
let encoded = try JSONEncoder().encode(fieldMask)

XCTAssertEqual(fieldMask, FieldMask(paths: ["foo._bar"]))
XCTAssertEqual(String(data: encoded, encoding: .utf8), #""foo.Bar""#)
}

func testGeneratedMessageWithFieldMask() throws {
let fieldMask = FieldMask(paths: ["user.display_name", "photo"])
let otherFieldMask = FieldMask(paths: ["updated_at.seconds"])
Expand All @@ -83,4 +93,12 @@ final class FieldMaskTests: XCTestCase {
XCTAssertEqual(decoded.masks, [fieldMask, otherFieldMask])
XCTAssertEqual(decoded.masks_by_id, [1: fieldMask, 2: otherFieldMask])
}

func testGeneratedMessageMergesDuplicateSingularFieldMask() throws {
let data = Foundation.Data(hexEncoded: "0a030a01610a030a0162")!

let decoded = try ProtoDecoder().decode(MessageContainingFieldMask.self, from: data)

XCTAssertEqual(decoded.mask, FieldMask(paths: ["a", "b"]))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,24 @@ class RuntimeMessageAdapter<M : Any, B : Any>(
val field = fields[tag]
try {
if (field != null) {
val adapter = if (field.isMap) {
field.adapter
if (field.isMap) {
val value = field.adapter.decode(reader)
field.value(builder, value)
} else {
field.singleAdapter
val singleAdapter = field.singleAdapter
if ((field.isMessage || singleAdapter == ProtoAdapter.FIELD_MASK) &&
!field.label.isRepeated &&
!field.label.isOneOf
) {
@Suppress("UNCHECKED_CAST")
val adapter = singleAdapter as ProtoAdapter<Any>
val value = decodeMessageOrMerge(adapter, reader, field.getFromBuilder(builder))
field.set(builder, value)
} else {
val value = singleAdapter.decode(reader)
field.value(builder, value!!)
}
}
val value = adapter.decode(reader)
field.value(builder, value!!)
} else {
val fieldEncoding = reader.peekFieldEncoding()!!
val value = fieldEncoding.rawProtoAdapter().decode(reader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.squareup.wire.durationOfSeconds
import com.squareup.wire.ofEpochSecond
import kotlin.test.Ignore
import kotlin.test.Test
import okio.ByteString.Companion.decodeHex
import okio.ByteString.Companion.encodeUtf8
import okio.Path.Companion.toPath

Expand Down Expand Up @@ -87,6 +88,30 @@ class DynamicSerializationTest {
assertThat(adapter.decode(adapter.encode(expected))).isEqualTo(expected)
}

@Test
fun singularFieldMaskOccurrencesAreMerged() {
val schema = buildSchema {
add(
"message.proto".toPath(),
"""
|syntax = "proto3";
|import "google/protobuf/field_mask.proto";
|
|message Message {
| google.protobuf.FieldMask field_mask_field = 1;
|}
|
""".trimMargin(),
)
}

val adapter = schema.protoAdapter(typeName = "Message", includeUnknown = true)
val encoded = "0a030a01610a030a0162".decodeHex()

assertThat(adapter.decode(encoded))
.isEqualTo(mapOf("field_mask_field" to FieldMask(listOf("a", "b"))))
}

@Ignore // Not supported.
@Test
fun mapTest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ class SwiftGenerator private constructor(
} else {
if (field.isRepeated) {
decoder.add("try $reader.decode(into: &%N", field.safeName)
} else if (field.type == ProtoType.FIELD_MASK) {
val typeName = field.typeName.makeNonOptional()

decoder.add("%N = try $reader.decode(%T.self, mergingInto: %N", field.safeName, typeName, field.safeName)
} else {
val typeName = field.typeName.makeNonOptional()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class SwiftGeneratorTest {
assertThat(code).contains("public var mask: FieldMask?")
assertThat(code).contains("public var masks: [FieldMask]")
assertThat(code).contains("public var masks_by_id: [Int32 : FieldMask]")
assertThat(code).contains("mask = try protoReader.decode(FieldMask.self)")
assertThat(code).contains("mask = try protoReader.decode(FieldMask.self, mergingInto: mask)")
assertThat(code).contains("try protoReader.decode(into: &masks)")
assertThat(code).contains("try protoReader.decode(into: &masks_by_id, keyEncoding: .variable)")
assertThat(code).doesNotContain("@ProtoDefaulted")
Expand Down
2 changes: 1 addition & 1 deletion wire-tests-swift/no-manifest/src/main/swift/AllTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,7 @@ extension AllTypes.Storage : Proto2Codable {
case 15: opt_bytes = try protoReader.decode(Foundation.Data.self)
case 16: opt_nested_enum = try protoReader.decode(AllTypes.NestedEnum.self)
case 17: opt_nested_message = try protoReader.decode(AllTypes.NestedMessage.self)
case 18: opt_field_mask = try protoReader.decode(FieldMask.self)
case 18: opt_field_mask = try protoReader.decode(FieldMask.self, mergingInto: opt_field_mask)
case 101: req_int32 = try protoReader.decode(Int32.self, encoding: .variable)
case 102: req_uint32 = try protoReader.decode(UInt32.self, encoding: .variable)
case 103: req_sint32 = try protoReader.decode(Int32.self, encoding: .signed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import com.squareup.wire.proto3.alltypes.AllTypes as AllTypesProto3
import java.io.File
import java.util.Collections
import okio.ByteString
import okio.ByteString.Companion.decodeHex
import okio.buffer
import okio.source
import org.junit.Assert.fail
Expand Down Expand Up @@ -315,6 +316,12 @@ class WireJsonTest {
assertThat(jsonLibrary.fromJson(json, FieldMask::class.java)).isEqualTo(value)
}

@Test fun singularFieldMaskOccurrencesAreMerged() {
val value = ContainsFieldMask.ADAPTER.decode("0a030a01610a030a0162".decodeHex())

assertThat(value.mask).isEqualTo(FieldMask(listOf("a", "b")))
}

@Test fun anyMessageWithUnregisteredTypeOnReading() {
try {
jsonLibrary.fromJson(PIZZA_DELIVERY_UNKNOWN_TYPE_JSON, PizzaDelivery::class.java)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class FieldMaskJsonFormatterTest {
.isEqualTo(FieldMask(listOf("user.display_name", "photo", "foo_bar.baz_qux")))
}

@Test fun `leading uppercase path segment decodes with leading underscore`() {
assertThat(fromString("foo.Bar"))
.isEqualTo(FieldMask(listOf("foo._bar")))
}

@Test fun `empty paths are skipped`() {
assertThat(toStringOrNumber(FieldMask(listOf("", "photo", ""))))
.isEqualTo("photo")
Expand Down
Loading