/*
 * BinaryOPNode.cpp
 *
 *  Created on: 6 Nov 2013
 *      Author: s0965328
 */

#include "auto_diff_types.h"
#include "BinaryOPNode.h"
#include "PNode.h"
#include "Stack.h"
#include "Tape.h"
#include "EdgeSet.h"
#include "Node.h"
#include "VNode.h"
#include "OPNode.h"
#include "ActNode.h"
#include "EdgeSet.h"

namespace AutoDiff {

// Builds a binary operator node: op_ and left_ are forwarded to the OPNode
// base, right_ is stored in this class.
BinaryOPNode::BinaryOPNode(OPCODE op_, Node* left_, Node* right_)
    : OPNode(op_, left_), right(right_)
{
}

// Factory for binary operator nodes.
// Both operands must be non-NULL (checked with assert). The node is allocated
// with new and returned through the OPNode base pointer.
OPNode* BinaryOPNode::createBinaryOpNode(OPCODE op, Node* left, Node* right)
{
    assert(left!=NULL && right!=NULL);
    return new BinaryOPNode(op,left,right);
}

// Deletes the right operand unless it reports VNode_Type, then clears the
// pointer. The left operand is not touched here.
// A NULL right operand is tolerated: the public constructor does not reject
// one, and the original code dereferenced it unconditionally.
BinaryOPNode::~BinaryOPNode()
{
    if(right!=NULL && right->getType()!=VNode_Type)
    {
        delete right;
        right = NULL;
    }
}

// In-order dump of the expression tree: left subtree, this node, right
// subtree. Children are visited one level deeper than this node.
void BinaryOPNode::inorder_visit(int level,ostream& oss)
{
    if(left!=NULL){
        left->inorder_visit(level+1,oss);
    }
    oss<<this->toString(level)<<endl;
    if(right!=NULL){
        right->inorder_visit(level+1,oss);
    }
}

// Counts this node in total, then forwards to whichever operands are present.
// NOTE(review): the set's template argument was lost when the file was
// extracted; Node* is a reconstruction -- confirm against BinaryOPNode.h.
void BinaryOPNode::collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total)
{
    total++;
    if (left != NULL) {
        left->collect_vnodes(nodes,total);
    }
    if (right != NULL) {
        right->collect_vnodes(nodes,total);
    }
}

// Evaluates the left operand, then the right, then combines their values.
// Both operands must be non-NULL (checked with assert).
void BinaryOPNode::eval_function()
{
    assert(left!=NULL && right!=NULL);
    left->eval_function();
    right->eval_function();
    this->calc_eval_function();
}

// Pops the right value and then the left value from SV, applies op to them
// and pushes the result back onto SV. For an op this node does not handle the
// result stays NaN_Double and a diagnostic is written to cerr.
void BinaryOPNode::calc_eval_function()
{
    double x = NaN_Double;
    double rx = SV->pop_back();
    double lx = SV->pop_back();
    switch (op)
    {
    case OP_PLUS:
        x = lx + rx;
        break;
    case OP_MINUS:
        x = lx - rx;
        break;
    case OP_TIMES:
        x = lx * rx;
        break;
    case OP_DIVID:
        x = lx / rx;
        break;
    case OP_POW:
        x = pow(lx,rx);
        break;
    default:
        // NOTE(review): everything between `cerr<<"op["<` and `push_back(x)`
        // was lost when the file was extracted. The message text below is a
        // reconstruction, and the original may also have asserted here --
        // confirm against version control.
        cerr<<"op["<<op<<"] not yet implemented!!"<<endl;
        break;
    }
    SV->push_back(x);
}

//1. visiting left if not NULL
//2. then, visiting right if not NULL
//3. 
// ---------------------------------------------------------------------------
// NOTE(review): the four source lines below are extraction-damaged and are
// kept exactly as found. Newlines were collapsed, so every `//` comment now
// swallows the code after it, and text between a '<' and a later '>' appears
// to have been stripped. Restore this region from version control rather than
// by hand. The leading words "calculating the immediate derivative hu and hv"
// are the tail of the `//3.` comment item that begins on the preceding line.
//
// Functions contained in these lines, in order:
//
// grad_reverse_0()
//   Asserts both operands are non-NULL, resets this->adj to 0, calls
//   grad_reverse_0() on left and then right, then calc_grad_reverse_0().
//
// grad_reverse_1()
//   Pops two values from SD. The first popped value times this->adj goes to
//   right->update_adj(), the second to left->update_adj(); it then recurses
//   into right before left. This matches calc_grad_reverse_0(), which pushes
//   l_dh before r_dh.
//
// calc_grad_reverse_0()
//   Pops rx and then lx from SV and, per op, computes the value x and the
//   partials l_dh / r_dh:
//     PLUS  1, 1        MINUS 1, -1        TIMES rx, lx
//     DIVID 1/rx, -lx/rx^2
//     POW   rx*lx^(rx-1) and lx^rx*log(lx); r_dh is 0 instead when the right
//           operand is PNode_Type, otherwise assert(lx>0.0) guards the log.
//   Then x is pushed and l_dh, r_dh are pushed onto SD.
//   NOTE(review): the default branch after "error op not impl" and the
//   receiver of push_back(x) (presumably SV) are missing.
//
// hess_reverse_0_init_n_in_arcs() / hess_reverse_1_clear_index()
//   Forward the same call to left, then right, then to the Node base
//   implementation.
//
// hess_reverse_0()
//   Returns this node's index. When index is 0 it first obtains the
//   operands' indices (asserted non-zero), records lindex then rindex via
//   II->set(), reads the operands' (x, x_bar, w, w_bar), computes this node's
//   values per op, stores x, x_bar, w, w_bar, l_dh, r_dh through TT->set() in
//   that order, and takes index from TT->index.
//   NOTE(review): the OP_PLUS..OP_DIVID cases and the start of OP_POW are
//   missing, as are the default branch and the receiver of set(x). The
//   surviving OP_POW code sets x_bar = 0, w_bar = 0 and
//   w = lw*l_dh + rw*r_dh.
//   NOTE(review): assert(TT->index == TT->index) compares a value with itself
//   and can never fail -- confirm what it was meant to check.
//
// hess_reverse_0_get_values(i, x, x_bar, w, w_bar)
//   Decrements i twice to step over the two derivative slots, then reads
//   w_bar, w, x_bar, x from TT at successively lower positions.
//
// hess_reverse_1(i)
//   Decrements n_in_arcs and does nothing more until it reaches 0. Then it
//   takes rindex and lindex back from II (in that order), reads l_dh, w_bar
//   and x_bar from TT below i, fetches each operand's (w, x), accumulates
//   lw_bar / rw_bar per op, pushes the x_bar and w_bar updates to both
//   operands, and recurses into right before left.
//   NOTE(review): the declaration `double r_dh = TT->` before the first
//   get(--i) is missing, as is everything between the default branch and
//   update_x_bar(rindex,rx_bar) -- including whatever declared and computed
//   lx_bar and rx_bar.
//
// hess_reverse_1_init_x_bar(i)   sets   TT->at(i-5) to 1
// update_x_bar(i, v)             adds v to TT->at(i-5)
// update_w_bar(i, v)             adds v to TT->at(i-3)
// hess_reverse_1_get_xw(i, w, x) reads  w = TT->get(i-4), x = TT->get(i-6)
// hess_reverse_get_x(i, x)       reads  x = TT->get(i-6)
//   These offsets agree with the store order in hess_reverse_0() if each
//   TT->set() advances TT->index by one: x at i-6, x_bar at i-5, w at i-4,
//   w_bar at i-3, l_dh at i-2, r_dh at i-1 -- TODO confirm against Tape.h.
//
// nonlinearEdges(edges)
//   Every edge in edges.edges that touches this node is erased and replaced
//   by edges on left/right: a self edge (this,this) becomes (left,left),
//   (right,right) and (left,right); an edge (this,o) becomes (left,o) and
//   (right,o). Then op-specific edges are inserted -- none for PLUS/MINUS,
//   (left,right) for TIMES, (left,right) and (right,right) for DIVID, all
//   three pairs for POW -- and the call is forwarded to the operands.
//   NOTE(review): the element type in `list::iterator`, the default branch
//   and the `left->` receiver of the first recursive call are missing.
//
// hess_forward(len, ret_vec)                 [only under #if FORWARD_ENABLED]
//   Asks each non-NULL operand for its vector, allocates *ret_vec with
//   new double[len] (the caller receives that allocation), fills it via
//   hess_forward_calc0(), and delete[]s the operands' vectors.
//
// hess_forward_calc0(...)
//   Only its beginning is on these lines; it continues on the next line.
// ---------------------------------------------------------------------------
calculating the immediate derivative hu and hv void BinaryOPNode::grad_reverse_0() { assert(left!=NULL && right != NULL); this->adj = 0; left->grad_reverse_0(); right->grad_reverse_0(); this->calc_grad_reverse_0(); } //right left - right most traversal void BinaryOPNode::grad_reverse_1() { assert(right!=NULL && left!=NULL); double r_adj = SD->pop_back()*this->adj; right->update_adj(r_adj); double l_adj = SD->pop_back()*this->adj; left->update_adj(l_adj); right->grad_reverse_1(); left->grad_reverse_1(); } void BinaryOPNode::calc_grad_reverse_0() { assert(left!=NULL && right != NULL); double l_dh = NaN_Double; double r_dh = NaN_Double; double rx = SV->pop_back(); double lx = SV->pop_back(); double x = NaN_Double; switch (op) { case OP_PLUS: x = lx + rx; l_dh = 1; r_dh = 1; break; case OP_MINUS: x = lx - rx; l_dh = 1; r_dh = -1; break; case OP_TIMES: x = lx * rx; l_dh = rx; r_dh = lx; break; case OP_DIVID: x = lx / rx; l_dh = 1 / rx; r_dh = -(lx) / pow(rx, 2); break; case OP_POW: if(right->getType()==PNode_Type){ x = pow(lx,rx); l_dh = rx*pow(lx,(rx-1)); r_dh = 0; } else{ assert(lx>0.0); //otherwise log(lx) is not defined in read number x = pow(lx,rx); l_dh = rx*pow(lx,(rx-1)); r_dh = pow(lx,rx)*log(lx); //this is for x1^x2 when x1=0 cause r_dh become +inf, however d(0^x2)/d(x2) = 0 } break; default: cerr<<"error op not impl"<push_back(x); SD->push_back(l_dh); SD->push_back(r_dh); } void BinaryOPNode::hess_reverse_0_init_n_in_arcs() { this->left->hess_reverse_0_init_n_in_arcs(); this->right->hess_reverse_0_init_n_in_arcs(); this->Node::hess_reverse_0_init_n_in_arcs(); } void BinaryOPNode::hess_reverse_1_clear_index() { this->left->hess_reverse_1_clear_index(); this->right->hess_reverse_1_clear_index(); this->Node::hess_reverse_1_clear_index(); } unsigned int BinaryOPNode::hess_reverse_0() { assert(this->left!=NULL && right!=NULL); if(index==0) { unsigned int lindex=0, rindex=0; lindex = left->hess_reverse_0(); rindex = right->hess_reverse_0(); assert(lindex!=0 && 
rindex !=0); II->set(lindex); II->set(rindex); double rx,rx_bar,rw,rw_bar; double lx,lx_bar,lw,lw_bar; double x,x_bar,w,w_bar; double r_dh, l_dh; right->hess_reverse_0_get_values(rindex,rx,rx_bar,rw,rw_bar); left->hess_reverse_0_get_values(lindex,lx,lx_bar,lw,lw_bar); switch(op) { case OP_PLUS: // cout<<"lindex="<getType()==PNode_Type) { x = pow(lx,rx); x_bar = 0; l_dh = rx*pow(lx,(rx-1)); r_dh = 0; w = lw * l_dh + rw * r_dh; w_bar = 0; } else { assert(lx>0.0); //otherwise log(lx) undefined in real number x = pow(lx,rx); x_bar = 0; l_dh = rx*pow(lx,(rx-1)); r_dh = pow(lx,rx)*log(lx); //log(lx) cause -inf when lx=0; w = lw * l_dh + rw * r_dh; w_bar = 0; } break; default: cerr<<"op["<set(x); TT->set(x_bar); TT->set(w); TT->set(w_bar); TT->set(l_dh); TT->set(r_dh); assert(TT->index == TT->index); index = TT->index; } return index; } void BinaryOPNode::hess_reverse_0_get_values(unsigned int i,double& x, double& x_bar, double& w, double& w_bar) { --i; // skip the r_dh (ie, dh/du) --i; // skip the l_dh (ie. 
dh/dv) w_bar = TT->get(--i); w = TT->get(--i); x_bar = TT->get(--i); x = TT->get(--i); } void BinaryOPNode::hess_reverse_1(unsigned int i) { n_in_arcs--; if(n_in_arcs==0) { assert(right!=NULL && left!=NULL); unsigned int rindex = II->get(--(II->index)); unsigned int lindex = II->get(--(II->index)); // cout<<"ri["<toString(0)<get(--i); double l_dh = TT->get(--i); double w_bar = TT->get(--i); --i; //skip w double x_bar = TT->get(--i); --i; //skip x double lw_bar=0,rw_bar=0; double lw=0,lx=0; left->hess_reverse_1_get_xw(lindex,lw,lx); double rw=0,rx=0; right->hess_reverse_1_get_xw(rindex,rw,rx); switch(op) { case OP_PLUS: assert(l_dh==1); assert(r_dh==1); lw_bar += w_bar*l_dh; rw_bar += w_bar*r_dh; break; case OP_MINUS: assert(l_dh==1); assert(r_dh==-1); lw_bar += w_bar*l_dh; rw_bar += w_bar*r_dh; break; case OP_TIMES: assert(rx == l_dh); assert(lx == r_dh); lw_bar += w_bar*rx; lw_bar += x_bar*lw*0 + x_bar*rw*1; rw_bar += w_bar*lx; rw_bar += x_bar*lw*1 + x_bar*rw*0; break; case OP_DIVID: lw_bar += w_bar*l_dh; lw_bar += x_bar*lw*0 + x_bar*rw*-1/(pow(rx,2)); rw_bar += w_bar*r_dh; rw_bar += x_bar*lw*-1/pow(rx,2) + x_bar*rw*2*lx/pow(rx,3); break; case OP_POW: if(right->getType()==PNode_Type){ lw_bar += w_bar*l_dh; lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + 0; rw_bar += w_bar*r_dh; assert(r_dh==0.0); rw_bar += 0; } else{ assert(lx>0.0); //otherwise log(lx) is not define in Real lw_bar += w_bar*l_dh; lw_bar += x_bar*lw*pow(lx,rx-2)*rx*(rx-1) + x_bar*rw*pow(lx,rx-1)*(rx*log(lx)+1); //cause log(lx)=-inf when rw_bar += w_bar*r_dh; rw_bar += x_bar*lw*pow(lx,rx-1)*(rx*log(lx)+1) + x_bar*rw*pow(lx,rx)*pow(log(lx),2); } break; default: cerr<<"op["<update_x_bar(rindex,rx_bar); left->update_x_bar(lindex,lx_bar); right->update_w_bar(rindex,rw_bar); left->update_w_bar(lindex,lw_bar); this->right->hess_reverse_1(rindex); this->left->hess_reverse_1(lindex); } } void BinaryOPNode::hess_reverse_1_init_x_bar(unsigned int i) { TT->at(i-5) = 1; } void BinaryOPNode::update_x_bar(unsigned int 
i ,double v) { TT->at(i-5) += v; } void BinaryOPNode::update_w_bar(unsigned int i ,double v) { TT->at(i-3) += v; } void BinaryOPNode::hess_reverse_1_get_xw(unsigned int i,double& w,double& x) { w = TT->get(i-4); x = TT->get(i-6); } void BinaryOPNode::hess_reverse_get_x(unsigned int i,double& x) { x = TT->get(i-6); } void BinaryOPNode::nonlinearEdges(EdgeSet& edges) { for(list::iterator it=edges.edges.begin();it!=edges.edges.end();) { Edge e = *it; if(e.a==this || e.b == this){ if(e.a == this && e.b == this) { Edge e1(left,left); Edge e2(right,right); Edge e3(left,right); edges.insertEdge(e1); edges.insertEdge(e2); edges.insertEdge(e3); } else { Node* o = e.a==this? e.b: e.a; Edge e1(left,o); Edge e2(right,o); edges.insertEdge(e1); edges.insertEdge(e2); } it = edges.edges.erase(it); } else { it++; } } Edge e1(left,right); Edge e2(left,left); Edge e3(right,right); switch(op) { case OP_PLUS: case OP_MINUS: //do nothing for linear operator break; case OP_TIMES: edges.insertEdge(e1); break; case OP_DIVID: edges.insertEdge(e1); edges.insertEdge(e3); break; case OP_POW: edges.insertEdge(e1); edges.insertEdge(e2); edges.insertEdge(e3); break; default: cerr<<"op["<nonlinearEdges(edges); right->nonlinearEdges(edges); } #if FORWARD_ENABLED void BinaryOPNode::hess_forward(unsigned int len, double** ret_vec) { double* lvec = NULL; double* rvec = NULL; if(left!=NULL){ left->hess_forward(len,&lvec); } if(right!=NULL){ right->hess_forward(len,&rvec); } *ret_vec = new double[len]; hess_forward_calc0(len,lvec,rvec,*ret_vec); //delete lvec, rvec delete[] lvec; delete[] rvec; } void BinaryOPNode::hess_forward_calc0(unsigned int& len, double* lvec, double* rvec, double* ret_vec) { double hu = NaN_Double, hv= NaN_Double; double lval = NaN_Double, rval = NaN_Double; double val = NaN_Double; unsigned int index = 0; switch (op) { case OP_PLUS: rval = SV->pop_back(); lval = SV->pop_back(); val = lval + rval; SV->push_back(val); //calculate the first order derivatives for(unsigned int 
i=0;ipop_back(); lval = SV->pop_back(); val = lval + rval; SV->push_back(val); //calculate the first order derivatives for(unsigned int i=0;ipop_back(); lval = SV->pop_back(); val = lval * rval; SV->push_back(val); hu = rval; hv = lval; //calculate the first order derivatives for(unsigned int i =0;ipop_back(); lval = SV->pop_back(); val = pow(lval,rval); SV->push_back(val); if(left->getType()==PNode_Type && right->getType()==PNode_Type) { std::fill_n(ret_vec,len,0); } else { hu = rval*pow(lval,(rval-1)); hv = pow(lval,rval)*log(lval); if(left->getType()==PNode_Type) { double coeff = pow(log(lval),2)*pow(lval,rval); //calculate the first order derivatives for(unsigned int i =0;igetType()==PNode_Type) { double coeff = rval*(rval-1)*pow(lval,rval-2); //calculate the first order derivatives for(unsigned int i =0;ipop_back(); val = sin(lval); SV->push_back(val); hu = cos(lval); double coeff; coeff = -val; //=sin(left->val); -- and avoid cross initialisation //calculate the first order derivatives for(unsigned int i =0;i