@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756756 raise TypeError (
757757 f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
758758 )
759- if not isinstance (ary_mask , dpt .usm_ndarray ):
760- raise TypeError (
761- f"Expecting type dpctl.tensor.usm_ndarray, got { type ( ary_mask ) } "
759+ if isinstance (ary_mask , dpt .usm_ndarray ):
760+ dst_usm_type = dpctl . utils . get_coerced_usm_type (
761+ ( ary . usm_type , ary_mask . usm_type )
762762 )
763- dst_usm_type = dpctl .utils .get_coerced_usm_type (
764- (ary .usm_type , ary_mask .usm_type )
765- )
766- exec_q = dpctl .utils .get_execution_queue (
767- (ary .sycl_queue , ary_mask .sycl_queue )
768- )
769- if exec_q is None :
770- raise dpctl .utils .ExecutionPlacementError (
771- "arrays have different associated queues. "
772- "Use `y.to_device(x.device)` to migrate."
763+ exec_q = dpctl .utils .get_execution_queue (
764+ (ary .sycl_queue , ary_mask .sycl_queue )
765+ )
766+ if exec_q is None :
767+ raise dpctl .utils .ExecutionPlacementError (
768+ "arrays have different associated queues. "
769+ "Use `y.to_device(x.device)` to migrate."
770+ )
771+ elif isinstance (ary_mask , np .ndarray ):
772+ dst_usm_type = ary .usm_type
773+ exec_q = ary .sycl_queue
774+ ary_mask = dpt .asarray (
775+ ary_mask , usm_type = dst_usm_type , sycl_queue = exec_q
776+ )
777+ else :
778+ raise TypeError (
779+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780+ f"{ type (ary_mask )} "
773781 )
774782 ary_nd = ary .ndim
775783 pp = normalize_axis_index (operator .index (axis ), ary_nd )
@@ -837,35 +845,40 @@ def _nonzero_impl(ary):
837845 return res
838846
839847
840- def _validate_indices (inds , queue_list , usm_type_list ):
848+ def _get_indices_queue_usm_type (inds , queue , usm_type ):
841849 """
842- Utility for validating indices are usm_ndarray of integral dtype or Python
843- integers. At least one must be an array.
850+ Utility for validating indices are NumPy ndarray or usm_ndarray of integral
851+ dtype or Python integers. At least one must be an array.
844852
845853 For each array, the queue and usm type are appended to `queue_list` and
846854 `usm_type_list`, respectively.
847855 """
848- any_usmarray = False
856+ queues = [queue ]
857+ usm_types = [usm_type ]
858+ any_array = False
849859 for ind in inds :
850- if isinstance (ind , dpt .usm_ndarray ):
851- any_usmarray = True
860+ if isinstance (ind , ( np . ndarray , dpt .usm_ndarray ) ):
861+ any_array = True
852862 if ind .dtype .kind not in "ui" :
853863 raise IndexError (
854864 "arrays used as indices must be of integer (or boolean) "
855865 "type"
856866 )
857- queue_list .append (ind .sycl_queue )
858- usm_type_list .append (ind .usm_type )
867+ if isinstance (ind , dpt .usm_ndarray ):
868+ queues .append (ind .sycl_queue )
869+ usm_types .append (ind .usm_type )
859870 elif not isinstance (ind , Integral ):
860871 raise TypeError (
861- "all elements of `ind` expected to be usm_ndarrays "
862- f"or integers, found { type (ind )} "
872+ "all elements of `ind` expected to be usm_ndarrays, "
873+ f"NumPy arrays, or integers, found { type (ind )} "
863874 )
864- if not any_usmarray :
875+ if not any_array :
865876 raise TypeError (
866- "at least one element of `inds` expected to be a usm_ndarray "
877+ "at least one element of `inds` expected to be an array "
867878 )
868- return inds
879+ usm_type = dpctl .utils .get_coerced_usm_type (usm_types )
880+ q = dpctl .utils .get_execution_queue (queues )
881+ return q , usm_type
869882
870883
871884def _prepare_indices_arrays (inds , q , usm_type ):
@@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
922935 raise ValueError (
923936 "Invalid value for mode keyword, only 0 or 1 is supported"
924937 )
925- queues_ = [
926- ary .sycl_queue ,
927- ]
928- usm_types_ = [
929- ary .usm_type ,
930- ]
931938 if not isinstance (inds , (list , tuple )):
932939 inds = (inds ,)
933940
934- _validate_indices ( inds , queues_ , usm_types_ )
935- res_usm_type = dpctl . utils . get_coerced_usm_type ( usm_types_ )
936- exec_q = dpctl . utils . get_execution_queue ( queues_ )
941+ exec_q , res_usm_type = _get_indices_queue_usm_type (
942+ inds , ary . sycl_queue , ary . usm_type
943+ )
937944 if exec_q is None :
938945 raise dpctl .utils .ExecutionPlacementError (
939946 "Can not automatically determine where to allocate the "
@@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942949 "be associated with the same queue."
943950 )
944951
945- if len (inds ) > 1 :
946- inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
952+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
947953
948954 ind0 = inds [0 ]
949955 ary_sh = ary .shape
@@ -976,21 +982,51 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976982 raise TypeError (
977983 f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
978984 )
979- if not isinstance (ary_mask , dpt .usm_ndarray ):
980- raise TypeError (
981- f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
985+ if isinstance (ary_mask , dpt .usm_ndarray ):
986+ exec_q = dpctl .utils .get_execution_queue (
987+ (
988+ ary .sycl_queue ,
989+ ary_mask .sycl_queue ,
990+ )
982991 )
983- exec_q = dpctl .utils .get_execution_queue (
984- (
985- ary .sycl_queue ,
986- ary_mask .sycl_queue ,
992+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
993+ (
994+ ary .usm_type ,
995+ ary_mask .usm_type ,
996+ )
997+ )
998+ if exec_q is None :
999+ raise dpctl .utils .ExecutionPlacementError (
1000+ "arrays have different associated queues. "
1001+ "Use `y.to_device(x.device)` to migrate."
1002+ )
1003+ elif isinstance (ary_mask , np .ndarray ):
1004+ exec_q = ary .sycl_queue
1005+ coerced_usm_type = ary .usm_type
1006+ ary_mask = dpt .asarray (
1007+ ary_mask , usm_type = coerced_usm_type , sycl_queue = exec_q
1008+ )
1009+ else :
1010+ raise TypeError (
1011+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1012+ f"{ type (ary_mask )} "
9871013 )
988- )
9891014 if exec_q is not None :
9901015 if not isinstance (vals , dpt .usm_ndarray ):
991- vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
1016+ vals = dpt .asarray (
1017+ vals ,
1018+ dtype = ary .dtype ,
1019+ usm_type = coerced_usm_type ,
1020+ sycl_queue = exec_q ,
1021+ )
9921022 else :
9931023 exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1024+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1025+ (
1026+ coerced_usm_type ,
1027+ vals .usm_type ,
1028+ )
1029+ )
9941030 if exec_q is None :
9951031 raise dpctl .utils .ExecutionPlacementError (
9961032 "arrays have different associated queues. "
@@ -1005,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10051041 )
10061042 mask_nelems = ary_mask .size
10071043 cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
1008- cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
1044+ cumsum = dpt .empty (
1045+ mask_nelems ,
1046+ dtype = cumsum_dt ,
1047+ usm_type = coerced_usm_type ,
1048+ device = ary_mask .device ,
1049+ )
10091050 exec_q = cumsum .sycl_queue
10101051 _manager = dpctl .utils .SequentialOrderManager [exec_q ]
10111052 dep_ev = _manager .submitted_events
@@ -1048,30 +1089,29 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10481089 raise ValueError (
10491090 "Invalid value for mode keyword, only 0 or 1 is supported"
10501091 )
1051- if isinstance (vals , dpt .usm_ndarray ):
1052- queues_ = [ary .sycl_queue , vals .sycl_queue ]
1053- usm_types_ = [ary .usm_type , vals .usm_type ]
1054- else :
1055- queues_ = [
1056- ary .sycl_queue ,
1057- ]
1058- usm_types_ = [
1059- ary .usm_type ,
1060- ]
10611092 if not isinstance (inds , (list , tuple )):
10621093 inds = (inds ,)
10631094
1064- _validate_indices (inds , queues_ , usm_types_ )
1095+ exec_q , coerced_usm_type = _get_indices_queue_usm_type (
1096+ inds , ary .sycl_queue , ary .usm_type
1097+ )
10651098
1066- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
1067- exec_q = dpctl .utils .get_execution_queue (queues_ )
10681099 if exec_q is not None :
10691100 if not isinstance (vals , dpt .usm_ndarray ):
10701101 vals = dpt .asarray (
1071- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
1102+ vals ,
1103+ dtype = ary .dtype ,
1104+ usm_type = coerced_usm_type ,
1105+ sycl_queue = exec_q ,
10721106 )
10731107 else :
10741108 exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1109+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1110+ (
1111+ coerced_usm_type ,
1112+ vals .usm_type ,
1113+ )
1114+ )
10751115 if exec_q is None :
10761116 raise dpctl .utils .ExecutionPlacementError (
10771117 "Can not automatically determine where to allocate the "
@@ -1080,8 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10801120 "be associated with the same queue."
10811121 )
10821122
1083- if len (inds ) > 1 :
1084- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1123+ inds = _prepare_indices_arrays (inds , exec_q , coerced_usm_type )
10851124
10861125 ind0 = inds [0 ]
10871126 ary_sh = ary .shape
0 commit comments