Skip to content

Commit 68a65ed

Browse files
authored
Merge pull request brucefan1983#728 from brucefan1983/lmax
more flexible lmax
2 parents f73e277 + ad2824d commit 68a65ed

File tree

8 files changed

+181
-322
lines changed

8 files changed

+181
-322
lines changed

src/force/nep3.cu

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ static __global__ void find_descriptor(
668668
weight_left +
669669
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
670670
weight_right;
671-
accumulate_s(d12, x12, y12, z12, gn12, s);
671+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
672672
#else
673673
float fc12;
674674
int t2 = g_type[n2];
@@ -690,16 +690,10 @@ static __global__ void find_descriptor(
690690
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
691691
gn12 += fn12[k] * annmb.c[c_index];
692692
}
693-
accumulate_s(d12, x12, y12, z12, gn12, s);
693+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
694694
#endif
695695
}
696-
if (paramb.num_L == paramb.L_max) {
697-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
698-
} else if (paramb.num_L == paramb.L_max + 1) {
699-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
700-
} else {
701-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
702-
}
696+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
703697
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
704698
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
705699
}
@@ -983,15 +977,7 @@ static __global__ void find_partial_force_angular(
983977
g_gn_angular[index_left_all] * weight_left + g_gn_angular[index_right_all] * weight_right;
984978
float gnp12 = g_gnp_angular[index_left_all] * weight_left +
985979
g_gnp_angular[index_right_all] * weight_right;
986-
if (paramb.num_L == paramb.L_max) {
987-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
988-
} else if (paramb.num_L == paramb.L_max + 1) {
989-
accumulate_f12_with_4body(
990-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
991-
} else {
992-
accumulate_f12_with_5body(
993-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
994-
}
980+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
995981
}
996982
#else
997983
float fc12, fcp12;
@@ -1019,15 +1005,7 @@ static __global__ void find_partial_force_angular(
10191005
gn12 += fn12[k] * annmb.c[c_index];
10201006
gnp12 += fnp12[k] * annmb.c[c_index];
10211007
}
1022-
if (paramb.num_L == paramb.L_max) {
1023-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1024-
} else if (paramb.num_L == paramb.L_max + 1) {
1025-
accumulate_f12_with_4body(
1026-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1027-
} else {
1028-
accumulate_f12_with_5body(
1029-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1030-
}
1008+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
10311009
}
10321010
#endif
10331011
g_f12x[index] = f12[0];
@@ -1683,7 +1661,7 @@ static __global__ void find_descriptor(
16831661
weight_left +
16841662
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
16851663
weight_right;
1686-
accumulate_s(d12, x12, y12, z12, gn12, s);
1664+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
16871665
#else
16881666
float fc12;
16891667
int t2 = g_type[n2];
@@ -1705,16 +1683,10 @@ static __global__ void find_descriptor(
17051683
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
17061684
gn12 += fn12[k] * annmb.c[c_index];
17071685
}
1708-
accumulate_s(d12, x12, y12, z12, gn12, s);
1686+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
17091687
#endif
17101688
}
1711-
if (paramb.num_L == paramb.L_max) {
1712-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
1713-
} else if (paramb.num_L == paramb.L_max + 1) {
1714-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
1715-
} else {
1716-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
1717-
}
1689+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
17181690
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
17191691
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
17201692
}

src/force/nep3_multigpu.cu

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ static __global__ void find_descriptor(
942942
weight_left +
943943
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
944944
weight_right;
945-
accumulate_s(d12, x12, y12, z12, gn12, s);
945+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
946946
#else
947947
float fc12;
948948
int t2 = g_type[n2];
@@ -964,16 +964,10 @@ static __global__ void find_descriptor(
964964
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
965965
gn12 += fn12[k] * annmb.c[c_index];
966966
}
967-
accumulate_s(d12, x12, y12, z12, gn12, s);
967+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
968968
#endif
969969
}
970-
if (paramb.num_L == paramb.L_max) {
971-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
972-
} else if (paramb.num_L == paramb.L_max + 1) {
973-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
974-
} else {
975-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
976-
}
970+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
977971
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
978972
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
979973
}
@@ -1254,15 +1248,7 @@ static __global__ void find_partial_force_angular(
12541248
g_gn_angular[index_left_all] * weight_left + g_gn_angular[index_right_all] * weight_right;
12551249
float gnp12 = g_gnp_angular[index_left_all] * weight_left +
12561250
g_gnp_angular[index_right_all] * weight_right;
1257-
if (paramb.num_L == paramb.L_max) {
1258-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1259-
} else if (paramb.num_L == paramb.L_max + 1) {
1260-
accumulate_f12_with_4body(
1261-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1262-
} else {
1263-
accumulate_f12_with_5body(
1264-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1265-
}
1251+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
12661252
}
12671253
#else
12681254
float fc12, fcp12;
@@ -1290,15 +1276,7 @@ static __global__ void find_partial_force_angular(
12901276
gn12 += fn12[k] * annmb.c[c_index];
12911277
gnp12 += fnp12[k] * annmb.c[c_index];
12921278
}
1293-
if (paramb.num_L == paramb.L_max) {
1294-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1295-
} else if (paramb.num_L == paramb.L_max + 1) {
1296-
accumulate_f12_with_4body(
1297-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1298-
} else {
1299-
accumulate_f12_with_5body(
1300-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
1301-
}
1279+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
13021280
}
13031281
#endif
13041282
g_f12x[index] = f12[0];
@@ -2088,7 +2066,7 @@ static __global__ void find_descriptor(
20882066
weight_left +
20892067
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
20902068
weight_right;
2091-
accumulate_s(d12, x12, y12, z12, gn12, s);
2069+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
20922070
#else
20932071
float fc12;
20942072
int t2 = g_type[n2];
@@ -2110,16 +2088,10 @@ static __global__ void find_descriptor(
21102088
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
21112089
gn12 += fn12[k] * annmb.c[c_index];
21122090
}
2113-
accumulate_s(d12, x12, y12, z12, gn12, s);
2091+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
21142092
#endif
21152093
}
2116-
if (paramb.num_L == paramb.L_max) {
2117-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
2118-
} else if (paramb.num_L == paramb.L_max + 1) {
2119-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
2120-
} else {
2121-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
2122-
}
2094+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
21232095
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
21242096
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
21252097
}

src/force/nep3_small_box.cuh

Lines changed: 8 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ static __global__ void find_descriptor_small_box(
259259
weight_left +
260260
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
261261
weight_right;
262-
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
262+
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
263263
#else
264264
float fc12;
265265
int t2 = g_type[n2];
@@ -281,16 +281,10 @@ static __global__ void find_descriptor_small_box(
281281
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
282282
gn12 += fn12[k] * annmb.c[c_index];
283283
}
284-
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
284+
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
285285
#endif
286286
}
287-
if (paramb.num_L == paramb.L_max) {
288-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
289-
} else if (paramb.num_L == paramb.L_max + 1) {
290-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
291-
} else {
292-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
293-
}
287+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
294288
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
295289
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
296290
}
@@ -454,7 +448,7 @@ static __global__ void find_descriptor_small_box(
454448
weight_left +
455449
g_gn_angular[(index_right * paramb.num_types_sq + t12) * (paramb.n_max_angular + 1) + n] *
456450
weight_right;
457-
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
451+
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
458452
#else
459453
float fc12;
460454
int t2 = g_type[n2];
@@ -476,16 +470,10 @@ static __global__ void find_descriptor_small_box(
476470
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
477471
gn12 += fn12[k] * annmb.c[c_index];
478472
}
479-
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
473+
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
480474
#endif
481475
}
482-
if (paramb.num_L == paramb.L_max) {
483-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
484-
} else if (paramb.num_L == paramb.L_max + 1) {
485-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
486-
} else {
487-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
488-
}
476+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
489477
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
490478
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
491479
}
@@ -698,15 +686,7 @@ static __global__ void find_force_angular_small_box(
698686
g_gn_angular[index_left_all] * weight_left + g_gn_angular[index_right_all] * weight_right;
699687
float gnp12 = g_gnp_angular[index_left_all] * weight_left +
700688
g_gnp_angular[index_right_all] * weight_right;
701-
if (paramb.num_L == paramb.L_max) {
702-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
703-
} else if (paramb.num_L == paramb.L_max + 1) {
704-
accumulate_f12_with_4body(
705-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
706-
} else {
707-
accumulate_f12_with_5body(
708-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
709-
}
689+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
710690
}
711691
#else
712692
float fc12, fcp12;
@@ -733,15 +713,7 @@ static __global__ void find_force_angular_small_box(
733713
gn12 += fn12[k] * annmb.c[c_index];
734714
gnp12 += fnp12[k] * annmb.c[c_index];
735715
}
736-
if (paramb.num_L == paramb.L_max) {
737-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
738-
} else if (paramb.num_L == paramb.L_max + 1) {
739-
accumulate_f12_with_4body(
740-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
741-
} else {
742-
accumulate_f12_with_5body(
743-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
744-
}
716+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
745717
}
746718
#endif
747719
double s_sxx = 0.0;

src/main_nep/nep3.cu

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,9 @@ static __global__ void find_descriptors_angular(
221221
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
222222
gn12 += fn12[k] * annmb.c[c_index];
223223
}
224-
accumulate_s(d12, x12, y12, z12, gn12, s);
225-
}
226-
if (paramb.num_L == paramb.L_max) {
227-
find_q(paramb.n_max_angular + 1, n, s, q);
228-
} else if (paramb.num_L == paramb.L_max + 1) {
229-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q);
230-
} else {
231-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q);
224+
accumulate_s(paramb.L_max, d12, x12, y12, z12, gn12, s);
232225
}
226+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q);
233227
for (int abc = 0; abc < NUM_OF_ABC; ++abc) {
234228
g_sum_fxyz[(n * NUM_OF_ABC + abc) * N + n1] = s[abc];
235229
}
@@ -726,15 +720,7 @@ static __global__ void find_force_angular(
726720
gn12 += fn12[k] * annmb.c[c_index];
727721
gnp12 += fnp12[k] * annmb.c[c_index];
728722
}
729-
if (paramb.num_L == paramb.L_max) {
730-
accumulate_f12(n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
731-
} else if (paramb.num_L == paramb.L_max + 1) {
732-
accumulate_f12_with_4body(
733-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
734-
} else {
735-
accumulate_f12_with_5body(
736-
n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
737-
}
723+
accumulate_f12(paramb.L_max, paramb.num_L, n, paramb.n_max_angular + 1, d12, r12, gn12, gnp12, Fp, sum_fxyz, f12);
738724
}
739725

740726
atomicAdd(&g_fx[n1], f12[0]);

src/main_nep/parameters.cu

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,8 +727,11 @@ void Parameters::parse_l_max(const char** param, int num_param)
727727
if (!is_valid_int(param[1], &L_max)) {
728728
PRINT_INPUT_ERROR("l_max for 3-body descriptors should be an integer.\n");
729729
}
730-
if (L_max != 4) {
731-
PRINT_INPUT_ERROR("l_max for 3-body descriptors should = 4.");
730+
if (L_max < 0) {
731+
PRINT_INPUT_ERROR("l_max for 3-body descriptors should >= 0.");
732+
}
733+
if (L_max > 4) {
734+
PRINT_INPUT_ERROR("l_max for 3-body descriptors should <= 4.");
732735
}
733736

734737
if (num_param >= 3) {
@@ -738,6 +741,9 @@ void Parameters::parse_l_max(const char** param, int num_param)
738741
if (L_max_4body != 0 && L_max_4body != 2) {
739742
PRINT_INPUT_ERROR("l_max for 4-body descriptors should = 0 or 2.");
740743
}
744+
if (L_max < L_max_4body) {
745+
PRINT_INPUT_ERROR("l_max_4body should <= l_max_3body.");
746+
}
741747
}
742748

743749
if (num_param == 4) {

src/mc/nep_energy.cu

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,9 @@ static __global__ void find_energy_nep(
360360
c_index += t1 * paramb.num_types + t2 + paramb.num_c_radial;
361361
gn12 += fn12[k] * annmb.c[c_index];
362362
}
363-
accumulate_s(d12, r12[0], r12[1], r12[2], gn12, s);
364-
}
365-
if (paramb.num_L == paramb.L_max) {
366-
find_q(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
367-
} else if (paramb.num_L == paramb.L_max + 1) {
368-
find_q_with_4body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
369-
} else {
370-
find_q_with_5body(paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
363+
accumulate_s(paramb.L_max, d12, r12[0], r12[1], r12[2], gn12, s);
371364
}
365+
find_q(paramb.L_max, paramb.num_L, paramb.n_max_angular + 1, n, s, q + (paramb.n_max_radial + 1));
372366
}
373367

374368
// nomalize descriptor

0 commit comments

Comments
 (0)