Source code for wbia_cnn.theano_ext

# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import utool as ut
from theano import tensor as T  # NOQA

# utool module injection: rebinds ``print`` for this module and provides the
# ``rrr`` and ``profile`` helpers. NOTE(review): by utool convention ``rrr``
# reloads this module and ``profile`` is a (possibly no-op) profiling
# decorator -- confirm against the utool version in use.
print, rrr, profile = ut.inject2(__name__)


def get_symbol_inputs(expr_list=None):
    """
    Find the free (leaf) symbolic inputs that Theano expressions depend on.

    Walks the graph upward from each expression via ``get_parents()`` and
    collects every node that has no parents. ``T.Constant`` nodes are skipped
    because they need no value supplied at evaluation time.

    Args:
        expr_list (list | tuple | object | None): a single expression or a
            list/tuple of expressions. ``None`` (the default) means "no
            expressions" and yields an empty result.

    Returns:
        list: the input nodes in depth-first, left-to-right order of first
        discovery. Each input appears once, even if it feeds several
        sub-expressions.
    """
    if expr_list is None:
        expr_list = []
    elif not isinstance(expr_list, (list, tuple)):
        expr_list = [expr_list]

    inputs_ = []
    # The graph is a DAG: one sub-expression can feed several ops. Visiting
    # each node once keeps every input listed a single time and avoids the
    # exponential re-walk that naive recursion does on shared subgraphs.
    # Keyed by id() so node hashing/equality semantics do not matter; the
    # values keep the nodes alive so an id cannot be recycled mid-walk.
    seen = {}
    # An explicit stack avoids Python's recursion limit on deep graphs.
    # Push in reverse so nodes pop in the original left-to-right order.
    stack = list(reversed(expr_list))
    while stack:
        expr = stack.pop()
        if id(expr) in seen:
            continue
        seen[id(expr)] = expr
        if isinstance(expr, T.Constant):
            # constants don't count as inputs
            continue
        parents = expr.get_parents()
        if len(parents) == 0:
            # no parents, this is an input
            inputs_.append(expr)
        else:
            stack.extend(reversed(list(parents)))
    return inputs_
def eval_symbol(expr, inputs_to_value):
    """
    Evaluate ``expr`` using only the supplied values it actually needs.

    ``expr.eval`` complains when given values for variables the expression
    does not use. So the mapping is first narrowed to the symbolic inputs
    returned by ``get_symbol_inputs``, and ``expr`` is evaluated with that
    reduced feed.

    Args:
        expr: the symbolic expression to evaluate.
        inputs_to_value (dict): candidate values keyed by input variable;
            entries for variables ``expr`` does not depend on are ignored.

    Returns:
        Whatever ``expr.eval`` returns for the narrowed feed.
    """
    needed_vars = get_symbol_inputs([expr])
    feed = dict(
        (var, inputs_to_value[var])
        for var in needed_vars
        if var in inputs_to_value
    )
    return expr.eval(feed)