Merge branch 'main' into unnecessary-config
yikf authored Jan 13, 2025
2 parents bc14926 + 318bb21 commit 533132b
Showing 85 changed files with 986 additions and 423 deletions.
2 changes: 1 addition & 1 deletion backends-clickhouse/pom.xml
@@ -5,7 +5,7 @@
<parent>
<artifactId>gluten-parent</artifactId>
<groupId>org.apache.gluten</groupId>
<version>1.3.0-SNAPSHOT</version>
<version>1.4.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>

@@ -38,6 +38,7 @@ import org.apache.spark.sql.delta.DeltaLogFileIndex
import org.apache.spark.sql.delta.rules.CHOptimizeMetadataOnlyDeltaQuery
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.noop.GlutenNoopWriterRule
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
import org.apache.spark.util.SparkPlanRules

@@ -132,6 +133,7 @@ object CHRuleApi {
c =>
intercept(
SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarPostRules)(c.session)))
injector.injectPost(c => GlutenNoopWriterRule.apply(c.session))

// Gluten columnar: Final rules.
injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
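
Both backends now register GlutenNoopWriterRule as a post-stage rule; its implementation is not part of this diff. As a hedged illustration only, a session-scoped post rule registered through the injector is simply a Rule[SparkPlan], roughly of the following shape (the class name and body are hypothetical placeholders, not the real rule):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical sketch of the shape of a post-stage rule such as the one injected
// above; the actual GlutenNoopWriterRule logic is not shown in this commit excerpt.
case class ExamplePostRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    // A post rule receives the physical plan produced by earlier columnar rules
    // and may rewrite it; this placeholder simply returns it unchanged.
    plan
  }
}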
@@ -3362,5 +3362,20 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr
val sql = "select * from test_filter where (c1, c2) in (('a1', 'b1'), ('a2', 'b2'))"
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}

test("GLUTEN-8343: Cast number to decimal") {
val create_table_sql = "create table test_tbl_8343(id bigint, d bigint, f double) using parquet"
val insert_data_sql =
"insert into test_tbl_8343 values(1, 55, 55.12345), (2, 137438953483, 137438953483.12345), (3, -12, -12.123), (4, 0, 0.0001), (5, NULL, NULL), (6, %d, NULL), (7, %d, NULL)"
.format(Double.MaxValue.longValue(), Double.MinValue.longValue())
val query_sql =
"select cast(d as decimal(1, 0)), cast(d as decimal(9, 1)), cast((f-55.12345) as decimal(9,1)), cast(f as decimal(4,2)), " +
"cast(f as decimal(32, 3)), cast(f as decimal(2, 1)), cast(d as decimal(38,3)) from test_tbl_8343"
spark.sql(create_table_sql);
spark.sql(insert_data_sql);
compareResultsAgainstVanillaSpark(query_sql, true, { _ => })
spark.sql("drop table test_tbl_8343")
}

}
// scalastyle:on line.size.limit
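
For context on the new GLUTEN-8343 test above: under Spark's default (non-ANSI) mode, a cast whose value exceeds the target decimal's range produces NULL rather than an error, and that vanilla behaviour is what the suite compares against. A minimal standalone check of that semantics (a sketch, not part of the commit; assumes a local SparkSession) could look like:

import org.apache.spark.sql.SparkSession

object CastToDecimalCheck {
  def main(args: Array[String]): Unit = {
    // Sketch only, not part of the commit. Assumes default (non-ANSI) Spark,
    // where an overflowing cast to decimal yields NULL instead of failing.
    val spark = SparkSession.builder().master("local[1]").appName("cast-to-decimal-check").getOrCreate()
    // 137438953483 has 12 integer digits and cannot fit decimal(9, 1), which
    // allows at most 8 integer digits, while 55 fits comfortably.
    spark.sql("select cast(137438953483 as decimal(9, 1)) as overflowed, cast(55 as decimal(9, 1)) as fits").show()
    // Expected (non-ANSI): overflowed = NULL, fits = 55.0
    spark.stop()
  }
}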
2 changes: 1 addition & 1 deletion backends-velox/pom.xml
@@ -5,7 +5,7 @@
<parent>
<artifactId>gluten-parent</artifactId>
<groupId>org.apache.gluten</groupId>
<version>1.3.0-SNAPSHOT</version>
<version>1.4.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>

@@ -36,6 +36,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.noop.GlutenNoopWriterRule
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.joins.BaseJoinExec
@@ -110,6 +111,7 @@ object VeloxRuleApi {
.getExtendedColumnarPostRules()
.foreach(each => injector.injectPost(c => each(c.session)))
injector.injectPost(c => ColumnarCollapseTransformStages(c.glutenConf))
injector.injectPost(c => GlutenNoopWriterRule(c.session))

// Gluten columnar: Final rules.
injector.injectFinal(c => RemoveGlutenTableCacheColumnarToRow(c.session))
@@ -188,6 +190,7 @@ object VeloxRuleApi {
.getExtendedColumnarPostRules()
.foreach(each => injector.injectPostTransform(c => each(c.session)))
injector.injectPostTransform(c => ColumnarCollapseTransformStages(c.glutenConf))
injector.injectPostTransform(c => GlutenNoopWriterRule(c.session))
injector.injectPostTransform(c => RemoveGlutenTableCacheColumnarToRow(c.session))
injector.injectPostTransform(c => GlutenFallbackReporter(c.glutenConf, c.session))
injector.injectPostTransform(_ => RemoveFallbackTagRule())
@@ -30,6 +30,7 @@ import org.apache.gluten.vectorized.{ArrowColumnarRow, ArrowWritableColumnVector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CaseWhen, Coalesce, Expression, If, LambdaFunction, NamedExpression, NaNvl, ScalaUDF}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.execution.{ExplainUtils, ProjectExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.hive.HiveUdfUtil
@@ -75,6 +76,14 @@ case class ColumnarPartialProjectExec(original: ProjectExec, child: SparkPlan)(

override def output: Seq[Attribute] = child.output ++ replacedAliasUdf.map(_.toAttribute)

override def doCanonicalize(): ColumnarPartialProjectExec = {
val canonicalized = original.canonicalized.asInstanceOf[ProjectExec]
this.copy(
original = canonicalized,
child = child.canonicalized
)(replacedAliasUdf.map(QueryPlan.normalizeExpressions(_, child.output)))
}

override def batchType(): Convention.BatchType = BackendsApiManager.getSettings.primaryBatchType

override def rowType0(): Convention.RowType = Convention.RowType.None
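
The doCanonicalize override above matters because Spark decides whether two physical plans compute the same result (for example when reusing exchanges or cached subplans) by comparing their canonicalized forms, so canonicalization must strip non-semantic details such as expression IDs. A rough sketch of that comparison (an illustration of the idea, not the actual Spark internals):

import org.apache.spark.sql.execution.SparkPlan

// Sketch only: plan-reuse checks boil down to comparing canonicalized plans, which is
// why ColumnarPartialProjectExec normalizes replacedAliasUdf with
// QueryPlan.normalizeExpressions, so two semantically identical projections that
// differ only in expression IDs canonicalize to equal plans.
def plansComputeSameResult(a: SparkPlan, b: SparkPlan): Boolean =
  a.canonicalized == b.canonicalized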
4 changes: 2 additions & 2 deletions cpp-ch/clickhouse.version
@@ -1,3 +1,3 @@
CH_ORG=Kyligence
CH_BRANCH=rebase_ch/20250107
CH_COMMIT=01d2a08fb01
CH_BRANCH=rebase_ch/20250110
CH_COMMIT=eafc5ef70b3
94 changes: 65 additions & 29 deletions cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp
@@ -25,6 +25,7 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <typeinfo>

namespace DB
{
@@ -34,6 +35,7 @@ extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int ILLEGAL_COLUMN;
extern const int TYPE_MISMATCH;
extern const int NOT_IMPLEMENTED;
}
}

@@ -78,25 +80,22 @@ class FunctionCheckDecimalOverflow : public IFunction

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (!isDecimal(arguments[0].type) || !isInteger(arguments[1].type) || !isInteger(arguments[2].type))
if ((!isDecimal(arguments[0].type) && !isNativeNumber(arguments[0].type)) || !isInteger(arguments[1].type) || !isInteger(arguments[2].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} {} {} of argument of function {}",
arguments[0].type->getName(),
arguments[1].type->getName(),
arguments[2].type->getName(),
getName());

UInt32 precision = extractArgument(arguments[1]);
UInt32 scale = extractArgument(arguments[2]);

auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
if constexpr (exception_mode == CheckExceptionMode::Null)
{
if (!arguments[0].type->isNullable())
return std::make_shared<DataTypeNullable>(return_type);
}

return return_type;
}

@@ -113,19 +112,15 @@ class FunctionCheckDecimalOverflow : public IFunction
using Types = std::decay_t<decltype(types)>;
using FromDataType = typename Types::LeftType;
using ToDataType = typename Types::RightType;

if constexpr (IsDataTypeDecimal<FromDataType>)
if constexpr (IsDataTypeDecimal<FromDataType> || IsDataTypeNumber<FromDataType>)
{
using FromFieldType = typename FromDataType::FieldType;
using ColVecType = ColumnDecimal<FromFieldType>;

if (const ColVecType * col_vec = checkAndGetColumn<ColVecType>(src_column.column.get()))
if (const ColumnVectorOrDecimal<FromFieldType> * col_vec = checkAndGetColumn<ColumnVectorOrDecimal<FromFieldType>>(src_column.column.get()))
{
executeInternal<FromFieldType, ToDataType>(*col_vec, result_column, input_rows_count, precision, scale);
executeInternal<FromDataType, ToDataType>(*col_vec, result_column, input_rows_count, precision, scale);
return true;
}
}

throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column while execute function {}", getName());
};

@@ -146,17 +141,28 @@ class FunctionCheckDecimalOverflow : public IFunction
}

private:
template <typename T, typename ToDataType>
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType> && (IsDataTypeDecimal<FromDataType> || IsDataTypeNumber<FromDataType>))
static void executeInternal(
const ColumnDecimal<T> & col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 precision, UInt32 scale_to)
const ColumnVectorOrDecimal<typename FromDataType::FieldType> & col_source, ColumnPtr & result_column, size_t input_rows_count, UInt32 precision, UInt32 scale_to)
{
using ToFieldType = typename ToDataType::FieldType;
using ToColumnType = typename ToDataType::ColumnType;
using T = typename FromDataType::FieldType;

ColumnUInt8::MutablePtr col_null_map_to;
ColumnUInt8::Container * vec_null_map_to [[maybe_unused]] = nullptr;
auto scale_from = col_source.getScale();

UInt32 scale_from = 0;
using ToFieldNativeType = typename ToFieldType::NativeType;
ToFieldNativeType decimal_int_part_max = 0;
ToFieldNativeType decimal_int_part_min = 0;
if constexpr (IsDataTypeDecimal<FromDataType>)
scale_from = col_source.getScale();
else
{
decimal_int_part_max = DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to) - 1;
decimal_int_part_min = 1 - DecimalUtils::scaleMultiplier<ToFieldNativeType>(precision - scale_to);
}
if constexpr (exception_mode == CheckExceptionMode::Null)
{
col_null_map_to = ColumnUInt8::create(input_rows_count, false);
Expand All @@ -170,17 +176,17 @@ class FunctionCheckDecimalOverflow : public IFunction
auto & datas = col_source.getData();
for (size_t i = 0; i < input_rows_count; ++i)
{
// bool overflow = outOfDigits<T>(datas[i], precision, scale_from, scale_to);
ToFieldType result;
bool success = convertToDecimalImpl<T, ToDataType>(datas[i], precision, scale_from, scale_to, result);

if (success)
bool success = convertToDecimalImpl<FromDataType, ToDataType>(datas[i], precision, scale_from, scale_to, decimal_int_part_max, decimal_int_part_min, result);
if constexpr (exception_mode == CheckExceptionMode::Null)
{
vec_to[i] = static_cast<ToFieldType>(result);
(*vec_null_map_to)[i] = !success;
}
else
{
vec_to[i] = static_cast<ToFieldType>(0);
if constexpr (exception_mode == CheckExceptionMode::Null)
(*vec_null_map_to)[i] = static_cast<UInt8>(1);
if (success)
vec_to[i] = static_cast<ToFieldType>(result);
else
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Decimal value is overflow.");
}
@@ -192,20 +198,50 @@ class FunctionCheckDecimalOverflow : public IFunction
result_column = std::move(col_to);
}

template <is_decimal FromFieldType, typename ToDataType>
template <typename FromDataType, typename ToDataType>
requires(IsDataTypeDecimal<ToDataType>)
static bool convertToDecimalImpl(
const FromFieldType & decimal, UInt32 precision_to, UInt32 scale_from, UInt32 scale_to, typename ToDataType::FieldType & result)
const FromDataType::FieldType & value,
UInt32 precision_to,
UInt32 scale_from,
UInt32 scale_to,
typename ToDataType::FieldType::NativeType decimal_int_part_max,
typename ToDataType::FieldType::NativeType decimal_int_part_min,
typename ToDataType::FieldType & result)
{
using FromFieldType = typename FromDataType::FieldType;
if constexpr (std::is_same_v<FromFieldType, Decimal32>)
return convertDecimalsImpl<DataTypeDecimal<Decimal32>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);

return convertDecimalsImpl<DataTypeDecimal<Decimal32>, ToDataType>(value, precision_to, scale_from, scale_to, result);
else if constexpr (std::is_same_v<FromFieldType, Decimal64>)
return convertDecimalsImpl<DataTypeDecimal<Decimal64>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
return convertDecimalsImpl<DataTypeDecimal<Decimal64>, ToDataType>(value, precision_to, scale_from, scale_to, result);
else if constexpr (std::is_same_v<FromFieldType, Decimal128>)
return convertDecimalsImpl<DataTypeDecimal<Decimal128>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
return convertDecimalsImpl<DataTypeDecimal<Decimal128>, ToDataType>(value, precision_to, scale_from, scale_to, result);
else if constexpr (std::is_same_v<FromFieldType, Decimal256>)
return convertDecimalsImpl<DataTypeDecimal<Decimal256>, ToDataType>(value, precision_to, scale_from, scale_to, result);
else if constexpr (IsDataTypeNumber<FromDataType> && !std::is_same_v<FromFieldType, BFloat16>)
return convertNumberToDecimalImpl<DataTypeNumber<FromFieldType>, ToDataType>(value, scale_to, decimal_int_part_max, decimal_int_part_min, result);
else
return convertDecimalsImpl<DataTypeDecimal<Decimal256>, ToDataType>(decimal, precision_to, scale_from, scale_to, result);
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Convert from {} type to decimal type is not implemented.", typeid(value).name());
}

template <typename FromDataType, typename ToDataType>
requires(IsDataTypeNumber<FromDataType> && IsDataTypeDecimal<ToDataType>)
static inline bool convertNumberToDecimalImpl(
const typename FromDataType::FieldType & value,
UInt32 scale,
typename ToDataType::FieldType::NativeType decimal_int_part_max,
typename ToDataType::FieldType::NativeType decimal_int_part_min,
typename ToDataType::FieldType & result)
{
using FromFieldType = typename FromDataType::FieldType;
using ToFieldNativeType = typename ToDataType::FieldType::NativeType;
ToFieldNativeType int_part = 0;
if constexpr (std::is_same_v<FromFieldType, Float32> || std::is_same_v<FromFieldType, Float64>)
int_part = static_cast<ToFieldNativeType>(value);
else
int_part = value;

return int_part >= decimal_int_part_min && int_part <= decimal_int_part_max && tryConvertToDecimal<FromDataType, ToDataType>(value, scale, result);
}

template <typename FromDataType, typename ToDataType>
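To summarize the new number-to-decimal path in SparkFunctionCheckDecimalOverflow.cpp: before attempting the conversion, it bounds the value's integer part by decimal_int_part_max/min, i.e. ±(10^(precision - scale) - 1). A rough Scala model of that bound check (an illustration only, not the ClickHouse implementation; BigDecimal stands in for the native decimal types):

// Illustration only, not part of the commit: models the integer-part bound used above.
def integerPartFits(value: BigDecimal, precision: Int, scale: Int): Boolean = {
  val intPartMax = BigInt(10).pow(precision - scale) - 1 // decimal_int_part_max
  val intPartMin = -intPartMax                           // decimal_int_part_min = 1 - 10^(p-s)
  val intPart = value.toBigInt                           // drop the fraction, as the C++ cast does
  intPart >= intPartMin && intPart <= intPartMax
}
// Example: integerPartFits(BigDecimal(137438953483L), 9, 1) == false, so
// cast(137438953483 as decimal(9, 1)) in the GLUTEN-8343 test returns NULL.
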
4 changes: 2 additions & 2 deletions cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp
@@ -964,10 +964,10 @@ jobject create(JNIEnv * env, const SparkRowInfo & spark_row_info)
{
auto * offsets_arr = env->NewLongArray(spark_row_info.getNumRows());
const auto * offsets_src = spark_row_info.getOffsets().data();
env->SetLongArrayRegion(offsets_arr, 0, spark_row_info.getNumRows(), static_cast<const jlong *>(offsets_src));
env->SetLongArrayRegion(offsets_arr, 0, spark_row_info.getNumRows(), reinterpret_cast<const jlong *>(offsets_src));
auto * lengths_arr = env->NewLongArray(spark_row_info.getNumRows());
const auto * lengths_src = spark_row_info.getLengths().data();
env->SetLongArrayRegion(lengths_arr, 0, spark_row_info.getNumRows(), static_cast<const jlong *>(lengths_src));
env->SetLongArrayRegion(lengths_arr, 0, spark_row_info.getNumRows(), reinterpret_cast<const jlong *>(lengths_src));
int64_t address = reinterpret_cast<int64_t>(spark_row_info.getBufferAddress());
int64_t column_number = spark_row_info.getNumCols();
int64_t total_size = spark_row_info.getTotalBytes();
13 changes: 8 additions & 5 deletions cpp-ch/local-engine/Parser/ExpressionParser.cpp
@@ -313,7 +313,6 @@ ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG &
DataTypePtr denull_input_type = removeNullable(input_type);
DataTypePtr output_type = TypeParser::parseType(substrait_type);
DataTypePtr denull_output_type = removeNullable(output_type);

const ActionsDAG::Node * result_node = nullptr;
if (substrait_type.has_binary())
{
@@ -336,11 +335,15 @@ ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG &
String function_name = "sparkCastFloatTo" + denull_output_type->getName();
result_node = toFunctionNode(actions_dag, function_name, args);
}
else if ((isDecimal(denull_input_type) && substrait_type.has_decimal()))
else if ((isDecimal(denull_input_type) || isNativeNumber(denull_input_type)) && substrait_type.has_decimal())
{
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().precision()));
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
result_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
int decimal_precision = substrait_type.decimal().precision();
if (decimal_precision)
{
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), decimal_precision));
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), substrait_type.decimal().scale()));
result_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
}
}
else if (isMap(denull_input_type) && isString(denull_output_type))
{
28 changes: 13 additions & 15 deletions cpp-ch/local-engine/Storages/MergeTree/SparkStorageMergeTree.cpp
@@ -469,24 +469,22 @@ MergeTreeDataWriter::TemporaryPart SparkMergeTreeDataWriter::writeTempPart(
new_data_part->uuid = UUIDHelpers::generateV4();

SyncGuardPtr sync_guard;
if (new_data_part->isStoredOnDisk())
{
/// The name could be non-unique in case of stale files from previous runs.
String full_path = new_data_part->getDataPartStorage().getFullPath();

if (new_data_part->getDataPartStorage().exists())
{
// LOG_WARNING(log, "Removing old temporary directory {}", full_path);
data_part_storage->removeRecursive();
}
/// The name could be non-unique in case of stale files from previous runs.
String full_path = new_data_part->getDataPartStorage().getFullPath();

data_part_storage->createDirectories();
if (new_data_part->getDataPartStorage().exists())
{
LOG_WARNING(log, "Removing old temporary directory {}", full_path);
data_part_storage->removeRecursive();
}

if ((*data.getSettings())[MergeTreeSetting::fsync_part_directory])
{
const auto disk = data_part_volume->getDisk();
sync_guard = disk->getDirectorySyncGuard(full_path);
}
data_part_storage->createDirectories();

if ((*data.getSettings())[MergeTreeSetting::fsync_part_directory])
{
const auto disk = data_part_volume->getDisk();
sync_guard = disk->getDirectorySyncGuard(full_path);
}

/// This effectively chooses minimal compression method:
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/local_engine_jni.cpp
@@ -1039,7 +1039,7 @@ JNIEXPORT jobject Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn
local_engine::BlockStripes bs = local_engine::BlockStripeSplitter::split(*block, partition_col_indice_vec, hasBucket, reserve_);

auto * addresses = env->NewLongArray(bs.block_addresses.size());
env->SetLongArrayRegion(addresses, 0, bs.block_addresses.size(), static_cast<const jlong *>(bs.block_addresses.data()));
env->SetLongArrayRegion(addresses, 0, bs.block_addresses.size(), reinterpret_cast<const jlong *>(bs.block_addresses.data()));
auto * indices = env->NewIntArray(bs.heading_row_indice.size());
env->SetIntArrayRegion(indices, 0, bs.heading_row_indice.size(), bs.heading_row_indice.data());

(Diff truncated: the remaining changed files in this commit are not shown.)
