Skip to content

Commit

Permalink
format output for Rosenbrock benchmark in SymJava and Generated C++ code
Browse files Browse the repository at this point in the history
  • Loading branch information
yuemingl committed Apr 10, 2015
1 parent 46f1784 commit 86025a7
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 69 deletions.
113 changes: 61 additions & 52 deletions src/symjava/examples/BenchmarkRosenbrock.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,33 +113,40 @@ public static double test(int N) {

int NN = 100000;
outAry = new double[N];
double checkSumGrad = 0.0;
double xx = 1.0;
begin = System.currentTimeMillis();
double out = 0.0;
for(int j=0; j<NN; j++) {
for(int k=0; k<N; k++)
args[k] += 1e-15;
for(int k=0; k<N; k++) {
xx += 1e-15;
args[k] = xx;
}
numGrad.eval(outAry, args);
for(int k=0; k<N; k++)
out += outAry[k];
checkSumGrad += outAry[k];
}
end = System.currentTimeMillis();
double timeGradEval = (end-begin)/1000.0;

outAry = new double[N*N];
double checkSumHess = 0.0;
xx = 1.0;
begin = System.currentTimeMillis();
for(int j=0; j<NN; j++) {
for(int k=0; k<N; k++)
args[k] += 1e-15;
for(int k=0; k<N; k++) {
xx += 1e-15;
args[k] = xx;
}
numHess.eval(outAry, args);
for(int k=0; k<N; k++) //Trace
out += outAry[k*N+k];
checkSumHess += outAry[k*N+k];
}
end = System.currentTimeMillis();
double timeHessEval = (end-begin)/1000.0;

System.out.println(N+"\t"+timeSym+"\t"+timeGrad+"\t"+timeGradEval+"\t"+timeHess+"\t"+timeHessEval+"\t"+timeCCompile);
System.out.println(N+"\t"+timeSym+"\t"+timeGrad+"\t"+timeGradEval+"\t"+timeHess+"\t"+timeHessEval+"\t"+checkSumGrad+"\t"+checkSumHess+"\t"+timeCCompile);

return out;
return checkSumGrad;
}

public static void print_header(PrintWriter writer) {
Expand Down Expand Up @@ -176,42 +183,48 @@ public static void print_main(PrintWriter writer, int N) {
writer.println(" double durationGrad, durationHess;");
writer.println(" int N = 100000;");
writer.println(" double *args, *outAry;");
writer.println(" double out = 0.0;");
writer.println(" double checkSumGrad = 0.0;");
writer.println(" double checkSumHess = 0.0;");
writer.println(" double xx = 1.0;");
writer.println();
writer.println(" args = new double["+N+"];");
writer.println(" //for(int i=0; i<"+N+"; i++)");
writer.println(" // args[i] = 1.0;");
writer.println(" outAry = new double["+(N*N)+"];");
writer.println(" xx = 1.0;");
writer.println(" start = std::clock();");
writer.println(" for(int i=0; i<N; i++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" xx += 1e-15;");
writer.println(" args[j] = xx;");
writer.println(" }");
writer.println(" grad_"+N+"(args, outAry);");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" checkSumGrad += outAry[j];");
writer.println(" }");
writer.println(" }");
writer.println(" durationGrad = ( std::clock() - start ) / (double) CLOCKS_PER_SEC;");
writer.println(" xx = 1.0;");
writer.println(" start = std::clock();");
writer.println(" for(int i=0; i<N; i++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" xx += 1e-15;");
writer.println(" args[j] = xx;");
writer.println(" }");
writer.println(" hess_"+N+"(args, outAry);");
// writer.println(" for(int j=0; j<"+N+"*"+N+"; j++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" checkSumHess += outAry[j*"+N+"+j];");
writer.println(" }");
writer.println(" }");
writer.println(" durationHess = ( std::clock() - start ) / (double) CLOCKS_PER_SEC;");
writer.println(" cout.precision(6);");
writer.println(" cout<<\"N="+N+": Grad=\"<< durationGrad << \" Hess=\" << durationHess;");
writer.println(" cout.precision(17);");
writer.println(" cout<< \" Grad CheckSum=\" << checkSumGrad << \" Hess CheckSum=\" << checkSumHess << endl;");
writer.println(" delete args;");
writer.println(" delete outAry;");
writer.println();
writer.println(" args = new double["+N+"];");
writer.println(" for(int i=0; i<"+N+"; i++)");
writer.println(" args[i] = 1.0;");
writer.println(" outAry = new double["+(N*N)+"];");
writer.println(" start = std::clock();");
writer.println(" for(int i=0; i<N; i++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" args[j] += 1e-15;");
writer.println(" }");
writer.println(" grad_"+N+"(args, outAry);");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" out += outAry[j];");
writer.println(" }");
// writer.println(" out += outAry["+(N-1)+"];");
writer.println(" }");
writer.println(" durationGrad = ( std::clock() - start ) / (double) CLOCKS_PER_SEC;");
writer.println(" start = std::clock();");
writer.println(" for(int i=0; i<N; i++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" args[j] += 1e-15;");
writer.println(" }");
writer.println(" hess_"+N+"(args, outAry);");
// writer.println(" for(int j=0; j<"+N+"*"+N+"; j++) {");
writer.println(" for(int j=0; j<"+N+"; j++) {");
writer.println(" out += outAry[j*"+N+"+j];");
writer.println(" }");
// writer.println(" out += outAry["+(N*N-1)+"];");
writer.println(" }");
writer.println(" durationHess = ( std::clock() - start ) / (double) CLOCKS_PER_SEC;");
writer.println(" cout<<\"N="+N+": Grad=\"<< durationGrad << \" Hess=\" << durationHess << endl;");
writer.println(" delete args;");
writer.println(" delete outAry;");
writer.println();
writer.println(" cout<<\" Final Value=\" << out << endl;");
writer.println("}");
writer.println("//g++ -O3 benchmark-rosenbrock-manual.cpp -o run");
}
Expand Down Expand Up @@ -260,13 +273,9 @@ public static void print_c_code(PrintWriter pw, SymMatrix hess) {

public static void main(String[] args) {
System.out.println("============Benchmark for Rosenbrock==============");
System.out.println("N|Symbolic Manipulaton|Compile Gradient|Eval Gradient|Compile Hessian|Eval Hessian|C Code Compile");
double out = 0.0;
for(int N=5; N<850; N+=50) {
double tmp = test(N);
System.out.println("Final Value="+tmp);
out += tmp;
}
System.out.println("Final Value="+out);//6.881736000015767E11
System.out.println("N|Symbolic Manipulaton|Compile Gradient|Eval Gradient|Compile Hessian|Eval Hessian|Grad CheckSum|Hess CheckSum|C Code Compile");
for(int N=5; N<850; N+=50)
test(N);
//test(5000);//Exception in thread "main" java.lang.OutOfMemoryError: GC overhead limit exceeded
}
}
12 changes: 6 additions & 6 deletions src/symjava/examples/BenchmarkSqrt.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ public static void test() {
}

int N=10000000;
double x = 0.1;
double xx = 0.1;
double out = 0.0;
for(int i=0; i<funcs.size(); i++) {
long begin = System.currentTimeMillis();
for(int j=0; j<N; j++) {
x += 1e-15;
out += funcs.get(i).apply(x);
xx += 1e-15;
out += funcs.get(i).apply(xx);
}
long end = System.currentTimeMillis();
System.out.println("Time: "+((end-begin)/1000.0)+" expr="+exprs.get(i));
Expand Down Expand Up @@ -87,13 +87,13 @@ public static void testBatchEval() {

int N=10000000/batchLen;
double out = 0.0;
double x = 0.1;
double xx = 0.1;
for(int i=0; i<funcs.size(); i++) {
long begin = System.currentTimeMillis();
for(int j=0; j<N; j++) {
for(int k=0; k<batchLen; k++) {
x += 1e-15;
args[k] = x;
xx += 1e-15;
args[k] = xx;
}
funcs.get(i).apply(outAry, 0, args);
for(int k=0; k<batchLen; k++)
Expand Down
12 changes: 6 additions & 6 deletions src/symjava/examples/BenchmarkTaylor.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ public static void test() {
}

int N=10000000;
double x = 0.1;
double xx = 0.1;
double out = 0.0;
for(int i=0; i<funcs.size(); i++) {
long begin = System.currentTimeMillis();
for(int j=0; j<N; j++) {
x += 1e-15;
out += funcs.get(i).apply(x);
xx += 1e-15;
out += funcs.get(i).apply(xx);
}
long end = System.currentTimeMillis();
System.out.println("Time: "+((end-begin)/1000.0)+" expr="+exprs.get(i));
Expand Down Expand Up @@ -85,13 +85,13 @@ public static void testBatchEval() {

int N=10000000/batchLen;
double out = 0.0;
double x = 0.1;
double xx = 0.1;
for(int i=0; i<funcs.size(); i++) {
long begin = System.currentTimeMillis();
for(int j=0; j<N; j++) {
for(int k=0; k<batchLen; k++) {
x += 1e-15;
args[k] = x;
xx += 1e-15;
args[k] = xx;
}
funcs.get(i).apply(outAry, 0, args);
for(int k=0; k<batchLen; k++)
Expand Down
2 changes: 1 addition & 1 deletion src/symjava/examples/Example2.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public static void example2() {
{7 , 478.019226091914},
{8 , 608.140949270688},
{9 , 754.598868667148},
{10, 916.128818085883},
{10, 916.128818085883},
};

double[] initialGuess = {1, 1, 1};
Expand Down
6 changes: 3 additions & 3 deletions src/symjava/examples/NumericalIntegration.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
public class NumericalIntegration {

public static void main(String[] args) {
// test_1D();
// test_2D();
// test_ND();
test_1D();
test_2D();
test_ND();

//Expr i = Integrate.apply(exp(pow(x,2)), Interval.apply(a, b).setStepSize(0.001));
//BytecodeFunc fi = JIT.compile(new Expr[]{a,b}, i);
Expand Down
2 changes: 1 addition & 1 deletion src/symjava/symbolic/utils/JIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public static BytecodeVecFunc compile(Expr[] args, Expr[] exprs) {
}

public static BytecodeBatchFunc compileBatchFunc(Expr[] args, Expr expr) {
String className = "JITVecFunc_YYYY"+java.util.UUID.randomUUID().toString().replaceAll("-", "");
String className = "JITVecFunc_YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY"+java.util.UUID.randomUUID().toString().replaceAll("-", "");
ClassGen genClass = BytecodeUtils.genClassBytecodeBatchFunc(className,expr, args, true, false);
FuncClassLoader<BytecodeBatchFunc> fcl = new FuncClassLoader<BytecodeBatchFunc>();
return fcl.newInstance(genClass);
Expand Down

0 comments on commit 86025a7

Please sign in to comment.