-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalue.c
131 lines (110 loc) · 3.35 KB
/
value.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include "value.h"
// A Value with no operator (a leaf made by Value_create) stores ' ' in its operator field; NULL is a pointer constant and cannot be stored in a char.
ValuePtr Value_create(double data){
ValuePtr newValue = (ValuePtr) malloc (sizeof(Value));
if(newValue != NULL){
newValue->data = data;
newValue->grad = 0.0;
newValue->operator = ' ';
newValue->operand1 = NULL;
newValue->operand2 = NULL;
}
return newValue;
}
/*
 * Allocates an interior graph node.
 *   data     - the already-computed forward result
 *   operator - tag for how it was produced ('+', '*', 'p', 't' in this file)
 *   op1/op2  - operand pointers, stored as-is (not copied); op2 is NULL for
 *              the unary 't' node built by Value_tanh
 * grad starts at 0.0. Returns NULL if malloc fails.
 */
ValuePtr Value_createFromOperator(double data, char operator, ValuePtr op1, ValuePtr op2){
    ValuePtr node = malloc(sizeof *node);
    if (!node) {
        return NULL;
    }
    node->data = data;
    node->grad = 0.0;
    node->operator = operator;
    node->operand1 = op1;
    node->operand2 = op2;
    return node;
}
/* Builds a '+' node whose data is a->data + b->data (a and b must be non-NULL). */
ValuePtr Value_add(ValuePtr a, ValuePtr b){
    const double sum = a->data + b->data;
    return Value_createFromOperator(sum, '+', a, b);
}
/* Builds a '*' node whose data is a->data * b->data (a and b must be non-NULL). */
ValuePtr Value_multiply(ValuePtr a, ValuePtr b){
    const double product = a->data * b->data;
    return Value_createFromOperator(product, '*', a, b);
}
/* Builds a 'p' node whose data is pow(a->data, b->data): a is the base, b the exponent. */
ValuePtr Value_power(ValuePtr a, ValuePtr b){
    const double raised = pow(a->data, b->data);
    return Value_createFromOperator(raised, 'p', a, b);
}
/* Builds a unary 't' node whose data is regular_tanh(a->data); operand2 is NULL. */
ValuePtr Value_tanh(ValuePtr a){
    const double squashed = regular_tanh(a->data);
    return Value_createFromOperator(squashed, 't', a, NULL);
}
/* Growable array of graph nodes; holds a topological order for Value_backward. */
typedef struct {
    ValuePtr *items;
    size_t count;
    size_t capacity;
} ValueTopoList;

/* Returns 1 if node is already in list, 0 otherwise (linear scan). */
static int topo_contains(const ValueTopoList *list, ValuePtr node){
    for (size_t i = 0; i < list->count; i++) {
        if (list->items[i] == node) {
            return 1;
        }
    }
    return 0;
}

/* Appends node to list, growing the buffer as needed. Returns 0 on OOM. */
static int topo_push(ValueTopoList *list, ValuePtr node){
    if (list->count == list->capacity) {
        size_t new_capacity = list->capacity ? list->capacity * 2 : 16;
        ValuePtr *grown = realloc(list->items, new_capacity * sizeof *grown);
        if (grown == NULL) {
            return 0; /* list->items is still valid and owned by the caller */
        }
        list->items = grown;
        list->capacity = new_capacity;
    }
    list->items[list->count++] = node;
    return 1;
}

/*
 * Post-order DFS: appends every node reachable from `node` exactly once, each
 * one after all of its operands. A node is recorded only when finished, which
 * is sufficient because the graph is assumed acyclic (the Value_* constructors
 * always allocate the result after its operands exist). Returns 0 on OOM.
 */
static int topo_build(ValuePtr node, ValueTopoList *list){
    if (node == NULL || topo_contains(list, node)) {
        return 1;
    }
    if (!topo_build(node->operand1, list) || !topo_build(node->operand2, list)) {
        return 0;
    }
    return topo_push(list, node);
}

/* Adds node->grad times the local derivative of node's operator to each operand's grad. */
static void propagate_to_operands(ValuePtr node){
    ValuePtr a = node->operand1;
    ValuePtr b = node->operand2;
    switch (node->operator) {
    case 't': /* d tanh(a)/da = 1 - tanh(a)^2 */
        if (a != NULL) {
            a->grad += node->grad * (1.0 - pow(regular_tanh(a->data), 2));
        }
        break;
    case '+':
        if (a != NULL) {
            a->grad += node->grad;
        }
        if (b != NULL) {
            b->grad += node->grad;
        }
        break;
    case '*':
        /* When a == b (x*x) this correctly contributes 2*x*grad in total. */
        if (a != NULL && b != NULL) {
            a->grad += node->grad * b->data;
            b->grad += node->grad * a->data;
        }
        break;
    case 'p': /* node = a^b */
        if (a != NULL && b != NULL) {
            a->grad += node->grad * b->data * pow(a->data, b->data - 1.0);
            /* d(a^b)/db = a^b * ln(a) is only real-valued for a > 0. */
            if (a->data > 0.0) {
                b->grad += node->grad * pow(a->data, b->data) * log(a->data);
            }
        }
        break;
    default: /* leaf (' ') or unknown operator: nothing to propagate */
        break;
    }
}

/*
 * Reverse-mode autodiff starting at `value`.
 * Sets value->grad to 1, then visits every reachable node in reverse
 * topological order, so each node's grad is complete before it is pushed to
 * its operands. A subexpression used several times is therefore credited
 * once per use, not once per path.
 * Grads of non-root nodes are accumulated (+=), not reset; zero them before
 * calling again on the same graph.
 * If the node buffer cannot be allocated, it falls back to the recursive walk in
 * Value_backward_helper. That walk is exact only for tree-shaped graphs.
 */
void Value_backward(ValuePtr value){
    if (value == NULL) {
        return;
    }
    ValueTopoList order = {NULL, 0, 0};
    if (!topo_build(value, &order)) {
        free(order.items);
        Value_backward_helper(NULL, value);
        return;
    }
    value->grad = 1;
    /* order ends with the root, so iterate backwards: root first, leaves last. */
    for (size_t i = order.count; i-- > 0;) {
        propagate_to_operands(order.items[i]);
    }
    free(order.items);
}
/*
 * Recursive tree-walk gradient propagation.
 * When parent is NULL, value is treated as the root and its grad is set to 1.
 * Otherwise the helper adds parent's contribution to value->grad for parent
 * operators 't', '+' and '*'. Any other parent operator (including 'p')
 * contributes nothing. It then recurses into operand1 and then operand2.
 * NOTE: a node with operands that is reachable by more than one path is
 * visited once per path. Its operands are then credited with its partial,
 * cumulative grad each time, which over-counts. The result is exact only
 * for tree-shaped graphs.
 */
void Value_backward_helper(ValuePtr parent, ValuePtr value){
    if (value == NULL) {
        return;
    }
    if (parent == NULL) {
        value->grad = 1;
    } else if (parent->operator == 't') {
        const double t = regular_tanh(value->data);
        value->grad += parent->grad * (1.0 - pow(t, 2));
    } else if (parent->operator == '+') {
        value->grad += parent->grad;
    } else if (parent->operator == '*') {
        value->grad += parent->grad * get_sibling_data(parent, value);
    }
    Value_backward_helper(value, value->operand1);
    Value_backward_helper(value, value->operand2);
}
/*
 * For a '*' node `parent` and one of its operands `value`, returns the data of
 * the other operand, i.e. d(parent)/d(value) for a product.
 * If value is operand1, operand2's data is returned; otherwise operand1's is.
 * So when both operands are the same node, that node's own data comes back.
 * Returns NAN (the standard <math.h> quiet NaN) when parent is not a '*'
 * node or the sibling operand is missing.
 */
double get_sibling_data(ValuePtr parent, ValuePtr value){
    if (parent->operator != '*') {
        return NAN;
    }
    ValuePtr sibling = (parent->operand1 == value) ? parent->operand2 : parent->operand1;
    if (sibling == NULL) {
        return NAN;
    }
    return sibling->data;
}
double regular_tanh(double x){
return (exp(2*x) - 1)/(exp(2*x) + 1);
}
/* Returns the forward-pass value stored in `value` (value must be non-NULL). */
double Value_getData(ValuePtr value){
    const double stored = value->data;
    return stored;
}
/* Prints the node's data, then its grad, each on its own line with 2 decimals. */
void Value_print(ValuePtr value){
    const double data = value->data;
    const double grad = value->grad;
    printf("Data is %.2f\n", data);
    printf("Grad is %.2f\n", grad);
}
/* Prints the computation graph rooted at `value`, with the root unindented. */
void Value_print_comp_graph(ValuePtr value){
    const int root_indent = 0;
    Value_print_comp_graph_helper(value, root_indent);
}
/*
 * Sideways in-order dump of the graph below `value`.
 * operand1's subtree is printed first (above the node), then the node itself,
 * then operand2's subtree (below it). Each node line is preceded by an empty
 * line and indented by `space` spaces; operands are printed with
 * space + TREE_PRINT_DISTANCE. A shared subexpression is printed once per path.
 */
void Value_print_comp_graph_helper(ValuePtr value, int space){
    if (value == NULL) {
        return;
    }
    const int child_space = space + TREE_PRINT_DISTANCE;
    Value_print_comp_graph_helper(value->operand1, child_space);
    printf("\n");
    for (int pad = child_space - TREE_PRINT_DISTANCE; pad > 0; pad--) {
        printf(" ");
    }
    printf("D:%.2f | G:%.2f | %c\n", value->data, value->grad, value->operator);
    Value_print_comp_graph_helper(value->operand2, child_space);
}
/*
 * Returns a pseudo-random double in the closed range [-1.0, 1.0] drawn from
 * rand() (seed with srand() for different sequences).
 * The division is done in floating point. The previous `RAND_MAX / 2` was
 * integer division, which truncated and let results slightly exceed 1.0.
 */
double gen_random(void){
    return 2.0 * ((double) rand() / RAND_MAX) - 1.0;
}