Skip to content

Commit

Permalink
More flexible regexes; fix default split function (#1443)
Browse files Browse the repository at this point in the history
Signed-off-by: Lee Yang <[email protected]>
  • Loading branch information
leewyang authored Dec 3, 2024
1 parent fc49eb6 commit fb72520
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion user_tools/src/spark_rapids_tools/tools/qualx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def extract_model_features(
default_df = default_df.loc[~default_df.appName.str.startswith(f'{ds_name}:')]
modified_default_df = default_split_fn(default_df)
if modified_default_df.index.equals(default_df.index):
cpu_aug_tbl.update(default_df)
cpu_aug_tbl.update(modified_default_df)
cpu_aug_tbl.astype(df_schema)
else:
raise ValueError('Default split_function unexpectedly modified row indices.')
Expand Down
10 changes: 7 additions & 3 deletions user_tools/src/spark_rapids_tools/tools/qualx/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
get_logger,
get_dataset_platforms,
load_plugin,
run_profiler_tool, log_fallback,
log_fallback,
run_profiler_tool,
RegexPattern
)

PREPROCESSED_FILE = 'preprocessed.parquet'
Expand Down Expand Up @@ -269,10 +271,12 @@ def infer_app_meta(eventlogs: List[str]) -> Mapping[str, Mapping]:
app_meta_inner = {}
for e in eventlog_list:
parts = Path(e).parts
app_id_inner = parts[-1]
app_id_part = parts[-1]
match = RegexPattern.app_id.search(app_id_part)
app_id = match.group() if match else app_id_part
run_type = parts[-2].upper()
job_name = parts[-4]
app_meta_inner[app_id_inner] = {
app_meta_inner[app_id] = {
'jobName': job_name,
'runType': run_type,
'scaleFactor': 1,
Expand Down
6 changes: 3 additions & 3 deletions user_tools/src/spark_rapids_tools/tools/qualx/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def get_logger(name: str) -> logging.Logger:

@dataclass
class RegexPattern:
app_id = re.compile(r'^app.*[_-][0-9]+[_-][0-9]+$')
profile = re.compile(r'^prof_[0-9]+_[0-9a-zA-Z]+$')
qual_tool = re.compile(r'^qual_[0-9]+_[0-9a-zA-Z]+$')
app_id = re.compile(r'app.*[_-][0-9]+[_-][0-9]+')
profile = re.compile(r'prof_[0-9]+_[0-9a-zA-Z]+')
qual_tool = re.compile(r'qual_[0-9]+_[0-9a-zA-Z]+')
rapids_profile = re.compile(r'rapids_4_spark_profile')
rapids_qual = re.compile(r'rapids_4_spark_qualification_output')
qual_tool_metrics = re.compile(r'raw_metrics')
Expand Down

0 comments on commit fb72520

Please sign in to comment.