@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546 case "NewSet" :
547547 pset , errs := oc .processNewSet (info , pkgPath , call , nil , varName )
548548 return pset , notePositionAll (exprPos , errs )
549+ case "Subtract" :
550+ pset , errs := oc .processSubtract (info , pkgPath , call , nil , varName )
551+ return pset , notePositionAll (exprPos , errs )
549552 case "Bind" :
550553 b , err := processBind (oc .fset , info , call )
551554 if err != nil {
@@ -590,6 +593,113 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
590593 return nil , []error {notePosition (exprPos , errors .New ("unknown pattern" ))}
591594}
592595
596+ func (oc * objectCache ) filterType (s * ProviderSet , st types.Type ) []error {
597+ hasType := func (outs []types.Type ) bool {
598+ for _ , o := range outs {
599+ if types .Identical (o , st ) {
600+ return true
601+ }
602+ pt , ok := o .(* types.Pointer )
603+ if ok && types .Identical (pt .Elem (), st ) {
604+ return true
605+ }
606+ }
607+ return false
608+ }
609+ providers := make ([]* Provider , 0 , len (s .Providers ))
610+ for _ , p := range s .Providers {
611+ if ! hasType (p .Out ) {
612+ providers = append (providers , p )
613+ }
614+ }
615+ s .Providers = providers
616+
617+ bindings := make ([]* IfaceBinding , 0 , len (s .Bindings ))
618+ for _ , i := range s .Bindings {
619+ if ! types .Identical (i .Iface , st ) {
620+ bindings = append (bindings , i )
621+ }
622+ }
623+ s .Bindings = bindings
624+
625+ values := make ([]* Value , 0 , len (s .Values ))
626+ for _ , v := range s .Values {
627+ if ! types .Identical (v .Out , st ) {
628+ values = append (values , v )
629+ }
630+ }
631+ s .Values = values
632+
633+ fields := make ([]* Field , 0 , len (s .Fields ))
634+ for _ , f := range s .Fields {
635+ if ! hasType (f .Out ) {
636+ fields = append (fields , f )
637+ }
638+ }
639+ s .Fields = fields
640+
641+ imports := make ([]* ProviderSet , 0 , len (s .Imports ))
642+ for _ , p := range s .Imports {
643+ clone := * p
644+ if errs := oc .filterType (& clone , st ); len (errs ) > 0 {
645+ return errs
646+ }
647+ imports = append (imports , & clone )
648+ }
649+ s .Imports = imports
650+ var errs []error
651+ s .providerMap , s .srcMap , errs = buildProviderMap (oc .fset , oc .hasher , s )
652+ if len (errs ) > 0 {
653+ return errs
654+ }
655+ return nil
656+ }
657+
658+ func (oc * objectCache ) processSubtract (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (interface {}, []error ) {
659+ // Assumes that call.Fun is wire.Subtract.
660+ if len (call .Args ) < 2 {
661+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
662+ errors .New ("call to Subtract must specify types to be subtracted" ))}
663+ }
664+ firstArg , errs := oc .processExpr (info , pkgPath , call .Args [0 ], "" )
665+ if len (errs ) > 0 {
666+ return nil , errs
667+ }
668+ set , ok := firstArg .(* ProviderSet )
669+ if ! ok {
670+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
671+ fmt .Errorf ("first argument to Subtract must be a Set" )),
672+ }
673+ }
674+ pset := & ProviderSet {
675+ Pos : call .Pos (),
676+ InjectorArgs : args ,
677+ PkgPath : pkgPath ,
678+ VarName : varName ,
679+ // Copy the other fields.
680+ Providers : set .Providers ,
681+ Bindings : set .Bindings ,
682+ Values : set .Values ,
683+ Fields : set .Fields ,
684+ Imports : set .Imports ,
685+ }
686+ ec := new (errorCollector )
687+ for _ , arg := range call .Args [1 :] {
688+ ptr , ok := info .TypeOf (arg ).(* types.Pointer )
689+ if ! ok {
690+ ec .add (notePosition (oc .fset .Position (arg .Pos ()),
691+ errors .New ("argument to Subtract must be a pointer" ),
692+ ))
693+ continue
694+ }
695+ ec .add (oc .filterType (pset , ptr .Elem ())... )
696+ }
697+ if len (ec .errors ) > 0 {
698+ return nil , ec .errors
699+ }
700+ return pset , nil
701+ }
702+
593703func (oc * objectCache ) processNewSet (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (* ProviderSet , []error ) {
594704 // Assumes that call.Fun is wire.NewSet or wire.Build.
595705
@@ -1173,9 +1283,9 @@ func (pt ProvidedType) IsNil() bool {
11731283//
11741284// - For a function provider, this is the first return value type.
11751285// - For a struct provider, this is either the struct type or the pointer type
1176- // whose element type is the struct type.
1177- // - For a value, this is the type of the expression.
1178- // - For an argument, this is the type of the argument.
1286+ // whose element type is the struct type.
1287+ // - For a value, this is the type of the expression.
1288+ // - For an argument, this is the type of the argument.
11791289func (pt ProvidedType ) Type () types.Type {
11801290 return pt .t
11811291}
0 commit comments