diff --git a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtils.java b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtils.java index e8f34431e222..a69893c9d217 100644 --- a/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtils.java +++ b/dubbo-rpc/dubbo-rpc-triple/src/main/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtils.java @@ -18,6 +18,7 @@ import org.apache.dubbo.common.URL; import org.apache.dubbo.common.constants.CommonConstants; +import org.apache.dubbo.common.stream.StreamObserver; import org.apache.dubbo.common.utils.CollectionUtils; import org.apache.dubbo.remoting.http12.exception.UnimplementedException; import org.apache.dubbo.rpc.Invoker; @@ -30,6 +31,7 @@ import java.io.IOException; import java.io.InputStream; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.List; @@ -99,49 +101,133 @@ public static MethodDescriptor findReflectionMethodDescriptor( .get(0); } else { List methodDescriptors = serviceDescriptor.getMethods(methodName); - if (CollectionUtils.isEmpty(methodDescriptors)) { - return null; - } - // In most cases there is only one method - if (methodDescriptors.size() == 1) { - methodDescriptor = methodDescriptors.get(0); - } - // generated unary method ,use unary type - // Response foo(Request) - // void foo(Request,StreamObserver) - if (methodDescriptors.size() == 2) { - if (methodDescriptors.get(1).getRpcType() == MethodDescriptor.RpcType.SERVER_STREAM) { - methodDescriptor = methodDescriptors.get(0); - } else if (methodDescriptors.get(0).getRpcType() == MethodDescriptor.RpcType.SERVER_STREAM) { - methodDescriptor = methodDescriptors.get(1); - } - } + methodDescriptor = findSingleOrGeneratedUnaryMethodDescriptor(methodDescriptors); } return methodDescriptor; } public static MethodDescriptor findTripleMethodDescriptor( ServiceDescriptor serviceDescriptor, String methodName, InputStream rawMessage) throws IOException { - MethodDescriptor methodDescriptor = findReflectionMethodDescriptor(serviceDescriptor, methodName); - if (methodDescriptor == null) { - rawMessage.mark(Integer.MAX_VALUE); - List methodDescriptors = serviceDescriptor.getMethods(methodName); - TripleRequestWrapper request = TripleRequestWrapper.parseFrom(rawMessage); - String[] paramTypes = request.getArgTypes().toArray(new String[0]); - // wrapper mode the method can overload so maybe list - for (MethodDescriptor descriptor : methodDescriptors) { - // params type is array - if (Arrays.equals(descriptor.getCompatibleParamSignatures(), paramTypes)) { - methodDescriptor = descriptor; - break; - } + if (isGeneric(methodName)) { + return ServiceDescriptorInternalCache.genericService() + .getMethods(methodName) + .get(0); + } + if (isEcho(methodName)) { + return ServiceDescriptorInternalCache.echoService() + .getMethods(methodName) + .get(0); + } + + List methodDescriptors = serviceDescriptor.getMethods(methodName); + if (CollectionUtils.isEmpty(methodDescriptors)) { + throw new UnimplementedException("method:" + methodName); + } + if (methodDescriptors.size() == 1) { + return methodDescriptors.get(0); + } + + TripleRequestWrapper request = parseRequestWrapper(rawMessage); + List argTypes = request == null ? null : request.getArgTypes(); + if (argTypes != null) { + MethodDescriptor methodDescriptor = + findMethodDescriptorByParamTypes(methodDescriptors, argTypes.toArray(new String[0])); + if (methodDescriptor != null) { + return methodDescriptor; } - if (methodDescriptor == null) { + if (CollectionUtils.isNotEmpty(argTypes)) { throw new UnimplementedException("method:" + methodName); } + } + + MethodDescriptor methodDescriptor = findGeneratedUnaryMethodDescriptor(methodDescriptors); + if (methodDescriptor != null) { + return methodDescriptor; + } + throw new UnimplementedException("method:" + methodName); + } + + private static MethodDescriptor findSingleOrGeneratedUnaryMethodDescriptor( + List methodDescriptors) { + if (CollectionUtils.isEmpty(methodDescriptors)) { + return null; + } + // In most cases there is only one method + if (methodDescriptors.size() == 1) { + return methodDescriptors.get(0); + } + return findGeneratedUnaryMethodDescriptor(methodDescriptors); + } + + private static TripleRequestWrapper parseRequestWrapper(InputStream rawMessage) throws IOException { + rawMessage.mark(Integer.MAX_VALUE); + try { + return TripleRequestWrapper.parseFrom(rawMessage); + } catch (IOException | RuntimeException ignored) { + return null; + } finally { rawMessage.reset(); } - return methodDescriptor; + } + + private static MethodDescriptor findMethodDescriptorByParamTypes( + List methodDescriptors, String[] paramTypes) { + for (MethodDescriptor descriptor : methodDescriptors) { + // wrapper mode the method can overload so maybe list + if (Arrays.equals(descriptor.getCompatibleParamSignatures(), paramTypes)) { + return descriptor; + } + } + return null; + } + + private static MethodDescriptor findGeneratedUnaryMethodDescriptor(List methodDescriptors) { + // Generated unary methods may expose two Java methods for the same RPC: + // Response foo(Request) + // void foo(Request, StreamObserver) + if (methodDescriptors.size() != 2) { + return null; + } + + MethodDescriptor unaryMethodDescriptor = null; + MethodDescriptor serverStreamMethodDescriptor = null; + for (MethodDescriptor descriptor : methodDescriptors) { + if (descriptor.getRpcType() == MethodDescriptor.RpcType.UNARY) { + unaryMethodDescriptor = descriptor; + } else if (descriptor.getRpcType() == MethodDescriptor.RpcType.SERVER_STREAM) { + serverStreamMethodDescriptor = descriptor; + } + } + if (unaryMethodDescriptor == null || serverStreamMethodDescriptor == null) { + return null; + } + if (!Arrays.equals( + unaryMethodDescriptor.getParameterClasses(), + getServerStreamRequestTypes(serverStreamMethodDescriptor))) { + return null; + } + if (!isSameResponseType(unaryMethodDescriptor, serverStreamMethodDescriptor)) { + return null; + } + return unaryMethodDescriptor; + } + + private static Class[] getServerStreamRequestTypes(MethodDescriptor serverStreamMethodDescriptor) { + Class[] parameterClasses = serverStreamMethodDescriptor.getParameterClasses(); + if (parameterClasses.length == 0 + || !StreamObserver.class.isAssignableFrom(parameterClasses[parameterClasses.length - 1])) { + return null; + } + return Arrays.copyOf(parameterClasses, parameterClasses.length - 1); + } + + private static boolean isSameResponseType( + MethodDescriptor unaryMethodDescriptor, MethodDescriptor serverStreamMethodDescriptor) { + Type[] returnTypes = unaryMethodDescriptor.getReturnTypes(); + if (returnTypes.length == 0 || !(returnTypes[0] instanceof Class)) { + return false; + } + return returnTypes[0] == serverStreamMethodDescriptor.getActualResponseType(); } private static boolean isGeneric(String methodName) { diff --git a/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtilsTest.java b/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtilsTest.java new file mode 100644 index 000000000000..1333b3609239 --- /dev/null +++ b/dubbo-rpc/dubbo-rpc-triple/src/test/java/org/apache/dubbo/rpc/protocol/tri/DescriptorUtilsTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.dubbo.rpc.protocol.tri; + +import org.apache.dubbo.common.stream.StreamObserver; +import org.apache.dubbo.remoting.http12.exception.UnimplementedException; +import org.apache.dubbo.rpc.model.MethodDescriptor; +import org.apache.dubbo.rpc.model.ReflectionServiceDescriptor; + +import java.io.ByteArrayInputStream; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class DescriptorUtilsTest { + + @Test + void shouldSelectServerStreamMethodByWrapperArgTypesWhenUnaryMethodHasSameName() throws Exception { + ReflectionServiceDescriptor serviceDescriptor = new ReflectionServiceDescriptor(OverloadedService.class); + TripleCustomerProtocolWrapper.TripleRequestWrapper request = + TripleCustomerProtocolWrapper.TripleRequestWrapper.Builder.newBuilder() + .setSerializeType("hessian4") + .addArgTypes(String.class.getName()) + .addArgTypes(StreamObserver.class.getName()) + .build(); + + MethodDescriptor methodDescriptor = DescriptorUtils.findTripleMethodDescriptor( + serviceDescriptor, "sync", new ByteArrayInputStream(request.toByteArray())); + + assertEquals(MethodDescriptor.RpcType.SERVER_STREAM, methodDescriptor.getRpcType()); + } + + @Test + void shouldSelectUnaryMethodByWrapperArgTypesWhenServerStreamMethodHasSameName() throws Exception { + ReflectionServiceDescriptor serviceDescriptor = new ReflectionServiceDescriptor(OverloadedService.class); + TripleCustomerProtocolWrapper.TripleRequestWrapper request = + TripleCustomerProtocolWrapper.TripleRequestWrapper.Builder.newBuilder() + .setSerializeType("hessian4") + .addArgTypes(String.class.getName()) + .build(); + + MethodDescriptor methodDescriptor = DescriptorUtils.findTripleMethodDescriptor( + serviceDescriptor, "sync", new ByteArrayInputStream(request.toByteArray())); + + assertEquals(MethodDescriptor.RpcType.UNARY, methodDescriptor.getRpcType()); + } + + @Test + void shouldSelectNoArgMethodByEmptyWrapperArgTypesWhenMethodIsOverloaded() throws Exception { + ReflectionServiceDescriptor serviceDescriptor = new ReflectionServiceDescriptor(NoArgOverloadedService.class); + TripleCustomerProtocolWrapper.TripleRequestWrapper request = + TripleCustomerProtocolWrapper.TripleRequestWrapper.Builder.newBuilder() + .setSerializeType("hessian4") + .build(); + + MethodDescriptor methodDescriptor = DescriptorUtils.findTripleMethodDescriptor( + serviceDescriptor, "overload", new ByteArrayInputStream(request.toByteArray())); + + assertEquals(0, methodDescriptor.getParameterClasses().length); + } + + @Test + void shouldRejectOverloadedMethodsWithoutMatchingWrapperArgTypes() { + ReflectionServiceDescriptor serviceDescriptor = new ReflectionServiceDescriptor(OverloadedService.class); + TripleCustomerProtocolWrapper.TripleRequestWrapper request = + TripleCustomerProtocolWrapper.TripleRequestWrapper.Builder.newBuilder() + .setSerializeType("hessian4") + .build(); + + assertThrows( + UnimplementedException.class, + () -> DescriptorUtils.findTripleMethodDescriptor( + serviceDescriptor, "sync", new ByteArrayInputStream(request.toByteArray()))); + } + + @Test + void shouldFallbackToUnaryMethodForGeneratedPairWithoutWrapperArgTypes() throws Exception { + ReflectionServiceDescriptor serviceDescriptor = new ReflectionServiceDescriptor(GeneratedUnaryService.class); + TripleCustomerProtocolWrapper.TripleRequestWrapper request = + TripleCustomerProtocolWrapper.TripleRequestWrapper.Builder.newBuilder() + .setSerializeType("hessian4") + .build(); + + MethodDescriptor methodDescriptor = DescriptorUtils.findTripleMethodDescriptor( + serviceDescriptor, "generated", new ByteArrayInputStream(request.toByteArray())); + + assertEquals(MethodDescriptor.RpcType.UNARY, methodDescriptor.getRpcType()); + } + + private interface OverloadedService { + + DataWrapper sync(String value); + + void sync(String value, StreamObserver response); + } + + private interface NoArgOverloadedService { + + String overload(); + + String overload(String value); + } + + private interface GeneratedUnaryService { + + String generated(String value); + + void generated(String value, StreamObserver response); + } +}