diff --git a/docs/src/index.md b/docs/src/index.md index 53f17057f3..71a76ed09f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -72,6 +72,44 @@ rust.toolchain( ) ``` +#### Using rust-toolchain.toml + +Alternatively, you can specify the Rust version using a `rust-toolchain.toml` file. This allows you to keep the Rust version in sync between Bazel and cargo/rustup: + +```python +rust = use_extension("@rules_rust//rust:extensions.bzl", "rust") +rust.toolchain( + edition = "2021", + rust_toolchain_file = "//:rust-toolchain.toml", +) +``` + +The parser supports both single-line and multi-line TOML arrays: + +```toml +[toolchain] +channel = "1.86.0" + +# Multi-line arrays are supported +components = [ + "rustfmt", + "clippy", +] + +# Single-line arrays work too +targets = ["wasm32-unknown-unknown"] +``` + +The following fields from `rust-toolchain.toml` are parsed and mapped to rules_rust attributes: + +| rust-toolchain.toml | rules_rust attribute | Notes | +|---------------------|---------------------|-------| +| `channel` | `versions` | Required. The Rust version to use. | +| `targets` | `extra_target_triples` | Merged with any explicit `extra_target_triples`. | +| `components` | `dev_components` | Set to `True` if `"rustc-dev"` is in components. | + +Explicit attributes in `rust.toolchain()` take precedence over values parsed from `rust-toolchain.toml`. + By default, a `stable` and `nightly` toolchain will be registered if no `toolchain` method is called (and thus no specific versions are registered). However, if only 1 version is passed and it is from the `nightly` or `beta` release channels (i.e. __not__ `stable`), then the following build setting flag must be present, either on the command line or set in the project's `.bazelrc` file: ```text diff --git a/rust/extensions.bzl b/rust/extensions.bzl index 749740032e..70f0c78034 100644 --- a/rust/extensions.bzl +++ b/rust/extensions.bzl @@ -10,6 +10,7 @@ load( "DEFAULT_NIGHTLY_VERSION", "DEFAULT_STATIC_RUST_URL_TEMPLATES", ) +load("//rust/private:rust_toolchain_toml.bzl", "parse_rust_toolchain_file") _RUST_TOOLCHAIN_VERSIONS = [ rust_common.default_version, @@ -98,7 +99,37 @@ def _rust_impl(module_ctx): for toolchain in toolchains: if toolchain.extra_rustc_flags and toolchain.extra_rustc_flags_triples: fail("Cannot define both extra_rustc_flags and extra_rustc_flags_triples") - if len(toolchain.versions) == 0: + + # Start with explicit attribute values + versions = toolchain.versions + extra_target_triples = toolchain.extra_target_triples + dev_components = toolchain.dev_components + + # Override/merge with rust_toolchain_file if specified + if toolchain.rust_toolchain_file: + toolchain_file_path = module_ctx.path(toolchain.rust_toolchain_file) + toolchain_file_content = module_ctx.read(toolchain_file_path) + parsed = parse_rust_toolchain_file(toolchain_file_content) + if parsed: + # Use parsed versions (required) + versions = parsed.versions + + # Merge extra_target_triples (explicit + parsed) + if parsed.extra_target_triples: + # Combine explicit targets with parsed targets, avoiding duplicates + combined_triples = list(extra_target_triples) + for triple in parsed.extra_target_triples: + if triple not in combined_triples: + combined_triples.append(triple) + extra_target_triples = combined_triples + + # Enable dev_components if specified in file (or already enabled explicitly) + if parsed.dev_components: + dev_components = True + else: + fail("Failed to parse rust-toolchain file: {}".format(toolchain.rust_toolchain_file)) + + if len(versions) == 0: # If the root module has asked for rules_rust to not register default # toolchains, an empty repository named `rust_toolchains` is created # so that the `register_toolchains()` in MODULES.bazel is still @@ -109,7 +140,7 @@ def _rust_impl(module_ctx): rust_register_toolchains( hub_name = "rust_toolchains", - dev_components = toolchain.dev_components, + dev_components = dev_components, edition = toolchain.edition, extra_rustc_flags = extra_rustc_flags, extra_exec_rustc_flags = toolchain.extra_exec_rustc_flags, @@ -117,9 +148,9 @@ def _rust_impl(module_ctx): rustfmt_version = toolchain.rustfmt_version, rust_analyzer_version = toolchain.rust_analyzer_version, sha256s = toolchain.sha256s, - extra_target_triples = toolchain.extra_target_triples, + extra_target_triples = extra_target_triples, urls = toolchain.urls, - versions = toolchain.versions, + versions = versions, register_toolchains = False, aliases = toolchain.aliases, toolchain_triples = toolchain_triples, @@ -225,6 +256,14 @@ _RUST_TOOLCHAIN_TAG = tag_class( "rust_analyzer_version": attr.string( doc = "The version of Rustc to pair with rust-analyzer.", ), + "rust_toolchain_file": attr.label( + doc = ( + "A label to a rust-toolchain.toml file. If specified, the toolchain version will be " + + "read from this file instead of using the `versions` attribute. This allows keeping " + + "the Rust version in sync between Bazel and cargo/rustup." + ), + allow_single_file = True, + ), "target_settings": attr.label_list( doc = "A list of `config_settings` that must be satisfied by the target configuration in order for this toolchain to be selected during toolchain resolution.", ), @@ -232,7 +271,8 @@ _RUST_TOOLCHAIN_TAG = tag_class( doc = ( "A list of toolchain versions to download. This parameter only accepts one version " + "per channel. E.g. `[\"1.65.0\", \"nightly/2022-11-02\", \"beta/2020-12-30\"]`. " + - "May be set to an empty list (`[]`) to inhibit `rules_rust` from registering toolchains." + "May be set to an empty list (`[]`) to inhibit `rules_rust` from registering toolchains. " + + "If `rust_toolchain_file` is specified, this attribute is ignored." ), default = _RUST_TOOLCHAIN_VERSIONS, ), diff --git a/rust/private/rust_toolchain_toml.bzl b/rust/private/rust_toolchain_toml.bzl new file mode 100644 index 0000000000..68507c380a --- /dev/null +++ b/rust/private/rust_toolchain_toml.bzl @@ -0,0 +1,139 @@ +"""Parser for rust-toolchain.toml files.""" + +def normalize_toml_multiline_arrays(content): + """Normalize multi-line TOML arrays to single-line for simpler parsing. + + Handles arrays like: + targets = [ + "wasm32-unknown-unknown", + "x86_64-unknown-linux-gnu", + ] + + Args: + content: The raw TOML file content. + + Returns: + Content with multi-line arrays collapsed to single lines. + """ + result = [] + in_array = False + array_buffer = "" + + for line in content.split("\n"): + stripped = line.strip() + + # Preserve comments and empty lines outside arrays + if not stripped or stripped.startswith("#"): + if not in_array: + result.append(line) + continue + + if in_array: + array_buffer += " " + stripped + if "]" in stripped: + in_array = False + result.append(array_buffer) + array_buffer = "" + elif "= [" in stripped and "]" not in stripped: + # Start of multi-line array + in_array = True + array_buffer = stripped + else: + result.append(line) + + return "\n".join(result) + +def parse_toml_string(line): + """Parse a TOML string value: key = "value" -> value + + Args: + line: A line containing a TOML key-value pair. + + Returns: + The parsed string value, or None if parsing fails. + """ + parts = line.split("=", 1) + if len(parts) == 2: + return parts[1].strip().strip("\"'") + return None + +def parse_toml_list(line): + """Parse a TOML list value: key = ["a", "b"] -> ["a", "b"] + + Args: + line: A line containing a TOML key-value pair with a list value. + + Returns: + The parsed list of strings, or an empty list if parsing fails. + """ + parts = line.split("=", 1) + if len(parts) == 2: + list_str = parts[1].strip() + if list_str.startswith("[") and list_str.endswith("]"): + items = list_str[1:-1].split(",") + return [item.strip().strip("\"'") for item in items if item.strip().strip("\"'")] + return [] + +def parse_rust_toolchain_file(content): + """Parse rust-toolchain.toml content and extract toolchain configuration. + + Supports: + - channel: The toolchain version (e.g., "1.92.0", "nightly-2024-01-01") + - targets: Additional target triples to install + - components: Additional components (sets dev_components=True if "rustc-dev" present) + + Both single-line and multi-line TOML arrays are supported. + + Args: + content: The content of the rust-toolchain.toml file. + + Returns: + A struct with versions, extra_target_triples, and dev_components fields, + or None if parsing fails completely. + """ + + # Normalize multi-line arrays first + content = normalize_toml_multiline_arrays(content) + + versions = None + extra_target_triples = [] + dev_components = False + + for line in content.split("\n"): + line = line.strip() + + # Skip empty lines, comments, and section headers + if not line or line.startswith("#") or line.startswith("["): + continue + + if line.startswith("channel"): + version = parse_toml_string(line) + if version: + versions = [version] + + elif line.startswith("targets"): + targets = parse_toml_list(line) + if targets: + extra_target_triples = targets + + elif line.startswith("components"): + components = parse_toml_list(line) + if "rustc-dev" in components: + dev_components = True + + # If no channel was found, try simple format (just version string) + if not versions: + for line in content.split("\n"): + line = line.strip() + if line and not line.startswith("#") and not line.startswith("[") and "=" not in line: + versions = [line] + break + + if not versions: + return None + + return struct( + versions = versions, + extra_target_triples = extra_target_triples, + dev_components = dev_components, + ) diff --git a/test/unit/rust_toolchain_toml/BUILD.bazel b/test/unit/rust_toolchain_toml/BUILD.bazel new file mode 100644 index 0000000000..688adf6ad0 --- /dev/null +++ b/test/unit/rust_toolchain_toml/BUILD.bazel @@ -0,0 +1,3 @@ +load(":rust_toolchain_toml_test.bzl", "rust_toolchain_toml_test_suite") + +rust_toolchain_toml_test_suite(name = "rust_toolchain_toml_tests") diff --git a/test/unit/rust_toolchain_toml/rust_toolchain_toml_test.bzl b/test/unit/rust_toolchain_toml/rust_toolchain_toml_test.bzl new file mode 100644 index 0000000000..be63a93b50 --- /dev/null +++ b/test/unit/rust_toolchain_toml/rust_toolchain_toml_test.bzl @@ -0,0 +1,190 @@ +"""Unit tests for rust_toolchain_toml.bzl parser functions.""" + +load("@bazel_skylib//lib:unittest.bzl", "asserts", "unittest") + +# buildifier: disable=bzl-visibility +load( + "//rust/private:rust_toolchain_toml.bzl", + "normalize_toml_multiline_arrays", + "parse_rust_toolchain_file", + "parse_toml_list", + "parse_toml_string", +) + +def _parse_toml_string_test_impl(ctx): + env = unittest.begin(ctx) + + # Basic string parsing + asserts.equals(env, "1.86.0", parse_toml_string('channel = "1.86.0"')) + asserts.equals(env, "nightly-2024-01-01", parse_toml_string('channel = "nightly-2024-01-01"')) + + # With extra whitespace + asserts.equals(env, "1.86.0", parse_toml_string('channel = "1.86.0"')) + + # Single quotes + asserts.equals(env, "1.86.0", parse_toml_string("channel = '1.86.0'")) + + # No equals sign + asserts.equals(env, None, parse_toml_string("just a string")) + + return unittest.end(env) + +def _parse_toml_list_test_impl(ctx): + env = unittest.begin(ctx) + + # Basic list parsing + asserts.equals( + env, + ["rustfmt", "clippy"], + parse_toml_list('components = ["rustfmt", "clippy"]'), + ) + + # Single item + asserts.equals( + env, + ["wasm32-unknown-unknown"], + parse_toml_list('targets = ["wasm32-unknown-unknown"]'), + ) + + # Empty list + asserts.equals(env, [], parse_toml_list("targets = []")) + + # With extra whitespace + asserts.equals( + env, + ["a", "b", "c"], + parse_toml_list('items = [ "a" , "b" , "c" ]'), + ) + + # No equals sign + asserts.equals(env, [], parse_toml_list("not a list")) + + # Not a list value + asserts.equals(env, [], parse_toml_list('key = "value"')) + + return unittest.end(env) + +def _normalize_multiline_arrays_test_impl(ctx): + env = unittest.begin(ctx) + + # Single-line array should be unchanged + single_line = 'targets = ["a", "b"]' + asserts.equals(env, single_line, normalize_toml_multiline_arrays(single_line)) + + # Multi-line array should be collapsed + multi_line = """components = [ + "rustfmt", + "clippy", +]""" + normalized = normalize_toml_multiline_arrays(multi_line) + asserts.true(env, "components = [" in normalized) + asserts.true(env, "]" in normalized) + asserts.true(env, "\n" not in normalized or normalized.count("\n") == 0) + + # Verify the list can be parsed after normalization + parsed = parse_toml_list(normalized) + asserts.equals(env, ["rustfmt", "clippy"], parsed) + + # Mixed content + mixed = """[toolchain] +channel = "1.86.0" +components = [ + "rustfmt", + "clippy", +] +targets = ["wasm32-unknown-unknown"]""" + normalized_mixed = normalize_toml_multiline_arrays(mixed) + asserts.true(env, 'channel = "1.86.0"' in normalized_mixed) + asserts.true(env, 'targets = ["wasm32-unknown-unknown"]' in normalized_mixed) + + return unittest.end(env) + +def _parse_rust_toolchain_file_test_impl(ctx): + env = unittest.begin(ctx) + + # Basic TOML format + basic = """[toolchain] +channel = "1.86.0" +""" + parsed = parse_rust_toolchain_file(basic) + asserts.true(env, parsed != None, "Should parse basic TOML") + asserts.equals(env, ["1.86.0"], parsed.versions) + asserts.equals(env, [], parsed.extra_target_triples) + asserts.equals(env, False, parsed.dev_components) + + # Full TOML with all fields + full = """[toolchain] +channel = "1.92.0" +components = ["rustfmt", "clippy", "rustc-dev"] +targets = ["wasm32-unknown-unknown", "x86_64-unknown-linux-gnu"] +""" + parsed_full = parse_rust_toolchain_file(full) + asserts.true(env, parsed_full != None, "Should parse full TOML") + asserts.equals(env, ["1.92.0"], parsed_full.versions) + asserts.equals( + env, + ["wasm32-unknown-unknown", "x86_64-unknown-linux-gnu"], + parsed_full.extra_target_triples, + ) + asserts.equals(env, True, parsed_full.dev_components) + + # Multi-line arrays + multiline = """[toolchain] +channel = "1.86.0" +components = [ + "rustfmt", + "clippy", +] +targets = ["wasm32-unknown-unknown"] +""" + parsed_multiline = parse_rust_toolchain_file(multiline) + asserts.true(env, parsed_multiline != None, "Should parse multi-line arrays") + asserts.equals(env, ["1.86.0"], parsed_multiline.versions) + asserts.equals(env, ["wasm32-unknown-unknown"], parsed_multiline.extra_target_triples) + + # Simple format (just version string) + simple = "1.86.0" + parsed_simple = parse_rust_toolchain_file(simple) + asserts.true(env, parsed_simple != None, "Should parse simple format") + asserts.equals(env, ["1.86.0"], parsed_simple.versions) + + # With comments + with_comments = """# This is a comment +[toolchain] +# Another comment +channel = "1.86.0" +""" + parsed_comments = parse_rust_toolchain_file(with_comments) + asserts.true(env, parsed_comments != None, "Should handle comments") + asserts.equals(env, ["1.86.0"], parsed_comments.versions) + + # Invalid content (no version) + invalid = """[toolchain] +components = ["rustfmt"] +""" + parsed_invalid = parse_rust_toolchain_file(invalid) + asserts.equals(env, None, parsed_invalid, "Should return None for invalid content") + + # Nightly channel + nightly = """[toolchain] +channel = "nightly-2024-06-01" +""" + parsed_nightly = parse_rust_toolchain_file(nightly) + asserts.true(env, parsed_nightly != None) + asserts.equals(env, ["nightly-2024-06-01"], parsed_nightly.versions) + + return unittest.end(env) + +parse_toml_string_test = unittest.make(_parse_toml_string_test_impl) +parse_toml_list_test = unittest.make(_parse_toml_list_test_impl) +normalize_multiline_arrays_test = unittest.make(_normalize_multiline_arrays_test_impl) +parse_rust_toolchain_file_test = unittest.make(_parse_rust_toolchain_file_test_impl) + +def rust_toolchain_toml_test_suite(name): + unittest.suite( + name, + parse_toml_string_test, + parse_toml_list_test, + normalize_multiline_arrays_test, + parse_rust_toolchain_file_test, + )