Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sub_queries arg #168

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion nds/nds_validate.py
Copy link
Collaborator

@gerashegalov gerashegalov Oct 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update copyright year, we could copy define a pre-commit hook like https://github.com/NVIDIA/spark-rapids/blob/branch-23.12/.pre-commit-config.yaml#L20

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated copyright year.
File #169 as a follow up issue.

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from pyspark.sql.types import *
from pyspark.sql.functions import col

from nds_power import gen_sql_from_stream
from nds_power import gen_sql_from_stream, get_query_subset

def compare_results(spark_session: SparkSession,
input1: str,
Expand Down Expand Up @@ -335,8 +335,17 @@ def update_summary(prefix, unmatch_queries):
' checks when the input data is float for some queries.')
parser.add_argument('--json_summary_folder',
help='path of a folder that contains json summary file for each query.')
parser.add_argument('--sub_queries',
type=lambda s: [x.strip() for x in s.split(',')],
help='comma separated list of queries to compare. If not specified, all queries ' +
'in the stream file will be compared. e.g. "query1,query2,query3". Note, use ' +
'"_part1" and "_part2" suffix for the following query names: ' +
'query14, query23, query24, query39. e.g. query14_part1, query39_part2')
args = parser.parse_args()
query_dict = gen_sql_from_stream(args.query_stream_file)
# if set sub_queries, only compare the specified queries
if args.sub_queries:
query_dict = get_query_subset(query_dict, args.sub_queries)
session_builder = SparkSession.builder.appName("Validate Query Output").getOrCreate()
unmatch_queries = iterate_queries(session_builder,
args.input1,
Expand Down