Feature/mlflow support #373
base: main
Signed-off-by: Arsen Gumin <gumin@live.ru>
Summary of Changes
Hello @aagumin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly extends the tool's capabilities by integrating with MLflow, a popular platform for managing the machine learning lifecycle. Users can now pull models stored in MLflow registries and generate modelfiles from them, including models whose artifacts are hosted on S3. This broadens the ecosystem of supported model sources, making the tool more versatile for MLOps workflows.
Code Review
This pull request introduces support for downloading models from MLflow, which is a great addition. The implementation includes a new mlflow provider, logic to interact with the MLflow registry, and functionality to download model artifacts from S3. While the overall structure is good, I've identified several critical issues that must be addressed. These include a file descriptor leak due to improper use of defer in a loop, a potential panic from unsafe string slicing, and incorrect credential handling logic. Additionally, I've provided feedback on improving logging, making version selection deterministic, and enhancing test coverage to ensure the new provider is robust and reliable.
    parsed, err := url.Parse(uri.ArtifactUri)
    if parsed == nil {
        return "", errors.New("failed to parse artifact uri")
    }
The error returned by url.Parse is not checked; the code only tests whether parsed is nil. In the current standard library a failed parse does return a nil URL, so the check catches the failure, but it discards the error detail and relies on behavior the documentation does not spell out. Check the returned error instead and wrap it so the caller can see why parsing failed.
Suggested change:
    - parsed, err := url.Parse(uri.ArtifactUri)
    - if parsed == nil {
    -     return "", errors.New("failed to parse artifact uri")
    - }
    + parsed, err := url.Parse(uri.ArtifactUri)
    + if err != nil {
    +     return "", fmt.Errorf("failed to parse artifact uri: %w", err)
    + }
        log.Printf("Error creating local file %s: %v\n", localFilePath, err)
        continue
    }
    defer file.Close()
defer file.Close() is used inside a for loop. The deferred call will not be executed until the surrounding function DownloadModel returns. This will cause file descriptors to leak, as all downloaded files will remain open until the entire download process is complete. For a large number of files, this could lead to the process hitting its file descriptor limit. You should close the file explicitly at the end of each loop iteration.
    err = os.Setenv("DATABRICKS_USERNAME", usr)
    if err != nil {
        return err
    }
    err = os.Setenv("DATABRICKS_PASSWORD", pass)
There is a bug in this logic. You are setting the DATABRICKS_USERNAME and DATABRICKS_PASSWORD environment variables using usr and pass, which were read from the (empty) DATABRICKS_* environment variables at the beginning of the function. You should be using mlfuser and mlfpass from the MLFLOW_* variables instead.
Suggested change:
    - err = os.Setenv("DATABRICKS_USERNAME", usr)
    - if err != nil {
    -     return err
    - }
    - err = os.Setenv("DATABRICKS_PASSWORD", pass)
    + err = os.Setenv("DATABRICKS_USERNAME", mlfuser)
    + if err != nil {
    +     return err
    + }
    + err = os.Setenv("DATABRICKS_PASSWORD", mlfpass)
    }

    bucketName := parsed.Host
    s3FolderPrefix := parsed.Path[1:]
This line s3FolderPrefix := parsed.Path[1:] will cause a panic if parsed.Path has a length less than 1 (e.g., an empty string). This can happen if the S3 URL is just the bucket name, like s3://my-bucket. You should handle this case more safely, for example by using strings.TrimPrefix.
Suggested change:
    - s3FolderPrefix := parsed.Path[1:]
    + s3FolderPrefix := strings.TrimPrefix(parsed.Path, "/")
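The two parsing fixes suggested so far (checking the url.Parse error and trimming the leading slash safely) could be combined in one small function. This is only a sketch: splitS3URI is a hypothetical name, and the explicit check for the s3 scheme and a non-empty bucket is an added assumption, not something the PR is known to do.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

// splitS3URI is a hypothetical helper that splits an artifact URI such as
// s3://bucket/some/prefix into its bucket name and key prefix.
func splitS3URI(artifactURI string) (bucket, prefix string, err error) {
	parsed, err := url.Parse(artifactURI)
	if err != nil {
		return "", "", fmt.Errorf("failed to parse artifact uri: %w", err)
	}
	// Assumption: only s3:// URIs with a bucket are accepted here.
	if parsed.Scheme != "s3" || parsed.Host == "" {
		return "", "", fmt.Errorf("unsupported artifact uri %q: want s3://bucket/prefix", artifactURI)
	}
	// TrimPrefix is safe on an empty path, unlike parsed.Path[1:].
	return parsed.Host, strings.TrimPrefix(parsed.Path, "/"), nil
}
```

With this shape, a bucket-only URI such as s3://my-bucket yields an empty prefix instead of a panic.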
    func checkMlflowAuth() error {

        var err error

        host := os.Getenv("DATABRICKS_HOST")
        usr := os.Getenv("DATABRICKS_USERNAME")
        pass := os.Getenv("DATABRICKS_PASSWORD")
        mlfhost := os.Getenv("MLFLOW_TRACKING_URI")
        mlfuser := os.Getenv("MLFLOW_TRACKING_USERNAME")
        mlfpass := os.Getenv("MLFLOW_TRACKING_PASSWORD")

        if host == "" && usr == "" && pass == "" {
            fmt.Println("Please set DATABRICKS_HOST environment variable.")
            fmt.Println("Please set DATABRICKS_USERNAME environment variable.")
            fmt.Println("Please set DATABRICKS_PASSWORD environment variable.")
        } else {
            return nil
        }
        if mlfhost != "" && mlfuser != "" && mlfpass != "" {
            err = os.Setenv("DATABRICKS_HOST", mlfhost)
            if err != nil {
                return err
            }
            err = os.Setenv("DATABRICKS_USERNAME", usr)
            if err != nil {
                return err
            }
            err = os.Setenv("DATABRICKS_PASSWORD", pass)
            if err != nil {
                return err
            }
        } else {
            return errors.New("please set MLFLOW tracking environment variable.")
        }
        return err
    }
The logic in checkMlflowAuth is complex, has side effects (setting environment variables), and prints to standard output instead of returning descriptive errors. This makes it hard to understand and test. Consider refactoring this function to be clearer. For example, it could first check for Databricks credentials, then MLflow credentials, and if neither is found, return a comprehensive error message listing all required environment variables. Also, printing messages to stdout should be avoided in library code; return errors instead.
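A possible shape for that refactor, as a sketch rather than a drop-in replacement: resolve the credentials in a function without side effects that takes the environment lookup as a parameter, and return a single descriptive error. The names credentials, resolveCredentials and credentialsFromEnv are hypothetical. The sketch also assumes that a credential set counts only when all three of its variables are present, which is stricter than the current code (it accepts Databricks settings when any one of the three is set).

```go
package main

import (
	"errors"
	"os"
)

// credentials holds the values a registry client would need (hypothetical type).
type credentials struct {
	Host     string
	Username string
	Password string
}

// complete reports whether all three fields are set.
func (c credentials) complete() bool {
	return c.Host != "" && c.Username != "" && c.Password != ""
}

// resolveCredentials prefers the DATABRICKS_* variables and falls back to the
// MLFLOW_TRACKING_* ones. getenv is injected so tests can pass a fake
// environment; nothing is printed and no environment variable is modified.
func resolveCredentials(getenv func(string) string) (credentials, error) {
	databricks := credentials{
		Host:     getenv("DATABRICKS_HOST"),
		Username: getenv("DATABRICKS_USERNAME"),
		Password: getenv("DATABRICKS_PASSWORD"),
	}
	if databricks.complete() {
		return databricks, nil
	}
	mlflow := credentials{
		Host:     getenv("MLFLOW_TRACKING_URI"),
		Username: getenv("MLFLOW_TRACKING_USERNAME"),
		Password: getenv("MLFLOW_TRACKING_PASSWORD"),
	}
	if mlflow.complete() {
		return mlflow, nil
	}
	return credentials{}, errors.New("set DATABRICKS_HOST, DATABRICKS_USERNAME and DATABRICKS_PASSWORD, " +
		"or MLFLOW_TRACKING_URI, MLFLOW_TRACKING_USERNAME and MLFLOW_TRACKING_PASSWORD")
}

// credentialsFromEnv wires the resolver to the real process environment.
func credentialsFromEnv() (credentials, error) {
	return resolveCredentials(os.Getenv)
}
```

Injecting getenv keeps the function testable with a plain map, and the caller decides how to hand the resolved values to the client.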
    if mlflowClient != nil {
        registry = ml.NewModelRegistry(mlflowClient)
        fmt.Println("Use default mlflow client for MlFlowRegistryAPI")
Using fmt.Println for logging in library code can lead to inconsistent and hard-to-manage output. Since you are using the log package elsewhere in this file, it would be better to use it here as well for consistency.
Suggested change:
    - fmt.Println("Use default mlflow client for MlFlowRegistryAPI")
    + log.Println("Use default mlflow client for MlFlowRegistryAPI")
    if modelVersion == "" {
        modelVersion = rawVersion[0]
    }
When modelVersion is not provided, you are selecting the first version from the rawVersion slice (rawVersion[0]). The GetLatestVersionsAll function from the Databricks SDK does not guarantee the order of the returned versions. This means your code might non-deterministically select a version, which could be undesirable. You should consider sorting the versions (e.g., using semantic versioning if applicable) to pick the latest one deterministically. At a minimum, you should log which version is being automatically selected.
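A deterministic selection could look like the sketch below. It assumes that rawVersion is a slice of version strings and that registry versions are decimal integers ("1", "2", ...), as MLflow model versions usually are; if the registry uses another scheme, the comparison would need to change. The function name latestVersion is hypothetical. It also guards against an empty slice, which rawVersion[0] does not.

```go
package main

import (
	"errors"
	"fmt"
	"strconv"
)

// latestVersion returns the numerically highest version from versions.
// It assumes versions are non-negative decimal integers encoded as strings.
func latestVersion(versions []string) (string, error) {
	if len(versions) == 0 {
		return "", errors.New("no model versions available")
	}
	best, bestNum := "", -1
	for _, v := range versions {
		n, err := strconv.Atoi(v)
		if err != nil {
			return "", fmt.Errorf("non-numeric model version %q: %w", v, err)
		}
		if n > bestNum {
			best, bestNum = v, n
		}
	}
	return best, nil
}
```

Comparing as integers matters here: a plain string sort would rank "9" above "10".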
    func TestMlFlowClient_PullModelByName(t *testing.T) {
        type fields struct {
            registry *ml.ModelRegistryAPI
        }
        type args struct {
            ctx          context.Context
            modelName    string
            modelVersion string
            destSrc      string
        }
        tests := []struct {
            name    string
            fields  fields
            args    args
            want    string
            wantErr bool
        }{
            {
                name:    "nil receiver returns error",
                fields:  fields{registry: nil},
                args:    args{ctx: context.Background(), modelName: "model", modelVersion: "1", destSrc: "/tmp"},
                want:    "",
                wantErr: true,
            },
        }
        for _, tt := range tests {
            t.Run(tt.name, func(t *testing.T) {
                mlfr := &MlFlowClient{
                    registry: tt.fields.registry,
                }
                got, err := mlfr.PullModelByName(tt.args.ctx, tt.args.modelName, tt.args.modelVersion, tt.args.destSrc)
                if (err != nil) != tt.wantErr {
                    t.Errorf("PullModelByName() error = %v, wantErr %v", err, tt.wantErr)
                    return
                }
                if got != tt.want {
                    t.Errorf("PullModelByName() got = %v, want %v", got, tt.want)
                }
            })
        }
    }
The test for PullModelByName only covers the case of a nil receiver. This is a good start, but it leaves the core logic of the function untested. Please add more test cases to cover success scenarios and other error conditions. This will likely require mocking the ml.ModelRegistryAPI interface to simulate responses from the MLflow registry.
    func (p *MlflowProvider) SupportsURL(url string) bool {
        url = strings.TrimSpace(url)
        // TODO Mlflow API equals with Databricks Model Registry, support later
        possibleUrls := []string{"models"}

        return hasAnyPrefix(url, possibleUrls)
    }
The SupportsURL implementation checks if the URL has the prefix "models". This is not specific enough and could lead to false positives (e.g., a URL like modelsomething-else://...). Given that MLflow model URIs use the models: scheme (e.g., models:/my-model/1), you should check for models:/ to be more precise.
Suggested change:
    - func (p *MlflowProvider) SupportsURL(url string) bool {
    -     url = strings.TrimSpace(url)
    -     // TODO Mlflow API equals with Databricks Model Registry, support later
    -     possibleUrls := []string{"models"}
    -     return hasAnyPrefix(url, possibleUrls)
    - }
    + func (p *MlflowProvider) SupportsURL(url string) bool {
    +     url = strings.TrimSpace(url)
    +     // TODO Mlflow API equals with Databricks Model Registry, support later
    +     possiblePrefixes := []string{"models:/"}
    +     return hasAnyPrefix(url, possiblePrefixes)
    + }
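To illustrate the difference between the two prefixes, here is a small standalone sketch. The body of hasAnyPrefix is an assumption, since the PR's own helper is not shown in this thread, and supportsURL is a free-function stand-in for the method.

```go
package main

import "strings"

// hasAnyPrefix reports whether s starts with any of the given prefixes.
// Assumed implementation; the PR's real helper may differ.
func hasAnyPrefix(s string, prefixes []string) bool {
	for _, p := range prefixes {
		if strings.HasPrefix(s, p) {
			return true
		}
	}
	return false
}

// supportsURL accepts only URIs that use the models: scheme.
func supportsURL(rawURL string) bool {
	return hasAnyPrefix(strings.TrimSpace(rawURL), []string{"models:/"})
}
```

With the stricter prefix, models:/my-model/1 and models://my-model are still accepted, while a string such as modelsomething-else://x is rejected.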
    {"models name with http schema", args{modelURL: "http://my-model/1"}, "", "", true},
    {"models name without version", args{modelURL: "my-model"}, "my-model", "", false},
    {
        "models with schema and without version",
        args{modelURL: "models://my-model"},
        "my-model",
        "",
        false,
    },
    {"invalid url", args{modelURL: "://my-model/1"}, "", "", true},
    {
        "model without schema should return error",
        args{modelURL: "my-model/1"},
        "my-model",
        "1",
        false,
    },
The test case model without schema should return error seems to have an incorrect expectation. The implementation of parseModelURL correctly handles the my-model/1 format and returns the model and version without an error. However, the test name suggests an error is expected. Please either adjust the test to reflect the expected successful parsing or clarify if this case should indeed be an error.
#342 Implement MLflow support
Now we can:
or
or autopull latest
Two dependencies have been added:
• databricks-sdk-go: includes the MLflow client, with possible future support for the Databricks model registry.
• AWS SDK: used for downloading model artifacts from S3. This dependency is relevant for many self-hosted solutions and cloud providers.