Skip to content

Commit cc3eeab

Browse files
feat(optimizer)!: annotate type for Snowflake REGEXP_SUBSTR_ALL function
1 parent 73186a8 commit cc3eeab

File tree

4 files changed

+52
-5
lines changed

4 files changed

+52
-5
lines changed

sqlglot/dialects/snowflake.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,9 @@ def _builder(args: t.List) -> E | exp.Escape:
333333
return _builder
334334

335335

336-
def _regexpextract_sql(self, expression: exp.RegexpExtract | exp.RegexpExtractAll) -> str:
336+
def _regexpextract_sql(
337+
self, expression: exp.RegexpExtract | exp.RegexpExtractAll | exp.RegexpSubstrAll
338+
) -> str:
337339
# Other dialects don't support all of the following parameters, so we need to
338340
# generate default values as necessary to ensure the transpilation is correct
339341
group = expression.args.get("group")
@@ -347,8 +349,15 @@ def _regexpextract_sql(self, expression: exp.RegexpExtract | exp.RegexpExtractAl
347349
occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1))
348350
position = expression.args.get("position") or (occurrence and exp.Literal.number(1))
349351

352+
if isinstance(expression, exp.RegexpExtract):
353+
func_name = "REGEXP_SUBSTR"
354+
elif isinstance(expression, exp.RegexpSubstrAll):
355+
func_name = "REGEXP_SUBSTR_ALL"
356+
else: # exp.RegexpExtractAll
357+
func_name = "REGEXP_EXTRACT_ALL"
358+
350359
return self.func(
351-
"REGEXP_SUBSTR" if isinstance(expression, exp.RegexpExtract) else "REGEXP_EXTRACT_ALL",
360+
func_name,
352361
expression.this,
353362
expression.expression,
354363
position,
@@ -585,6 +594,7 @@ class Snowflake(Dialect):
585594
},
586595
exp.DataType.Type.ARRAY: {
587596
exp.Split,
597+
exp.RegexpSubstrAll,
588598
},
589599
exp.DataType.Type.OBJECT: {
590600
exp.ParseUrl,
@@ -755,7 +765,7 @@ class Parser(parser.Parser):
755765
"REGEXP_EXTRACT_ALL": _build_regexp_extract(exp.RegexpExtractAll),
756766
"REGEXP_REPLACE": _build_regexp_replace,
757767
"REGEXP_SUBSTR": _build_regexp_extract(exp.RegexpExtract),
758-
"REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll),
768+
"REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpSubstrAll),
759769
"REPLACE": build_replace_with_optional_replacement,
760770
"RLIKE": exp.RegexpLike.from_arg_list,
761771
"SHA1_BINARY": exp.SHA1Digest.from_arg_list,
@@ -1424,6 +1434,7 @@ class Generator(generator.Generator):
14241434
exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
14251435
exp.RegexpExtract: _regexpextract_sql,
14261436
exp.RegexpExtractAll: _regexpextract_sql,
1437+
exp.RegexpSubstrAll: _regexpextract_sql,
14271438
exp.RegexpILike: _regexpilike_sql,
14281439
exp.Rand: rename_func("RANDOM"),
14291440
exp.Select: transforms.preprocess(

sqlglot/expressions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7256,6 +7256,17 @@ class RegexpExtractAll(Func):
72567256
}
72577257

72587258

7259+
class RegexpSubstrAll(Func):
7260+
arg_types = {
7261+
"this": True,
7262+
"expression": True,
7263+
"position": False,
7264+
"occurrence": False,
7265+
"parameters": False,
7266+
"group": False,
7267+
}
7268+
7269+
72597270
class RegexpReplace(Func):
72607271
arg_types = {
72617272
"this": True,

tests/dialects/test_snowflake.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2416,10 +2416,15 @@ def test_regexp_substr(self, logger):
24162416
},
24172417
)
24182418

2419+
self.validate_identity("REGEXP_SUBSTR_ALL(subject, pattern)")
2420+
self.validate_identity("REGEXP_SUBSTR_ALL(subject, pattern, 1)")
2421+
self.validate_identity("REGEXP_SUBSTR_ALL(subject, pattern, 1, 1)")
2422+
self.validate_identity("REGEXP_SUBSTR_ALL(subject, pattern, 1, 1, 'i')")
24192423
self.validate_identity(
2420-
"REGEXP_SUBSTR_ALL(subject, pattern)",
2421-
"REGEXP_EXTRACT_ALL(subject, pattern)",
2424+
"REGEXP_SUBSTR_ALL(subject, pattern, 1, 1, 'i', 0)",
2425+
"REGEXP_SUBSTR_ALL(subject, pattern, 1, 1, 'i')",
24222426
)
2427+
self.validate_identity("REGEXP_EXTRACT_ALL(subject, pattern)")
24232428

24242429
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l')")
24252430
self.validate_identity("SELECT REGEXP_COUNT('hello world', 'l', 1)")

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,6 +1927,26 @@ VARCHAR;
19271927
REGEXP_SUBSTR('hello world', 'world', 1, 1, 'e', NULL);
19281928
VARCHAR;
19291929

1930+
# dialect: snowflake
1931+
REGEXP_SUBSTR_ALL('hello world', 'world');
1932+
ARRAY;
1933+
1934+
# dialect: snowflake
1935+
REGEXP_SUBSTR_ALL('hello world', 'world', 1);
1936+
ARRAY;
1937+
1938+
# dialect: snowflake
1939+
REGEXP_SUBSTR_ALL('hello world', 'world', 1, 1);
1940+
ARRAY;
1941+
1942+
# dialect: snowflake
1943+
REGEXP_SUBSTR_ALL('hello world', 'world', 1, 1, 'i');
1944+
ARRAY;
1945+
1946+
# dialect: snowflake
1947+
REGEXP_SUBSTR_ALL('hello world', 'world', 1, 1, 'i', 0);
1948+
ARRAY;
1949+
19301950
# dialect: snowflake
19311951
REPEAT('hello', 3);
19321952
VARCHAR;

0 commit comments

Comments
 (0)