Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 13 additions & 82 deletions protocol/triple/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"fmt"
"net"
"net/http"
"reflect"
"strings"
"time"
)
Expand Down Expand Up @@ -57,66 +56,41 @@ const (
// callUnary, callClientStream, callServerStream, callBidiStream.
// A Reference has a clientManager.
type clientManager struct {
isIDL bool
// triple_protocol clients, key is method name
triClients map[string]*tri.Client
isIDL bool
triClient *tri.Client
}

// TODO: code a triple client between clientManager and triple_protocol client
// TODO: write a NewClient for triple client

func (cm *clientManager) getClient(method string) (*tri.Client, error) {
triClient, ok := cm.triClients[method]
if !ok {
return nil, fmt.Errorf("missing triple client for method: %s", method)
}
return triClient, nil
}

func (cm *clientManager) callUnary(ctx context.Context, method string, req, resp any) error {
triClient, err := cm.getClient(method)
if err != nil {
return err
}
triReq := tri.NewRequest(req)
triResp := tri.NewResponse(resp)
if err := triClient.CallUnary(ctx, triReq, triResp); err != nil {
if err := cm.triClient.CallUnary(ctx, triReq, method, triResp); err != nil {
return err
}
return nil
}

func (cm *clientManager) callClientStream(ctx context.Context, method string) (any, error) {
triClient, err := cm.getClient(method)
if err != nil {
return nil, err
}
stream, err := triClient.CallClientStream(ctx)
stream, err := cm.triClient.CallClientStream(ctx, method)
if err != nil {
return nil, err
}
return stream, nil
}

func (cm *clientManager) callServerStream(ctx context.Context, method string, req any) (any, error) {
triClient, err := cm.getClient(method)
if err != nil {
return nil, err
}
triReq := tri.NewRequest(req)
stream, err := triClient.CallServerStream(ctx, triReq)
stream, err := cm.triClient.CallServerStream(ctx, triReq, method)
if err != nil {
return nil, err
}
return stream, nil
}

func (cm *clientManager) callBidiStream(ctx context.Context, method string) (any, error) {
triClient, err := cm.getClient(method)
if err != nil {
return nil, err
}
stream, err := triClient.CallBidiStream(ctx)
stream, err := cm.triClient.CallBidiStream(ctx, method)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -282,59 +256,16 @@ func newClientManager(url *common.URL) (*clientManager, error) {
baseTriURL = httpPrefix + baseTriURL
}

triClients := make(map[string]*tri.Client)

// Check if this is a generic call - for generic call, we only need $invoke method
generic := url.GetParam(constant.GenericKey, "")
isGeneric := isGenericCall(generic)

if isGeneric {
// For generic call, only register $invoke method
invokeURL, err := joinPath(baseTriURL, url.Interface(), constant.Generic)
if err != nil {
return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), constant.Generic)
}
triClients[constant.Generic] = tri.NewClient(httpClient, invokeURL, cliOpts...)
} else if len(url.Methods) != 0 {
for _, method := range url.Methods {
triURL, err := joinPath(baseTriURL, url.Interface(), method)
if err != nil {
return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), method)
}
triClient := tri.NewClient(httpClient, triURL, cliOpts...)
triClients[method] = triClient
}
} else {
// This branch is for the non-IDL mode, where we pass in the service solely
// for the purpose of using reflection to obtain all methods of the service.
// There might be potential for optimization in this area later on.
service, ok := url.GetAttribute(constant.RpcServiceKey)
if !ok {
return nil, fmt.Errorf("triple clientmanager can't get methods")
}

serviceType := reflect.TypeOf(service)
for i := range serviceType.NumMethod() {
methodName := serviceType.Method(i).Name
triURL, err := joinPath(baseTriURL, url.Interface(), methodName)
if err != nil {
return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), methodName)
}
triClient := tri.NewClient(httpClient, triURL, cliOpts...)
triClients[methodName] = triClient
}

// Register $invoke method for generic call support in non-IDL mode
invokeURL, err := joinPath(baseTriURL, url.Interface(), constant.Generic)
if err != nil {
return nil, fmt.Errorf("JoinPath failed for base %s, interface %s, method %s", baseTriURL, url.Interface(), constant.Generic)
}
triClients[constant.Generic] = tri.NewClient(httpClient, invokeURL, cliOpts...)
triURL, err := joinPath(baseTriURL, url.Interface())
if err != nil {
return nil, fmt.Errorf("JoinPath failed for base %s, interface %s", baseTriURL, url.Interface())
}

triClient := tri.NewClient(httpClient, triURL, cliOpts...)

return &clientManager{
isIDL: isIDL,
triClients: triClients,
isIDL: isIDL,
triClient: triClient,
}, nil
}

Expand Down
129 changes: 16 additions & 113 deletions protocol/triple/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,7 @@ func TestClientManager_HTTP2AndHTTP3(t *testing.T) {
// If successfully created, verify the client manager
assert.NotNil(t, clientManager)
assert.True(t, clientManager.isIDL)
assert.NotEmpty(t, clientManager.triClients)

// Verify that the client for the specific method exists
client, exists := clientManager.triClients["testMethod"]
assert.True(t, exists)
assert.NotNil(t, client)
assert.NotNil(t, clientManager.triClient)
}

func TestDualTransport(t *testing.T) {
Expand All @@ -98,104 +93,18 @@ func TestDualTransport(t *testing.T) {
assert.True(t, ok, "transport should implement http.RoundTripper")
}

func TestClientManager_GetClient(t *testing.T) {
tests := []struct {
desc string
cm *clientManager
method string
expectErr bool
}{
{
desc: "method exists",
cm: &clientManager{
triClients: map[string]*tri.Client{
"TestMethod": tri.NewClient(&http.Client{}, "http://localhost:8080/test"),
},
},
method: "TestMethod",
expectErr: false,
},
{
desc: "method not exists",
cm: &clientManager{
triClients: map[string]*tri.Client{
"TestMethod": tri.NewClient(&http.Client{}, "http://localhost:8080/test"),
},
},
method: "NonExistMethod",
expectErr: true,
},
{
desc: "empty triClients",
cm: &clientManager{
triClients: map[string]*tri.Client{},
},
method: "AnyMethod",
expectErr: true,
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
client, err := test.cm.getClient(test.method)
if test.expectErr {
require.Error(t, err)
assert.Nil(t, client)
assert.Contains(t, err.Error(), "missing triple client")
} else {
require.NoError(t, err)
assert.NotNil(t, client)
}
})
}
}

func TestClientManager_Close(t *testing.T) {
cm := &clientManager{
isIDL: true,
triClients: map[string]*tri.Client{
"Method1": tri.NewClient(&http.Client{}, "http://localhost:8080/test1"),
"Method2": tri.NewClient(&http.Client{}, "http://localhost:8080/test2"),
},
isIDL: true,
triClient: tri.NewClient(&http.Client{}, "http://localhost:8080/test"),
}

err := cm.close()
require.NoError(t, err)
}

func TestClientManager_CallMethods_MissingClient(t *testing.T) {
cm := &clientManager{
triClients: map[string]*tri.Client{},
}
ctx := context.Background()

t.Run("callUnary missing client", func(t *testing.T) {
err := cm.callUnary(ctx, "NonExist", nil, nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "missing triple client")
})

t.Run("callClientStream missing client", func(t *testing.T) {
stream, err := cm.callClientStream(ctx, "NonExist")
require.Error(t, err)
assert.Nil(t, stream)
assert.Contains(t, err.Error(), "missing triple client")
})

t.Run("callServerStream missing client", func(t *testing.T) {
stream, err := cm.callServerStream(ctx, "NonExist", nil)
require.Error(t, err)
assert.Nil(t, stream)
assert.Contains(t, err.Error(), "missing triple client")
})

t.Run("callBidiStream missing client", func(t *testing.T) {
stream, err := cm.callBidiStream(ctx, "NonExist")
require.Error(t, err)
assert.Nil(t, stream)
assert.Contains(t, err.Error(), "missing triple client")
})
}
// TestClientManager_CallMethods_MissingClient removed - no longer applicable
// in the service-level client architecture where all methods share a single triClient.

func Test_genKeepAliveOptions(t *testing.T) {
defaultInterval, _ := time.ParseDuration(constant.DefaultKeepAliveInterval)
Expand Down Expand Up @@ -359,16 +268,17 @@ func Test_newClientManager_Serialization(t *testing.T) {
}

func Test_newClientManager_NoMethods(t *testing.T) {
// Test when url has no methods and no RpcServiceKey attribute
// Test when url has no methods - in service-level client architecture,
// this is valid as the client is created at service level, not method level
url := common.NewURLWithOptions(
common.WithLocation("localhost:20000"),
common.WithPath("com.example.TestService"),
)

cm, err := newClientManager(url)
require.Error(t, err)
assert.Nil(t, cm)
assert.Contains(t, err.Error(), "can't get methods")
require.NoError(t, err, "service-level client should be created even without method list")
assert.NotNil(t, cm)
assert.NotNil(t, cm.triClient, "triClient should be created at service level")
}

func Test_newClientManager_WithMethods(t *testing.T) {
Expand All @@ -381,10 +291,7 @@ func Test_newClientManager_WithMethods(t *testing.T) {
cm, err := newClientManager(url)
require.NoError(t, err)
assert.NotNil(t, cm)
assert.Len(t, cm.triClients, 3)
assert.Contains(t, cm.triClients, "Method1")
assert.Contains(t, cm.triClients, "Method2")
assert.Contains(t, cm.triClients, "Method3")
assert.NotNil(t, cm.triClient, "triClient should be created")
}

func Test_newClientManager_WithGroupAndVersion(t *testing.T) {
Expand Down Expand Up @@ -475,8 +382,8 @@ func Test_newClientManager_WithRpcService(t *testing.T) {
cm, err := newClientManager(url)
require.NoError(t, err)
assert.NotNil(t, cm)
// Should have methods from mockService (Reference, TestMethod1, TestMethod2)
assert.GreaterOrEqual(t, len(cm.triClients), 2)
// In service-level client architecture, a single triClient is created
assert.NotNil(t, cm.triClient, "triClient should be created for non-IDL mode")
}

func TestDualTransport_Structure(t *testing.T) {
Expand Down Expand Up @@ -552,7 +459,7 @@ func Test_newClientManager_URLPrefixHandling(t *testing.T) {
cm, err := newClientManager(url)
require.NoError(t, err)
assert.NotNil(t, cm)
assert.Len(t, cm.triClients, 1)
assert.NotNil(t, cm.triClient, "triClient should be created")
})
}
}
Expand Down Expand Up @@ -635,12 +542,8 @@ func Test_newClientManager_MultipleMethods(t *testing.T) {
cm, err := newClientManager(url)
require.NoError(t, err)
assert.NotNil(t, cm)
assert.Len(t, cm.triClients, len(methods))

for _, method := range methods {
_, exists := cm.triClients[method]
assert.True(t, exists, "method %s should exist", method)
}
// In service-level client architecture, a single triClient handles all methods
assert.NotNil(t, cm.triClient, "triClient should be created to handle all methods")
}

func Test_newClientManager_InterfaceName(t *testing.T) {
Expand Down
Loading
Loading