@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546 case "NewSet" :
547547 pset , errs := oc .processNewSet (info , pkgPath , call , nil , varName )
548548 return pset , notePositionAll (exprPos , errs )
549+ case "Subtract" :
550+ pset , errs := oc .processSubtract (info , pkgPath , call , nil , varName )
551+ return pset , notePositionAll (exprPos , errs )
549552 case "Bind" :
550553 b , err := processBind (oc .fset , info , call )
551554 if err != nil {
@@ -590,6 +593,114 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
590593 return nil , []error {notePosition (exprPos , errors .New ("unknown pattern" ))}
591594}
592595
596+ func (oc * objectCache ) filterType (s * ProviderSet , st types.Type ) []error {
597+ hasType := func (outs []types.Type ) bool {
598+ for _ , o := range outs {
599+ if types .Identical (o , st ) {
600+ return true
601+ }
602+ pt , ok := o .(* types.Pointer )
603+ if ok && types .Identical (pt .Elem (), st ) {
604+ return true
605+ }
606+ }
607+ return false
608+ }
609+ providers := make ([]* Provider , 0 , len (s .Providers ))
610+ for _ , p := range s .Providers {
611+ if ! hasType (p .Out ) {
612+ providers = append (providers , p )
613+ }
614+ }
615+ s .Providers = providers
616+
617+ bindings := make ([]* IfaceBinding , 0 , len (s .Bindings ))
618+ for _ , i := range s .Bindings {
619+ if ! types .Identical (i .Iface , st ) {
620+ bindings = append (bindings , i )
621+ }
622+ }
623+ s .Bindings = bindings
624+
625+ values := make ([]* Value , 0 , len (s .Values ))
626+ for _ , v := range s .Values {
627+ if ! types .Identical (v .Out , st ) {
628+ values = append (values , v )
629+ }
630+ }
631+ s .Values = values
632+
633+ fields := make ([]* Field , 0 , len (s .Fields ))
634+ for _ , f := range s .Fields {
635+ if ! hasType (f .Out ) {
636+ fields = append (fields , f )
637+ }
638+ }
639+ s .Fields = fields
640+
641+ imports := make ([]* ProviderSet , 0 , len (s .Imports ))
642+ for _ , p := range s .Imports {
643+ clone := * p
644+ if errs := oc .filterType (& clone , st ); len (errs ) > 0 {
645+ return errs
646+ }
647+ imports = append (imports , & clone )
648+ }
649+ s .Imports = imports
650+
651+ var errs []error
652+ s .providerMap , s .srcMap , errs = buildProviderMap (oc .fset , oc .hasher , s )
653+ if len (errs ) > 0 {
654+ return errs
655+ }
656+ return nil
657+ }
658+
659+ func (oc * objectCache ) processSubtract (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (interface {}, []error ) {
660+ // Assumes that call.Fun is wire.Subtract.
661+ if len (call .Args ) < 2 {
662+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
663+ errors .New ("call to Subtract must specify types to be subtracted" ))}
664+ }
665+ firstArg , errs := oc .processExpr (info , pkgPath , call .Args [0 ], "" )
666+ if len (errs ) > 0 {
667+ return nil , errs
668+ }
669+ set , ok := firstArg .(* ProviderSet )
670+ if ! ok {
671+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
672+ fmt .Errorf ("first argument to Subtract must be a Set" )),
673+ }
674+ }
675+ pset := & ProviderSet {
676+ Pos : call .Pos (),
677+ InjectorArgs : args ,
678+ PkgPath : pkgPath ,
679+ VarName : varName ,
680+ // Copy the other fields.
681+ Providers : set .Providers ,
682+ Bindings : set .Bindings ,
683+ Values : set .Values ,
684+ Fields : set .Fields ,
685+ Imports : set .Imports ,
686+ }
687+ ec := new (errorCollector )
688+ for _ , arg := range call .Args [1 :] {
689+ ptr , ok := info .TypeOf (arg ).(* types.Pointer )
690+ if ! ok {
691+ ec .add (notePosition (oc .fset .Position (arg .Pos ()),
692+ errors .New ("argument to Subtract must be a pointer" ),
693+ ))
694+ continue
695+ }
696+ ec .add (oc .filterType (pset , ptr .Elem ())... )
697+ }
698+ if len (ec .errors ) > 0 {
699+ return nil , ec .errors
700+ }
701+ return pset , nil
702+ }
703+
593704func (oc * objectCache ) processNewSet (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (* ProviderSet , []error ) {
594705 // Assumes that call.Fun is wire.NewSet or wire.Build.
595706
@@ -1173,9 +1284,9 @@ func (pt ProvidedType) IsNil() bool {
11731284//
11741285// - For a function provider, this is the first return value type.
11751286// - For a struct provider, this is either the struct type or the pointer type
1176- // whose element type is the struct type.
1177- // - For a value, this is the type of the expression.
1178- // - For an argument, this is the type of the argument.
1287+ // whose element type is the struct type.
1288+ // - For a value, this is the type of the expression.
1289+ // - For an argument, this is the type of the argument.
11791290func (pt ProvidedType ) Type () types.Type {
11801291 return pt .t
11811292}
0 commit comments