summaryrefslogtreecommitdiffstats
path: root/src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp')
-rw-r--r--src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp375
1 files changed, 375 insertions, 0 deletions
diff --git a/src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp b/src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp
new file mode 100644
index 000000000..73e2711cb
--- /dev/null
+++ b/src/boost/libs/yap/example/autodiff_library/UaryOPNode.cpp
@@ -0,0 +1,375 @@
+/*
+ * UaryOPNode.cpp
+ *
+ * Created on: 6 Nov 2013
+ * Author: s0965328
+ */
+
+#include "UaryOPNode.h"
+#include "BinaryOPNode.h"
+#include "PNode.h"
+#include "Stack.h"
+#include "Tape.h"
+#include "Edge.h"
+#include "EdgeSet.h"
+#include "auto_diff_types.h"
+
+#include <list>
+
+using namespace std;
+
+namespace AutoDiff {
+
+UaryOPNode::UaryOPNode(OPCODE op_, Node* left): OPNode(op_,left) {
+}
+
+OPNode* UaryOPNode::createUnaryOpNode(OPCODE op, Node* left)
+{
+ assert(left!=NULL);
+ OPNode* node = NULL;
+ if(op == OP_SQRT)
+ {
+ double param = 0.5;
+ node = BinaryOPNode::createBinaryOpNode(OP_POW,left,new PNode(param));
+ }
+ else if(op == OP_NEG)
+ {
+ double param = -1;
+ node = BinaryOPNode::createBinaryOpNode(OP_TIMES,left,new PNode(param));
+ }
+ else
+ {
+ node = new UaryOPNode(op,left);
+ }
+ return node;
+}
+
+UaryOPNode::~UaryOPNode() {
+
+}
+
+
+void UaryOPNode::inorder_visit(int level,ostream& oss){
+ if(left!=NULL){
+ left->inorder_visit(level+1,oss);
+ }
+ oss<<this->toString(level)<<endl;
+}
+
+void UaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total)
+{
+ total++;
+ if(left!=NULL){
+ left->collect_vnodes(nodes,total);
+ }
+}
+
+void UaryOPNode::eval_function()
+{
+ if(left!=NULL){
+ left->eval_function();
+ }
+ this->calc_eval_function();
+}
+
+//1. visiting left if not NULL
+//2. then, visiting right if not NULL
+//3. calculating the immediate derivative hu and hv
+void UaryOPNode::grad_reverse_0(){
+ assert(left!=NULL);
+ this->adj = 0;
+ left->grad_reverse_0();
+ this->calc_grad_reverse_0();
+}
+
+//right left - right most traversal
+void UaryOPNode::grad_reverse_1()
+{
+ assert(left!=NULL);
+ double l_adj = SD->pop_back()*this->adj;
+ left->update_adj(l_adj);
+ left->grad_reverse_1();
+}
+
+void UaryOPNode::calc_grad_reverse_0()
+{
+ assert(left!=NULL);
+ double hu = NaN_Double;
+ double lval = SV->pop_back();
+ double val = NaN_Double;
+ switch (op)
+ {
+ case OP_SIN:
+ val = sin(lval);
+ hu = cos(lval);
+ break;
+ case OP_COS:
+ val = cos(lval);
+ hu = -sin(lval);
+ break;
+ default:
+ cerr<<"error op not impl"<<endl;
+ break;
+ }
+ SV->push_back(val);
+ SD->push_back(hu);
+}
+
+void UaryOPNode::calc_eval_function()
+{
+ double lval = SV->pop_back();
+ double val = NaN_Double;
+ switch (op)
+ {
+ case OP_SIN:
+ assert(left!=NULL);
+ val = sin(lval);
+ break;
+ case OP_COS:
+ assert(left!=NULL);
+ val = cos(lval);
+ break;
+ default:
+ cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
+ assert(false);
+ break;
+ }
+ SV->push_back(val);
+}
+
+void UaryOPNode::hess_reverse_0_init_n_in_arcs()
+{
+ this->left->hess_reverse_0_init_n_in_arcs();
+ this->Node::hess_reverse_0_init_n_in_arcs();
+}
+
+void UaryOPNode::hess_reverse_1_clear_index()
+{
+ this->left->hess_reverse_1_clear_index();
+ this->Node::hess_reverse_1_clear_index();
+}
+
+unsigned int UaryOPNode::hess_reverse_0()
+{
+ assert(left!=NULL);
+ if(index==0)
+ {
+ unsigned int lindex=0;
+ lindex = left->hess_reverse_0();
+ assert(lindex!=0);
+ II->set(lindex);
+ double lx,lx_bar,lw,lw_bar;
+ double x,x_bar,w,w_bar;
+ double l_dh;
+ switch(op)
+ {
+ case OP_SIN:
+ assert(left != NULL);
+ left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
+ x = sin(lx);
+ x_bar = 0;
+ l_dh = cos(lx);
+ w = lw*l_dh;
+ w_bar = 0;
+ break;
+ case OP_COS:
+ assert(left!=NULL);
+ left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar);
+ x = cos(lx);
+ x_bar = 0;
+ l_dh = -sin(lx);
+ w = lw*l_dh;
+ w_bar = 0;
+ break;
+ default:
+ cerr<<"op["<<op<<"] not yet implemented!"<<endl;
+ assert(false);
+ break;
+ }
+ TT->set(x);
+ TT->set(x_bar);
+ TT->set(w);
+ TT->set(w_bar);
+ TT->set(l_dh);
+ assert(TT->index == TT->index);
+ index = TT->index;
+ }
+ return index;
+}
+
+void UaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar)
+{
+ --i; // skip the l_dh (ie, dh/du)
+ w_bar = TT->get(--i);
+ w = TT->get(--i);
+ x_bar = TT->get(--i);
+ x = TT->get(--i);
+}
+
+void UaryOPNode::hess_reverse_1(unsigned int i)
+{
+ n_in_arcs--;
+ if(n_in_arcs==0)
+ {
+ double lindex = II->get(--(II->index));
+ // cout<<"li["<<lindex<<"]\t"<<this->toString(0)<<endl;
+ double l_dh = TT->get(--i);
+ double w_bar = TT->get(--i);
+ --i; //skip w
+ double x_bar = TT->get(--i);
+ --i; //skip x
+ // cout<<"i["<<i<<"]"<<endl;
+
+ assert(left!=NULL);
+ left->update_x_bar(lindex,x_bar*l_dh);
+ double lw_bar = 0;
+ double lw = 0,lx = 0;
+ left->hess_reverse_1_get_xw(lindex,lw,lx);
+ switch(op)
+ {
+ case OP_SIN:
+ assert(l_dh == cos(lx));
+ lw_bar += w_bar*l_dh;
+ lw_bar += x_bar*lw*(-sin(lx));
+ break;
+ case OP_COS:
+ assert(l_dh == -sin(lx));
+ lw_bar += w_bar*l_dh;
+ lw_bar += x_bar*lw*(-cos(lx));
+ break;
+ default:
+ cerr<<"op["<<op<<"] not yet implemented!"<<endl;
+ break;
+ }
+ left->update_w_bar(lindex,lw_bar);
+ left->hess_reverse_1(lindex);
+ }
+}
+
+void UaryOPNode::hess_reverse_1_init_x_bar(unsigned int i)
+{
+ TT->at(i-4) = 1;
+}
+
+void UaryOPNode::update_x_bar(unsigned int i ,double v)
+{
+ TT->at(i-4) += v;
+}
+void UaryOPNode::update_w_bar(unsigned int i ,double v)
+{
+ TT->at(i-2) += v;
+}
+void UaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x)
+{
+ w = TT->get(i-3);
+ x = TT->get(i-5);
+}
+void UaryOPNode::hess_reverse_get_x(unsigned int i, double& x)
+{
+ x = TT->get(i-5);
+}
+
+void UaryOPNode::nonlinearEdges(EdgeSet& edges)
+{
+ for(list<Edge>::iterator it=edges.edges.begin();it!=edges.edges.end();)
+ {
+ Edge& e = *it;
+ if(e.a == this || e.b == this){
+ if(e.a == this && e.b == this)
+ {
+ Edge e1(left,left);
+ edges.insertEdge(e1);
+ }
+ else{
+ Node* o = e.a==this?e.b:e.a;
+ Edge e1(left,o);
+ edges.insertEdge(e1);
+ }
+ it = edges.edges.erase(it);
+ }
+ else
+ {
+ it++;
+ }
+ }
+
+ Edge e1(left,left);
+ switch(op)
+ {
+ case OP_SIN:
+ edges.insertEdge(e1);
+ break;
+ case OP_COS:
+ edges.insertEdge(e1);
+ break;
+ default:
+ cerr<<"op["<<op<<"] is not yet implemented !"<<endl;
+ assert(false);
+ break;
+ }
+ left->nonlinearEdges(edges);
+}
+
+#if FORWARD_ENABLED
+void UaryOPNode::hess_forward(unsigned int len, double** ret_vec)
+{
+ double* lvec = NULL;
+ if(left!=NULL){
+ left->hess_forward(len,&lvec);
+ }
+
+ *ret_vec = new double[len];
+ this->hess_forward_calc0(len,lvec,*ret_vec);
+ delete[] lvec;
+}
+
+void UaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* ret_vec)
+{
+ double hu = NaN_Double;
+ double lval = NaN_Double;
+ double val = NaN_Double;
+ unsigned int index = 0;
+ switch (op)
+ {
+ case OP_SIN:
+ assert(left!=NULL);
+ lval = SV->pop_back();
+ val = sin(lval);
+ SV->push_back(val);
+ hu = cos(lval);
+
+ double coeff;
+ coeff = -val; //=sin(left->val); -- and avoid cross initialisation
+ //calculate the first order derivatives
+ for(unsigned int i =0;i<AutoDiff::num_var;++i)
+ {
+ ret_vec[i] = hu*lvec[i] + 0;
+ }
+ //calculate the second order
+ index = AutoDiff::num_var;
+ for(unsigned int i=0;i<AutoDiff::num_var;++i)
+ {
+ for(unsigned int j=i;j<AutoDiff::num_var;++j)
+ {
+ ret_vec[index] = hu*lvec[index] + lvec[i] * coeff * lvec[j] + 0 + 0;
+ ++index;
+ }
+ }
+ assert(index==len);
+ break;
+ default:
+ cerr<<"op["<<op<<"] not yet implemented!";
+ break;
+ }
+}
+#endif
+
+string UaryOPNode::toString(int level)
+{
+ ostringstream oss;
+ string s(level,'\t');
+ oss<<s<<"[UaryOPNode]("<<op<<")";
+ return oss.str();
+}
+
+} /* namespace AutoDiff */