43
43
44
44
from pyspark .sql .types import VarcharType , CharType
45
45
46
- def decimalType (non_decimal , precision , scale ):
47
- if non_decimal :
48
- return DoubleType ()
49
- else :
46
+ def decimalType (use_decimal , precision , scale ):
47
+ if use_decimal :
50
48
return DecimalType (precision , scale )
51
-
49
+ else :
50
+ return DoubleType ()
52
51
53
- def get_schemas (non_decimal ):
52
+ def get_schemas (use_decimal ):
54
53
SCHEMAS = {}
55
54
SCHEMAS ["dbgen_version" ] = StructType ([
56
55
StructField ("dv_version" , VarcharType (16 )),
@@ -71,7 +70,7 @@ def get_schemas(non_decimal):
71
70
StructField ("ca_state" , CharType (2 )),
72
71
StructField ("ca_zip" , CharType (10 )),
73
72
StructField ("ca_country" , VarcharType (20 )),
74
- StructField ("ca_gmt_offset" , decimalType (non_decimal , 5 , 2 )),
73
+ StructField ("ca_gmt_offset" , decimalType (use_decimal , 5 , 2 )),
75
74
StructField ("ca_location_type" , CharType (20 ))
76
75
])
77
76
@@ -132,7 +131,7 @@ def get_schemas(non_decimal):
132
131
StructField ("w_state" , CharType (2 )),
133
132
StructField ("w_zip" , CharType (10 )),
134
133
StructField ("w_country" , VarcharType (20 )),
135
- StructField ("w_gmt_offset" , decimalType (non_decimal , 5 , 2 ))
134
+ StructField ("w_gmt_offset" , decimalType (use_decimal , 5 , 2 ))
136
135
])
137
136
138
137
SCHEMAS ["ship_mode" ] = StructType ([
@@ -175,8 +174,8 @@ def get_schemas(non_decimal):
175
174
StructField ("i_rec_start_date" , DateType ()),
176
175
StructField ("i_rec_end_date" , DateType ()),
177
176
StructField ("i_item_desc" , VarcharType (200 )),
178
- StructField ("i_current_price" , decimalType (non_decimal , 7 , 2 )),
179
- StructField ("i_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
177
+ StructField ("i_current_price" , decimalType (use_decimal , 7 , 2 )),
178
+ StructField ("i_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
180
179
StructField ("i_brand_id" , IntegerType ()),
181
180
StructField ("i_brand" , CharType (50 )),
182
181
StructField ("i_class_id" , IntegerType ()),
@@ -222,8 +221,8 @@ def get_schemas(non_decimal):
222
221
StructField ("s_state" , CharType (2 )),
223
222
StructField ("s_zip" , CharType (10 )),
224
223
StructField ("s_country" , VarcharType (20 )),
225
- StructField ("s_gmt_offset" , decimalType (non_decimal , 5 , 2 )),
226
- StructField ("s_tax_precentage" , decimalType (non_decimal , 5 , 2 ))
224
+ StructField ("s_gmt_offset" , decimalType (use_decimal , 5 , 2 )),
225
+ StructField ("s_tax_precentage" , decimalType (use_decimal , 5 , 2 ))
227
226
])
228
227
229
228
SCHEMAS ["call_center" ] = StructType ([
@@ -256,8 +255,8 @@ def get_schemas(non_decimal):
256
255
StructField ("cc_state" , CharType (2 )),
257
256
StructField ("cc_zip" , CharType (10 )),
258
257
StructField ("cc_country" , VarcharType (20 )),
259
- StructField ("cc_gmt_offset" , decimalType (non_decimal , 5 , 2 )),
260
- StructField ("cc_tax_percentage" , decimalType (non_decimal , 5 , 2 ))
258
+ StructField ("cc_gmt_offset" , decimalType (use_decimal , 5 , 2 )),
259
+ StructField ("cc_tax_percentage" , decimalType (use_decimal , 5 , 2 ))
261
260
])
262
261
263
262
SCHEMAS ["customer" ] = StructType ([
@@ -306,8 +305,8 @@ def get_schemas(non_decimal):
306
305
StructField ("web_state" , CharType (2 )),
307
306
StructField ("web_zip" , CharType (10 )),
308
307
StructField ("web_country" , VarcharType (20 )),
309
- StructField ("web_gmt_offset" , decimalType (non_decimal , 5 , 2 )),
310
- StructField ("web_tax_percentage" , decimalType (non_decimal , 5 , 2 ))
308
+ StructField ("web_gmt_offset" , decimalType (use_decimal , 5 , 2 )),
309
+ StructField ("web_tax_percentage" , decimalType (use_decimal , 5 , 2 ))
311
310
])
312
311
313
312
SCHEMAS ["store_returns" ] = StructType ([
@@ -322,15 +321,15 @@ def get_schemas(non_decimal):
322
321
StructField ("sr_reason_sk" , IntegerType ()),
323
322
StructField ("sr_ticket_number" , IntegerType (), nullable = False ),
324
323
StructField ("sr_return_quantity" , IntegerType ()),
325
- StructField ("sr_return_amt" , decimalType (non_decimal , 7 , 2 )),
326
- StructField ("sr_return_tax" , decimalType (non_decimal , 7 , 2 )),
327
- StructField ("sr_return_amt_inc_tax" , decimalType (non_decimal , 7 , 2 )),
328
- StructField ("sr_fee" , decimalType (non_decimal , 7 , 2 )),
329
- StructField ("sr_return_ship_cost" , decimalType (non_decimal , 7 , 2 )),
330
- StructField ("sr_refunded_cash" , decimalType (non_decimal , 7 , 2 )),
331
- StructField ("sr_reversed_charge" , decimalType (non_decimal , 7 , 2 )),
332
- StructField ("sr_store_credit" , decimalType (non_decimal , 7 , 2 )),
333
- StructField ("sr_net_loss" , decimalType (non_decimal , 7 , 2 ))
324
+ StructField ("sr_return_amt" , decimalType (use_decimal , 7 , 2 )),
325
+ StructField ("sr_return_tax" , decimalType (use_decimal , 7 , 2 )),
326
+ StructField ("sr_return_amt_inc_tax" , decimalType (use_decimal , 7 , 2 )),
327
+ StructField ("sr_fee" , decimalType (use_decimal , 7 , 2 )),
328
+ StructField ("sr_return_ship_cost" , decimalType (use_decimal , 7 , 2 )),
329
+ StructField ("sr_refunded_cash" , decimalType (use_decimal , 7 , 2 )),
330
+ StructField ("sr_reversed_charge" , decimalType (use_decimal , 7 , 2 )),
331
+ StructField ("sr_store_credit" , decimalType (use_decimal , 7 , 2 )),
332
+ StructField ("sr_net_loss" , decimalType (use_decimal , 7 , 2 ))
334
333
])
335
334
336
335
SCHEMAS ["household_demographics" ] = StructType ([
@@ -364,7 +363,7 @@ def get_schemas(non_decimal):
364
363
StructField ("p_start_date_sk" , IntegerType ()),
365
364
StructField ("p_end_date_sk" , IntegerType ()),
366
365
StructField ("p_item_sk" , IntegerType ()),
367
- StructField ("p_cost" , decimalType (non_decimal , 15 , 2 )),
366
+ StructField ("p_cost" , decimalType (use_decimal , 15 , 2 )),
368
367
StructField ("p_response_target" , IntegerType ()),
369
368
StructField ("p_promo_name" , CharType (50 )),
370
369
StructField ("p_channel_dmail" , CharType (1 )),
@@ -418,15 +417,15 @@ def get_schemas(non_decimal):
418
417
StructField ("cr_reason_sk" , IntegerType ()),
419
418
StructField ("cr_order_number" , IntegerType (), nullable = False ),
420
419
StructField ("cr_return_quantity" , IntegerType ()),
421
- StructField ("cr_return_amount" , decimalType (non_decimal , 7 , 2 )),
422
- StructField ("cr_return_tax" , decimalType (non_decimal , 7 , 2 )),
423
- StructField ("cr_return_amt_inc_tax" , decimalType (non_decimal , 7 , 2 )),
424
- StructField ("cr_fee" , decimalType (non_decimal , 7 , 2 )),
425
- StructField ("cr_return_ship_cost" , decimalType (non_decimal , 7 , 2 )),
426
- StructField ("cr_refunded_cash" , decimalType (non_decimal , 7 , 2 )),
427
- StructField ("cr_reversed_charge" , decimalType (non_decimal , 7 , 2 )),
428
- StructField ("cr_store_credit" , decimalType (non_decimal , 7 , 2 )),
429
- StructField ("cr_net_loss" , decimalType (non_decimal , 7 , 2 ))
420
+ StructField ("cr_return_amount" , decimalType (use_decimal , 7 , 2 )),
421
+ StructField ("cr_return_tax" , decimalType (use_decimal , 7 , 2 )),
422
+ StructField ("cr_return_amt_inc_tax" , decimalType (use_decimal , 7 , 2 )),
423
+ StructField ("cr_fee" , decimalType (use_decimal , 7 , 2 )),
424
+ StructField ("cr_return_ship_cost" , decimalType (use_decimal , 7 , 2 )),
425
+ StructField ("cr_refunded_cash" , decimalType (use_decimal , 7 , 2 )),
426
+ StructField ("cr_reversed_charge" , decimalType (use_decimal , 7 , 2 )),
427
+ StructField ("cr_store_credit" , decimalType (use_decimal , 7 , 2 )),
428
+ StructField ("cr_net_loss" , decimalType (use_decimal , 7 , 2 ))
430
429
])
431
430
432
431
SCHEMAS ["web_returns" ] = StructType ([
@@ -445,15 +444,15 @@ def get_schemas(non_decimal):
445
444
StructField ("wr_reason_sk" , IntegerType ()),
446
445
StructField ("wr_order_number" , IntegerType (), nullable = False ),
447
446
StructField ("wr_return_quantity" , IntegerType ()),
448
- StructField ("wr_return_amt" , decimalType (non_decimal , 7 , 2 )),
449
- StructField ("wr_return_tax" , decimalType (non_decimal , 7 , 2 )),
450
- StructField ("wr_return_amt_inc_tax" , decimalType (non_decimal , 7 , 2 )),
451
- StructField ("wr_fee" , decimalType (non_decimal , 7 , 2 )),
452
- StructField ("wr_return_ship_cost" , decimalType (non_decimal , 7 , 2 )),
453
- StructField ("wr_refunded_cash" , decimalType (non_decimal , 7 , 2 )),
454
- StructField ("wr_reversed_charge" , decimalType (non_decimal , 7 , 2 )),
455
- StructField ("wr_account_credit" , decimalType (non_decimal , 7 , 2 )),
456
- StructField ("wr_net_loss" , decimalType (non_decimal , 7 , 2 ))
447
+ StructField ("wr_return_amt" , decimalType (use_decimal , 7 , 2 )),
448
+ StructField ("wr_return_tax" , decimalType (use_decimal , 7 , 2 )),
449
+ StructField ("wr_return_amt_inc_tax" , decimalType (use_decimal , 7 , 2 )),
450
+ StructField ("wr_fee" , decimalType (use_decimal , 7 , 2 )),
451
+ StructField ("wr_return_ship_cost" , decimalType (use_decimal , 7 , 2 )),
452
+ StructField ("wr_refunded_cash" , decimalType (use_decimal , 7 , 2 )),
453
+ StructField ("wr_reversed_charge" , decimalType (use_decimal , 7 , 2 )),
454
+ StructField ("wr_account_credit" , decimalType (use_decimal , 7 , 2 )),
455
+ StructField ("wr_net_loss" , decimalType (use_decimal , 7 , 2 ))
457
456
])
458
457
459
458
SCHEMAS ["web_sales" ] = StructType ([
@@ -476,21 +475,21 @@ def get_schemas(non_decimal):
476
475
StructField ("ws_promo_sk" , IntegerType ()),
477
476
StructField ("ws_order_number" , IntegerType (), nullable = False ),
478
477
StructField ("ws_quantity" , IntegerType ()),
479
- StructField ("ws_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
480
- StructField ("ws_list_price" , decimalType (non_decimal , 7 , 2 )),
481
- StructField ("ws_sales_price" , decimalType (non_decimal , 7 , 2 )),
482
- StructField ("ws_ext_discount_amt" , decimalType (non_decimal , 7 , 2 )),
483
- StructField ("ws_ext_sales_price" , decimalType (non_decimal , 7 , 2 )),
484
- StructField ("ws_ext_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
485
- StructField ("ws_ext_list_price" , decimalType (non_decimal , 7 , 2 )),
486
- StructField ("ws_ext_tax" , decimalType (non_decimal , 7 , 2 )),
487
- StructField ("ws_coupon_amt" , decimalType (non_decimal , 7 , 2 )),
488
- StructField ("ws_ext_ship_cost" , decimalType (non_decimal , 7 , 2 )),
489
- StructField ("ws_net_paid" , decimalType (non_decimal , 7 , 2 )),
490
- StructField ("ws_net_paid_inc_tax" , decimalType (non_decimal , 7 , 2 )),
491
- StructField ("ws_net_paid_inc_ship" , decimalType (non_decimal , 7 , 2 )),
492
- StructField ("ws_net_paid_inc_ship_tax" , decimalType (non_decimal , 7 , 2 )),
493
- StructField ("ws_net_profit" , decimalType (non_decimal , 7 , 2 ))
478
+ StructField ("ws_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
479
+ StructField ("ws_list_price" , decimalType (use_decimal , 7 , 2 )),
480
+ StructField ("ws_sales_price" , decimalType (use_decimal , 7 , 2 )),
481
+ StructField ("ws_ext_discount_amt" , decimalType (use_decimal , 7 , 2 )),
482
+ StructField ("ws_ext_sales_price" , decimalType (use_decimal , 7 , 2 )),
483
+ StructField ("ws_ext_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
484
+ StructField ("ws_ext_list_price" , decimalType (use_decimal , 7 , 2 )),
485
+ StructField ("ws_ext_tax" , decimalType (use_decimal , 7 , 2 )),
486
+ StructField ("ws_coupon_amt" , decimalType (use_decimal , 7 , 2 )),
487
+ StructField ("ws_ext_ship_cost" , decimalType (use_decimal , 7 , 2 )),
488
+ StructField ("ws_net_paid" , decimalType (use_decimal , 7 , 2 )),
489
+ StructField ("ws_net_paid_inc_tax" , decimalType (use_decimal , 7 , 2 )),
490
+ StructField ("ws_net_paid_inc_ship" , decimalType (use_decimal , 7 , 2 )),
491
+ StructField ("ws_net_paid_inc_ship_tax" , decimalType (use_decimal , 7 , 2 )),
492
+ StructField ("ws_net_profit" , decimalType (use_decimal , 7 , 2 ))
494
493
])
495
494
496
495
SCHEMAS ["catalog_sales" ] = StructType ([
@@ -513,21 +512,21 @@ def get_schemas(non_decimal):
513
512
StructField ("cs_promo_sk" , IntegerType ()),
514
513
StructField ("cs_order_number" , IntegerType (), nullable = False ),
515
514
StructField ("cs_quantity" , IntegerType ()),
516
- StructField ("cs_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
517
- StructField ("cs_list_price" , decimalType (non_decimal , 7 , 2 )),
518
- StructField ("cs_sales_price" , decimalType (non_decimal , 7 , 2 )),
519
- StructField ("cs_ext_discount_amt" , decimalType (non_decimal , 7 , 2 )),
520
- StructField ("cs_ext_sales_price" , decimalType (non_decimal , 7 , 2 )),
521
- StructField ("cs_ext_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
522
- StructField ("cs_ext_list_price" , decimalType (non_decimal , 7 , 2 )),
523
- StructField ("cs_ext_tax" , decimalType (non_decimal , 7 , 2 )),
524
- StructField ("cs_coupon_amt" , decimalType (non_decimal , 7 , 2 )),
525
- StructField ("cs_ext_ship_cost" , decimalType (non_decimal , 7 , 2 )),
526
- StructField ("cs_net_paid" , decimalType (non_decimal , 7 , 2 )),
527
- StructField ("cs_net_paid_inc_tax" , decimalType (non_decimal , 7 , 2 )),
528
- StructField ("cs_net_paid_inc_ship" , decimalType (non_decimal , 7 , 2 )),
529
- StructField ("cs_net_paid_inc_ship_tax" , decimalType (non_decimal , 7 , 2 )),
530
- StructField ("cs_net_profit" , decimalType (non_decimal , 7 , 2 ))
515
+ StructField ("cs_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
516
+ StructField ("cs_list_price" , decimalType (use_decimal , 7 , 2 )),
517
+ StructField ("cs_sales_price" , decimalType (use_decimal , 7 , 2 )),
518
+ StructField ("cs_ext_discount_amt" , decimalType (use_decimal , 7 , 2 )),
519
+ StructField ("cs_ext_sales_price" , decimalType (use_decimal , 7 , 2 )),
520
+ StructField ("cs_ext_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
521
+ StructField ("cs_ext_list_price" , decimalType (use_decimal , 7 , 2 )),
522
+ StructField ("cs_ext_tax" , decimalType (use_decimal , 7 , 2 )),
523
+ StructField ("cs_coupon_amt" , decimalType (use_decimal , 7 , 2 )),
524
+ StructField ("cs_ext_ship_cost" , decimalType (use_decimal , 7 , 2 )),
525
+ StructField ("cs_net_paid" , decimalType (use_decimal , 7 , 2 )),
526
+ StructField ("cs_net_paid_inc_tax" , decimalType (use_decimal , 7 , 2 )),
527
+ StructField ("cs_net_paid_inc_ship" , decimalType (use_decimal , 7 , 2 )),
528
+ StructField ("cs_net_paid_inc_ship_tax" , decimalType (use_decimal , 7 , 2 )),
529
+ StructField ("cs_net_profit" , decimalType (use_decimal , 7 , 2 ))
531
530
])
532
531
533
532
SCHEMAS ["store_sales" ] = StructType ([
@@ -542,18 +541,18 @@ def get_schemas(non_decimal):
542
541
StructField ("ss_promo_sk" , IntegerType ()),
543
542
StructField ("ss_ticket_number" , IntegerType (), nullable = False ),
544
543
StructField ("ss_quantity" , IntegerType ()),
545
- StructField ("ss_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
546
- StructField ("ss_list_price" , decimalType (non_decimal , 7 , 2 )),
547
- StructField ("ss_sales_price" , decimalType (non_decimal , 7 , 2 )),
548
- StructField ("ss_ext_discount_amt" , decimalType (non_decimal , 7 , 2 )),
549
- StructField ("ss_ext_sales_price" , decimalType (non_decimal , 7 , 2 )),
550
- StructField ("ss_ext_wholesale_cost" , decimalType (non_decimal , 7 , 2 )),
551
- StructField ("ss_ext_list_price" , decimalType (non_decimal , 7 , 2 )),
552
- StructField ("ss_ext_tax" , decimalType (non_decimal , 7 , 2 )),
553
- StructField ("ss_coupon_amt" , decimalType (non_decimal , 7 , 2 )),
554
- StructField ("ss_net_paid" , decimalType (non_decimal , 7 , 2 )),
555
- StructField ("ss_net_paid_inc_tax" , decimalType (non_decimal , 7 , 2 )),
556
- StructField ("ss_net_profit" , decimalType (non_decimal , 7 , 2 ))
544
+ StructField ("ss_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
545
+ StructField ("ss_list_price" , decimalType (use_decimal , 7 , 2 )),
546
+ StructField ("ss_sales_price" , decimalType (use_decimal , 7 , 2 )),
547
+ StructField ("ss_ext_discount_amt" , decimalType (use_decimal , 7 , 2 )),
548
+ StructField ("ss_ext_sales_price" , decimalType (use_decimal , 7 , 2 )),
549
+ StructField ("ss_ext_wholesale_cost" , decimalType (use_decimal , 7 , 2 )),
550
+ StructField ("ss_ext_list_price" , decimalType (use_decimal , 7 , 2 )),
551
+ StructField ("ss_ext_tax" , decimalType (use_decimal , 7 , 2 )),
552
+ StructField ("ss_coupon_amt" , decimalType (use_decimal , 7 , 2 )),
553
+ StructField ("ss_net_paid" , decimalType (use_decimal , 7 , 2 )),
554
+ StructField ("ss_net_paid_inc_tax" , decimalType (use_decimal , 7 , 2 )),
555
+ StructField ("ss_net_profit" , decimalType (use_decimal , 7 , 2 ))
557
556
])
558
557
return SCHEMAS
559
558
@@ -602,8 +601,8 @@ def store(df, filename, prefix=""):
602
601
603
602
results = {}
604
603
605
- schemas = get_schemas (args .non_decimal )
606
-
604
+ schemas = get_schemas (use_decimal = not args .non_decimal )
605
+
607
606
for fn , schema in schemas .items ():
608
607
results [fn ] = timeit .timeit (lambda : store (load (f"{ fn } { args .input_suffix } " , schema , prefix = args .input_prefix ), f"{ fn } " , args .output_prefix ), number = 1 )
609
608
0 commit comments