Skip to content

Commit 7156655

Browse files
committed
Introduce cluster support (#299)
1 parent f5758f6 commit 7156655

File tree

9 files changed

+338
-56
lines changed

9 files changed

+338
-56
lines changed

BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ rust_binary(
2929
"@crates//:futures",
3030
"@crates//:glob",
3131
"@crates//:home",
32+
"@crates//:itertools",
3233
"@crates//:rpassword",
3334
"@crates//:rustyline",
3435
"@crates//:sentry",

WORKSPACE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ rust_register_toolchains(
7575
"x86_64-pc-windows-msvc",
7676
"x86_64-unknown-linux-gnu",
7777
],
78-
rust_analyzer_version = "1.81.0",
79-
versions = ["1.81.0"],
78+
rust_analyzer_version = "1.84.0",
79+
versions = ["1.84.0"],
8080
)
8181

8282
rust_analyzer_toolchain_tools_repository(

dependencies/typedb/repositories.bzl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,19 @@
55
load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
66

77
def typedb_dependencies():
8-
git_repository(
9-
name = "typedb_dependencies",
10-
remote = "https://github.com/typedb/typedb-dependencies",
11-
commit = "f6e710f9857b1c30ad1764c1c41afce4e4e02981", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_dependencies
12-
)
8+
# TODO: Return ref after merge to master
9+
git_repository(
10+
name = "typedb_dependencies",
11+
remote = "https://github.com/typedb/typedb-dependencies",
12+
commit = "19a70bcad19b9a28814016f183ac3e3a23c1ff0d", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_dependencies
13+
)
1314

1415
def typedb_driver():
16+
# TODO: Return ref after merge to master
1517
git_repository(
1618
name = "typedb_driver",
1719
remote = "https://github.com/typedb/typedb-driver",
18-
tag = "3.7.0", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_driver
20+
commit = "a804da767e68154cc3ee5a7531feeaba5dbf0c17", # sync-marker: do not remove this comment, this is used for sync-dependencies by @typedb_driver
1921
)
2022

2123
def typeql():

src/cli.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,28 @@ pub struct Args {
2424
#[arg(long, value_name = "path to script file")]
2525
pub script: Vec<String>,
2626

27-
/// TypeDB address to connect to. If using TLS encryption, this must start with "https://"
28-
#[arg(long, value_name = ADDRESS_VALUE_NAME)]
27+
/// TypeDB address to connect to (host:port). If using TLS encryption, this must start with "https://".
28+
#[arg(long, value_name = ADDRESS_VALUE_NAME, conflicts_with_all = ["addresses", "address_translation"])]
2929
pub address: Option<String>,
3030

31+
/// A comma-separated list of TypeDB replica addresses of a single cluster to connect to.
32+
#[arg(long, value_name = "host1:port1,host2:port2", conflicts_with_all = ["address", "address_translation"])]
33+
pub addresses: Option<String>,
34+
35+
/// A comma-separated list of public=private address pairs. Public addresses are the user-facing
36+
/// addresses of the replicas, and private addresses are the addresses used for the server-side
37+
/// connection between replicas.
38+
#[arg(long, value_name = "public=private,...", conflicts_with_all = ["address", "addresses"])]
39+
pub address_translation: Option<String>,
40+
41+
/// If used in a Cluster environment (Cloud or Enterprise), disables attempts to redirect
42+
/// requests to server replicas, limiting Console to communicate only with the single address
43+
/// specified in the `address` argument.
44+
/// Use for administrative / debug purposes to test a specific replica only: this option will
45+
/// lower the success rate of Console's operations in production.
46+
#[arg(long = "replication-disabled", default_value = "false")]
47+
pub replication_disabled: bool,
48+
3149
/// Username for authentication
3250
#[arg(long, value_name = USERNAME_VALUE_NAME)]
3351
pub username: Option<String>,
@@ -48,8 +66,8 @@ pub struct Args {
4866

4967
/// Disable error reporting. Error reporting helps TypeDB improve by reporting
5068
/// errors and crashes to the development team.
51-
#[arg(long = "diagnostics-disable", default_value = "false")]
52-
pub diagnostics_disable: bool,
69+
#[arg(long = "diagnostics-disabled", default_value = "false")]
70+
pub diagnostics_disabled: bool,
5371

5472
/// Print the Console binary version
5573
#[arg(long = "version")]

src/completion_cache.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use std::collections::HashSet;
2+
use std::hash::Hash;
3+
use std::sync::{Arc, RwLock};
4+
5+
#[derive(Clone)]
6+
pub struct CompletionCache<T> {
7+
inner: Arc<RwLock<HashSet<T>>>,
8+
}
9+
10+
impl<T> CompletionCache<T>
11+
where
12+
T: Eq + Hash + Clone,
13+
{
14+
pub fn new() -> Self {
15+
Self { inner: Arc::new(RwLock::new(HashSet::new())) }
16+
}
17+
18+
pub fn replace_all<I>(&self, items: impl IntoIterator<Item = T>) {
19+
*self.inner.write().unwrap() = items.into_iter().collect();
20+
}
21+
22+
pub fn snapshot(&self) -> Vec<T> {
23+
self.inner.read().unwrap().iter().cloned().collect()
24+
}
25+
26+
pub fn is_empty(&self) -> bool {
27+
self.inner.read().unwrap().is_empty()
28+
}
29+
30+
pub fn clear(&self) {
31+
self.inner.write().unwrap().clear();
32+
}
33+
}

src/main.rs

Lines changed: 89 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66

77
use std::{
8+
collections::HashMap,
89
env,
910
env::temp_dir,
1011
error::Error,
@@ -22,15 +23,16 @@ use clap::Parser;
2223
use home::home_dir;
2324
use rustyline::error::ReadlineError;
2425
use sentry::ClientOptions;
25-
use typedb_driver::{Credentials, DriverOptions, Transaction, TransactionType, TypeDBDriver};
26+
use typedb_driver::{Addresses, Credentials, DriverOptions, Transaction, TransactionType, TypeDBDriver};
2627

2728
use crate::{
2829
cli::{Args, ADDRESS_VALUE_NAME, USERNAME_VALUE_NAME},
2930
completions::{database_name_completer_fn, file_completer},
3031
operations::{
3132
database_create, database_create_init, database_delete, database_export, database_import, database_list,
32-
database_schema, transaction_close, transaction_commit, transaction_query, transaction_read,
33-
transaction_rollback, transaction_schema, transaction_source, transaction_write, user_create, user_delete,
33+
database_schema, replica_deregister, replica_list, replica_primary, replica_register, server_version,
34+
transaction_close, transaction_commit, transaction_query, transaction_read, transaction_rollback,
35+
transaction_schema, transaction_source, transaction_write, user_create, user_delete,
3436
user_list, user_update_password,
3537
},
3638
repl::{
@@ -126,13 +128,19 @@ fn main() {
126128
println!("{}", VERSION);
127129
exit(ExitCode::Success as i32);
128130
}
129-
let address = match args.address {
130-
Some(address) => address,
131-
None => {
132-
println_error!("missing server address ('{}').", format_argument!("--address <{ADDRESS_VALUE_NAME}>"));
133-
exit(ExitCode::UserInputError as i32);
134-
}
135-
};
131+
let address_info = parse_addresses(&args);
132+
if !args.tls_disabled && !address_info.only_https {
133+
println_error!(
134+
"\
135+
TLS connections can only be enabled when connecting to HTTPS endpoints. \
136+
For example, using 'https://<ip>:port'.\n\
137+
Please modify the address or disable TLS ('{}'). {}\
138+
",
139+
format_argument!("--tls-disabled"),
140+
format_warning!("WARNING: this will send passwords over plaintext!"),
141+
);
142+
exit(ExitCode::UserInputError as i32);
143+
}
136144
let username = match args.username {
137145
Some(username) => username,
138146
None => {
@@ -146,34 +154,28 @@ fn main() {
146154
if args.password.is_none() {
147155
args.password = Some(LineReaderHidden::new().readline(&format!("password for '{username}': ")));
148156
}
149-
if !args.diagnostics_disable {
157+
if !args.diagnostics_disabled {
150158
init_diagnostics()
151159
}
152-
if !args.tls_disabled && !address.starts_with("https:") {
153-
println_error!(
154-
"\
155-
TLS connections can only be enabled when connecting to HTTPS endpoints. \
156-
For example, using 'https://<ip>:port'.\n\
157-
Please modify the address or disable TLS ('{}'). {}\
158-
",
159-
format_argument!("--tls-disabled"),
160-
format_warning!("WARNING: this will send passwords over plaintext!"),
161-
);
162-
exit(ExitCode::UserInputError as i32);
163-
}
164160
let tls_root_ca_path = args.tls_root_ca.as_ref().map(|value| Path::new(value));
165-
166161
let runtime = BackgroundRuntime::new();
162+
let driver_options = DriverOptions::new()
163+
.use_replication(!args.replication_disabled)
164+
.is_tls_enabled(!args.tls_disabled)
165+
.tls_root_ca(tls_root_ca_path)
166+
.unwrap();
167167
let driver = match runtime.run(TypeDBDriver::new(
168-
address,
168+
address_info.addresses,
169169
Credentials::new(&username, args.password.as_ref().unwrap()),
170-
DriverOptions::new(!args.tls_disabled, tls_root_ca_path).unwrap(),
170+
driver_options,
171171
)) {
172172
Ok(driver) => Arc::new(driver),
173173
Err(err) => {
174174
let tls_error =
175175
if args.tls_disabled { "" } else { "\nVerify that the server is also configured with TLS encryption." };
176-
println_error!("Failed to create driver connection to server. {err}{tls_error}");
176+
let replication_error =
177+
if args.replication_disabled { "\nVerify that the connection address is **exactly** the same as the server address specified in its config." } else { "" };
178+
println_error!("Failed to create driver connection to server. {err}{tls_error}{replication_error}");
177179
exit(ExitCode::ConnectionError as i32);
178180
}
179181
};
@@ -332,6 +334,28 @@ fn execute_commands(context: &mut ConsoleContext, mut input: &str, must_log_comm
332334
}
333335

334336
fn entry_repl(driver: Arc<TypeDBDriver>, runtime: BackgroundRuntime) -> Repl<ConsoleContext> {
337+
let server_commands =
338+
Subcommand::new("server").add(CommandLeaf::new("version", "Retrieve server version.", server_version));
339+
340+
let replica_commands = Subcommand::new("replica")
341+
.add(CommandLeaf::new("list", "List replicas.", replica_list))
342+
.add(CommandLeaf::new("primary", "Get current primary replica.", replica_primary))
343+
.add(CommandLeaf::new_with_inputs(
344+
"register",
345+
"Register new replica. Requires a clustering address, not a connection address.",
346+
vec![
347+
CommandInput::new_required("replica id", get_word, None),
348+
CommandInput::new_required("clustering address", get_word, None),
349+
],
350+
replica_register,
351+
))
352+
.add(CommandLeaf::new_with_input(
353+
"deregister",
354+
"Deregister existing replica.",
355+
CommandInput::new_required("replica id", get_word, None),
356+
replica_deregister,
357+
));
358+
335359
let database_commands = Subcommand::new("database")
336360
.add(CommandLeaf::new("list", "List databases on the server.", database_list))
337361
.add(CommandLeaf::new_with_input(
@@ -451,8 +475,10 @@ fn entry_repl(driver: Arc<TypeDBDriver>, runtime: BackgroundRuntime) -> Repl<Con
451475
let history_path = home_dir().unwrap_or_else(|| temp_dir()).join(ENTRY_REPL_HISTORY);
452476

453477
let repl = Repl::new(PROMPT.to_owned(), history_path, false, None)
478+
.add(server_commands)
454479
.add(database_commands)
455480
.add(user_commands)
481+
.add(replica_commands)
456482
.add(transaction_commands);
457483

458484
repl
@@ -508,6 +534,42 @@ fn transaction_type_str(transaction_type: TransactionType) -> &'static str {
508534
}
509535
}
510536

537+
struct AddressInfo {
538+
only_https: bool,
539+
addresses: Addresses,
540+
}
541+
542+
fn parse_addresses(args: &Args) -> AddressInfo {
543+
if let Some(address) = &args.address {
544+
AddressInfo {
545+
only_https: is_https_address(address),
546+
addresses: Addresses::try_from_address_str(address).unwrap(),
547+
}
548+
} else if let Some(addresses) = &args.addresses {
549+
let split = addresses.split(',').map(str::to_string).collect::<Vec<_>>();
550+
let only_https = split.iter().all(|address| is_https_address(address));
551+
AddressInfo { only_https, addresses: Addresses::try_from_addresses_str(split).unwrap() }
552+
} else if let Some(translation) = &args.address_translation {
553+
let mut map = HashMap::new();
554+
let mut only_https = true;
555+
for pair in translation.split(',') {
556+
let (public_address, private_address) = pair
557+
.split_once('=')
558+
.unwrap_or_else(|| panic!("Invalid address pair: {pair}. Must be of form public=private"));
559+
only_https = only_https && is_https_address(public_address);
560+
map.insert(public_address.to_string(), private_address.to_string());
561+
}
562+
println!("Translation map:: {map:?}"); // TODO: Remove
563+
AddressInfo { only_https, addresses: Addresses::try_from_translation_str(map).unwrap() }
564+
} else {
565+
panic!("At least one of --address, --addresses, or --address-translation must be provided.");
566+
}
567+
}
568+
569+
fn is_https_address(address: &str) -> bool {
570+
address.starts_with("https:")
571+
}
572+
511573
fn init_diagnostics() {
512574
let _ = sentry::init((
513575
DIAGNOSTICS_REPORTING_URI,

0 commit comments

Comments
 (0)