@@ -23,31 +23,25 @@ enum Dist
23
23
CIRC
24
24
};
25
25
26
- template <Dist U> constexpr Dist Collect () { return STAR; }
27
- template <> constexpr Dist Collect<CIRC>() { return CIRC; }
26
+ template <Dist U> constexpr Dist Collect () { return (U == CIRC) ? CIRC : STAR; }
28
27
29
28
template <typename T,Dist U,Dist V>
30
29
class DistMatrix { };
31
30
32
31
template <typename T,Dist U,Dist V>
33
32
void AllGather
34
- ( const DistMatrix<T, U, V >& A,
35
- DistMatrix<T,Collect<U>(),Collect<V>()>& B )
36
- {
37
- std::cout << " U=" << U << " , V=" << V << std::endl;
38
- }
33
+ (const DistMatrix<T,U,V>& A, DistMatrix<T,Collect<U>(),Collect<V>()>& B)
34
+ { }
35
+
36
+ #ifdef USE_CONSTEXPR
37
+ template void AllGather (const DistMatrix<int ,MC,MR>& A, DistMatrix<int ,Collect<MC>(),Collect<MR>()>& B);
38
+ template void AllGather (const DistMatrix<double ,CIRC,CIRC>& A, DistMatrix<double ,Collect<CIRC>(),Collect<CIRC>()>& B);
39
+ #else
40
+ template void AllGather (const DistMatrix<int ,MC,MR>& A, DistMatrix<int ,STAR,STAR>& B);
41
+ template void AllGather (const DistMatrix<double ,CIRC,CIRC>& A, DistMatrix<double ,CIRC,CIRC>& B);
42
+ #endif
39
43
40
- int main ( int argc, char * argv[] )
44
+ int main (int argc, char * argv[])
41
45
{
42
- {
43
- DistMatrix<double ,MC,MR> A;
44
- DistMatrix<double ,STAR,STAR> B;
45
- AllGather ( A, B );
46
- }
47
- {
48
- DistMatrix<int ,CIRC,CIRC> A;
49
- DistMatrix<int ,CIRC,CIRC> B;
50
- AllGather ( A, B );
51
- }
52
- return 0 ;
46
+ return 0 ;
53
47
}
0 commit comments