Skip to content

Commit ab92b8b

Browse files
author
Varun Deep Saini
committed
Sync registered model aliases via SetAlias/DeleteAlias in direct mode
Signed-off-by: Varun Deep Saini <varun.23bcs10048@ms.sst.scaler.com>
1 parent c520ed9 commit ab92b8b

4 files changed

Lines changed: 342 additions & 7 deletions

File tree

bundle/direct/dresources/registered_model.go

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package dresources
22

33
import (
44
"context"
5+
"fmt"
56

67
"github.com/databricks/cli/bundle/config/resources"
78
"github.com/databricks/cli/libs/utils"
89
"github.com/databricks/databricks-sdk-go"
910
"github.com/databricks/databricks-sdk-go/service/catalog"
11+
"golang.org/x/sync/errgroup"
1012
)
1113

1214
type ResourceRegisteredModel struct {
@@ -48,10 +50,9 @@ func (*ResourceRegisteredModel) RemapState(model *catalog.RegisteredModelInfo) *
4850

4951
func (r *ResourceRegisteredModel) DoRead(ctx context.Context, id string) (*catalog.RegisteredModelInfo, error) {
5052
return r.client.RegisteredModels.Get(ctx, catalog.GetRegisteredModelRequest{
51-
FullName: id,
52-
IncludeAliases: false,
53-
IncludeBrowse: false,
54-
ForceSendFields: nil,
53+
FullName: id,
54+
IncludeAliases: true,
55+
IncludeBrowse: false,
5556
})
5657
}
5758

@@ -61,10 +62,24 @@ func (r *ResourceRegisteredModel) DoCreate(ctx context.Context, config *catalog.
6162
return "", nil, err
6263
}
6364

65+
// The Create API does not apply aliases, so we sync them separately.
66+
if err := r.syncAliases(ctx, response.FullName, config.Aliases, nil); err != nil {
67+
return "", nil, fmt.Errorf("failed to sync aliases: %w", err)
68+
}
69+
6470
return response.FullName, response, nil
6571
}
6672

6773
func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, config *catalog.CreateRegisteredModelRequest, _ Changes) (*catalog.RegisteredModelInfo, error) {
74+
// Fetch current remote state to determine which aliases to add/remove.
75+
remote, err := r.client.RegisteredModels.Get(ctx, catalog.GetRegisteredModelRequest{
76+
FullName: id,
77+
IncludeAliases: true,
78+
})
79+
if err != nil {
80+
return nil, fmt.Errorf("failed to read current state for alias sync: %w", err)
81+
}
82+
6883
updateRequest := catalog.UpdateRegisteredModelRequest{
6984
FullName: id,
7085
Comment: config.Comment,
@@ -77,7 +92,8 @@ func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, confi
7792
// Note: TF also does not support changing name without a recreate so the current behavior matches TF.
7893
NewName: "",
7994

80-
Aliases: config.Aliases,
95+
// Aliases are synced separately via SetAlias/DeleteAlias calls because
96+
// the Update API ignores the Aliases field.
8197
BrowseOnly: config.BrowseOnly,
8298
CreatedAt: config.CreatedAt,
8399
CreatedBy: config.CreatedBy,
@@ -90,8 +106,20 @@ func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, confi
90106
CatalogName: config.CatalogName,
91107
}
92108

93-
response, err := r.client.RegisteredModels.Update(ctx, updateRequest)
94-
if err != nil {
109+
var eg errgroup.Group
110+
var response *catalog.RegisteredModelInfo
111+
112+
eg.Go(func() error {
113+
var err error
114+
response, err = r.client.RegisteredModels.Update(ctx, updateRequest)
115+
return err
116+
})
117+
118+
eg.Go(func() error {
119+
return r.syncAliases(ctx, id, config.Aliases, remote.Aliases)
120+
})
121+
122+
if err := eg.Wait(); err != nil {
95123
return nil, err
96124
}
97125

@@ -103,3 +131,49 @@ func (r *ResourceRegisteredModel) DoDelete(ctx context.Context, id string) error
103131
FullName: id,
104132
})
105133
}
134+
135+
// syncAliases compares desired and current aliases and calls SetAlias/DeleteAlias
136+
// APIs to reconcile the difference. The Update API ignores the Aliases field,
137+
// so separate API calls are required.
138+
func (r *ResourceRegisteredModel) syncAliases(ctx context.Context, fullName string, desired, current []catalog.RegisteredModelAlias) error {
139+
desiredByName := make(map[string]int, len(desired))
140+
for _, a := range desired {
141+
desiredByName[a.AliasName] = a.VersionNum
142+
}
143+
144+
currentByName := make(map[string]int, len(current))
145+
for _, a := range current {
146+
currentByName[a.AliasName] = a.VersionNum
147+
}
148+
149+
var eg errgroup.Group
150+
151+
// Set new or updated aliases.
152+
for name, version := range desiredByName {
153+
if v, ok := currentByName[name]; ok && v == version {
154+
continue
155+
}
156+
eg.Go(func() error {
157+
_, err := r.client.RegisteredModels.SetAlias(ctx, catalog.SetRegisteredModelAliasRequest{
158+
FullName: fullName,
159+
Alias: name,
160+
VersionNum: version,
161+
})
162+
return err
163+
})
164+
}
165+
166+
// Delete removed aliases.
167+
for name := range currentByName {
168+
if _, ok := desiredByName[name]; !ok {
169+
eg.Go(func() error {
170+
return r.client.RegisteredModels.DeleteAlias(ctx, catalog.DeleteAliasRequest{
171+
FullName: fullName,
172+
Alias: name,
173+
})
174+
})
175+
}
176+
}
177+
178+
return eg.Wait()
179+
}
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
package dresources
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/databricks/cli/libs/testserver"
8+
"github.com/databricks/databricks-sdk-go"
9+
"github.com/databricks/databricks-sdk-go/service/catalog"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func setupRegisteredModelTest(t *testing.T) (*ResourceRegisteredModel, *testserver.Server) {
15+
server := testserver.New(t)
16+
testserver.AddDefaultHandlers(server)
17+
18+
client, err := databricks.NewWorkspaceClient(&databricks.Config{
19+
Host: server.URL,
20+
Token: "testtoken",
21+
})
22+
require.NoError(t, err)
23+
24+
r := (&ResourceRegisteredModel{}).New(client)
25+
return r, server
26+
}
27+
28+
func TestRegisteredModelCreateSetsAliases(t *testing.T) {
29+
r, _ := setupRegisteredModelTest(t)
30+
ctx := context.Background()
31+
32+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
33+
Name: "my_model",
34+
CatalogName: "main",
35+
SchemaName: "default",
36+
Aliases: []catalog.RegisteredModelAlias{
37+
{AliasName: "champion", VersionNum: 1},
38+
{AliasName: "staging", VersionNum: 2},
39+
},
40+
})
41+
require.NoError(t, err)
42+
43+
remote, err := r.DoRead(ctx, id)
44+
require.NoError(t, err)
45+
46+
assert.Len(t, remote.Aliases, 2)
47+
aliasByName := aliasMap(remote.Aliases)
48+
assert.Equal(t, 1, aliasByName["champion"])
49+
assert.Equal(t, 2, aliasByName["staging"])
50+
}
51+
52+
func TestRegisteredModelUpdateAddsAliases(t *testing.T) {
53+
r, _ := setupRegisteredModelTest(t)
54+
ctx := context.Background()
55+
56+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
57+
Name: "my_model",
58+
CatalogName: "main",
59+
SchemaName: "default",
60+
})
61+
require.NoError(t, err)
62+
63+
_, err = r.DoUpdate(ctx, id, &catalog.CreateRegisteredModelRequest{
64+
Name: "my_model",
65+
CatalogName: "main",
66+
SchemaName: "default",
67+
Aliases: []catalog.RegisteredModelAlias{
68+
{AliasName: "champion", VersionNum: 3},
69+
},
70+
}, nil)
71+
require.NoError(t, err)
72+
73+
remote, err := r.DoRead(ctx, id)
74+
require.NoError(t, err)
75+
76+
assert.Len(t, remote.Aliases, 1)
77+
assert.Equal(t, "champion", remote.Aliases[0].AliasName)
78+
assert.Equal(t, 3, remote.Aliases[0].VersionNum)
79+
}
80+
81+
func TestRegisteredModelUpdateModifiesAndDeletesAliases(t *testing.T) {
82+
r, _ := setupRegisteredModelTest(t)
83+
ctx := context.Background()
84+
85+
// Create with two aliases.
86+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
87+
Name: "my_model",
88+
CatalogName: "main",
89+
SchemaName: "default",
90+
Aliases: []catalog.RegisteredModelAlias{
91+
{AliasName: "champion", VersionNum: 1},
92+
{AliasName: "staging", VersionNum: 2},
93+
},
94+
})
95+
require.NoError(t, err)
96+
97+
// Update: modify "champion" version, remove "staging", add "latest".
98+
_, err = r.DoUpdate(ctx, id, &catalog.CreateRegisteredModelRequest{
99+
Name: "my_model",
100+
CatalogName: "main",
101+
SchemaName: "default",
102+
Aliases: []catalog.RegisteredModelAlias{
103+
{AliasName: "champion", VersionNum: 5},
104+
{AliasName: "latest", VersionNum: 3},
105+
},
106+
}, nil)
107+
require.NoError(t, err)
108+
109+
remote, err := r.DoRead(ctx, id)
110+
require.NoError(t, err)
111+
112+
assert.Len(t, remote.Aliases, 2)
113+
aliasByName := aliasMap(remote.Aliases)
114+
assert.Equal(t, 5, aliasByName["champion"])
115+
assert.Equal(t, 3, aliasByName["latest"])
116+
_, hasStaging := aliasByName["staging"]
117+
assert.False(t, hasStaging)
118+
}
119+
120+
func TestRegisteredModelUpdateRemovesAllAliases(t *testing.T) {
121+
r, _ := setupRegisteredModelTest(t)
122+
ctx := context.Background()
123+
124+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
125+
Name: "my_model",
126+
CatalogName: "main",
127+
SchemaName: "default",
128+
Aliases: []catalog.RegisteredModelAlias{
129+
{AliasName: "champion", VersionNum: 1},
130+
},
131+
})
132+
require.NoError(t, err)
133+
134+
// Update with no aliases to remove all.
135+
_, err = r.DoUpdate(ctx, id, &catalog.CreateRegisteredModelRequest{
136+
Name: "my_model",
137+
CatalogName: "main",
138+
SchemaName: "default",
139+
}, nil)
140+
require.NoError(t, err)
141+
142+
remote, err := r.DoRead(ctx, id)
143+
require.NoError(t, err)
144+
145+
assert.Empty(t, remote.Aliases)
146+
}
147+
148+
func TestRegisteredModelUpdateNoopWhenAliasesUnchanged(t *testing.T) {
149+
r, server := setupRegisteredModelTest(t)
150+
ctx := context.Background()
151+
152+
aliases := []catalog.RegisteredModelAlias{
153+
{AliasName: "champion", VersionNum: 1},
154+
}
155+
156+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
157+
Name: "my_model",
158+
CatalogName: "main",
159+
SchemaName: "default",
160+
Aliases: aliases,
161+
})
162+
require.NoError(t, err)
163+
164+
// Count SetAlias calls during update.
165+
setAliasCalls := 0
166+
server.Handle("PUT", "/api/2.1/unity-catalog/models/{full_name}/aliases/{alias}", func(req testserver.Request) any {
167+
setAliasCalls++
168+
return req.Workspace.RegisteredModelsSetAlias(req, req.Vars["full_name"], req.Vars["alias"])
169+
})
170+
171+
_, err = r.DoUpdate(ctx, id, &catalog.CreateRegisteredModelRequest{
172+
Name: "my_model",
173+
CatalogName: "main",
174+
SchemaName: "default",
175+
Aliases: aliases,
176+
}, nil)
177+
require.NoError(t, err)
178+
179+
assert.Equal(t, 0, setAliasCalls, "expected no SetAlias calls when aliases are unchanged")
180+
}
181+
182+
func aliasMap(aliases []catalog.RegisteredModelAlias) map[string]int {
183+
m := make(map[string]int, len(aliases))
184+
for _, a := range aliases {
185+
m[a.AliasName] = a.VersionNum
186+
}
187+
return m
188+
}

libs/testserver/handlers.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,14 @@ func AddDefaultHandlers(server *Server) {
485485
return MapDelete(req.Workspace, req.Workspace.RegisteredModels, req.Vars["full_name"])
486486
})
487487

488+
server.Handle("PUT", "/api/2.1/unity-catalog/models/{full_name}/aliases/{alias}", func(req Request) any {
489+
return req.Workspace.RegisteredModelsSetAlias(req, req.Vars["full_name"], req.Vars["alias"])
490+
})
491+
492+
server.Handle("DELETE", "/api/2.1/unity-catalog/models/{full_name}/aliases/{alias}", func(req Request) any {
493+
return req.Workspace.RegisteredModelsDeleteAlias(req.Vars["full_name"], req.Vars["alias"])
494+
})
495+
488496
// Volumes:
489497

490498
server.Handle("GET", "/api/2.1/unity-catalog/volumes/{full_name}", func(req Request) any {

0 commit comments

Comments
 (0)