@@ -440,24 +440,25 @@ Status DispatchRadixSort(OpKernelContext* context, const int32_t size,
440440 keys_out = mutable_keys_out;
441441 }
442442
443- if (size <= KEYS_PER_ITEM * GROUP_SIZE) {
444- using Rsortor = GroupRadixSortor<
445- KeyT, /* key_per_item==*/ KEYS_PER_ITEM, /* group_size=*/ GROUP_SIZE,
446- /* subgroup_size =*/ SUBGROUP_SIZE, sycl::group<1 >, ValueT>;
447- // Compute the required local memory size
448- size_t local_memory_size = Rsortor::LocalStorage::SIZE;
449- const int32_t num_wg = 1 ;
450- sycl::range<1 > global_range (num_wg * GROUP_SIZE);
451- sycl::range<1 > local_range (GROUP_SIZE);
452-
453- return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
454- Rsortor>(
455- stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
456- local_range, local_memory_size, num_bits);
457- } else {
443+ if (size > KEYS_PER_ITEM * GROUP_SIZE &&
444+ !std::is_floating_point_v<KeyT>) { // DeviceRadixSort will write OOM for
445+ // float/double point types.
458446 return DispatchDeviceRadixSort (context, keys_in, indices_in, keys_out,
459447 indices_out, size);
460448 }
449+ using Rsortor = GroupRadixSortor<
450+ KeyT, /* key_per_item==*/ KEYS_PER_ITEM, /* group_size=*/ GROUP_SIZE,
451+ /* subgroup_size =*/ SUBGROUP_SIZE, sycl::group<1 >, ValueT>;
452+ // Compute the required local memory size
453+ size_t local_memory_size = Rsortor::LocalStorage::SIZE;
454+ const int32_t num_wg = 1 ;
455+ sycl::range<1 > global_range (num_wg * GROUP_SIZE);
456+ sycl::range<1 > local_range (GROUP_SIZE);
457+
458+ return LaunchRadixSortKernel<KeyT, ValueT, KEYS_PER_ITEM, SUBGROUP_SIZE,
459+ Rsortor>(
460+ stream, size, keys_in, indices_in, keys_out, indices_out, global_range,
461+ local_range, local_memory_size, num_bits);
461462}
462463
463464template <typename InputIteratorT, typename OutputIteratorT, typename BinaryOp>
0 commit comments