2626namespace tc {
2727namespace polyhedral {
2828namespace {
29- // This returns the (inclusive) range of the mapping parameter that is active
30- // at node under root given:
31- // 1. a context that is the intersection of the specialization context and
32- // the mapping context
33- // 2. a MappingId
34- // This range corresponds to the blocks/threads active at that particular
35- // location in the tree.
29+ // This returns the (inclusive) range of the mapping parameter "mappingId"
30+ // within the context "mappingContext".
31+ // This range corresponds to the blocks/threads active at the particular
32+ // location in the tree where this mapping is active.
3633//
3734// This is used to tighten the kernel to only launch on the necessary amount
3835// of resources.
@@ -43,23 +40,20 @@ namespace {
4340// Otherwise, the range is asserted bounded on the left and to lie in the
4441// positive half of the integer axis.
4542std::pair<size_t , size_t > rangeOfMappingParameter (
46- const detail::ScheduleTree* root,
47- const detail::ScheduleTree* node,
48- isl::set context,
43+ isl::set mappingContext,
4944 mapping::MappingId mappingId) {
50- auto active =
51- activeDomainPoints (root, node).intersect_params (context).params ();
52- if (!active.involves_param (mappingId)) {
45+ if (!mappingContext.involves_param (mappingId)) {
5346 return std::make_pair (0 , std::numeric_limits<size_t >::max ());
5447 }
55- isl::aff a (isl::aff::param_on_domain_space (active.get_space (), mappingId));
56- auto max = active.max_val (a);
48+ auto space = mappingContext.get_space ();
49+ isl::aff a (isl::aff::param_on_domain_space (space, mappingId));
50+ auto max = mappingContext.max_val (a);
5751 if (max.is_nan () || max.is_infty ()) {
5852 return std::make_pair (0 , std::numeric_limits<size_t >::max ());
5953 }
6054 TC_CHECK (max.is_int ()) << max.to_str ();
6155 TC_CHECK (max.is_nonneg ()) << max.to_str ();
62- auto min = active .min_val (a);
56+ auto min = mappingContext .min_val (a);
6357 TC_CHECK (min.is_int ()) << max.to_str ();
6458 TC_CHECK (min.is_nonneg ()) << max.to_str ();
6559
@@ -68,13 +62,52 @@ std::pair<size_t, size_t> rangeOfMappingParameter(
6862 static_cast <size_t >(max.get_num_si ()));
6963}
7064
71- // Look for nodes with no children.
72- inline std::vector<const detail::ScheduleTree*> leaves (
73- const detail::ScheduleTree* tree) {
74- return functional::Filter (
75- [](const detail::ScheduleTree* st) { return st->numChildren () == 0 ; },
76- detail::ScheduleTree::collect (tree));
65+ /*
66+ * Compute the maximal value attained by the mapping parameter "id".
67+ */
68+ template <typename MappingIdType>
69+ size_t maxValue (const Scop& scop, const MappingIdType& id) {
70+ using namespace polyhedral ::detail;
71+
72+ auto root = scop.scheduleRoot ();
73+ auto params = scop.context ();
74+ size_t sizetMax = std::numeric_limits<size_t >::max ();
75+ size_t max = 0 ;
76+ size_t min = sizetMax;
77+ auto filters = root->collect (root, ScheduleTreeType::Mapping);
78+ filters = functional::Filter (isMappingTo<MappingIdType>, filters);
79+ for (auto p : filters) {
80+ auto mappingNode = p->elemAs <ScheduleTreeElemMapping>();
81+ auto active = activeDomainPoints (root, p).intersect_params (params);
82+ active = active.intersect (mappingNode->filter_ );
83+ auto range = rangeOfMappingParameter (active.params (), id);
84+ min = std::min (min, range.first );
85+ max = std::max (max, range.second );
86+ }
87+ // Ignore min for now but there is a future possibility for shifting
88+ LOG_IF (WARNING, min > 0 )
89+ << " Opportunity for tightening launch bounds with shifting -> min:"
90+ << min;
91+ TC_CHECK (max < sizetMax) << " missing mapping to " << id << *root;
92+ // Inclusive range needs + 1 to translate to sizes
93+ return max + 1 ;
94+ }
95+
96+ /*
97+ * Take grid or block launch bounds "size" and replace them
98+ * by the tightened, actual, launch bounds used in practice.
99+ */
100+ template <typename MappingIdType, typename Size>
101+ Size launchBounds (const Scop& scop, Size size) {
102+ Size tightened;
103+
104+ for (size_t i = 0 ; i < size.view .size (); ++i) {
105+ tightened.view [i] = maxValue (scop, MappingIdType::makeId (i));
106+ }
107+
108+ return tightened;
77109}
110+
78111} // namespace
79112
80113// Takes grid/block launch bounds that have been passed to mapping and
@@ -84,56 +117,9 @@ std::pair<tc::Grid, tc::Block> tightenLaunchBounds(
84117 const Scop& scop,
85118 const tc::Grid& grid,
86119 const tc::Block& block) {
87- auto root = scop.scheduleRoot ();
88- auto params = scop.context ();
89-
90- auto max = [root, params](const mapping::MappingId& id) -> size_t {
91- size_t sizetMax = std::numeric_limits<size_t >::max ();
92- size_t max = 0 ;
93- size_t min = sizetMax;
94- auto nonSyncLeaves = functional::Filter (
95- [root, params](const detail::ScheduleTree* node) {
96- auto f = node->elemAsBase <detail::ScheduleTreeElemFilter>();
97- if (!f) {
98- return true ;
99- }
100- if (f->filter_ .n_set () != 1 ) {
101- std::stringstream ss;
102- ss << " In tree:\n "
103- << *root << " \n not a single set in filter: " << f->filter_ ;
104- throw tightening::TighteningException (ss.str ());
105- }
106- auto single = isl::set::from_union_set (f->filter_ );
107- auto single_id = single.get_tuple_id ();
108- return !Scop::isSyncId (single_id) && !Scop::isWarpSyncId (single_id);
109- },
110- leaves (root));
111- for (auto p : nonSyncLeaves) {
112- auto range = rangeOfMappingParameter (root, p, params, id);
113- min = std::min (min, range.first );
114- max = std::max (max, range.second );
115- }
116- // Ignore min for now but there is a future possibility for shifting
117- LOG_IF (WARNING, min > 0 )
118- << " Opportunity for tightening launch bounds with shifting -> min:"
119- << min;
120- // Inclusive range needs + 1 to translate to sizes
121- if (max < sizetMax) { // avoid overflow
122- return max + 1 ;
123- }
124- return sizetMax;
125- };
126-
127- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
128- // Corner case: take the min with the current size to avoid degenerate
129- // range in the unbounded case.
130120 return std::make_pair (
131- tc::Grid ({std::min (max (BX), BX.mappingSize (grid)),
132- std::min (max (BY), BY.mappingSize (grid)),
133- std::min (max (BZ), BZ.mappingSize (grid))}),
134- tc::Block ({std::min (max (TX), TX.mappingSize (block)),
135- std::min (max (TY), TY.mappingSize (block)),
136- std::min (max (TZ), TZ.mappingSize (block))}));
121+ launchBounds<mapping::BlockId>(scop, grid),
122+ launchBounds<mapping::ThreadId>(scop, block));
137123}
138124} // namespace polyhedral
139125} // namespace tc
0 commit comments