HIP: Heterogeneous-computing Interface for Portability
amd_hip_cooperative_groups.h
/*
Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if __cplusplus
#if !defined(__HIPCC_RTC__)
// The helper header providing the internal:: utilities used below was elided
// in the extracted listing; this path is assumed.
#include <hip/amd_detail/hip_cooperative_groups_helper.h>
#endif

namespace cooperative_groups {

class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // lane mask for coalesced and tiled partitioned group types;
                   // LSB represents lane 0, and MSB represents lane 63

  // Construct a thread group, and set the thread group type and other essential
  // thread group properties. This generic thread group is directly constructed
  // only when the group is supposed to contain only the calling thread
  // (through the API `this_thread()`); in all other cases, this thread
  // group object is a sub-object of some other derived thread group object.
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group; this also serves all derived
  // cooperative group types, since their `size` is saved directly at
  // construction.
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is this cooperative group type valid?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group
  __CG_QUALIFIER__ void sync() const;
};
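
// Usage sketch (illustrative addition, not part of the original header): the
// concrete group types below all derive from thread_group, so generic device
// code can accept this handle and rely on the type-dispatched members
// defined later in this file.
__device__ inline unsigned int example_rank_after_sync(thread_group g) {
  g.sync();                // dispatched on the group type stored in _type
  return g.thread_rank();  // rank of the calling thread within g
}
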
class multi_grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of invocations participating in this multi-grid group. In other
  // words, the number of GPUs.
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this invocation, i.e. an ID in [0, num_grids()) for the GPU this
  // kernel is running on.
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};

__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}

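// Usage sketch (illustrative addition, not part of the original header):
// multi-grid groups are only valid for kernels launched on several devices
// via hipLaunchCooperativeKernelMultiDevice.
static __global__ void example_multi_gpu_kernel() {
  multi_grid_group mg = this_multi_grid();
  if (mg.thread_rank() == 0) {
    // mg.grid_rank() identifies this GPU's grid in [0, mg.num_grids())
  }
  mg.sync();  // every thread on every participating device reaches this point
}
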
class grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct a grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};

__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }

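// Usage sketch (illustrative addition, not part of the original header):
// grid.sync() is only defined for kernels launched through
// hipLaunchCooperativeKernel; it is invalid under a plain kernel launch.
static __global__ void example_grid_kernel(float* data, size_t n) {
  grid_group grid = this_grid();
  // Grid-stride loop over the whole array, then a whole-grid barrier.
  for (size_t i = grid.thread_rank(); i < n; i += grid.size()) data[i] *= 2.0f;
  grid.sync();  // every thread of the grid reaches this point before any continues
}
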
class thread_block : public thread_group {
  // Only these friend functions are allowed to construct an object of this
  // class and access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Invalid tile size, assert
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    auto block_size = size();
    auto rank = thread_rank();
    auto partitions = (block_size + tile_size - 1) / tile_size;
    auto tail = (partitions * tile_size) - block_size;
    // Threads in the last partition get a smaller tile when block_size is not
    // a multiple of tile_size.
    auto partition_size = tile_size - tail * (rank >= (partitions - 1) * tile_size);
    thread_group tiledGroup = thread_group(internal::cg_tiled_group, partition_size);

    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = rank / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = partitions;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};

__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}

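// Usage sketch (illustrative addition, not part of the original header): the
// canonical block-level pattern of staging data in shared memory and
// synchronizing the block before reading it back.
static __global__ void example_block_reverse(int* data) {
  __shared__ int tmp[256];  // assumes a launch with 256 threads per block
  thread_block block = this_thread_block();
  unsigned int r = block.thread_rank();
  tmp[r] = data[block.group_index().x * block.size() + r];
  block.sync();  // make the staged values visible to every thread in the block
  data[block.group_index().x * block.size() + r] = tmp[block.size() - 1 - r];
}
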
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const {
    internal::tiled_group::sync();
  }
};

class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If a tiled group is passed to be partitioned further into a coalesced_group,
    // prepare a mask for further partitioning it so that it stays coalesced.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    }
    // Here the parent coalesced_group is not partitioned.
    else {
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        // Widened to lane_mask so the shift is defined for lanes >= 32
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Make sure the lane is active
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            // Prepare a member_mask that is appropriate for a tile
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Constructor
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;                    // Which threads are active
    coalesced_info.size = __popcll(coalesced_info.member_mask);  // How many threads are active
    coalesced_info.tiled_info.is_tiled = false;                  // Not a partitioned group
    coalesced_info.tiled_info.meta_group_rank = 0;
    coalesced_info.tiled_info.meta_group_size = 1;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const {
    return coalesced_info.size;
  }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const {
    internal::coalesced_group::sync();
  }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }

  template <class T>
  __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64)
            ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
            : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    // Mirror shfl_down: fall back to __fns32 on wave32 so `lane` is always
    // initialized.
    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    } else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};

__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}

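// Usage sketch (illustrative addition, not part of the original header): the
// aggregated-atomic pattern, where one lane of the currently active branch
// performs the atomic for the whole coalesced group.
__device__ inline int example_aggregated_atomic(int* counter) {
  coalesced_group active = coalesced_threads();
  int base = 0;
  if (active.thread_rank() == 0) {
    base = atomicAdd(counter, static_cast<int>(active.size()));
  }
  // Broadcast lane 0's base ticket, then offset by each lane's rank.
  return active.shfl(base, 0) + static_cast<int>(active.thread_rank());
}
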
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}

__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return false;
    }
  }
}

__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
    }
  }
}

template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }

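// Usage sketch (illustrative addition, not part of the original header): the
// free functions above let templated device code stay agnostic of the
// concrete group type.
template <class CGTy>
__device__ unsigned int example_group_report(const CGTy& g) {
  sync(g);               // forwards to g.sync()
  return group_size(g);  // forwards to g.size()
}
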
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  _CG_STATIC_CONST_DECL_ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};

template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() {
    internal::tiled_group::sync();
  }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer or float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};

template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Returns the linear rank of the group within the set of tiles partitioned
  // from a parent group (bounded by meta_group_size)
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Returns the number of groups created when the parent group was partitioned.
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};

template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;
  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;
};

// Partial template specialization
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank, unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
};

__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  } else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}

// Thread block type overload
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// If a coalesced group is passed to be partitioned, it should remain coalesced
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

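// Usage sketch (illustrative addition, not part of the original header):
// run-time partitioning of the block into tiles of 8 threads; tile sizes
// must be powers of two no larger than the wavefront size.
__device__ inline void example_dynamic_tiles() {
  thread_group tile8 = tiled_partition(this_thread_block(), 8);
  unsigned int lane = tile8.thread_rank();  // in [0, 8) for full tiles
  (void)lane;
  tile8.sync();  // synchronizes this tile rather than the whole block
}
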
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl

template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};

template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 protected:
 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};

template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};

}  // namespace impl

template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size. Currently not supported ");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
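
// Usage sketch (illustrative addition, not part of the original header): a
// compile-time tile combined with shfl_down to reduce within each tile.
__device__ inline float example_tile_sum(float val) {
  auto tile = tiled_partition<16>(this_thread_block());
  for (unsigned int offset = tile.size() / 2; offset > 0; offset /= 2) {
    val += tile.shfl_down(val, offset);  // shift-down tree reduction
  }
  return val;  // the thread with thread_rank() == 0 in each tile holds the sum
}
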
}  // namespace cooperative_groups

#endif  // __cplusplus
#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
Device-side implementation of the cooperative groups feature.