Coverage for src/FuncVect.py : 22%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python3
3from FADiff import FADiff
4from fad.Gradients import Scal as fadScal
5from rev.Gradients import Scal as revScal
class FuncVect:
    """A vector of functions built from FADiff Scal or Vect objects.

    Wraps a sequence of function objects that must all share exactly the same
    type, and exposes their values (``val``) and their partial derivatives
    with respect to the input variables collected at construction (``der``).
    """

    def __init__(self, funcs):
        """Validate ``funcs`` and collect the input variables they use.

        Args:
            funcs: Indexable sequence of function objects (Scal or Vect).
                Every element must have exactly the same type as
                ``funcs[0]``. The sequence itself is stored, not copied.

        Raises:
            TypeError: If the elements of ``funcs`` are not all of the same
                type. (Subclass of ``Exception``, so existing
                ``except Exception`` handlers still catch it.)
            IndexError: If ``funcs`` is empty (raised by ``funcs[0]``).
        """
        # All fxns must be all Scal xor all Vect: compare each element's
        # exact type against the first one.
        func_type = type(funcs[0])
        for func in funcs[1:]:
            if type(func) is not func_type:
                raise TypeError('All functions must be of same type (Scalar '
                                'or Vector).')

        # Set inputs list based on type (Scal or Vect) of functions.
        # NOTE(review): every type other than the forward-mode Scal
        # (including the rev-mode Scal imported as ``revScal``, which is
        # never checked here) falls back to the Vect inputs list -- confirm
        # that is intended.
        if func_type is fadScal:
            inputs = FADiff._fadscal_inputs
        else:
            inputs = FADiff._fadvect_inputs

        self._f_vect = funcs  # List of objects (Scal or Vect) used in fxns
        input_vars = []  # Complete list of input vars of f_vect
        for func in funcs:
            if func._parents:
                # NOTE(review): only derivative keys that are also *direct*
                # parents of this function are kept; confirm this is right
                # for nested expressions whose inputs are not immediate
                # parents.
                input_vars.extend(var for var in func._der.keys()
                                  if var in func._parents)
            elif func in inputs:
                # The function is itself an input var (identity fxn).
                # TODO: Input var (identity fxn) is correct? Make sure to test
                input_vars.append(func)
        # De-duplicate while keeping first-seen order so the result is
        # deterministic (``list(set(...))`` produced an arbitrary order).
        self._input_vars = list(dict.fromkeys(input_vars))

    @property
    def val(self):
        """List of each function's value, in the same order as ``funcs``."""
        return [func._val for func in self._f_vect]

    @property
    def der(self):
        """Per-function lists of partial derivatives w.r.t. the input vars.

        Returns a list with one inner list per function, in the same order as
        ``funcs``. Each inner list holds the values from that function's
        ``_der`` mapping whose key is one of the collected input variables,
        in that mapping's own iteration order.

        NOTE(review): rows are not aligned to a common variable order, and
        variables that are not keys of a function's ``_der`` are simply
        omitted, so rows may differ in length and order; this is not a
        rectangular Jacobian. TODO: Return correct?
        """
        input_vars = self._input_vars
        return [[part_der for var, part_der in func._der.items()
                 if var in input_vars]
                for func in self._f_vect]