summary refs log tree commit diff stats
path: root/src/boost/libs/yap/example/autodiff_library/BinaryOPNode.h
blob: 7abc702aaeb0398179ac6d495d9a43df24d2bea4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
/*
 * BinaryOPNode.h
 *
 *  Created on: 6 Nov 2013
 *      Author: s0965328
 */

#ifndef BINARYOPNODE_H_
#define BINARYOPNODE_H_

#include "OPNode.h"

namespace AutoDiff {

// Forward declaration: within this header EdgeSet is used only as a
// reference parameter (BinaryOPNode::nonlinearEdges), so its full
// definition is not required here.
class EdgeSet;

/*
 * BinaryOPNode - an OPNode specialisation constructed from an OPCODE and two
 * Node operands (left and right).
 *
 * This header contains declarations only. The comments below describe what
 * the signatures themselves show; anything about runtime behaviour is marked
 * as an assumption and should be confirmed against BinaryOPNode.cpp and
 * OPNode.h.
 *
 * The constructor is private, so code outside the class obtains instances
 * through the static createBinaryOpNode() factory.
 *
 * NOTE(review): apart from the destructor, no member is marked virtual or
 * override here. Whether the public methods override base-class virtuals
 * depends on declarations in OPNode.h / Node.h - confirm there.
 */
class BinaryOPNode: public OPNode {
public:

	// Factory entry point. Note that the declared return type is OPNode*,
	// not BinaryOPNode*.
	// NOTE(review): ownership of `left`, `right` and of the returned node is
	// not visible from this header - confirm in the implementation.
	static OPNode* createBinaryOpNode(OPCODE op, Node* left, Node* right);
	virtual ~BinaryOPNode();

	// Both arguments are non-const references, so the call may modify them.
	// Presumably collects variable nodes into `nodes` and updates a running
	// count in `total` - TODO confirm.
	void collect_vnodes(boost::unordered_set<Node*>& nodes,unsigned int& total);
	// Presumably evaluates this node's function value - TODO confirm.
	void eval_function();

	// Presumably the two stages of a reverse-mode gradient computation
	// (judging by the names only) - TODO confirm ordering and contract.
	void grad_reverse_0();
	void grad_reverse_1();

	// `len` is passed by value; `ret_vec` is a pointer-to-pointer, so the
	// callee can redirect the caller's double* as well as write through it.
	// NOTE(review): who allocates/frees the buffer is not visible here.
	void hess_forward(unsigned int len, double** ret_vec);

	// Family of hess_reverse_* helpers. Most take an unsigned int as their
	// first argument and report results through double& out-parameters.
	// NOTE(review): the meaning of the unsigned int argument (looks like an
	// index) and of each double& is not established by this header - confirm
	// against the definitions before relying on it.
	unsigned int hess_reverse_0();
	void hess_reverse_0_init_n_in_arcs();
	void hess_reverse_0_get_values(unsigned int,double&, double&, double&, double&);
	void hess_reverse_1(unsigned int i);
	void hess_reverse_1_init_x_bar(unsigned int);
	void update_x_bar(unsigned int,double);
	void update_w_bar(unsigned int,double);
	void hess_reverse_1_get_xw(unsigned int, double&,double&);
	void hess_reverse_get_x(unsigned int,double& x);
	void hess_reverse_1_clear_index();

	// `a` is a non-const reference, so the call may modify the caller's
	// EdgeSet. EdgeSet is only forward-declared in this header.
	void nonlinearEdges(EdgeSet& a);

	// Output/printing helpers. `ostream` and `string` are used unqualified,
	// so this header relies on those names already being visible at this
	// point (e.g. via something pulled in through OPNode.h).
	// NOTE(review): meaning of `level` (looks like a nesting/indent depth)
	// is an assumption - confirm.
	void inorder_visit(int level,ostream& oss);
	string toString(int level);

	// Public raw pointer to the right-hand operand node. No matching `left`
	// member is declared in this class; presumably the left operand is held
	// by the OPNode base - TODO confirm.
	Node* right;

private:
	// Private constructor: construction from outside goes through
	// createBinaryOpNode().
	BinaryOPNode(OPCODE op, Node* left, Node* right);
	// Private helpers; by name they appear to back eval_function() and
	// grad_reverse_0() respectively - TODO confirm.
	void calc_eval_function();
	void calc_grad_reverse_0();
	// Unlike hess_forward(), `len` is taken by non-const reference here, so
	// this helper may modify the caller's length value. The three double*
	// arguments are presumably left-input, right-input and output vectors,
	// as their names suggest - TODO confirm.
	void hess_forward_calc0(unsigned int& len, double* lvec, double* rvec,double* ret_vec);
};

} /* namespace AutoDiff */
#endif /* BINARYOPNODE_H_ */