A comprehensive Go client library for interacting with the MLflow REST API.
go get github.com/julpayne/mlflow-go-client/pkg/mlflow
package main
import (
"fmt"
"time"
"github.com/julpayne/mlflow-go-client/pkg/mlflow"
)
func main() {
// Create a new client
client := mlflow.NewClient("http://localhost:5000")
// Optional: Set authentication token
client.SetAuthToken("your-token-here")
// Optional: Set custom timeout
client.SetTimeout(60 * time.Second)
}
req := mlflow.CreateExperimentRequest{
Name: "my-experiment",
Tags: []mlflow.ExperimentTag{
{Key: "team", Value: "ml-team"},
},
}
resp, err := client.CreateExperiment(req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Created experiment with ID: %s\n", resp.ExperimentID)
// By ID
experiment, err := client.GetExperiment("experiment-id")
if err != nil {
log.Fatal(err)
}
// By name
experiment, err := client.GetExperimentByName("my-experiment")
if err != nil {
log.Fatal(err)
}
fmt.Printf("Experiment: %s\n", experiment.Experiment.Name)
searchReq := mlflow.SearchExperimentsRequest{
MaxResults: 200,
}
experiments, err := client.SearchExperiments(searchReq)
if err != nil {
log.Fatal(err)
}
for _, exp := range experiments.Experiments {
fmt.Printf("Experiment: %s (ID: %s)\n", exp.Name, exp.ExperimentID)
}
// Search experiments with filters
searchReq := mlflow.SearchExperimentsRequest{
ViewType: "ACTIVE_ONLY", // ACTIVE_ONLY, DELETED_ONLY, or ALL
MaxResults: 200,
Filter: "name LIKE '%test%'",
OrderBy: []string{"name ASC"},
}
results, err := client.SearchExperiments(searchReq)
if err != nil {
log.Fatal(err)
}
for _, exp := range results.Experiments {
fmt.Printf("Experiment: %s\n", exp.Name)
}
// Update experiment name
err := client.UpdateExperiment("experiment-id", "new-name")
// Set experiment tag
err := client.SetExperimentTag("experiment-id", "key", "value")
// Delete experiment tag
err := client.DeleteExperimentTag("experiment-id", "key")
// Delete experiment
err := client.DeleteExperiment("experiment-id")
// Restore deleted experiment
err := client.RestoreExperiment("experiment-id")
req := mlflow.CreateRunRequest{
ExperimentID: "experiment-id",
RunName: "my-run",
Tags: []mlflow.RunTag{
{Key: "version", Value: "1.0"},
},
}
run, err := client.CreateRun(req)
if err != nil {
log.Fatal(err)
}
runID := run.Run.Info.RunID
fmt.Printf("Created run with ID: %s\n", runID)
// Log a metric
err := client.LogMetric(mlflow.LogMetricRequest{
RunID: runID,
Key: "accuracy",
Value: 0.95,
Step: 1,
})
// Log a parameter
err := client.LogParam(mlflow.LogParamRequest{
RunID: runID,
Key: "learning_rate",
Value: "0.01",
})
// Set a tag
err := client.SetTag(mlflow.SetTagRequest{
RunID: runID,
Key: "model_type",
Value: "neural_network",
})
metrics := []mlflow.Metric{
{Key: "loss", Value: 0.5, Step: 1, Timestamp: time.Now().UnixMilli()},
{Key: "accuracy", Value: 0.95, Step: 1, Timestamp: time.Now().UnixMilli()},
}
params := []mlflow.Param{
{Key: "epochs", Value: "100"},
{Key: "batch_size", Value: "32"},
}
tags := []mlflow.RunTag{
{Key: "framework", Value: "pytorch"},
}
err := client.LogBatch(runID, metrics, params, tags)
// Get a run by ID
run, err := client.GetRun(runID)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Run status: %s\n", run.Run.Info.Status)
// Search for runs
searchReq := mlflow.SearchRunsRequest{
ExperimentIDs: []string{"experiment-id"},
Filter: "metrics.accuracy > 0.9",
MaxResults: 10,
OrderBy: []string{"metrics.accuracy DESC"},
}
results, err := client.SearchRuns(searchReq)
if err != nil {
log.Fatal(err)
}
for _, run := range results.Runs {
fmt.Printf("Run: %s, Accuracy: %f\n", run.Info.RunID,
getMetricValue(run.Data.Metrics, "accuracy"))
}
// Log a model to a run
err := client.LogModel(mlflow.LogModelRequest{
RunID: runID,
ModelJSON: `{"model": "sklearn", "version": "1.0"}`,
})
// Log inputs (datasets and model inputs)
err := client.LogInputs(mlflow.LogInputsRequest{
RunID: runID,
Datasets: []mlflow.Dataset{
{
Name: "training_data",
Digest: "abc123",
SourceType: "LOCAL",
Source: "/path/to/data",
},
},
})
// Get metric history for a run
history, err := client.GetMetricHistory(mlflow.GetMetricHistoryRequest{
RunUUID: runID,
MetricKey: "accuracy",
MaxResults: 200,
})
if err != nil {
log.Fatal(err)
}
for _, metric := range history.Metrics {
fmt.Printf("Step %d: %f\n", metric.Step, metric.Value)
}
// List artifacts for a run
artifacts, err := client.ListArtifacts(runID, "", "")
if err != nil {
log.Fatal(err)
}
for _, file := range artifacts.Files {
fmt.Printf("Artifact: %s (size: %d)\n", file.Path, file.FileSize)
}
// Update run status
err := client.UpdateRun(mlflow.UpdateRunRequest{
RunID: runID,
Status: "FINISHED",
EndTime: time.Now().UnixMilli(),
})
// Delete a run
err := client.DeleteRun(runID)
// Restore a deleted run
err := client.RestoreRun(runID)
req := mlflow.CreateRegisteredModelRequest{
Name: "my-model",
Description: "A machine learning model",
Tags: []mlflow.RegisteredModelTag{
{Key: "type", Value: "classification"},
},
}
model, err := client.CreateRegisteredModel(req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Created model: %s\n", model.RegisteredModel.Name)
req := mlflow.CreateModelVersionRequest{
Name: "my-model",
Source: "runs:/run-id/model",
RunID: "run-id",
Description: "Version 1.0 of the model",
}
version, err := client.CreateModelVersion(req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Created version: %s\n", version.ModelVersion.Version)
// Get registered model
model, err := client.GetRegisteredModel("my-model")
// List all registered models
models, err := client.ListRegisteredModels(100, "")
// Get model version
version, err := client.GetModelVersion("my-model", "1")
// List model versions
versions, err := client.ListModelVersions("my-model", 100, "")
// Search registered models
searchReq := mlflow.SearchRegisteredModelsRequest{
Filter: "name LIKE '%classifier%'",
MaxResults: 50,
OrderBy: []string{"name ASC"},
}
models, err := client.SearchRegisteredModels(searchReq)
if err != nil {
log.Fatal(err)
}
// Search model versions
versionSearchReq := mlflow.SearchModelVersionsRequest{
Filter: "name='my-model' AND version='1'",
MaxResults: 200,
}
versions, err := client.SearchModelVersions(versionSearchReq)
if err != nil {
log.Fatal(err)
}
// Rename registered model
renamed, err := client.RenameRegisteredModel(mlflow.RenameRegisteredModelRequest{
Name: "old-name",
NewName: "new-name",
})
// Get latest model versions
latest, err := client.GetLatestModelVersions(mlflow.GetLatestModelVersionsRequest{
Name: "my-model",
Stages: []string{"Production", "Staging"},
})
// Update model version
err := client.UpdateModelVersion("my-model", "1", "Updated description", "Production")
// Transition model version stage
version, err := client.TransitionModelVersionStage(
"my-model",
"1",
"Production",
"true", // archive existing versions
)
// Get download URIs for model artifacts
uris, err := client.GetDownloadURIs(mlflow.GetDownloadURIsRequest{
Name: "my-model",
Version: "1",
Paths: []string{"model.pkl", "requirements.txt"},
})
// Set registered model tag
err := client.SetRegisteredModelTag(mlflow.SetRegisteredModelTagRequest{
Name: "my-model",
Key: "team",
Value: "ml-team",
})
// Set model version tag
err := client.SetModelVersionTag(mlflow.SetModelVersionTagRequest{
Name: "my-model",
Version: "1",
Key: "deployed",
Value: "true",
})
// Delete tags
err := client.DeleteRegisteredModelTag(mlflow.DeleteRegisteredModelTagRequest{
Name: "my-model",
Key: "team",
})
err := client.DeleteModelVersionTag(mlflow.DeleteModelVersionTagRequest{
Name: "my-model",
Version: "1",
Key: "deployed",
})
// Set model alias
err := client.SetRegisteredModelAlias(mlflow.SetRegisteredModelAliasRequest{
Name: "my-model",
Alias: "production",
Version: "1",
})
// Get model version by alias
version, err := client.GetModelVersionByAlias(mlflow.GetModelVersionByAliasRequest{
Name: "my-model",
Alias: "production",
})
// Delete alias
err := client.DeleteRegisteredModelAlias(mlflow.DeleteRegisteredModelAliasRequest{
Name: "my-model",
Alias: "production",
})
// Delete model version
err := client.DeleteModelVersion("my-model", "1")
// Delete registered model
err := client.DeleteRegisteredModel("my-model")

This client supports the following MLflow API endpoints:
- ✅ Create experiment
- ✅ Get experiment (by ID)
- ✅ Get experiment (by name)
- ✅ Search experiments (use instead of list)
- ✅ Update experiment
- ✅ Delete experiment
- ✅ Restore experiment
- ✅ Set experiment tag
- ✅ Delete experiment tag
- ✅ Create run
- ✅ Get run
- ✅ Search runs
- ✅ Update run
- ✅ Delete run
- ✅ Restore run
- ✅ Log metric
- ✅ Log parameter
- ✅ Set tag
- ✅ Delete tag
- ✅ Log batch (multiple metrics/params/tags)
- ✅ Log model
- ✅ Log inputs (datasets and model inputs)
- ✅ Get metric history
- ✅ List artifacts
- ✅ Create registered model
- ✅ Get registered model
- ✅ List registered models
- ✅ Search registered models
- ✅ Update registered model
- ✅ Rename registered model
- ✅ Delete registered model
- ✅ Create model version
- ✅ Get model version
- ✅ List model versions
- ✅ Search model versions
- ✅ Get latest model versions
- ✅ Update model version
- ✅ Delete model version
- ✅ Transition model version stage
- ✅ Get download URIs for model version artifacts
- ✅ Set registered model tag
- ✅ Set model version tag
- ✅ Delete registered model tag
- ✅ Delete model version tag
- ✅ Set registered model alias
- ✅ Delete registered model alias
- ✅ Get model version by alias
All methods return errors that can be checked. The client uses a custom APIError type that provides detailed information about API errors:
experiment, err := client.GetExperiment("id")
if err != nil {
// Check if it's an APIError to access detailed information
if apiErr, ok := mlflow.IsAPIError(err); ok {
fmt.Printf("Status Code: %d\n", apiErr.GetStatusCode())
fmt.Printf("Error Code: %s\n", apiErr.GetErrorCode())
fmt.Printf("Message: %s\n", apiErr.GetMessage())
fmt.Printf("Response Body: %s\n", apiErr.GetResponseBodyString())
} else {
// Handle other types of errors (network, etc.)
fmt.Printf("Error: %v\n", err)
}
return
}

The APIError type provides the following methods:
- GetStatusCode() - Returns the HTTP status code
- GetErrorCode() - Returns the MLflow error code (if available)
- GetMessage() - Returns the error message
- GetResponseBody() - Returns the raw response body as bytes
- GetResponseBodyString() - Returns the response body as a string
The client supports Bearer token authentication:
client.SetAuthToken("your-api-token")

This repository includes scripts and Makefile targets to easily download and run the MLflow server locally for testing and development.
- Install MLflow:
  make install-mlflow
- Start the server:
  make run-mlflow
- Use with the Go client:
  client := mlflow.NewClient("http://localhost:5000")
The server will be available at http://localhost:5000 by default.
You can customize the server settings using environment variables:
# Custom port
MLFLOW_PORT=8080 make run-mlflow
# Use PostgreSQL backend
MLFLOW_BACKEND_STORE_URI=postgresql://user:password@localhost/mlflow \
make run-mlflow
# Stop the server
make stop-mlflow

Alternatively, you can use the scripts directly:
# Install MLflow
./scripts/download_mlflow.sh
# Start the server
./scripts/run_mlflow.sh
# With custom configuration
MLFLOW_HOST=0.0.0.0 MLFLOW_PORT=8080 ./scripts/run_mlflow.sh

- make install-mlflow - Download and install MLflow server
- make run-mlflow - Run MLflow server locally
- make stop-mlflow - Stop MLflow server (if running)
- make build - Build the Go client library
- make test - Run tests
- make fmt - Format Go code
- make vet - Run go vet
- make lint - Run golangci-lint (if installed)
- make clean - Clean build artifacts
- make help - Show all available targets
For more details, see scripts/README.md.
MIT