This simple fusion is segmented:
```cpp
TEST_F(NVFuserTest, TMP) {
  auto fusion_ptr = std::make_unique<Fusion>();
  FusionGuard fg(fusion_ptr.get());
  Fusion& fusion = *fusion_ptr;

  // Inputs: tv0 is [2, 4, 8], tv1 is [2]
  auto tv0 = makeConcreteTensor({2, 4, 8});
  fusion.addInput(tv0);
  auto tv1 = makeConcreteTensor({2});
  fusion.addInput(tv1);

  // Reshape tv0 to [2, 32], broadcast tv1 to [2, 1], then add them
  auto tv2 = reshape(tv0, {2, 4, 8}, {2, 32});
  auto tv3 = broadcast(tv1, {false, true});
  auto tv4 = add(tv2, tv3);

  // Both the sum and the broadcast tensor are fusion outputs
  fusion.addOutput(tv4);
  fusion.addOutput(tv3);

  fusion.printMath();

  auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({2, 4, 8}, options);
  auto t1 = at::randn({2}, options);
  std::vector<c10::IValue> inputs({t0, t1});

  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto out_tensors = executor_cache.runFusionWithInputs(inputs);
  testValidate(
      executor_cache.fusion(), out_tensors, inputs, __LINE__, __FILE__);
}
```
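For reference, the math this fusion performs can be written as a plain C++ loop. The helper `referenceT4` below is hypothetical (not part of nvFuser or the test) and assumes row-major, contiguous buffers, so the reshape of `[2, 4, 8]` to `[2, 32]` just reinterprets T0's storage. It makes the access pattern explicit: every element of T4 reads exactly one element of T0 and one of T1, with no reduction.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not nvFuser API) for the fusion above:
//   T2 = reshape(T0, [2, 4, 8] -> [2, 32])
//   T3 = broadcast(T1, [2] -> [2, 1])   // same values as T1
//   T4 = T2 + T3                        // T3 broadcast along the 32-wide axis
// Assumes row-major, contiguous buffers, so the reshape reuses T0's layout.
std::vector<float> referenceT4(const std::vector<float>& t0,
                               const std::vector<float>& t1) {
  constexpr std::size_t kRows = 2;
  constexpr std::size_t kCols = 4 * 8;
  assert(t0.size() == kRows * kCols && t1.size() == kRows);

  std::vector<float> t4(kRows * kCols);
  for (std::size_t i = 0; i < kRows; ++i) {
    for (std::size_t j = 0; j < kCols; ++j) {
      // One read from each input per output element: purely pointwise,
      // no reduction and no data movement across rows.
      t4[i * kCols + j] = t0[i * kCols + j] + t1[i];
    }
  }
  return t4;
}
```

Nothing in this loop nest needs a second kernel, which is the behavior the issue expects from the pointwise scheduler.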
```
Inputs:
  T0_g_float[iS0{2}, iS1{4}, iS2{8}]
  T1_g_float[iS3{2}]
Outputs:
  T4_g_float[iS12{2}, iS13{32}]
  T3_g_float[iS10{2}, bS11{1}]

%kernel_math {
T2_l_float[iS4{2}, iS9{32}rf] = view( T0_g_float[iS0{2}, iS1{4}, iS2{8}] )
T3_g_float[iS10{2}, bS11{1}]
   = broadcast( T1_g_float[iS3{2}], flags = {false, true} )
T4_g_float[iS12{2}, iS13{32}]
   = T2_l_float[iS4{2}, iS9{32}rf]
   + T3_g_float[iS10{2}, bS11{1}];
} // %kernel_math
```
```
***Runtime***: Try to schedule fusion un-segmented:
Scheduler _expr_eval_ ***rejected*** because : Fusion must contain only a single expression.
Scheduler _no_op_ ***rejected*** because : output has a concrete dimension
Scheduler _matmul_ ***rejected*** because : No matmul patterns were found
Scheduler _reduction_ ***rejected*** because : No reduction op to schedule
Scheduler _resize_ ***rejected*** because : No resize op to schedule
Scheduler _transpose_ ***rejected*** because : cannot find two mismatching inner most dimensions
Scheduler _pointwise_ ***rejected*** because : cannot find reference tensor
Scheduler _inner_persistent_ ***rejected*** because : needs a reduction op
Scheduler _outer_persistent_ ***rejected*** because : needs a reduction op
Scheduler _inner_outer_persistent_ ***rejected*** because : needs a reduction op
```
I don't see any reason this should be segmented. It should be scheduled as a single pointwise kernel with `T4` as the reference, but the pointwise scheduler fails with "cannot find reference tensor": `DomainMap` rejects `T4`, likely because of its use of the permissive map here.
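To make "`T4` should be the reference" concrete, here is a toy model of the coverage criterion a pointwise reference is expected to meet: every non-broadcast iteration domain of every output must map to a domain of the reference. This is a sketch under stated assumptions, not nvFuser's `DomainMap`: `DisjointSet`, `Domain`, and `covers` are hypothetical names, and the mapping is a simple union-find where the `add` unites `T4`'s axes position-wise with `T3`'s (so `iS13` is united with the broadcast `bS11`, roughly as a permissive map would).

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Toy union-find over iteration-domain ids. Hypothetical: a stand-in for the
// idea of a domain map, NOT nvFuser's DomainMap or ComputeAtMap.
struct DisjointSet {
  std::vector<std::size_t> parent;
  explicit DisjointSet(std::size_t n) : parent(n) {
    std::iota(parent.begin(), parent.end(), std::size_t{0});
  }
  std::size_t find(std::size_t x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];  // path halving
      x = parent[x];
    }
    return x;
  }
  void unite(std::size_t a, std::size_t b) { parent[find(a)] = find(b); }
};

struct Domain {
  std::size_t id;     // e.g. 13 for iS13
  bool is_broadcast;  // true for bS domains
};
using TensorDomain = std::vector<Domain>;

// Toy criterion: `ref` covers `t` if every non-broadcast domain of `t` is
// mapped to a non-broadcast domain of `ref`. Broadcast domains are skipped
// because they do not need a loop of their own.
bool covers(DisjointSet& map, const TensorDomain& ref, const TensorDomain& t) {
  for (const Domain& d : t) {
    if (d.is_broadcast) continue;
    bool found = false;
    for (const Domain& r : ref) {
      if (!r.is_broadcast && map.find(r.id) == map.find(d.id)) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }
  return true;
}
```

With the output domains from the `printMath` dump (`T4 = {iS12, iS13}`, `T3 = {iS10, bS11}`) and the `add` uniting 12 with 10 and 13 with 11, `T4` covers both outputs while `T3` does not cover `T4`, so only `T4` qualifies in this toy model. The sketch leaves out the inputs and the `view` entirely.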