Skip to content

Commit

Permalink
Merge pull request #6998 from janezd/create-class-re
Browse files Browse the repository at this point in the history
Create Class: Add regular expressions
  • Loading branch information
VesnaT authored Jan 24, 2025
2 parents 083c53d + 9bd18c7 commit 091e88c
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 80 deletions.
205 changes: 148 additions & 57 deletions Orange/widgets/data/owcreateclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Widget for creating classes from non-numeric attribute by substrings"""
import re
from itertools import count
from typing import Optional, Sequence

import numpy as np

Expand All @@ -19,39 +20,71 @@
from Orange.widgets.widget import Msg, Input, Output


def map_by_substring(a, patterns, case_sensitive, match_beginning,
map_values=None):
def map_by_substring(
a: np.ndarray,
patterns: list[str],
case_sensitive: bool, match_beginning: bool, regular_expressions: bool,
map_values: Optional[Sequence[int]] = None) -> np.ndarray:
"""
Map values in a using a list of patterns. The patterns are considered in
order of appearance.
Flags `match_beginning` and `regular_expressions` are incompatible.
Args:
a (np.array): input array of `dtype` `str`
patterns (list of str): list of strings
case_sensitive (bool): case sensitive match
match_beginning (bool): match only at the beginning of the string
map_values (list of int): list of len(pattens);
contains return values for each pattern
map_values (list of int, optional):
list of len(patterns); return values for each pattern
regular_expressions (bool): use regular expressions
Returns:
np.array of floats representing indices of matched patterns
"""
assert not (regular_expressions and match_beginning)
if map_values is None:
map_values = np.arange(len(patterns))
else:
map_values = np.array(map_values, dtype=int)
res = np.full(len(a), np.nan)
if not case_sensitive:
if not case_sensitive and not regular_expressions:
a = np.char.lower(a)
patterns = (pattern.lower() for pattern in patterns)
for val_idx, pattern in reversed(list(enumerate(patterns))):
indices = np.char.find(a, pattern)
matches = indices == 0 if match_beginning else indices != -1
# Note that similar code repeats in update_counts. Any changes here
# should be reflected there.
if regular_expressions:
re_pattern = re.compile(pattern,
re.IGNORECASE if not case_sensitive else 0)
matches = np.array([bool(re_pattern.search(s)) for s in a],
dtype=bool)
else:
indices = np.char.find(a, pattern)
matches = indices == 0 if match_beginning else indices != -1
res[matches] = map_values[val_idx]
return res


class ValueFromStringSubstring(Transformation):
class _EqHashMixin:
def __eq__(self, other):
return super().__eq__(other) \
and self.patterns == other.patterns \
and self.case_sensitive == other.case_sensitive \
and self.match_beginning == other.match_beginning \
and self.regular_expressions == other.regular_expressions \
and np.all(self.map_values == other.map_values)

def __hash__(self):
return hash((type(self), self.variable,
tuple(self.patterns),
self.case_sensitive, self.match_beginning,
self.regular_expressions,
None if self.map_values is None else tuple(self.map_values)
))

class ValueFromStringSubstring(_EqHashMixin, Transformation):
"""
Transformation that computes a discrete variable from a string variable by
pattern matching.
Expand All @@ -67,15 +100,28 @@ class ValueFromStringSubstring(Transformation):
sensitive
match_beginning (bool, optional): if set to `True`, the pattern must
appear at the beginning of the string
map_values (list of int, optional): return values for each pattern
regular_expressions (bool, optional): if set to `True`, the patterns are
"""
def __init__(self, variable, patterns,
case_sensitive=False, match_beginning=False, map_values=None):
# regular_expressions was added later and at the end (instead of with other
# flags) for compatibility with older existing pickles
def __init__(
self,
variable: StringVariable,
patterns: list[str],
case_sensitive: bool = False,
match_beginning: bool = False,
map_values: Optional[Sequence[int]] = None,
regular_expressions: bool = False):
super().__init__(variable)
self.patterns = patterns
self.case_sensitive = case_sensitive
self.match_beginning = match_beginning
self.regular_expressions = regular_expressions
self.map_values = map_values

InheritEq = True

def transform(self, c):
"""
Transform the given data.
Expand All @@ -90,26 +136,14 @@ def transform(self, c):
c = c.astype(str)
c[nans] = ""
res = map_by_substring(
c, self.patterns, self.case_sensitive, self.match_beginning,
c, self.patterns,
self.case_sensitive, self.match_beginning, self.regular_expressions,
self.map_values)
res[nans] = np.nan
return res

def __eq__(self, other):
return super().__eq__(other) \
and self.patterns == other.patterns \
and self.case_sensitive == other.case_sensitive \
and self.match_beginning == other.match_beginning \
and self.map_values == other.map_values

def __hash__(self):
return hash((type(self), self.variable,
tuple(self.patterns),
self.case_sensitive, self.match_beginning,
self.map_values))


class ValueFromDiscreteSubstring(Lookup):
class ValueFromDiscreteSubstring(_EqHashMixin, Lookup):
"""
Transformation that computes a discrete variable from discrete variable by
pattern matching.
Expand All @@ -126,16 +160,29 @@ class ValueFromDiscreteSubstring(Lookup):
sensitive
match_beginning (bool, optional): if set to `True`, the pattern must
appear at the beginning of the string
map_values (list of int, optional): return values for each pattern
regular_expressions (bool, optional): if set to `True`, the patterns are
"""
def __init__(self, variable, patterns,
case_sensitive=False, match_beginning=False,
map_values=None):
# regular_expressions was added later and at the end (instead of with other
# flags) for compatibility with older existing pickles
def __init__(
self,
variable: DiscreteVariable,
patterns: list[str],
case_sensitive: bool = False,
match_beginning: bool = False,
map_values: Optional[Sequence[int]] = None,
regular_expressions: bool = False):
super().__init__(variable, [])
self.case_sensitive = case_sensitive
self.match_beginning = match_beginning
self.map_values = map_values
self.regular_expressions = regular_expressions
self.patterns = patterns # Finally triggers computation of the lookup

InheritEq = True

def __setattr__(self, key, value):
"""__setattr__ is overloaded to recompute the lookup table when the
patterns, the original attribute or the flags change."""
Expand All @@ -145,10 +192,10 @@ def __setattr__(self, key, value):
"variable", "map_values"):
self.lookup_table = map_by_substring(
self.variable.values, self.patterns,
self.case_sensitive, self.match_beginning, self.map_values)

self.case_sensitive, self.match_beginning,
self.regular_expressions, self.map_values)

def unique_in_order_mapping(a):
def unique_in_order_mapping(a: Sequence[str]) -> tuple[list[str], list[int]]:
""" Return
- unique elements of the input list (in the order of appearance)
- indices of the input list onto the returned uniques
Expand Down Expand Up @@ -187,6 +234,7 @@ class Outputs:
rules = ContextSetting({})
match_beginning = ContextSetting(False)
case_sensitive = ContextSetting(False)
regular_expressions = ContextSetting(False)

TRANSFORMERS = {StringVariable: ValueFromStringSubstring,
DiscreteVariable: ValueFromDiscreteSubstring}
Expand All @@ -202,6 +250,7 @@ class Warning(widget.OWWidget.Warning):
class Error(widget.OWWidget.Error):
class_name_duplicated = Msg("Class name duplicated.")
class_name_empty = Msg("Class name should not be empty.")
invalid_regular_expression = Msg("Invalid regular expression: {}")

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -252,9 +301,9 @@ def __init__(self):
rules_box.addWidget(QLabel("Count"), 0, 3, 1, 2)
self.update_rules()

widget = QWidget(patternbox)
widget.setLayout(rules_box)
patternbox.layout().addWidget(widget)
widg = QWidget(patternbox)
widg.setLayout(rules_box)
patternbox.layout().addWidget(widg)

box = gui.hBox(patternbox)
gui.rubber(box)
Expand All @@ -264,8 +313,12 @@ def __init__(self):
QSizePolicy.Maximum))

optionsbox = gui.vBox(self.controlArea, "Options")
gui.checkBox(
optionsbox, self, "regular_expressions", "Use regular expressions",
callback=self.options_changed)
gui.checkBox(
optionsbox, self, "match_beginning", "Match only at the beginning",
stateWhenDisabled=False,
callback=self.options_changed)
gui.checkBox(
optionsbox, self, "case_sensitive", "Case sensitive",
Expand Down Expand Up @@ -322,6 +375,7 @@ def update_rules(self):
# TODO: Indicator that changes need to be applied

def options_changed(self):
self.controls.match_beginning.setEnabled(not self.regular_expressions)
self.update_counts()

def adjust_n_rule_rows(self):
Expand All @@ -344,8 +398,8 @@ def _add_line():
self.rules_box.addWidget(button, n_lines, 0)
self.counts.append([])
for coli, kwargs in enumerate(
(dict(),
dict(styleSheet="color: gray"))):
({},
{"styleSheet": "color: gray"})):
label = QLabel(alignment=Qt.AlignCenter, **kwargs)
self.counts[-1].append(label)
self.rules_box.addWidget(label, n_lines, 3 + coli)
Expand Down Expand Up @@ -401,23 +455,48 @@ def class_labels(self):
if re.match("^C\\d+", label)),
default=0)
class_count = count(largest_c + 1)
return [label_edit.text() or "C{}".format(next(class_count))
return [label_edit.text() or f"C{next(class_count)}"
for label_edit, _ in self.line_edits]

def invalid_patterns(self):
if not self.regular_expressions:
return None
for _, pattern in self.active_rules:
try:
re.compile(pattern)
except re.error:
return pattern
return None

def update_counts(self):
"""Recompute and update the counts of matches."""
def _matcher(strings, pattern):
"""Return indices of strings into patterns; consider case
sensitivity and matching at the beginning. The given strings are
assumed to be in lower case if match is case insensitive. Patterns
are fixed on the fly."""
if not self.case_sensitive:
pattern = pattern.lower()
indices = np.char.find(strings, pattern.strip())
return indices == 0 if self.match_beginning else indices != -1

def _lower_if_needed(strings):
return strings if self.case_sensitive else np.char.lower(strings)
if self.regular_expressions:
def _matcher(strings, pattern):
# Note that similar code repeats in map_by_substring.
# Any changes here should be reflected there.
re_pattern = re.compile(
pattern,
re.IGNORECASE if not self.case_sensitive else 0)
return np.array([bool(re_pattern.search(s)) for s in strings],
dtype=bool)

def _lower_if_needed(strings):
return strings
else:
def _matcher(strings, pattern):
"""Return indices of strings into patterns; consider case
sensitivity and matching at the beginning. The given strings are
assumed to be in lower case if match is case insensitive. Patterns
are fixed on the fly."""
# Note that similar code repeats in map_by_substring.
# Any changes here should be reflected there.
if not self.case_sensitive:
pattern = pattern.lower()
indices = np.char.find(strings, pattern.strip())
return indices == 0 if self.match_beginning else indices != -1

def _lower_if_needed(strings):
return strings if self.case_sensitive else np.char.lower(strings)

def _string_counts():
"""
Expand Down Expand Up @@ -469,9 +548,9 @@ def _set_labels():
for (n_matched, n_total), (lab_matched, lab_total), (lab, patt) in \
zip(self.match_counts, self.counts, self.active_rules):
n_before = n_total - n_matched
lab_matched.setText("{}".format(n_matched))
lab_matched.setText(f"{n_matched}")
if n_before and (lab or patt):
lab_total.setText("+ {}".format(n_before))
lab_total.setText(f"+ {n_before}")
if n_matched:
tip = f"{n_before} o" \
f"f {n_total} matching {pl(n_total, 'instance')} " \
Expand All @@ -496,6 +575,11 @@ def _set_placeholders():
lab_edit.setPlaceholderText(label)

_clear_labels()
if (invalid := self.invalid_patterns()) is not None:
self.Error.invalid_regular_expression(invalid)
return
self.Error.invalid_regular_expression.clear()

attr = self.attribute
if attr is None:
return
Expand All @@ -510,6 +594,11 @@ def _set_placeholders():
def apply(self):
"""Output the transformed data."""
self.Error.clear()
if (invalid := self.invalid_patterns()) is not None:
self.Error.invalid_regular_expression(invalid)
self.Outputs.data.send(None)
return

self.class_name = self.class_name.strip()
if not self.attribute:
self.Outputs.data.send(None)
Expand Down Expand Up @@ -541,19 +630,21 @@ def _create_variable(self):
if valid)
transformer = self.TRANSFORMERS[type(self.attribute)]

# join patters with the same names
# join patterns with the same names
names, map_values = unique_in_order_mapping(names)
names = tuple(str(a) for a in names)
map_values = tuple(map_values)

var_key = (self.attribute, self.class_name, names,
patterns, self.case_sensitive, self.match_beginning, map_values)
patterns, self.case_sensitive, self.match_beginning,
self.regular_expressions, map_values)
if var_key in self.cached_variables:
return self.cached_variables[var_key]

compute_value = transformer(
self.attribute, patterns, self.case_sensitive, self.match_beginning,
map_values)
self.attribute, patterns, self.case_sensitive,
self.match_beginning and not self.regular_expressions,
map_values, self.regular_expressions)
new_var = DiscreteVariable(
self.class_name, names, compute_value=compute_value)
self.cached_variables[var_key] = new_var
Expand Down Expand Up @@ -597,10 +688,10 @@ def _count_part():
for (n_matched, n_total), class_name, (lab, patt) in \
zip(self.match_counts, names, self.active_rules):
if lab or patt or n_total:
output += "<li>{}; {}</li>".format(_cond_part(), _count_part())
output += f"<li>{_cond_part()}; {_count_part()}</li>"
if output:
self.report_items("Output", [("Class name", self.class_name)])
self.report_raw("<ol>{}</ol>".format(output))
self.report_raw(f"<ol>{output}</ol>")


if __name__ == "__main__": # pragma: no cover
Expand Down
Loading

0 comments on commit 091e88c

Please sign in to comment.