Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ object BlockingStubGenerator {
this,
service,
ClassName("io.grpc.stub", "AbstractStub"),
true,
)
}
.addBlockingStubRpcCalls(generator, service)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ object StubGenerator {
this,
service,
ClassName("io.grpc.kotlin", "AbstractCoroutineStub"),
false,
)
addSuspendedStubRpcCalls(generator, this, service, options)
}
Expand Down Expand Up @@ -98,7 +99,13 @@ object StubGenerator {
.addType(
TypeSpec.classBuilder(stubClassName)
.apply {
addAbstractStubConstructor(generator, this, service, ClassName("io.grpc.stub", "AbstractStub"))
addAbstractStubConstructor(
generator,
this,
service,
ClassName("io.grpc.stub", "AbstractStub"),
false,
)
addStubRpcCalls(generator, this, service, options)
}
.build(),
Expand All @@ -110,12 +117,14 @@ object StubGenerator {
builder: TypeSpec.Builder,
service: Service,
superClass: ClassName,
blockingStub: Boolean,
): TypeSpec.Builder {
val stubType = if (blockingStub) "Blocking" else ""
val serviceClassName = generator.classNameFor(service.type)
val stubClassName = ClassName(
packageName = serviceClassName.packageName,
"${serviceClassName.simpleName}WireGrpc",
"${serviceClassName.simpleName}Stub",
"${serviceClassName.simpleName}${stubType}Stub",
)
return builder
// Really this is a superclass, just want to add secondary constructors.
Expand All @@ -140,8 +149,8 @@ object StubGenerator {
.addModifiers(KModifier.OVERRIDE)
.addParameter("channel", ClassName("io.grpc", "Channel"))
.addParameter("callOptions", ClassName("io.grpc", "CallOptions"))
.addStatement("return ${service.name}Stub(channel, callOptions)")
.returns(ClassName("", "${service.name}Stub"))
.addStatement("return ${service.name}${stubType}Stub(channel, callOptions)")
.returns(ClassName("", "${service.name}${stubType}Stub"))
.build(),
)
}
Expand Down
166 changes: 166 additions & 0 deletions server-generator/src/test/golden/BlockingBidiStreamingService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
// Code generated by Wire protocol buffer compiler, do not edit.
package test

import com.google.protobuf.DescriptorProtos
import com.google.protobuf.Descriptors
import com.squareup.wire.kotlin.grpcserver.MessageSinkAdapter
import com.squareup.wire.kotlin.grpcserver.MessageSourceAdapter
import com.squareup.wire.kotlin.grpcserver.WireBindableService
import com.squareup.wire.kotlin.grpcserver.WireMethodMarshaller
import io.grpc.CallOptions
import io.grpc.Channel
import io.grpc.MethodDescriptor
import io.grpc.ServerServiceDefinition
import io.grpc.ServiceDescriptor
import io.grpc.ServiceDescriptor.newBuilder
import io.grpc.stub.AbstractStub
import io.grpc.stub.StreamObserver
import java.io.InputStream
import java.lang.Class
import java.lang.UnsupportedOperationException
import java.util.concurrent.ExecutorService
import kotlin.Array
import kotlin.String
import kotlin.collections.Map
import kotlin.collections.Set
import kotlin.jvm.Volatile
import io.grpc.stub.ClientCalls.asyncBidiStreamingCall as clientCallsAsyncBidiStreamingCall
import io.grpc.stub.ServerCalls.asyncBidiStreamingCall as serverCallsAsyncBidiStreamingCall

public object TestServiceWireGrpc {
public const val SERVICE_NAME: String = "test.TestService"

@Volatile
private var serviceDescriptor: ServiceDescriptor? = null

private val descriptorMap: Map<String, DescriptorProtos.FileDescriptorProto> =
createDescriptorMap0()


@Volatile
private var getTestRPCMethod: MethodDescriptor<Test, Test>? = null

private fun descriptorFor(`data`: Array<String>): DescriptorProtos.FileDescriptorProto {
val str = data.fold(java.lang.StringBuilder()) { b, s -> b.append(s) }.toString()
val bytes = java.util.Base64.getDecoder().decode(str)
return DescriptorProtos.FileDescriptorProto.parseFrom(bytes)
}

private fun fileDescriptor(path: String, visited: Set<String>): Descriptors.FileDescriptor {
val proto = descriptorMap[path]!!
val deps = proto.dependencyList.filter { !visited.contains(it) }.map { fileDescriptor(it,
visited + path) }
return Descriptors.FileDescriptor.buildFrom(proto, deps.toTypedArray())
}

private fun createDescriptorMap0(): Map<String, DescriptorProtos.FileDescriptorProto> {
val subMap = mapOf(
"service.proto" to descriptorFor(arrayOf(
"Cg1zZXJ2aWNlLnByb3RvEgR0ZXN0IgYKBFRlc3QyNAoLVGVzdFNlcnZpY2USJQoHVGVzdFJQQxIKLnRl",
"c3QuVGVzdBoKLnRlc3QuVGVzdCgBMAE=",
)),
)
return subMap
}

public fun getServiceDescriptor(): ServiceDescriptor? {
var result = serviceDescriptor
if (result == null) {
synchronized(TestServiceWireGrpc::class) {
result = serviceDescriptor
if (result == null) {
result = newBuilder(SERVICE_NAME)
.addMethod(getTestRPCMethod())
.setSchemaDescriptor(io.grpc.protobuf.ProtoFileDescriptorSupplier {
fileDescriptor("service.proto", emptySet())
})
.build()
serviceDescriptor = result
}
}
}
return result
}

public fun getTestRPCMethod(): MethodDescriptor<Test, Test> {
var result: MethodDescriptor<Test, Test>? = getTestRPCMethod
if (result == null) {
synchronized(TestServiceWireGrpc::class) {
result = getTestRPCMethod
if (result == null) {
getTestRPCMethod = MethodDescriptor.newBuilder<Test, Test>()
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.setFullMethodName(
MethodDescriptor.generateFullMethodName(
"test.TestService", "TestRPC"
)
)
.setSampledToLocalTracing(true)
.setRequestMarshaller(TestServiceImplBase.TestMarshaller())
.setResponseMarshaller(TestServiceImplBase.TestMarshaller())
.build()
}
}
}
return getTestRPCMethod!!
}

public fun newStub(channel: Channel): TestServiceStub = TestServiceStub(channel)

public fun newBlockingStub(channel: Channel): TestServiceBlockingStub =
TestServiceBlockingStub(channel)

public abstract class TestServiceImplBase : WireBindableService {
public open fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> = throw
UnsupportedOperationException()

override fun bindService(): ServerServiceDefinition =
ServerServiceDefinition.builder(getServiceDescriptor()).addMethod(
getTestRPCMethod(),
serverCallsAsyncBidiStreamingCall(this@TestServiceImplBase::TestRPC)
).build()

public class TestMarshaller : WireMethodMarshaller<Test> {
override fun stream(`value`: Test): InputStream = Test.ADAPTER.encode(value).inputStream()

override fun marshalledClass(): Class<Test> = Test::class.java

override fun parse(stream: InputStream): Test = Test.ADAPTER.decode(stream)
}
}

public class BindableAdapter(
private val streamExecutor: ExecutorService,
private val service: () -> TestServiceBlockingServer,
) : TestServiceImplBase() {
override fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> {
val requestStream = MessageSourceAdapter<Test>()
streamExecutor.submit {
service().TestRPC(requestStream, MessageSinkAdapter(response))
}
return requestStream
}
}

public class TestServiceStub : AbstractStub<TestServiceStub> {
internal constructor(channel: Channel) : super(channel)

internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions)

override fun build(channel: Channel, callOptions: CallOptions): TestServiceStub =
TestServiceStub(channel, callOptions)

public fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> =
clientCallsAsyncBidiStreamingCall(channel.newCall(getTestRPCMethod(), callOptions),
response)
}

public class TestServiceBlockingStub : AbstractStub<TestServiceBlockingStub> {
internal constructor(channel: Channel) : super(channel)

internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions)

override fun build(channel: Channel, callOptions: CallOptions): TestServiceBlockingStub =
TestServiceBlockingStub(channel, callOptions)
}
}
171 changes: 171 additions & 0 deletions server-generator/src/test/golden/BlockingClientStreamingService.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Code generated by Wire protocol buffer compiler, do not edit.
package test

import com.google.protobuf.DescriptorProtos
import com.google.protobuf.Descriptors
import com.squareup.wire.kotlin.grpcserver.MessageSourceAdapter
import com.squareup.wire.kotlin.grpcserver.WireBindableService
import com.squareup.wire.kotlin.grpcserver.WireMethodMarshaller
import io.grpc.CallOptions
import io.grpc.Channel
import io.grpc.MethodDescriptor
import io.grpc.ServerServiceDefinition
import io.grpc.ServiceDescriptor
import io.grpc.ServiceDescriptor.newBuilder
import io.grpc.stub.AbstractStub
import io.grpc.stub.ClientCalls.blockingServerStreamingCall
import io.grpc.stub.StreamObserver
import java.io.InputStream
import java.lang.Class
import java.lang.UnsupportedOperationException
import java.util.concurrent.ExecutorService
import kotlin.Array
import kotlin.String
import kotlin.collections.Iterator
import kotlin.collections.Map
import kotlin.collections.Set
import kotlin.jvm.Volatile
import io.grpc.stub.ClientCalls.asyncClientStreamingCall as clientCallsAsyncClientStreamingCall
import io.grpc.stub.ServerCalls.asyncClientStreamingCall as serverCallsAsyncClientStreamingCall

public object TestServiceWireGrpc {
public const val SERVICE_NAME: String = "test.TestService"

@Volatile
private var serviceDescriptor: ServiceDescriptor? = null

private val descriptorMap: Map<String, DescriptorProtos.FileDescriptorProto> =
createDescriptorMap0()


@Volatile
private var getTestRPCMethod: MethodDescriptor<Test, Test>? = null

private fun descriptorFor(`data`: Array<String>): DescriptorProtos.FileDescriptorProto {
val str = data.fold(java.lang.StringBuilder()) { b, s -> b.append(s) }.toString()
val bytes = java.util.Base64.getDecoder().decode(str)
return DescriptorProtos.FileDescriptorProto.parseFrom(bytes)
}

private fun fileDescriptor(path: String, visited: Set<String>): Descriptors.FileDescriptor {
val proto = descriptorMap[path]!!
val deps = proto.dependencyList.filter { !visited.contains(it) }.map { fileDescriptor(it,
visited + path) }
return Descriptors.FileDescriptor.buildFrom(proto, deps.toTypedArray())
}

private fun createDescriptorMap0(): Map<String, DescriptorProtos.FileDescriptorProto> {
val subMap = mapOf(
"service.proto" to descriptorFor(arrayOf(
"Cg1zZXJ2aWNlLnByb3RvEgR0ZXN0IgYKBFRlc3QyMgoLVGVzdFNlcnZpY2USIwoHVGVzdFJQQxIKLnRl",
"c3QuVGVzdBoKLnRlc3QuVGVzdCgB",
)),
)
return subMap
}

public fun getServiceDescriptor(): ServiceDescriptor? {
var result = serviceDescriptor
if (result == null) {
synchronized(TestServiceWireGrpc::class) {
result = serviceDescriptor
if (result == null) {
result = newBuilder(SERVICE_NAME)
.addMethod(getTestRPCMethod())
.setSchemaDescriptor(io.grpc.protobuf.ProtoFileDescriptorSupplier {
fileDescriptor("service.proto", emptySet())
})
.build()
serviceDescriptor = result
}
}
}
return result
}

public fun getTestRPCMethod(): MethodDescriptor<Test, Test> {
var result: MethodDescriptor<Test, Test>? = getTestRPCMethod
if (result == null) {
synchronized(TestServiceWireGrpc::class) {
result = getTestRPCMethod
if (result == null) {
getTestRPCMethod = MethodDescriptor.newBuilder<Test, Test>()
.setType(MethodDescriptor.MethodType.CLIENT_STREAMING)
.setFullMethodName(
MethodDescriptor.generateFullMethodName(
"test.TestService", "TestRPC"
)
)
.setSampledToLocalTracing(true)
.setRequestMarshaller(TestServiceImplBase.TestMarshaller())
.setResponseMarshaller(TestServiceImplBase.TestMarshaller())
.build()
}
}
}
return getTestRPCMethod!!
}

public fun newStub(channel: Channel): TestServiceStub = TestServiceStub(channel)

public fun newBlockingStub(channel: Channel): TestServiceBlockingStub =
TestServiceBlockingStub(channel)

public abstract class TestServiceImplBase : WireBindableService {
public open fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> = throw
UnsupportedOperationException()

override fun bindService(): ServerServiceDefinition =
ServerServiceDefinition.builder(getServiceDescriptor()).addMethod(
getTestRPCMethod(),
serverCallsAsyncClientStreamingCall(this@TestServiceImplBase::TestRPC)
).build()

public class TestMarshaller : WireMethodMarshaller<Test> {
override fun stream(`value`: Test): InputStream = Test.ADAPTER.encode(value).inputStream()

override fun marshalledClass(): Class<Test> = Test::class.java

override fun parse(stream: InputStream): Test = Test.ADAPTER.decode(stream)
}
}

public class BindableAdapter(
private val streamExecutor: ExecutorService,
private val service: () -> TestServiceBlockingServer,
) : TestServiceImplBase() {
override fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> {
val requestStream = MessageSourceAdapter<Test>()
streamExecutor.submit {
response.onNext(service().TestRPC(requestStream))
response.onCompleted()
}
return requestStream
}
}

public class TestServiceStub : AbstractStub<TestServiceStub> {
internal constructor(channel: Channel) : super(channel)

internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions)

override fun build(channel: Channel, callOptions: CallOptions): TestServiceStub =
TestServiceStub(channel, callOptions)

public fun TestRPC(response: StreamObserver<Test>): StreamObserver<Test> =
clientCallsAsyncClientStreamingCall(channel.newCall(getTestRPCMethod(), callOptions),
response)
}

public class TestServiceBlockingStub : AbstractStub<TestServiceBlockingStub> {
internal constructor(channel: Channel) : super(channel)

internal constructor(channel: Channel, callOptions: CallOptions) : super(channel, callOptions)

override fun build(channel: Channel, callOptions: CallOptions): TestServiceBlockingStub =
TestServiceBlockingStub(channel, callOptions)

public fun TestRPC(request: Test): Iterator<Test> = blockingServerStreamingCall(channel,
getTestRPCMethod(), callOptions, request)
}
}
Loading