-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'refs/remotes/upstream/pr_writecoeff' in…
…to pr_writecoeff
- Loading branch information
Showing
3 changed files
with
268 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
#include <complex> | ||
#include <memory> | ||
#include <madness/mra/mra.h> | ||
#include <madness/world/worldmutex.h> | ||
|
||
using namespace madness; | ||
|
||
static const size_t D = 2; | ||
typedef Vector<double,D> coordT; | ||
typedef std::shared_ptr< FunctionFunctorInterface<std::complex<double>,D> > functorT; | ||
typedef Function<std::complex<double>,D> cfunctionT; | ||
typedef FunctionFactory<std::complex<double>,D> factoryT; | ||
typedef SeparatedConvolution<std::complex<double>,D> operatorT; | ||
|
||
static const double R = 1.4; // bond length | ||
static const double L = 32.0*R; // box size | ||
static const long k = 3; // wavelet order | ||
static const double thresh = 1e-3; // precision | ||
|
||
static std::complex<double> f(const coordT& r) | ||
{ | ||
return std::complex<double>(0.0,2.0); | ||
} | ||
|
||
template <typename T, size_t D> | ||
class WriteCoeffImpl : public Mutex { | ||
const int k; | ||
std::ostream& out; | ||
public: | ||
WriteCoeffImpl(const int k, std::ostream& out) | ||
: k(k) | ||
, out(out) | ||
{} | ||
|
||
void operator()(const Key<D>& key, const Tensor< T >& t) const { | ||
ScopedMutex obolus(*this); | ||
out << key << " " << t << std::endl; | ||
} | ||
}; | ||
|
||
template <typename T, size_t D> | ||
class WriteCoeff { | ||
std::shared_ptr<WriteCoeffImpl<T,D>> impl; | ||
|
||
public: | ||
WriteCoeff(const int k, std::ostream& out) | ||
: impl(new WriteCoeffImpl<T,D>(k, out)) | ||
{} | ||
|
||
void operator()(const Key<D>& key, const Tensor< T >& t) const { | ||
(*impl)(key, t); | ||
} | ||
}; | ||
|
||
|
||
|
||
int main(int argc, char** argv) | ||
{ | ||
initialize(argc, argv); | ||
World world(SafeMPI::COMM_WORLD); | ||
|
||
startup(world,argc,argv); | ||
std::cout.precision(6); | ||
|
||
FunctionDefaults<D>::set_k(k); | ||
FunctionDefaults<D>::set_thresh(thresh); | ||
FunctionDefaults<D>::set_refine(true); | ||
FunctionDefaults<D>::set_initial_level(2); | ||
FunctionDefaults<D>::set_truncate_mode(0); | ||
FunctionDefaults<D>::set_cubic_cell(-L/2, L/2); | ||
|
||
cfunctionT fun = factoryT(world).f(f); | ||
fun.truncate(); | ||
|
||
cfunctionT sqrt_of_fun = copy(fun); | ||
auto op = WriteCoeff<std::complex<double>,D>(k, std::cout); | ||
fun.unaryop(op); | ||
world.gop.fence(); | ||
|
||
finalize(); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
#include <complex> | ||
#include <iomanip> | ||
#include <iostream> | ||
#include <madness/mra/mra.h> | ||
#include <memory> | ||
|
||
using namespace madness; | ||
|
||
static const size_t D = 2; | ||
typedef Vector<double, D> coordT; | ||
typedef Key<D> keyT; | ||
typedef double dataT;// was std::complex<double> | ||
typedef std::shared_ptr<FunctionFunctorInterface<dataT, D>> functorT; | ||
typedef Function<dataT, D> functionT; | ||
typedef FunctionFactory<dataT, D> factoryT; | ||
typedef SeparatedConvolution<dataT, D> operatorT; | ||
|
||
static const double L = 4.0; | ||
static const long k = 5; // wavelet order | ||
static const double thresh = 1e-3;// precision | ||
|
||
static dataT f(const coordT &r) { | ||
double R = r.normf(); | ||
return std::exp(-R * R); | ||
} | ||
|
||
template<typename T, std::size_t NDIM> | ||
void write_function_coeffs(const Function<T, NDIM> &f, std::ostream &out, const Key<NDIM> &key) { | ||
const auto &coeffs = f.get_impl()->get_coeffs(); | ||
auto it = coeffs.find(key).get(); | ||
if (it == coeffs.end()) { | ||
for (int i = 0; i < key.level(); ++i) out << " "; | ||
out << key << " missing --> " << coeffs.owner(key) << "\n"; | ||
} else { | ||
const auto &node = it->second; | ||
if (node.has_coeff()) { | ||
auto values = f.get_impl()->coeffs2values(key, node.coeff()); | ||
for (int i = 0; i < key.level(); ++i) out << " "; | ||
out << key.level() << " "; | ||
for (int i = 0; i < NDIM; ++i) out << key.translation()[i] << " "; | ||
out << std::endl; | ||
for (size_t i = 0; i < values.size(); i++) out << values.ptr()[i] << " "; | ||
out << std::endl; | ||
} | ||
if (node.has_children()) { | ||
for (KeyChildIterator<NDIM> kit(key); kit; ++kit) { write_function_coeffs<T, NDIM>(f, out, kit.key()); } | ||
} | ||
} | ||
} | ||
|
||
template<typename T, std::size_t NDIM> | ||
size_t count_leaf_nodes(const Function<T, NDIM> &f) { | ||
const auto &coeffs = f.get_impl()->get_coeffs(); | ||
size_t count = 0; | ||
for (auto it = coeffs.begin(); it != coeffs.end(); ++it) { | ||
const auto &key = it->first; | ||
const auto &node = it->second; | ||
if (node.has_coeff()) { count++; } | ||
} | ||
f.get_impl()->world.gop.sum(count); | ||
return count; | ||
} | ||
|
||
template<typename T, std::size_t NDIM> | ||
void write_function(const Function<T, NDIM> &f, std::ostream &out) { | ||
f.reconstruct(); | ||
std::cout << "NUMBER OF LEAF NODES: " << count_leaf_nodes(f) << std::endl; | ||
|
||
auto flags = out.flags(); | ||
auto precision = out.precision(); | ||
out << std::setprecision(17); | ||
out << std::scientific; | ||
|
||
if (f.get_impl()->world.rank() == 0) { | ||
out << NDIM << std::endl; | ||
const auto &cell = FunctionDefaults<NDIM>::get_cell(); | ||
for (int d = 0; d < NDIM; ++d) { | ||
for (int i = 0; i < 2; ++i) out << cell(d, i) << " "; | ||
out << std::endl; | ||
} | ||
out << f.k() << std::endl; | ||
out << count_leaf_nodes(f) << std::endl; | ||
|
||
write_function_coeffs(f, out, Key<NDIM>(0)); | ||
} | ||
f.get_impl()->world.gop.fence(); | ||
|
||
out << std::setprecision(precision); | ||
out.setf(flags); | ||
} | ||
|
||
template<typename T, std::size_t NDIM> | ||
void read_function_coeffs(Function<T, NDIM> &f, std::istream &in) { | ||
auto &coeffs = f.get_impl()->get_coeffs(); | ||
|
||
while (true) { | ||
Level n; | ||
Vector<Translation, NDIM> l; | ||
long dims[NDIM]; | ||
in >> n; | ||
if (in.eof()) break; | ||
|
||
for (int i = 0; i < NDIM; ++i) { | ||
in >> l[i]; | ||
dims[i] = f.k(); | ||
} | ||
Key<NDIM> key(n, l); | ||
|
||
Tensor<T> values(NDIM, dims); | ||
for (size_t i = 0; i < values.size(); i++) in >> values.ptr()[i]; | ||
auto t = f.get_impl()->values2coeffs(key, values); | ||
|
||
// f.get_impl()->accumulate2(t, coeffs, key); | ||
coeffs.task(key, &FunctionNode<T, NDIM>::accumulate2, t, coeffs, key); | ||
} | ||
} | ||
|
||
template<typename T, std::size_t NDIM> | ||
Function<T, NDIM> read_function(World &world, std::istream &in) { | ||
size_t ndim; | ||
in >> ndim; | ||
MADNESS_CHECK(ndim == NDIM); | ||
|
||
Tensor<double> cell(NDIM, 2); | ||
for (int d = 0; d < NDIM; ++d) { | ||
for (int i = 0; i < 2; ++i) in >> cell(d, i); | ||
} | ||
FunctionDefaults<NDIM>::set_cell(cell); | ||
|
||
int k; | ||
in >> k; | ||
FunctionFactory<T, NDIM> factory(world); | ||
Function<T, NDIM> f(factory.k(k).empty()); | ||
world.gop.fence(); | ||
|
||
read_function_coeffs(f, in); | ||
|
||
f.verify_tree(); | ||
|
||
return f; | ||
} | ||
|
||
void test(World &world) { | ||
functionT fun = factoryT(world).f(f); | ||
fun.truncate(); | ||
|
||
{ | ||
double norm = fun.norm2(); | ||
if (world.rank() == 0) std::cout << "norm = " << norm << std::endl; | ||
std::ofstream out("fun.dat", std::ios::out); | ||
write_function(fun, out); | ||
out.close(); | ||
// fun.print_tree(); | ||
} | ||
|
||
{ | ||
std::ifstream in("fun.dat", std::ios::in); | ||
functionT fun2 = read_function<dataT, D>(world, in); | ||
double norm = fun2.norm2(); | ||
if (world.rank() == 0) std::cout << "norm = " << norm << std::endl; | ||
// write_function(fun2,std::cout); | ||
// fun2.print_tree(); | ||
double err = (fun - fun2).norm2(); | ||
if (world.rank() == 0) std::cout << "error = " << err << std::endl; | ||
} | ||
} | ||
|
||
int main(int argc, char **argv) { | ||
World &world = initialize(argc, argv); | ||
startup(world, argc, argv); | ||
std::cout.precision(6); | ||
|
||
FunctionDefaults<D>::set_k(k); | ||
FunctionDefaults<D>::set_thresh(thresh); | ||
FunctionDefaults<D>::set_refine(true); | ||
FunctionDefaults<D>::set_initial_level(2); | ||
FunctionDefaults<D>::set_truncate_mode(0); | ||
FunctionDefaults<D>::set_cubic_cell(-L / 2, L / 2); | ||
|
||
test(world); | ||
|
||
world.gop.fence(); | ||
finalize(); | ||
return 0; | ||
} |