Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 109 additions & 14 deletions datafusion/functions/src/regex/regexpreplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
// under the License.

//! Regex expressions
use memchr::memchr;

use arrow::array::ArrayDataBuilder;
use arrow::array::BufferBuilder;
use arrow::array::GenericStringArray;
Expand All @@ -40,7 +42,8 @@ use datafusion_expr::{
Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;
use regex::Regex;
use regex::{CaptureLocations, Regex};
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::{Arc, LazyLock};

Expand Down Expand Up @@ -199,6 +202,96 @@ fn regex_replace_posix_groups(replacement: &str) -> String {
.into_owned()
}

struct ShortRegex {
/// Shortened anchored regex used to extract capture group 1 directly.
/// See [`try_build_short_extract_regex`] for details.
short_re: Regex,
/// Reusable capture locations for `short_re` to avoid per-row allocation.
locs: CaptureLocations,
}

/// Holds the normal compiled regex together with the optional fast path used
/// for `regexp_replace(str, '^...(capture)...*$', '\1')`.
struct OptimizedRegex {
/// Full regex used for the normal replacement path and as a correctness fallback.
re: Regex,
/// Precomputed state for the direct-extraction fast path, when applicable.
short_re: Option<ShortRegex>,
}

impl OptimizedRegex {
/// Builds any reusable state needed by the extraction fast path.
///
/// The fast path is only enabled for single replacements where the pattern
/// and replacement satisfy [`try_build_short_extract_regex`].
fn new(re: Regex, limit: usize, pattern: &str, replacement: &str) -> Self {
let short_re = if limit == 1 {
try_build_short_extract_regex(pattern, replacement)
} else {
None
};

let short_re = short_re.map(|short_re| {
let locs = short_re.capture_locations();
ShortRegex { short_re, locs }
});

Self { re, short_re }
}

/// Applies the direct-extraction fast path when it preserves the result of
/// `Regex::replacen`; otherwise falls back to the full regex replacement.
fn replacen<'a>(
&mut self,
val: &'a str,
limit: usize,
replacement: &str,
) -> Cow<'a, str> {
// If this pattern is not eligible for direct extraction, use the full regex.
let Some(ShortRegex { short_re, locs }) = self.short_re.as_mut() else {
return self.re.replacen(val, limit, replacement);
};

// If the shortened regex does not match, the original anchored regex would
// also leave the input unchanged.
if short_re.captures_read(locs, val).is_none() {
return Cow::Borrowed(val);
};

// `captures_read` succeeded, so the overall shortened match is present.
let match_end = locs.get(0).unwrap().1;
if memchr(b'\n', &val.as_bytes()[match_end..]).is_some() {
// If there is a newline after the match, we can't use the short
// regex since it won't match across lines. Fall back to the full
// regex replacement.
return self.re.replacen(val, limit, replacement);
};
// The fast path only applies to `${1}` replacements, so the result is
// either capture group 1 or the empty string if that group did not match.
if let Some((start, end)) = locs.get(1) {
Cow::Borrowed(&val[start..end])
} else {
Cow::Borrowed("")
}
}
}

/// For anchored patterns like `^...(capture)....*$` where the replacement
/// is `\1`, build a shorter regex (stripping trailing `.*$`) and use
/// `captures_read` with `CaptureLocations` for direct extraction — no
/// `expand()`, no `String` allocation.
fn try_build_short_extract_regex(pattern: &str, replacement: &str) -> Option<Regex> {
if replacement != "${1}" || !pattern.starts_with('^') || !pattern.ends_with(".*$") {
return None;
}
let short = &pattern[..pattern.len() - 3];
let re = Regex::new(short).ok()?;
if re.captures_len() != 2 {
return None;
}
Some(re)
}

/// Replaces substring(s) matching a PCRE-like regular expression.
///
/// The full list of supported features and syntax can be found at
Expand Down Expand Up @@ -422,7 +515,7 @@ macro_rules! fetch_string_arg {
/// hold a single Regex object for the replace operation. This also speeds
/// up the pre-processing time of the replacement string, since it only
/// needs to processed once.
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
fn regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let array_size = args[0].len();
Expand Down Expand Up @@ -457,6 +550,8 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
// with rust ones.
let replacement = regex_replace_posix_groups(replacement);

let mut opt_re = OptimizedRegex::new(re, limit, &pattern, &replacement);

let string_array_type = args[0].data_type();
match string_array_type {
DataType::Utf8 | DataType::LargeUtf8 => {
Expand All @@ -475,7 +570,7 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(

string_array.iter().for_each(|val| {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
let result = opt_re.replacen(val, limit, replacement.as_str());
vals.append_slice(result.as_bytes());
}
new_offsets.append(T::from_usize(vals.len()).unwrap());
Expand All @@ -496,8 +591,8 @@ fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(

for val in string_view_array.iter() {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
builder.append_value(result);
let result = opt_re.replacen(val, limit, replacement.as_str());
builder.append_value(result.as_ref());
} else {
builder.append_null();
}
Expand Down Expand Up @@ -576,7 +671,7 @@ fn specialize_regexp_replace<T: OffsetSizeTrait>(
arg.to_array(expansion_len)
})
.collect::<Result<Vec<_>>>()?;
_regexp_replace_static_pattern_replace::<T>(&args)
regexp_replace_static_pattern_replace::<T>(&args)
}

// If there are no specialized implementations, we'll fall back to the
Expand Down Expand Up @@ -710,7 +805,7 @@ mod tests {
let replacements = <$T>::from(replacement);
let expected = <$T>::from(expected);

let re = _regexp_replace_static_pattern_replace::<$O>(&[
let re = regexp_replace_static_pattern_replace::<$O>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand Down Expand Up @@ -755,7 +850,7 @@ mod tests {
let flags = StringArray::from(vec!["i"; 5]);
let expected = <$T>::from(expected);

let re = _regexp_replace_static_pattern_replace::<$O>(&[
let re = regexp_replace_static_pattern_replace::<$O>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand Down Expand Up @@ -787,7 +882,7 @@ mod tests {
let replacements = StringArray::from(vec!["foo"; 5]);
let expected = StringArray::from(vec![None::<&str>; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand All @@ -804,7 +899,7 @@ mod tests {
let replacements = StringArray::from(Vec::<Option<&str>>::new());
let expected = StringArray::from(Vec::<Option<&str>>::new());

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand All @@ -822,7 +917,7 @@ mod tests {
let flags = StringArray::from(vec![None::<&str>; 5]);
let expected = StringArray::from(vec![None::<&str>; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand All @@ -841,7 +936,7 @@ mod tests {
let patterns = StringArray::from(vec!["["; 5]);
let replacements = StringArray::from(vec!["foo"; 5]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand Down Expand Up @@ -878,7 +973,7 @@ mod tests {
Some("c"),
]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand Down Expand Up @@ -906,7 +1001,7 @@ mod tests {
let replacements = StringArray::from(vec!["foo"; 1]);
let expected = StringArray::from(vec![Some("b"), None, Some("foo"), None, None]);

let re = _regexp_replace_static_pattern_replace::<i32>(&[
let re = regexp_replace_static_pattern_replace::<i32>(&[
Arc::new(values),
Arc::new(patterns),
Arc::new(replacements),
Expand Down
29 changes: 29 additions & 0 deletions datafusion/sqllogictest/test_files/regexp/regexp_replace.slt
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,35 @@ from (values ('a'), ('b')) as tbl(col);
NULL NULL NULL
NULL NULL NULL

# Extract domain from URL using anchored pattern with trailing .*
# This tests that the full URL suffix is replaced, not just the matched prefix
query T
SELECT regexp_replace(url, '^https?://(?:www\.)?([^/]+)/.*$', '\1') FROM (VALUES
('https://www.example.com/path/to/page?q=1'),
('http://test.org/foo/bar'),
('https://example.com/'),
('not-a-url')
) AS t(url);
----
example.com
test.org
example.com
not-a-url

# More than one capture group should disable the short-regex fast path.
# This still uses replacement \1, but captures_len() will be > 2, so the
# implementation must fall back to the normal regexp_replace path.
query T
SELECT regexp_replace(url, '^https?://((www\.)?([^/]+))/.*$', '\1') FROM (VALUES
('https://www.example.com/path/to/page?q=1'),
('http://test.org/foo/bar'),
('not-a-url')
) AS t(url);
----
www.example.com
test.org
not-a-url

# If the overall pattern matches but capture group 1 does not participate,
# regexp_replace(..., '\1') should substitute the empty string, not keep
# the original input.
Expand Down
Loading