Skip to content

Commit 86d4e85

Browse files
author
Varun Deep Saini
committed
Sync registered model aliases via SetAlias/DeleteAlias in direct mode
Signed-off-by: Varun Deep Saini <varun.23bcs10048@ms.sst.scaler.com>
1 parent 0c0b80f commit 86d4e85

8 files changed

Lines changed: 311 additions & 6 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
### Bundles
88
* Validate that either source_code_path or git_source is set for apps ([#4632](https://github.com/databricks/cli/pull/4632))
9+
* direct: Sync registered model aliases via SetAlias/DeleteAlias ([#4637](https://github.com/databricks/cli/pull/4637))
910

1011
### Dependency updates
1112

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
bundle:
2+
name: deploy-registered-models-aliases-$UNIQUE_NAME
3+
4+
resources:
5+
registered_models:
6+
my_registered_model:
7+
name: my-registered-model-$UNIQUE_NAME
8+
comment: "test model"
9+
catalog_name: main
10+
schema_name: default
11+
aliases:
12+
- alias_name: champion
13+
version_num: 1
14+
- alias_name: staging
15+
version_num: 2
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
envsubst < databricks.yml.tmpl > databricks.yml
2+
3+
cleanup() {
4+
trace $CLI bundle destroy --auto-approve
5+
}
6+
trap cleanup EXIT
7+
8+
get_aliases() {
9+
registered_model_id=$($CLI bundle summary --output json | jq -r '.resources.registered_models.my_registered_model.id')
10+
$CLI registered-models get "${registered_model_id}" --include-aliases | jq '[.aliases[] | {alias_name, version_num}] | sort_by(.alias_name)'
11+
}
12+
13+
title "create with aliases"
14+
trace $CLI bundle deploy
15+
trace print_requests.py //unity-catalog/models --sort > out.create-requests.$DATABRICKS_BUNDLE_ENGINE.json
16+
trace get_aliases
17+
18+
title "update: modify champion version, remove staging, add latest"
19+
update_file.py databricks.yml "version_num: 1" "version_num: 5"
20+
update_file.py databricks.yml "staging" "latest"
21+
update_file.py databricks.yml "version_num: 2" "version_num: 3"
22+
trace $CLI bundle deploy
23+
trace print_requests.py //unity-catalog/models --sort > out.update-requests.$DATABRICKS_BUNDLE_ENGINE.json
24+
trace get_aliases
25+
26+
title "remove all aliases"
27+
update_file.py databricks.yml " aliases:" ""
28+
update_file.py databricks.yml " - alias_name: champion" ""
29+
update_file.py databricks.yml " version_num: 5" ""
30+
update_file.py databricks.yml " - alias_name: latest" ""
31+
update_file.py databricks.yml " version_num: 3" ""
32+
trace $CLI bundle deploy
33+
trace print_requests.py //unity-catalog/models --sort > out.remove-requests.$DATABRICKS_BUNDLE_ENGINE.json
34+
trace get_aliases
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Cloud = true
2+
Local = true
3+
RequiresUnityCatalog = true

bundle/direct/dresources/registered_model.go

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ import (
44
"context"
55

66
"github.com/databricks/cli/bundle/config/resources"
7+
"github.com/databricks/cli/libs/structs/structpath"
78
"github.com/databricks/cli/libs/utils"
89
"github.com/databricks/databricks-sdk-go"
910
"github.com/databricks/databricks-sdk-go/service/catalog"
11+
"golang.org/x/sync/errgroup"
1012
)
1113

14+
var pathAliases = structpath.MustParsePath("aliases")
15+
1216
type ResourceRegisteredModel struct {
1317
client *databricks.WorkspaceClient
1418
}
@@ -48,10 +52,9 @@ func (*ResourceRegisteredModel) RemapState(model *catalog.RegisteredModelInfo) *
4852

4953
func (r *ResourceRegisteredModel) DoRead(ctx context.Context, id string) (*catalog.RegisteredModelInfo, error) {
5054
return r.client.RegisteredModels.Get(ctx, catalog.GetRegisteredModelRequest{
51-
FullName: id,
52-
IncludeAliases: false,
53-
IncludeBrowse: false,
54-
ForceSendFields: nil,
55+
FullName: id,
56+
IncludeAliases: true,
57+
IncludeBrowse: false,
5558
})
5659
}
5760

@@ -61,10 +64,15 @@ func (r *ResourceRegisteredModel) DoCreate(ctx context.Context, config *catalog.
6164
return "", nil, err
6265
}
6366

67+
// The Create API does not apply aliases, so we sync them separately.
68+
if err := r.syncAliases(ctx, response.FullName, config.Aliases, nil); err != nil {
69+
return "", nil, err
70+
}
71+
6472
return response.FullName, response, nil
6573
}
6674

67-
func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, config *catalog.CreateRegisteredModelRequest, _ Changes) (*catalog.RegisteredModelInfo, error) {
75+
func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, config *catalog.CreateRegisteredModelRequest, changes Changes) (*catalog.RegisteredModelInfo, error) {
6876
updateRequest := catalog.UpdateRegisteredModelRequest{
6977
FullName: id,
7078
Comment: config.Comment,
@@ -77,7 +85,8 @@ func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, confi
7785
// Note: TF also does not support changing name without a recreate so the current behavior matches TF.
7886
NewName: "",
7987

80-
Aliases: config.Aliases,
88+
// Aliases are synced separately via SetAlias/DeleteAlias calls because
89+
// the Update API ignores the Aliases field.
8190
BrowseOnly: config.BrowseOnly,
8291
CreatedAt: config.CreatedAt,
8392
CreatedBy: config.CreatedBy,
@@ -95,6 +104,12 @@ func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, confi
95104
return nil, err
96105
}
97106

107+
if changes.HasChange(pathAliases) {
108+
if err := r.syncAliases(ctx, id, config.Aliases, nil); err != nil {
109+
return nil, err
110+
}
111+
}
112+
98113
return response, nil
99114
}
100115

@@ -103,3 +118,61 @@ func (r *ResourceRegisteredModel) DoDelete(ctx context.Context, id string) error
103118
FullName: id,
104119
})
105120
}
121+
122+
// syncAliases compares desired and current aliases and calls SetAlias/DeleteAlias
123+
// APIs to reconcile the difference. The Update API ignores the Aliases field,
124+
// so separate API calls are required.
125+
// If current is nil, the current aliases are fetched from the remote.
126+
func (r *ResourceRegisteredModel) syncAliases(ctx context.Context, fullName string, desired, current []catalog.RegisteredModelAlias) error {
127+
if current == nil {
128+
remote, err := r.client.RegisteredModels.Get(ctx, catalog.GetRegisteredModelRequest{
129+
FullName: fullName,
130+
IncludeAliases: true,
131+
})
132+
if err != nil {
133+
return err
134+
}
135+
current = remote.Aliases
136+
}
137+
138+
desiredByName := make(map[string]int, len(desired))
139+
for _, a := range desired {
140+
desiredByName[a.AliasName] = a.VersionNum
141+
}
142+
143+
currentByName := make(map[string]int, len(current))
144+
for _, a := range current {
145+
currentByName[a.AliasName] = a.VersionNum
146+
}
147+
148+
var eg errgroup.Group
149+
150+
// Set new or updated aliases.
151+
for name, version := range desiredByName {
152+
if v, ok := currentByName[name]; ok && v == version {
153+
continue
154+
}
155+
eg.Go(func() error {
156+
_, err := r.client.RegisteredModels.SetAlias(ctx, catalog.SetRegisteredModelAliasRequest{
157+
FullName: fullName,
158+
Alias: name,
159+
VersionNum: version,
160+
})
161+
return err
162+
})
163+
}
164+
165+
// Delete removed aliases.
166+
for name := range currentByName {
167+
if _, ok := desiredByName[name]; !ok {
168+
eg.Go(func() error {
169+
return r.client.RegisteredModels.DeleteAlias(ctx, catalog.DeleteAliasRequest{
170+
FullName: fullName,
171+
Alias: name,
172+
})
173+
})
174+
}
175+
}
176+
177+
return eg.Wait()
178+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package dresources
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/databricks/cli/libs/testserver"
8+
"github.com/databricks/databricks-sdk-go"
9+
"github.com/databricks/databricks-sdk-go/service/catalog"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func setupRegisteredModelTest(t *testing.T) *ResourceRegisteredModel {
15+
server := testserver.New(t)
16+
testserver.AddDefaultHandlers(server)
17+
18+
client, err := databricks.NewWorkspaceClient(&databricks.Config{
19+
Host: server.URL,
20+
Token: "testtoken",
21+
})
22+
require.NoError(t, err)
23+
24+
return (&ResourceRegisteredModel{}).New(client)
25+
}
26+
27+
func TestSyncAliases_AddsNewAliases(t *testing.T) {
28+
r := setupRegisteredModelTest(t)
29+
ctx := context.Background()
30+
31+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
32+
Name: "my_model", CatalogName: "main", SchemaName: "default",
33+
Aliases: []catalog.RegisteredModelAlias{
34+
{AliasName: "champion", VersionNum: 1},
35+
{AliasName: "staging", VersionNum: 2},
36+
},
37+
})
38+
require.NoError(t, err)
39+
40+
remote, err := r.DoRead(ctx, id)
41+
require.NoError(t, err)
42+
assert.Len(t, remote.Aliases, 2)
43+
44+
m := aliasMap(remote.Aliases)
45+
assert.Equal(t, 1, m["champion"])
46+
assert.Equal(t, 2, m["staging"])
47+
}
48+
49+
func TestSyncAliases_UpdatesAndDeletesAliases(t *testing.T) {
50+
r := setupRegisteredModelTest(t)
51+
ctx := context.Background()
52+
53+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
54+
Name: "my_model", CatalogName: "main", SchemaName: "default",
55+
Aliases: []catalog.RegisteredModelAlias{
56+
{AliasName: "champion", VersionNum: 1},
57+
{AliasName: "staging", VersionNum: 2},
58+
},
59+
})
60+
require.NoError(t, err)
61+
62+
// Modify champion version, remove staging, add latest.
63+
err = r.syncAliases(ctx, id, []catalog.RegisteredModelAlias{
64+
{AliasName: "champion", VersionNum: 5},
65+
{AliasName: "latest", VersionNum: 3},
66+
}, nil)
67+
require.NoError(t, err)
68+
69+
remote, err := r.DoRead(ctx, id)
70+
require.NoError(t, err)
71+
assert.Len(t, remote.Aliases, 2)
72+
73+
m := aliasMap(remote.Aliases)
74+
assert.Equal(t, 5, m["champion"])
75+
assert.Equal(t, 3, m["latest"])
76+
_, hasStaging := m["staging"]
77+
assert.False(t, hasStaging)
78+
}
79+
80+
func TestSyncAliases_RemovesAllAliases(t *testing.T) {
81+
r := setupRegisteredModelTest(t)
82+
ctx := context.Background()
83+
84+
id, _, err := r.DoCreate(ctx, &catalog.CreateRegisteredModelRequest{
85+
Name: "my_model", CatalogName: "main", SchemaName: "default",
86+
Aliases: []catalog.RegisteredModelAlias{
87+
{AliasName: "champion", VersionNum: 1},
88+
},
89+
})
90+
require.NoError(t, err)
91+
92+
err = r.syncAliases(ctx, id, nil, nil)
93+
require.NoError(t, err)
94+
95+
remote, err := r.DoRead(ctx, id)
96+
require.NoError(t, err)
97+
assert.Empty(t, remote.Aliases)
98+
}
99+
100+
func aliasMap(aliases []catalog.RegisteredModelAlias) map[string]int {
101+
m := make(map[string]int, len(aliases))
102+
for _, a := range aliases {
103+
m[a.AliasName] = a.VersionNum
104+
}
105+
return m
106+
}

libs/testserver/handlers.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,14 @@ func AddDefaultHandlers(server *Server) {
485485
return MapDelete(req.Workspace, req.Workspace.RegisteredModels, req.Vars["full_name"])
486486
})
487487

488+
server.Handle("PUT", "/api/2.1/unity-catalog/models/{full_name}/aliases/{alias}", func(req Request) any {
489+
return req.Workspace.RegisteredModelsSetAlias(req, req.Vars["full_name"], req.Vars["alias"])
490+
})
491+
492+
server.Handle("DELETE", "/api/2.1/unity-catalog/models/{full_name}/aliases/{alias}", func(req Request) any {
493+
return req.Workspace.RegisteredModelsDeleteAlias(req.Vars["full_name"], req.Vars["alias"])
494+
})
495+
488496
// Volumes:
489497

490498
server.Handle("GET", "/api/2.1/unity-catalog/volumes/{full_name}", func(req Request) any {

libs/testserver/registered_models.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,68 @@ func (s *FakeWorkspace) RegisteredModelsUpdate(req Request, fullName string) Res
8484
Body: existing,
8585
}
8686
}
87+
88+
func (s *FakeWorkspace) RegisteredModelsSetAlias(req Request, fullName, alias string) Response {
89+
defer s.LockUnlock()()
90+
91+
existing, ok := s.RegisteredModels[fullName]
92+
if !ok {
93+
return Response{
94+
StatusCode: http.StatusNotFound,
95+
Body: fmt.Sprintf("registered model %s not found", fullName),
96+
}
97+
}
98+
99+
var setRequest catalog.SetRegisteredModelAliasRequest
100+
if err := json.Unmarshal(req.Body, &setRequest); err != nil {
101+
return Response{
102+
Body: fmt.Sprintf("internal error: %s", err),
103+
StatusCode: http.StatusInternalServerError,
104+
}
105+
}
106+
107+
newAlias := catalog.RegisteredModelAlias{
108+
AliasName: alias,
109+
VersionNum: setRequest.VersionNum,
110+
}
111+
112+
// Update existing alias or append new one.
113+
found := false
114+
for i, a := range existing.Aliases {
115+
if a.AliasName == alias {
116+
existing.Aliases[i] = newAlias
117+
found = true
118+
break
119+
}
120+
}
121+
if !found {
122+
existing.Aliases = append(existing.Aliases, newAlias)
123+
}
124+
125+
s.RegisteredModels[fullName] = existing
126+
return Response{
127+
Body: newAlias,
128+
}
129+
}
130+
131+
func (s *FakeWorkspace) RegisteredModelsDeleteAlias(fullName, alias string) Response {
132+
defer s.LockUnlock()()
133+
134+
existing, ok := s.RegisteredModels[fullName]
135+
if !ok {
136+
return Response{
137+
StatusCode: http.StatusNotFound,
138+
Body: fmt.Sprintf("registered model %s not found", fullName),
139+
}
140+
}
141+
142+
for i, a := range existing.Aliases {
143+
if a.AliasName == alias {
144+
existing.Aliases = append(existing.Aliases[:i], existing.Aliases[i+1:]...)
145+
break
146+
}
147+
}
148+
149+
s.RegisteredModels[fullName] = existing
150+
return Response{}
151+
}

0 commit comments

Comments
 (0)