Coverage for src/rev/Gradients.py : 18%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python3
3from FADiff import FADiff
class Scal:
    """Scalar node of an expression graph built with ``+`` and ``*``.

    Every node stores its value (``_val``) and a dict ``_der`` that maps
    registered input variables to the derivative of this node with respect
    to each of them. ``_der`` is filled in as expressions are built:
    ``__add__`` applies the sum rule and ``__mul__`` the product rule. A node
    also remembers the operand nodes it was computed from (``_parents``) and
    the input variables it ultimately depends on (``_root_inputs``).
    Registered inputs are kept in ``FADiff._revscal_inputs``.
    """

    def __init__(self, val, der=None, parents=None,
                 roots=None, name=None, new_input=False):
        """Create a node.

        Args:
            val: Numeric value of the node.
            der: With ``new_input=True``, the seed derivative of the new input
                with respect to itself. Otherwise the derivative dict of this
                node; ``None`` means it depends on no input (stored as ``{}``).
            parents: Operand nodes this node was computed from
                (default: none).
            roots: Input variables this node depends on (default: none).
            name: Optional label, stored as ``_name``.
            new_input: If true, register this node as a new independent input
                by appending it to ``FADiff._revscal_inputs``.
        """
        self._val = val
        self._grad = 0  # TODO: Not sure if need
        if new_input:
            self._der = {}
            # Give the new input a 0 entry for every existing input and vice
            # versa, so all registered inputs carry the same set of keys.
            for var in FADiff._revscal_inputs:
                self._der[var] = 0
                var._der[self] = 0
            self._der[self] = der
            FADiff._revscal_inputs.append(self)
        else:
            self._der = {} if der is None else der
        self._name = name
        # Fresh lists per instance: a shared `[]` default would be aliased by
        # every node created without explicit parents/roots.
        self._parents = [] if parents is None else parents
        self._root_inputs = [] if roots is None else roots

    @staticmethod
    def _union_keys(first, second):
        """Return the keys of ``first`` followed by the keys only in ``second``."""
        return list(dict.fromkeys([*first, *second]))

    def __add__(self, other):
        """Return ``self + other`` for another Scal or a plain number (sum rule)."""
        try:
            other_der, other_val = other._der, other._val
        except AttributeError:
            # `other` is a constant: derivatives are unchanged. Copy the dict
            # so the result does not alias this node's `_der`.
            return Scal(self._val + other, dict(self._der), [self],
                        self._set_roots(self))
        # A variable missing from one operand's dict contributes 0.
        der = {var: self._der.get(var, 0) + other_der.get(var, 0)
               for var in self._union_keys(self._der, other_der)}
        return Scal(self._val + other_val, der, [self, other],
                    self._set_roots(self, other))

    def __radd__(self, other):
        return self.__add__(other)

    def __mul__(self, other):
        """Return ``self * other`` for another Scal or a plain number (product rule)."""
        try:
            other_der, other_val = other._der, other._val
        except AttributeError:
            # `other` is a constant c: d(c * u) = c * du.
            der = {var: part_der * other
                   for var, part_der in self._der.items()}
            # Only `self` is a graph node here; a plain number has no roots.
            return Scal(self._val * other, der, [self], self._set_roots(self))
        # Product rule; a variable missing from one operand's dict counts as 0.
        der = {var: self._val * other_der.get(var, 0)
               + self._der.get(var, 0) * other_val
               for var in self._union_keys(self._der, other_der)}
        return Scal(self._val * other_val, der, [self, other],
                    self._set_roots(self, other))

    def __rmul__(self, other):
        return self.__mul__(other)

    @property
    def val(self):
        """Value of this node, wrapped in a single-element list."""
        return [self._val]

    @property
    def der(self):
        """Derivatives of this node w.r.t. the inputs it depends on.

        Returns ``_der[var]`` for each ``var`` in ``_der`` that is one of this
        node's root inputs, in ``_der`` order. A node whose ``_root_inputs``
        is empty (e.g. an input created without ``roots``) yields ``[]``.
        """
        # `_der` already holds the complete chain-rule result accumulated by
        # __add__/__mul__, so it is read directly. Walking the graph and
        # multiplying by each ancestor's `_der` would double-count
        # (for y = x + x, z = y + y it yields 8 rather than dz/dx = 4).
        return [part_der for var, part_der in self._der.items()
                if var in self._root_inputs]

    @staticmethod
    def _set_roots(var1, var2=None):
        """Collect the input variables that ``var1`` (and ``var2``) depend on.

        A parentless node that is a registered input is its own root;
        otherwise its ``_root_inputs`` are inherited. Duplicates are dropped,
        keeping first-seen order. ``var2`` must be a Scal or ``None``.
        """
        roots = []
        for node in (var1, var2):
            if node is None:
                continue
            if not node._parents and node in FADiff._revscal_inputs:
                candidates = [node]  # Root parent
            else:
                candidates = node._root_inputs
            for root in candidates:
                if root not in roots:
                    roots.append(root)
        return roots