Skip to content

Commit

Permalink
Fix compare function
Browse files (browse the repository at this point in the history)
Signed-off-by: Yinqing Hao <[email protected]>
  • Loading branch information
yinqingh committed Oct 17, 2024
1 parent 574de5a commit 56fc6f2
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions nds-h/nds_h_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
import os
import re
import time
from decimal import *
from decimal import Decimal

from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import *
from pyspark.sql.types import DoubleType, FloatType
from pyspark.sql.functions import col

from nds_h_power import gen_sql_from_stream, get_query_subset
Expand Down Expand Up @@ -132,16 +132,10 @@ def collect_results(df: DataFrame,
df = df.drop(*SKIP_COLUMNS[query_name])

# apply sorting if specified
non_float_cols = [col(field.name) for \
field in df.schema.fields \
if (field.dataType.typeName() != FloatType.typeName()) \
and \
(field.dataType.typeName() != DoubleType.typeName())]
float_cols = [col(field.name) for \
field in df.schema.fields \
if (field.dataType.typeName() == FloatType.typeName()) \
or \
(field.dataType.typeName() == DoubleType.typeName())]
non_float_cols = [col(field.name) for field in df.schema.fields
if field.dataType.typeName() not in (FloatType.typeName(), DoubleType.typeName())]
float_cols = [col(field.name) for field in df.schema.fields
if field.dataType.typeName() in (FloatType.typeName(), DoubleType.typeName())]
if ignore_ordering:
df = df.sort(non_float_cols + float_cols)

Expand Down Expand Up @@ -172,21 +166,16 @@ def compare(expected, actual, epsilon=0.00001):
# Double is converted to float in pyspark...
if math.isnan(expected) and math.isnan(actual):
return True
else:
return math.isclose(expected, actual, rel_tol=epsilon)
elif isinstance(expected, str) and isinstance(actual, str):
return expected == actual
elif expected == None and actual == None:
return True
elif expected != None and actual == None:
return False
elif expected == None and actual != None:
return False
elif isinstance(expected, Decimal) and isinstance(actual, Decimal):
return math.isclose(expected, actual, rel_tol=epsilon)
else:

if isinstance(expected, Decimal) and isinstance(actual, Decimal):
return math.isclose(expected, actual, rel_tol=epsilon)

if isinstance(expected, str) and isinstance(actual, str):
return expected == actual

return expected == actual

def iterate_queries(spark_session: SparkSession,
input1: str,
input2: str,
Expand Down Expand Up @@ -239,22 +228,25 @@ def update_summary(prefix, unmatch_queries):
for query_name in query_dict.keys():
summary_wildcard = prefix + f'/*{query_name}-*.json'
file_glob = glob.glob(summary_wildcard)

# Expect only one summary file for each query
if len(file_glob) > 1:
raise Exception(f"More than one summary file found for query {query_name} in folder {prefix}.")
if len(file_glob) == 0:
raise Exception(f"No summary file found for query {query_name} in folder {prefix}.")
for filename in file_glob:
with open(filename, 'r') as f:
summary = json.load(f)
if query_name in unmatch_queries:
if 'Completed' in summary['queryStatus'] or 'CompletedWithTaskFailures' in summary['queryStatus']:
summary['queryValidationStatus'] = ['Fail']
else:
summary['queryValidationStatus'] = ['NotAttempted']

filename = file_glob[0]
with open(filename, 'r') as f:
summary = json.load(f)
if query_name in unmatch_queries:
if 'Completed' in summary['queryStatus'] or 'CompletedWithTaskFailures' in summary['queryStatus']:
summary['queryValidationStatus'] = ['Fail']
else:
summary['queryValidationStatus'] = ['Pass']
with open(filename, 'w') as f:
json.dump(summary, f, indent=2)
summary['queryValidationStatus'] = ['NotAttempted']
else:
summary['queryValidationStatus'] = ['Pass']
with open(filename, 'w') as f:
json.dump(summary, f, indent=2)

if __name__ == "__main__":
parser = parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -313,4 +305,4 @@ def update_summary(prefix, unmatch_queries):
max_errors=args.max_errors,
epsilon=args.epsilon)
if args.json_summary_folder:
update_summary(args.json_summary_folder, unmatch_queries)
update_summary(args.json_summary_folder, unmatch_queries)

0 comments on commit 56fc6f2

Please sign in to comment.