Skip to content

Commit b7f9071

Browse files
committed
🚸 Add support for predicates on optionals
1 parent 966d8a4 commit b7f9071

10 files changed

Lines changed: 268 additions & 19 deletions

File tree

PredicateKit.podspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
Pod::Spec.new do |spec|
1717
spec.name = "PredicateKit"
18-
spec.version = "1.3.0"
18+
spec.version = "1.4.0"
1919
spec.summary = "Write expressive and type-safe predicates for CoreData using key-paths, comparisons and logical operators, literal values, and functions."
2020
spec.description = <<-DESC
2121
PredicateKit allows Swift developers to write expressive and type-safe predicates for CoreData using key-paths, comparisons and logical operators, literal values, and functions.

PredicateKit/CoreData/NSFetchRequestBuilder.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ struct NSFetchRequestBuilder {
7373
case .direct, .any, .all:
7474
return NSComparisonPredicate(
7575
leftExpression: makeExpression(from: comparison.expression),
76-
rightExpression: NSExpression(forConstantValue: comparison.value),
76+
rightExpression: makeExpression(from: comparison.value),
7777
modifier: makeComparisonModifier(from: comparison.modifier),
7878
type: makeOperator(from: comparison.operator),
7979
options: makeComparisonOptions(from: comparison.options)
@@ -110,6 +110,10 @@ struct NSFetchRequestBuilder {
110110
expression.toNSExpression(conversionOptions)
111111
}
112112

113+
private func makeExpression(from primitive: Primitive) -> NSExpression {
114+
return NSExpression(forConstantValue: primitive.value)
115+
}
116+
113117
private func makeOperator(from operator: ComparisonOperator) -> NSComparisonPredicate.Operator {
114118
switch `operator` {
115119
case .beginsWith:
@@ -298,6 +302,19 @@ extension Query: NSExpressionConvertible {
298302
}
299303
}
300304

305+
// MARK: - Primitive
306+
307+
private extension Primitive {
308+
var value: Any? {
309+
switch Self.type {
310+
case .nil:
311+
return NSNull()
312+
default:
313+
return self
314+
}
315+
}
316+
}
317+
301318
// MARK: - KeyPath
302319

303320
extension AnyKeyPath {

PredicateKit/Predicate.swift

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ public enum Function<Input: Expression, Output>: Expression where Input.Value: A
290290

291291
public enum Index<Array: Expression>: Expression where Array.Value: AnyArray {
292292
public typealias Root = Array.Root
293-
public typealias Value = Array.Value.Element
293+
public typealias Value = Array.Value.ArrayElement
294294

295295
case index(Array, Int)
296296
case first(Array)
@@ -371,6 +371,11 @@ public func == <E: Expression, T: Equatable & Primitive> (lhs: E, rhs: T) -> Pre
371371
.comparison(.init(lhs, .equal, rhs))
372372
}
373373

374+
@_disfavoredOverload
375+
public func == <E: Expression> (lhs: E, rhs: Nil) -> Predicate<E.Root> where E.Value: OptionalType {
376+
.comparison(.init(lhs, .equal, rhs))
377+
}
378+
374379
public func != <E: Expression, T: Equatable & Primitive> (lhs: E, rhs: T) -> Predicate<E.Root> where E.Value == T {
375380
.comparison(.init(lhs, .notEqual, rhs))
376381
}
@@ -495,15 +500,15 @@ extension Expression where Value: AnyArray {
495500
.last(self)
496501
}
497502

498-
public func at<T>(index: Int, _ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
503+
public func at<T>(index: Int, _ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
499504
.init(.index(index), self, keyPath)
500505
}
501506

502-
public func first<T>(_ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
507+
public func first<T>(_ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
503508
.init(.first, self, keyPath)
504509
}
505510

506-
public func last<T>(_ keyPath: KeyPath<Value.Element, T>) -> ArrayElementKeyPath<Self, T> {
511+
public func last<T>(_ keyPath: KeyPath<Value.ArrayElement, T>) -> ArrayElementKeyPath<Self, T> {
507512
.init(.last, self, keyPath)
508513
}
509514
}
@@ -600,6 +605,8 @@ extension Expression {
600605

601606
// MARK: - Supporting Protocols
602607

608+
// MARK: - StringValue
609+
603610
public protocol StringValue {
604611
}
605612

@@ -609,10 +616,15 @@ extension String: StringValue {
609616
extension Optional: StringValue where Wrapped == String {
610617
}
611618

619+
// MARK: - AnyArrayOrSet
620+
612621
public protocol AnyArrayOrSet {
613622
associatedtype Element
614623
}
615624

625+
extension Array: AnyArrayOrSet {
626+
}
627+
616628
extension Set: AnyArrayOrSet {
617629
}
618630

@@ -623,16 +635,22 @@ extension Optional: AnyArrayOrSet where Wrapped: AnyArrayOrSet {
623635
public typealias Element = Wrapped.Element
624636
}
625637

638+
// MARK: - AnyArray
639+
626640
public protocol AnyArray {
627-
associatedtype Element
641+
associatedtype ArrayElement
628642
}
629643

630-
extension Array: AnyArrayOrSet {
644+
extension Array: AnyArray {
645+
public typealias ArrayElement = Element
631646
}
632647

633-
extension Array: AnyArray {
648+
extension Optional: AnyArray where Wrapped: AnyArray {
649+
public typealias ArrayElement = Wrapped.ArrayElement
634650
}
635651

652+
// MARK: - PrimitiveCollection
653+
636654
public protocol PrimitiveCollection {
637655
associatedtype PrimitiveElement: Primitive
638656
}
@@ -649,6 +667,8 @@ extension Optional: PrimitiveCollection where Wrapped: PrimitiveCollection {
649667
public typealias PrimitiveElement = Wrapped.PrimitiveElement
650668
}
651669

670+
// MARK: - AdditiveCollection
671+
652672
public protocol AdditiveCollection {
653673
associatedtype AdditiveElement: AdditiveArithmetic & Primitive
654674
}
@@ -661,6 +681,8 @@ extension Optional: AdditiveCollection where Wrapped: PrimitiveCollection & Addi
661681
public typealias AdditiveElement = Wrapped.AdditiveElement
662682
}
663683

684+
// MARK: - ComparableCollection
685+
664686
public protocol ComparableCollection {
665687
associatedtype ComparableElement: Comparable & Primitive
666688
}
@@ -673,7 +695,7 @@ extension Optional: ComparableCollection where Wrapped: ComparableCollection {
673695
public typealias ComparableElement = Wrapped.ComparableElement
674696
}
675697

676-
// MARK: -
698+
// MARK: - Private Initializers
677699

678700
extension Comparison {
679701
fileprivate init<E: Expression>(

PredicateKit/Primitive.swift

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import Foundation
2222

23+
// MARK: - Primitive
24+
2325
public protocol Primitive {
2426
static var type: Type { get }
2527
}
@@ -45,6 +47,7 @@ public indirect enum Type: Equatable {
4547
case data
4648
case wrapped(Type)
4749
case array(Type)
50+
case `nil`
4851
}
4952

5053
extension Bool: Primitive {
@@ -131,6 +134,22 @@ extension Optional: Primitive where Wrapped: Primitive {
131134
public static var type: Type { Wrapped.type }
132135
}
133136

137+
public struct Nil: Primitive, ExpressibleByNilLiteral {
138+
public static var type: Type { .nil }
139+
140+
public init(nilLiteral: ()) {
141+
}
142+
}
143+
144+
// MARK: - Optional
145+
146+
public protocol OptionalType {
147+
associatedtype Wrapped
148+
}
149+
150+
extension Optional: OptionalType {
151+
}
152+
134153
extension Optional: Comparable where Wrapped: Comparable {
135154
public static func < (lhs: Self, rhs: Self) -> Bool {
136155
switch (lhs, rhs) {

PredicateKit/SwiftUI/SwiftUISupport.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,32 @@ extension FetchRequest {
135135
self.init(context: context, predicate: predicate)
136136
}
137137
}
138+
139+
@available(iOS 13.0, watchOS 6.0, tvOS 13.0, *)
140+
extension FetchRequest {
141+
/// Creates a fetch request that returns all objects in the underlying store.
142+
///
143+
/// - Important: Use this initializer **only** in conjunction with the SwiftUI property wrapper` @FetchRequest`. Fetch
144+
/// requests created with this initializer cannot be executed outside of SwiftUI as they rely on the CoreData
145+
/// managed object context injected in the environment of a SwiftUI view.
146+
///
147+
/// ## Example
148+
///
149+
/// struct ContentView: View {
150+
/// @SwiftUI.FetchRequest()
151+
/// .sorted(by: \Note.creationDate, .ascending)
152+
/// .limit(100)
153+
/// )
154+
/// var notes: FetchedResults<Note>
155+
///
156+
/// var body: some View {
157+
/// List(notes, id: \.self) {
158+
/// Text($0.text)
159+
/// }
160+
/// }
161+
/// }
162+
///
163+
public init() {
164+
self.init(predicate: true)
165+
}
166+
}

PredicateKitTests/CoreDataTests/NSFetchRequestBuilderTests.swift

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,45 @@ final class NSFetchRequestBuilderTests: XCTestCase {
10391039

10401040
XCTAssertTrue(fatalError.contains("does not conform to NSExpressionConvertible"))
10411041
}
1042+
1043+
func testObjectNilEqualityPredicate() throws {
1044+
let request = makeRequest(\Data.optionalRelationship == nil)
1045+
let builder = makeRequestBuilder()
1046+
1047+
let result: NSFetchRequest<Data> = builder.makeRequest(from: request)
1048+
1049+
let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
1050+
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationship"))
1051+
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
1052+
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
1053+
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
1054+
}
1055+
1056+
func testArrayNilEqualityPredicate() throws {
1057+
let request = makeRequest(\Data.optionalRelationships == nil)
1058+
let builder = makeRequestBuilder()
1059+
1060+
let result: NSFetchRequest<Data> = builder.makeRequest(from: request)
1061+
1062+
let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
1063+
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationships"))
1064+
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
1065+
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
1066+
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
1067+
}
1068+
1069+
func testNestedPrimitiveNilEqualityPredicate() throws {
1070+
let request = makeRequest(\Data.optionalRelationship?.text == nil)
1071+
let builder = makeRequestBuilder()
1072+
1073+
let result: NSFetchRequest<Data> = builder.makeRequest(from: request)
1074+
1075+
let comparison = try XCTUnwrap(result.predicate as? NSComparisonPredicate)
1076+
XCTAssertEqual(comparison.leftExpression, NSExpression(forKeyPath: "optionalRelationship.text"))
1077+
XCTAssertEqual(comparison.rightExpression, NSExpression(forConstantValue: NSNull()))
1078+
XCTAssertEqual(comparison.predicateOperatorType, .equalTo)
1079+
XCTAssertEqual(comparison.comparisonPredicateModifier, .direct)
1080+
}
10421081
}
10431082

10441083
// MARK: -
@@ -1051,6 +1090,8 @@ private class Data: NSManagedObject {
10511090
@NSManaged var creationDate: Date
10521091
@NSManaged var relationship: Relationship
10531092
@NSManaged var relationships: [Relationship]
1093+
@NSManaged var optionalRelationship: Relationship?
1094+
@NSManaged var optionalRelationships: [Relationship]?
10541095
}
10551096

10561097
private class Relationship: NSManagedObject {
@@ -1079,3 +1120,12 @@ private func makeRequestBuilder(
10791120
) -> NSFetchRequestBuilder {
10801121
.init(entityName: "")
10811122
}
1123+
1124+
class NoteGroup: NSManagedObject {
1125+
@NSManaged var notes: [NewNote]?
1126+
}
1127+
1128+
class NewNote: NSManagedObject {
1129+
@NSManaged var group: NoteGroup?
1130+
@NSManaged var id: String?
1131+
}

PredicateKitTests/CoreDataTests/NSManagedObjectContextExtensionsTests.swift

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,40 @@ final class NSManagedObjectContextExtensionsTests: XCTestCase {
629629
XCTAssertTrue(inspector.inspectCalled)
630630
}
631631

632+
func testFetchWithNilEquality() throws {
633+
let now = Date()
634+
635+
try container.viewContext.insertNotes(
636+
(text: "Hello, World!", creationDate: .distantFuture, updateDate: now, numberOfViews: 42, tags: ["greeting"]),
637+
(text: "Goodbye!", creationDate: .distantPast, updateDate: nil, numberOfViews: 3, tags: ["greeting"])
638+
)
639+
640+
let notes: [Note] = try container.viewContext
641+
.fetch(where: \Note.updateDate == nil)
642+
.result()
643+
644+
XCTAssertEqual(notes.count, 1)
645+
XCTAssertEqual(notes.first?.text, "Goodbye!")
646+
XCTAssertEqual(notes.first?.tags, ["greeting"])
647+
XCTAssertEqual(notes.first?.numberOfViews, 3)
648+
}
649+
650+
func testFetchWithArrayNilEqualityNilEquality() throws {
651+
try container.viewContext.insertUsers(
652+
(name: "John Doe", billingAccountType: "Pro", purchases: [35.0, 120.0]),
653+
(name: "Jane Doe", billingAccountType: "Default", purchases: nil)
654+
)
655+
656+
let users: [User] = try container.viewContext
657+
.fetch(where: \User.billingInfo.purchases == nil)
658+
.inspect(on: MockNSFetchRequestInspector())
659+
.result()
660+
661+
XCTAssertEqual(users.count, 1)
662+
XCTAssertEqual(users.first?.name, "Jane Doe")
663+
XCTAssertEqual(users.first?.billingInfo.accountType, "Default")
664+
}
665+
632666
private func makePersistentContainer() -> NSPersistentContainer {
633667
return self.makePersistentContainer(with: model)
634668
}
@@ -639,6 +673,7 @@ final class NSManagedObjectContextExtensionsTests: XCTestCase {
639673
class Note: NSManagedObject {
640674
@NSManaged var text: String
641675
@NSManaged var creationDate: Date
676+
@NSManaged var updateDate: Date?
642677
@NSManaged var numberOfViews: Int
643678
@NSManaged var tags: [String]
644679
}
@@ -654,7 +689,7 @@ class User: NSManagedObject {
654689

655690
class BillingInfo: NSManagedObject {
656691
@NSManaged var accountType: String
657-
@NSManaged var purchases: [Double]
692+
@NSManaged var purchases: [Double]?
658693
}
659694

660695
class UserAccount: NSManagedObject {
@@ -703,6 +738,21 @@ private extension NSManagedObjectContext {
703738
try save()
704739
}
705740

741+
func insertNotes(
742+
_ notes: (text: String, creationDate: Date, updateDate: Date?, numberOfViews: Int, tags: [String])...
743+
) throws {
744+
for description in notes {
745+
let note = NSEntityDescription.insertNewObject(forEntityName: "Note", into: self) as! Note
746+
note.text = description.text
747+
note.tags = description.tags
748+
note.numberOfViews = description.numberOfViews
749+
note.creationDate = description.creationDate
750+
note.updateDate = description.updateDate
751+
}
752+
753+
try save()
754+
}
755+
706756
func insertAccounts(purchases: [[Double]]) throws {
707757
for description in purchases {
708758
let account = NSEntityDescription.insertNewObject(forEntityName: "Account", into: self) as! Account
@@ -712,7 +762,7 @@ private extension NSManagedObjectContext {
712762
try save()
713763
}
714764

715-
func insertUsers(_ users: (name: String, billingAccountType: String, purchases: [Double])...) throws {
765+
func insertUsers(_ users: (name: String, billingAccountType: String, purchases: [Double]?)...) throws {
716766
for description in users {
717767
let user = NSEntityDescription.insertNewObject(forEntityName: "User", into: self) as! User
718768
user.name = description.name

0 commit comments

Comments
 (0)