diff --git a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java index 226176d25ff..1da12939937 100644 --- a/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java +++ b/rls/src/main/java/io/grpc/rls/LbPolicyConfiguration.java @@ -248,10 +248,10 @@ ChildPolicyWrapper createOrGet(String target) { RefCountedChildPolicyWrapper pooledChildPolicyWrapper = childPolicyMap.get(target); if (pooledChildPolicyWrapper == null) { ChildPolicyWrapper childPolicyWrapper = new ChildPolicyWrapper( - target, childPolicy, childLbResolvedAddressFactory, childLbHelperProvider, - childLbStatusListener); + target, childPolicy, childLbHelperProvider, childLbStatusListener); pooledChildPolicyWrapper = RefCountedChildPolicyWrapper.of(childPolicyWrapper); childPolicyMap.put(target, pooledChildPolicyWrapper); + childPolicyWrapper.start(childLbResolvedAddressFactory); return pooledChildPolicyWrapper.getObject(); } else { ChildPolicyWrapper childPolicyWrapper = pooledChildPolicyWrapper.getObject(); @@ -298,7 +298,6 @@ static final class ChildPolicyWrapper { public ChildPolicyWrapper( String target, ChildLoadBalancingPolicy childPolicy, - final ResolvedAddressFactory childLbResolvedAddressFactory, ChildLoadBalancerHelperProvider childLbHelperProvider, ChildLbStatusListener childLbStatusListener) { this.target = target; @@ -313,6 +312,9 @@ public ChildPolicyWrapper( this.childLbConfig = lbConfig.getConfig(); helper.getChannelLogger().log( ChannelLogLevel.DEBUG, "RLS child lb created. config: {0}", childLbConfig); + } + + void start(ResolvedAddressFactory childLbResolvedAddressFactory) { helper.getSynchronizationContext().execute( new Runnable() { @Override diff --git a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java index d6025d5bad4..e4c21bd48a6 100644 --- a/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java +++ b/rls/src/test/java/io/grpc/rls/LbPolicyConfigurationTest.java @@ -21,6 +21,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -45,8 +46,8 @@ import io.grpc.rls.LbPolicyConfiguration.ChildPolicyWrapper.ChildPolicyReportingHelper; import io.grpc.rls.LbPolicyConfiguration.InvalidChildPolicyConfigException; import io.grpc.rls.LbPolicyConfiguration.RefCountedChildPolicyWrapperFactory; -import java.lang.Thread.UncaughtExceptionHandler; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,6 +62,9 @@ public class LbPolicyConfigurationTest { private final LoadBalancer lb = mock(LoadBalancer.class); private final SubchannelStateManager subchannelStateManager = new SubchannelStateManagerImpl(); private final SubchannelPicker picker = mock(SubchannelPicker.class); + private final SynchronizationContext syncContext = new SynchronizationContext((t, e) -> { + throw new AssertionError(e); + }); private final ChildLbStatusListener childLbStatusListener = mock(ChildLbStatusListener.class); private final ResolvedAddressFactory resolvedAddressFactory = new ResolvedAddressFactory() { @@ -84,15 +88,7 @@ public ResolvedAddresses create(Object childLbConfig) { @Before public void setUp() { doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger(); - doReturn( - new SynchronizationContext( - new UncaughtExceptionHandler() { - @Override - public void uncaughtException(Thread t, Throwable e) { - throw new AssertionError(e); - } - })) - .when(helper).getSynchronizationContext(); + doReturn(syncContext).when(helper).getSynchronizationContext(); doReturn(lb).when(lbProvider).newLoadBalancer(any(Helper.class)); doReturn(ConfigOrError.fromConfig(new Object())) .when(lbProvider).parseLoadBalancingPolicyConfig(ArgumentMatchers.>any()); @@ -190,4 +186,22 @@ public void updateBalancingState_triggersListener() { // picker governs childPickers will be reported to parent LB verify(helper).updateBalancingState(ConnectivityState.READY, picker); } + + @Test + public void refCountedGetOrCreate_addsChildBeforeConfiguringChild() { + AtomicBoolean calledAlready = new AtomicBoolean(); + when(lb.acceptResolvedAddresses(any(ResolvedAddresses.class))).thenAnswer(i -> { + if (!calledAlready.get()) { + calledAlready.set(true); + // Should end up calling this function again, as this child should already be added to the + // list of children. In practice, this can be caused by CDS is_dynamic=true starting a watch + // when XdsClient already has the cluster cached (e.g., from another channel). + syncContext.execute(() -> + factory.acceptResolvedAddressFactory(resolvedAddressFactory)); + } + return Status.OK; + }); + ChildPolicyWrapper unused = factory.createOrGet("foo.google.com"); + verify(lb, times(2)).acceptResolvedAddresses(any(ResolvedAddresses.class)); + } }