Skip to content

Commit 99fa840

Browse files
committed
Merge remote-tracking branch 'refs/remotes/upstream/pr_writecoeff' into pr_writecoeff
2 parents 5455760 + 31b59ac commit 99fa840

File tree

3 files changed

+268
-1
lines changed

3 files changed

+268
-1
lines changed

src/examples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ set(EXAMPLE_SOURCES
99
dataloadbal hatom_1d binaryop dielectric hehf 3dharmonic testsolver
1010
testspectralprop dielectric_external_field tiny h2dynamic newsolver testcomplexfunctionsolver
1111
helium_exact density_smoothing siam_example ac_corr dirac-hatom
12-
derivatives array_worldobject)
12+
derivatives array_worldobject writecoeff writecoeff2)
1313

1414
if(LIBXC_FOUND)
1515
list(APPEND EXAMPLE_SOURCES hefxc)

src/examples/writecoeff.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include <complex>
2+
#include <memory>
3+
#include <madness/mra/mra.h>
4+
#include <madness/world/worldmutex.h>
5+
6+
using namespace madness;
7+
8+
static const size_t D = 2;
9+
typedef Vector<double,D> coordT;
10+
typedef std::shared_ptr< FunctionFunctorInterface<std::complex<double>,D> > functorT;
11+
typedef Function<std::complex<double>,D> cfunctionT;
12+
typedef FunctionFactory<std::complex<double>,D> factoryT;
13+
typedef SeparatedConvolution<std::complex<double>,D> operatorT;
14+
15+
static const double R = 1.4; // bond length
16+
static const double L = 32.0*R; // box size
17+
static const long k = 3; // wavelet order
18+
static const double thresh = 1e-3; // precision
19+
20+
static std::complex<double> f(const coordT& r)
21+
{
22+
return std::complex<double>(0.0,2.0);
23+
}
24+
25+
template <typename T, size_t D>
26+
class WriteCoeffImpl : public Mutex {
27+
const int k;
28+
std::ostream& out;
29+
public:
30+
WriteCoeffImpl(const int k, std::ostream& out)
31+
: k(k)
32+
, out(out)
33+
{}
34+
35+
void operator()(const Key<D>& key, const Tensor< T >& t) const {
36+
ScopedMutex obolus(*this);
37+
out << key << " " << t << std::endl;
38+
}
39+
};
40+
41+
template <typename T, size_t D>
42+
class WriteCoeff {
43+
std::shared_ptr<WriteCoeffImpl<T,D>> impl;
44+
45+
public:
46+
WriteCoeff(const int k, std::ostream& out)
47+
: impl(new WriteCoeffImpl<T,D>(k, out))
48+
{}
49+
50+
void operator()(const Key<D>& key, const Tensor< T >& t) const {
51+
(*impl)(key, t);
52+
}
53+
};
54+
55+
56+
57+
int main(int argc, char** argv)
58+
{
59+
initialize(argc, argv);
60+
World world(SafeMPI::COMM_WORLD);
61+
62+
startup(world,argc,argv);
63+
std::cout.precision(6);
64+
65+
FunctionDefaults<D>::set_k(k);
66+
FunctionDefaults<D>::set_thresh(thresh);
67+
FunctionDefaults<D>::set_refine(true);
68+
FunctionDefaults<D>::set_initial_level(2);
69+
FunctionDefaults<D>::set_truncate_mode(0);
70+
FunctionDefaults<D>::set_cubic_cell(-L/2, L/2);
71+
72+
cfunctionT fun = factoryT(world).f(f);
73+
fun.truncate();
74+
75+
cfunctionT sqrt_of_fun = copy(fun);
76+
auto op = WriteCoeff<std::complex<double>,D>(k, std::cout);
77+
fun.unaryop(op);
78+
world.gop.fence();
79+
80+
finalize();
81+
return 0;
82+
}

src/examples/writecoeff2.cc

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#include <complex>
2+
#include <iomanip>
3+
#include <iostream>
4+
#include <madness/mra/mra.h>
5+
#include <memory>
6+
7+
using namespace madness;
8+
9+
static const size_t D = 2;
10+
typedef Vector<double, D> coordT;
11+
typedef Key<D> keyT;
12+
typedef double dataT;// was std::complex<double>
13+
typedef std::shared_ptr<FunctionFunctorInterface<dataT, D>> functorT;
14+
typedef Function<dataT, D> functionT;
15+
typedef FunctionFactory<dataT, D> factoryT;
16+
typedef SeparatedConvolution<dataT, D> operatorT;
17+
18+
static const double L = 4.0;
19+
static const long k = 5; // wavelet order
20+
static const double thresh = 1e-3;// precision
21+
22+
static dataT f(const coordT &r) {
23+
double R = r.normf();
24+
return std::exp(-R * R);
25+
}
26+
27+
template<typename T, std::size_t NDIM>
28+
void write_function_coeffs(const Function<T, NDIM> &f, std::ostream &out, const Key<NDIM> &key) {
29+
const auto &coeffs = f.get_impl()->get_coeffs();
30+
auto it = coeffs.find(key).get();
31+
if (it == coeffs.end()) {
32+
for (int i = 0; i < key.level(); ++i) out << " ";
33+
out << key << " missing --> " << coeffs.owner(key) << "\n";
34+
} else {
35+
const auto &node = it->second;
36+
if (node.has_coeff()) {
37+
auto values = f.get_impl()->coeffs2values(key, node.coeff());
38+
for (int i = 0; i < key.level(); ++i) out << " ";
39+
out << key.level() << " ";
40+
for (int i = 0; i < NDIM; ++i) out << key.translation()[i] << " ";
41+
out << std::endl;
42+
for (size_t i = 0; i < values.size(); i++) out << values.ptr()[i] << " ";
43+
out << std::endl;
44+
}
45+
if (node.has_children()) {
46+
for (KeyChildIterator<NDIM> kit(key); kit; ++kit) { write_function_coeffs<T, NDIM>(f, out, kit.key()); }
47+
}
48+
}
49+
}
50+
51+
template<typename T, std::size_t NDIM>
52+
size_t count_leaf_nodes(const Function<T, NDIM> &f) {
53+
const auto &coeffs = f.get_impl()->get_coeffs();
54+
size_t count = 0;
55+
for (auto it = coeffs.begin(); it != coeffs.end(); ++it) {
56+
const auto &key = it->first;
57+
const auto &node = it->second;
58+
if (node.has_coeff()) { count++; }
59+
}
60+
f.get_impl()->world.gop.sum(count);
61+
return count;
62+
}
63+
64+
template<typename T, std::size_t NDIM>
65+
void write_function(const Function<T, NDIM> &f, std::ostream &out) {
66+
f.reconstruct();
67+
std::cout << "NUMBER OF LEAF NODES: " << count_leaf_nodes(f) << std::endl;
68+
69+
auto flags = out.flags();
70+
auto precision = out.precision();
71+
out << std::setprecision(17);
72+
out << std::scientific;
73+
74+
if (f.get_impl()->world.rank() == 0) {
75+
out << NDIM << std::endl;
76+
const auto &cell = FunctionDefaults<NDIM>::get_cell();
77+
for (int d = 0; d < NDIM; ++d) {
78+
for (int i = 0; i < 2; ++i) out << cell(d, i) << " ";
79+
out << std::endl;
80+
}
81+
out << f.k() << std::endl;
82+
out << count_leaf_nodes(f) << std::endl;
83+
84+
write_function_coeffs(f, out, Key<NDIM>(0));
85+
}
86+
f.get_impl()->world.gop.fence();
87+
88+
out << std::setprecision(precision);
89+
out.setf(flags);
90+
}
91+
92+
template<typename T, std::size_t NDIM>
93+
void read_function_coeffs(Function<T, NDIM> &f, std::istream &in) {
94+
auto &coeffs = f.get_impl()->get_coeffs();
95+
96+
while (true) {
97+
Level n;
98+
Vector<Translation, NDIM> l;
99+
long dims[NDIM];
100+
in >> n;
101+
if (in.eof()) break;
102+
103+
for (int i = 0; i < NDIM; ++i) {
104+
in >> l[i];
105+
dims[i] = f.k();
106+
}
107+
Key<NDIM> key(n, l);
108+
109+
Tensor<T> values(NDIM, dims);
110+
for (size_t i = 0; i < values.size(); i++) in >> values.ptr()[i];
111+
auto t = f.get_impl()->values2coeffs(key, values);
112+
113+
// f.get_impl()->accumulate2(t, coeffs, key);
114+
coeffs.task(key, &FunctionNode<T, NDIM>::accumulate2, t, coeffs, key);
115+
}
116+
}
117+
118+
template<typename T, std::size_t NDIM>
119+
Function<T, NDIM> read_function(World &world, std::istream &in) {
120+
size_t ndim;
121+
in >> ndim;
122+
MADNESS_CHECK(ndim == NDIM);
123+
124+
Tensor<double> cell(NDIM, 2);
125+
for (int d = 0; d < NDIM; ++d) {
126+
for (int i = 0; i < 2; ++i) in >> cell(d, i);
127+
}
128+
FunctionDefaults<NDIM>::set_cell(cell);
129+
130+
int k;
131+
in >> k;
132+
FunctionFactory<T, NDIM> factory(world);
133+
Function<T, NDIM> f(factory.k(k).empty());
134+
world.gop.fence();
135+
136+
read_function_coeffs(f, in);
137+
138+
f.verify_tree();
139+
140+
return f;
141+
}
142+
143+
void test(World &world) {
144+
functionT fun = factoryT(world).f(f);
145+
fun.truncate();
146+
147+
{
148+
double norm = fun.norm2();
149+
if (world.rank() == 0) std::cout << "norm = " << norm << std::endl;
150+
std::ofstream out("fun.dat", std::ios::out);
151+
write_function(fun, out);
152+
out.close();
153+
// fun.print_tree();
154+
}
155+
156+
{
157+
std::ifstream in("fun.dat", std::ios::in);
158+
functionT fun2 = read_function<dataT, D>(world, in);
159+
double norm = fun2.norm2();
160+
if (world.rank() == 0) std::cout << "norm = " << norm << std::endl;
161+
// write_function(fun2,std::cout);
162+
// fun2.print_tree();
163+
double err = (fun - fun2).norm2();
164+
if (world.rank() == 0) std::cout << "error = " << err << std::endl;
165+
}
166+
}
167+
168+
int main(int argc, char **argv) {
169+
World &world = initialize(argc, argv);
170+
startup(world, argc, argv);
171+
std::cout.precision(6);
172+
173+
FunctionDefaults<D>::set_k(k);
174+
FunctionDefaults<D>::set_thresh(thresh);
175+
FunctionDefaults<D>::set_refine(true);
176+
FunctionDefaults<D>::set_initial_level(2);
177+
FunctionDefaults<D>::set_truncate_mode(0);
178+
FunctionDefaults<D>::set_cubic_cell(-L / 2, L / 2);
179+
180+
test(world);
181+
182+
world.gop.fence();
183+
finalize();
184+
return 0;
185+
}

0 commit comments

Comments
 (0)