# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import functools
from wbia_cnn import draw_results
import utool as ut

print, rrr, profile = ut.inject2(__name__)


def sift_dataset_separability(dataset):
    """
    VERY HACKED RIGHT NOW. ONLY LIBERTY. BLINDLY CACHES

    Args:
        dataset (?):

    CommandLine:
        python -m wbia_cnn.experiments --exec-sift_dataset_separability --show

    Example:
        >>> # SCRIPT
        >>> from wbia_cnn.experiments import *  # NOQA
        >>> from wbia_cnn import ingest_data
        >>> dataset = ingest_data.grab_liberty_siam_dataset(250000)
        >>> ut.quit_if_noshow()
        >>> sift_dataset_separability(dataset)
        >>> ut.show_if_requested()
    """
    import vtool as vt

    @ut.cached_func('tempsiftscorecache', cache_dir='.')
    def cached_siftscores():
        data, labels = dataset.subset('test')
        sift_scores, sift_list = test_sift_patchmatch_scores(data, labels)
        sift_scores = sift_scores.astype(np.float64)
        return sift_scores, labels, sift_list

    sift_scores, labels, sift_list = cached_siftscores()

    # I don't think we can compare lnbnn on liberty because we don't have a
    # set of id labels; we have pairs of correspondences.
    # import pyflann
    # flann = pyflann.FLANN()
    # flann.build_index(sift_list)
    # idxs, dists = flann.nn_index(sift_list, 10)

    encoder_kw = {
        # 'monotonize': False,
        'monotonize': True,
    }
    sift_encoder = vt.ScoreNormalizer(**encoder_kw)
    sift_encoder.fit(sift_scores, labels)
    dataname = dataset.alias_key
    viz_kw = dict(
        with_scores=False,
        with_postbayes=False,
        with_prebayes=False,
        target_tpr=0.95,
        score_range=(0, 1),
    )
    inter_sift = sift_encoder.visualize(
        figtitle=dataname + ' SIFT scores. #data=' + str(len(labels)),
        fnum=None,
        **viz_kw
    )

    import plottool as pt

    # icon = ibs.get_database_icon()
    icon = (
        'http://www.councilchronicle.com/wp-content/uploads/2015/08/'
        'West-Virginia-Arrested-over-Bogus-Statue-of-Liberty-Bomb-Threat.jpg'
    )
    if icon is not None:
        pt.overlay_icon(
            icon, coords=(1, 0), bbox_alignment=(1, 0), max_dsize=(None, 192)
        )
    if ut.get_argflag('--contextadjust'):
        pt.adjust_subplots(left=0.1, bottom=0.25, wspace=0.2, hspace=0.2)
        pt.adjust_subplots(use_argv=True)
    return inter_sift
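

# The ``@ut.cached_func`` decorator above is why the docstring warns that this
# function "BLINDLY CACHES": the first call pickles the result to disk and every
# later call reuses it, with no invalidation when the dataset changes. A minimal
# stdlib-only sketch of the same pattern (``_disk_cached`` is an illustrative
# name, not part of utool):
def _disk_cached(fpath):
    import os
    import pickle

    def _decorator(func):
        def _wrapper():
            if os.path.exists(fpath):
                # Blindly reuse the pickled result from any previous run
                with open(fpath, 'rb') as file_:
                    return pickle.load(file_)
            result = func()
            with open(fpath, 'wb') as file_:
                pickle.dump(result, file_)
            return result

        return _wrapper

    return _decorator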


def test_sift_patchmatch_scores(data, labels):
    """
    data = X_test
    labels = y_test
    """
    import pyhesaff

    if len(data.shape) == 4 and data.shape[-1] == 1:
        data = data.reshape(data.shape[0:3])
    elif len(data.shape) == 4 and data.shape[-1] == 3:
        import vtool as vt

        # TODO: use dataset to infer data colorspace
        data = vt.convert_image_list_colorspace(data, 'GRAY', src_colorspace='BGR')
    patch_list = data
    print('Extract SIFT descr')
    vecs_list = pyhesaff.extract_desc_from_patches(patch_list)
    print('Compute SIFT dist')
    sqrddist = (
        (vecs_list[0::2].astype(np.float32) - vecs_list[1::2].astype(np.float32)) ** 2
    ).sum(axis=1)
    sqrddist_ = sqrddist[None, :].T
    VEC_PSEUDO_MAX_DISTANCE_SQRD = 2.0 * (512.0 ** 2.0)
    # sift_scores = 1 - (sqrddist_.flatten() / VEC_PSEUDO_MAX_DISTANCE_SQRD)
    sift_scores = sqrddist_.flatten() / VEC_PSEUDO_MAX_DISTANCE_SQRD
    sift_list = vecs_list
    # test_siamese_thresholds(sqrddist_, labels, figtitle='SIFT descriptor distances')
    return sift_scores, sift_list
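

# Why dividing by ``VEC_PSEUDO_MAX_DISTANCE_SQRD = 2 * 512 ** 2`` above yields
# scores in [0, 1]: SIFT components are nonnegative and the descriptors are
# treated as having a pseudo-max norm of 512 (the convention the constant
# assumes), so ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b <= 2 * 512 ** 2
# because a.b >= 0. A quick numeric sanity check with illustrative vectors:
def _check_sift_dist_bound():
    rng = np.random.RandomState(0)
    a = rng.rand(128)
    b = rng.rand(128)
    a *= 512.0 / np.linalg.norm(a)  # scale onto the pseudo-max sphere
    b *= 512.0 / np.linalg.norm(b)
    sqrddist = ((a - b) ** 2).sum()
    assert sqrddist <= 2.0 * (512.0 ** 2.0)
    return sqrddist / (2.0 * (512.0 ** 2.0))  # a score in [0, 1]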


def test_siamese_performance(model, data, labels, flat_metadata, dataname=''):
    r"""
    CommandLine:
        utprof.py -m wbia_cnn --tf pz_patchmatch --db liberty --test --weights=liberty:current --arch=siaml2_128 --test
        python -m wbia_cnn --tf netrun --db liberty --arch=siaml2_128 --test --ensure
        python -m wbia_cnn --tf netrun --db liberty --arch=siaml2_128 --test --ensure --weights=new
        python -m wbia_cnn --tf netrun --db liberty --arch=siaml2_128 --train --weights=new
        python -m wbia_cnn --tf netrun --db pzmtest --weights=liberty:current --arch=siaml2_128 --test  # NOQA
        python -m wbia_cnn --tf netrun --db pzmtest --weights=liberty:current --arch=siaml2_128
    """
    import vtool as vt
    import plottool as pt

    # TODO: save in model.trainind_dpath/diagnostics/figures
    ut.colorprint('\n[siam_perf] Testing Siamese Performance', 'white')
    # epoch_dpath = model.get_epoch_diagnostic_dpath()
    epoch_dpath = model.arch_dpath
    ut.vd(epoch_dpath)

    dataname += ' ' + model.get_history_hashid() + '\n'
    history_text = ut.list_str(model.era_history, newlines=True)
    ut.write_to(ut.unixjoin(epoch_dpath, 'era_history.txt'), history_text)

    # if True:
    #     import matplotlib as mpl
    #     mpl.rcParams['agg.path.chunksize'] = 100000

    # data = data[::50]
    # labels = labels[::50]
    # from wbia_cnn import utils
    # data, labels = utils.random_xy_sample(data, labels, 10000, model.data_per_label_input)

    FULL = not ut.get_argflag('--quick')

    fnum_gen = pt.make_fnum_nextgen()

    ut.colorprint('[siam_perf] Show era history', 'white')
    fig = model.show_era_loss(fnum=fnum_gen())
    pt.save_figure(fig=fig, dpath=epoch_dpath, dpi=180)

    # hack
    ut.colorprint('[siam_perf] Show weights image', 'white')
    fig = model.show_weights_image(fnum=fnum_gen())
    pt.save_figure(fig=fig, dpath=epoch_dpath, dpi=180)
    # model.draw_all_conv_layer_weights(fnum=fnum_gen())
    # model.imwrite_weights(1)
    # model.imwrite_weights(2)

    # Compute each type of score
    ut.colorprint('[siam_perf] Building Scores', 'white')
    test_outputs = model.predict2(model, data)
    network_output = test_outputs['network_output_determ']
    # hack converting network output to distances for non-descriptor networks
    if len(network_output.shape) == 2 and network_output.shape[1] == 1:
        cnn_scores = network_output.T[0]
    elif len(network_output.shape) == 1:
        cnn_scores = network_output
    elif len(network_output.shape) == 2 and network_output.shape[1] > 1:
        assert model.data_per_label_output == 2
        vecs1 = network_output[0::2]
        vecs2 = network_output[1::2]
        cnn_scores = vt.L2(vecs1, vecs2)
    else:
        assert False
    cnn_scores = cnn_scores.astype(np.float64)

    # Segfaults when the data passed in is large (AND MEMMAPPED apparently).
    # Fixed in the hesaff implementation.
    SIFT = FULL
    if SIFT:
        sift_scores, sift_list = test_sift_patchmatch_scores(data, labels)
        sift_scores = sift_scores.astype(np.float64)

    ut.colorprint('[siam_perf] Learning Encoders', 'white')
    # Learn encoders
    encoder_kw = {
        # 'monotonize': False,
        'monotonize': True,
    }
    cnn_encoder = vt.ScoreNormalizer(**encoder_kw)
    cnn_encoder.fit(cnn_scores, labels)

    if SIFT:
        sift_encoder = vt.ScoreNormalizer(**encoder_kw)
        sift_encoder.fit(sift_scores, labels)

    # Visualize
    ut.colorprint('[siam_perf] Visualize Encoders', 'white')
    viz_kw = dict(
        with_scores=False,
        with_postbayes=False,
        with_prebayes=False,
        target_tpr=0.95,
    )
    inter_cnn = cnn_encoder.visualize(
        figtitle=dataname + ' CNN scores. #data=' + str(len(data)),
        fnum=fnum_gen(),
        **viz_kw
    )
    if SIFT:
        inter_sift = sift_encoder.visualize(
            figtitle=dataname + ' SIFT scores. #data=' + str(len(data)),
            fnum=fnum_gen(),
            **viz_kw
        )

    # Save
    pt.save_figure(fig=inter_cnn.fig, dpath=epoch_dpath)
    if SIFT:
        pt.save_figure(fig=inter_sift.fig, dpath=epoch_dpath)

    # Save out examples of hard errors
    # cnn_fp_label_indicies, cnn_fn_label_indicies = cnn_encoder.get_error_indicies(cnn_scores, labels)
    # sift_fp_label_indicies, sift_fn_label_indicies = sift_encoder.get_error_indicies(sift_scores, labels)

    with_patch_examples = FULL
    if with_patch_examples:
        ut.colorprint('[siam_perf] Visualize Confusion Examples', 'white')
        cnn_indicies = cnn_encoder.get_confusion_indicies(cnn_scores, labels)
        if SIFT:
            sift_indicies = sift_encoder.get_confusion_indicies(sift_scores, labels)

        warped_patch1_list, warped_patch2_list = list(zip(*ut.ichunks(data, 2)))
        samp_args = (warped_patch1_list, warped_patch2_list, labels)
        _sample = functools.partial(draw_results.get_patch_sample_img, *samp_args)

        cnn_fp_img = _sample({'fs': cnn_scores}, cnn_indicies.fp)[0]
        cnn_fn_img = _sample({'fs': cnn_scores}, cnn_indicies.fn)[0]
        cnn_tp_img = _sample({'fs': cnn_scores}, cnn_indicies.tp)[0]
        cnn_tn_img = _sample({'fs': cnn_scores}, cnn_indicies.tn)[0]

        if SIFT:
            sift_fp_img = _sample({'fs': sift_scores}, sift_indicies.fp)[0]
            sift_fn_img = _sample({'fs': sift_scores}, sift_indicies.fn)[0]
            sift_tp_img = _sample({'fs': sift_scores}, sift_indicies.tp)[0]
            sift_tn_img = _sample({'fs': sift_scores}, sift_indicies.tn)[0]

        # if ut.show_was_requested():
        #     def rectify(arr):
        #         return np.flipud(arr)
        SINGLE_FIG = False
        if SINGLE_FIG:

            def dump_img(img_, lbl, fnum):
                fig, ax = pt.imshow(img_, figtitle=dataname + ' ' + lbl, fnum=fnum)
                pt.save_figure(fig=fig, dpath=epoch_dpath, dpi=180)

            dump_img(cnn_fp_img, 'cnn_fp_img', fnum_gen())
            dump_img(cnn_fn_img, 'cnn_fn_img', fnum_gen())
            dump_img(cnn_tp_img, 'cnn_tp_img', fnum_gen())
            dump_img(cnn_tn_img, 'cnn_tn_img', fnum_gen())

            dump_img(sift_fp_img, 'sift_fp_img', fnum_gen())
            dump_img(sift_fn_img, 'sift_fn_img', fnum_gen())
            dump_img(sift_tp_img, 'sift_tp_img', fnum_gen())
            dump_img(sift_tn_img, 'sift_tn_img', fnum_gen())

            # vt.imwrite(dataname + '_' + 'cnn_fp_img.png', (cnn_fp_img))
            # vt.imwrite(dataname + '_' + 'cnn_fn_img.png', (cnn_fn_img))
            # vt.imwrite(dataname + '_' + 'sift_fp_img.png', (sift_fp_img))
            # vt.imwrite(dataname + '_' + 'sift_fn_img.png', (sift_fn_img))
        else:
            print('Drawing TP FP TN FN')
            fnum = fnum_gen()
            pnum_gen = pt.make_pnum_nextgen(4, 2)
            fig = pt.figure(fnum)
            pt.imshow(cnn_fp_img, title='CNN FP', fnum=fnum, pnum=pnum_gen())
            pt.imshow(sift_fp_img, title='SIFT FP', fnum=fnum, pnum=pnum_gen())
            pt.imshow(cnn_fn_img, title='CNN FN', fnum=fnum, pnum=pnum_gen())
            pt.imshow(sift_fn_img, title='SIFT FN', fnum=fnum, pnum=pnum_gen())
            pt.imshow(cnn_tp_img, title='CNN TP', fnum=fnum, pnum=pnum_gen())
            pt.imshow(sift_tp_img, title='SIFT TP', fnum=fnum, pnum=pnum_gen())
            pt.imshow(cnn_tn_img, title='CNN TN', fnum=fnum, pnum=pnum_gen())
            pt.imshow(sift_tn_img, title='SIFT TN', fnum=fnum, pnum=pnum_gen())
            pt.set_figtitle(dataname + ' confusions')
            pt.adjust_subplots(left=0, right=1.0, bottom=0.0, wspace=0.01, hspace=0.05)
            pt.save_figure(fig=fig, dpath=epoch_dpath, dpi=180, figsize=(9, 18))

    with_patch_desc = FULL
    if with_patch_desc:
        ut.colorprint('[siam_perf] Visualize Patch Descriptors', 'white')
        fnum = fnum_gen()
        fig = pt.figure(fnum=fnum, pnum=(1, 1, 1))
        num_rows = 7
        pnum_gen = pt.make_pnum_nextgen(num_rows, 3)
        # Compare actual output descriptors
        for index in ut.random_indexes(len(sift_list), num_rows):
            vec_sift = sift_list[index]
            vec_cnn = network_output[index]
            patch = data[index]
            pt.imshow(patch, fnum=fnum, pnum=pnum_gen())
            pt.plot_descriptor_signature(vec_cnn, 'cnn vec', fnum=fnum, pnum=pnum_gen())
            pt.plot_sift_signature(vec_sift, 'sift vec', fnum=fnum, pnum=pnum_gen())
        pt.set_figtitle('Patch Descriptors')
        pt.adjust_subplots(left=0, right=0.95, bottom=0.0, wspace=0.1, hspace=0.15)
        pt.save_figure(fig=fig, dpath=epoch_dpath, dpi=180, figsize=(9, 18))
    # ut.vd(epoch_dpath)
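

# For the descriptor-network branch above, patches arrive interleaved as
# (left, right, left, right, ...), so pair scores are distances between
# even-indexed and odd-indexed output rows. A numpy-only restatement of that
# pairing step (the helper name is illustrative; the function above uses vt.L2):
def _demo_pair_distances(network_output):
    vecs1 = network_output[0::2].astype(np.float64)  # left patch of each pair
    vecs2 = network_output[1::2].astype(np.float64)  # right patch of each pair
    return np.sqrt(((vecs1 - vecs2) ** 2).sum(axis=1))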


def show_hard_cases(model, data, labels, scores):
    from wbia_cnn import utils

    encoder = model.learn_encoder(labels, scores)
    encoder.visualize()

    # x = encoder.inverse_normalize(np.cast['float32'](encoder.learned_thresh))
    # encoder.normalize_scores(x)
    # encoder.inverse_normalize(np.cast['float32'](encoder.learned_thresh))

    fp_label_indicies, fn_label_indicies = encoder.get_error_indicies(scores, labels)
    fn_data_indicies = utils.expand_data_indicies(
        fn_label_indicies, model.data_per_label_input
    )
    fp_data_indicies = utils.expand_data_indicies(
        fp_label_indicies, model.data_per_label_input
    )

    fn_data = data.take(fn_data_indicies, axis=0)
    fn_labels = labels.take(fn_label_indicies, axis=0)
    fn_scores = scores.take(fn_label_indicies, axis=0)

    fp_data = data.take(fp_data_indicies, axis=0)
    fp_labels = labels.take(fp_label_indicies, axis=0)
    fp_scores = scores.take(fp_label_indicies, axis=0)

    from wbia_cnn import draw_results

    draw_results.rrr()
    draw_results.interact_siamsese_data_patches(
        fn_labels, fn_data, {'fs': fn_scores}, figtitle='FN'
    )
    draw_results.interact_siamsese_data_patches(
        fp_labels, fp_data, {'fs': fp_scores}, figtitle='FP'
    )
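

# show_hard_cases relies on utils.expand_data_indicies to translate label
# indices into data indices: with data_per_label_input == 2, label k covers
# data rows 2 * k and 2 * k + 1. An illustrative reimplementation of that
# mapping (not the wbia_cnn.utils source):
def _demo_expand_indicies(label_indicies, data_per_label=2):
    label_indicies = np.asarray(label_indicies)
    offsets = np.arange(data_per_label)
    return (label_indicies[:, None] * data_per_label + offsets[None, :]).ravel()


# e.g. _demo_expand_indicies([0, 2]) -> array([0, 1, 4, 5])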


def test_siamese_thresholds(network_output, y_test, **kwargs):
    """
    Test function to see how good of a threshold we can learn

    network_output = prob_list
    """
    import vtool as vt

    # batch cycling may cause more outputs than test labels;
    # should be able to just crop
    network_output_ = network_output[0 : len(y_test)].copy() ** 2
    tp_support = network_output_.T[0][y_test.astype(bool)].astype(np.float64)
    tn_support = network_output_.T[0][~(y_test.astype(bool))].astype(np.float64)
    if tp_support.mean() < tn_support.mean():
        print('need to invert scores')
        tp_support *= -1
        tn_support *= -1
    bottom = min(tn_support.min(), tp_support.min())
    if bottom < 0:
        print('need to subtract from scores')
        tn_support -= bottom
        tp_support -= bottom

    vt.score_normalization.rrr()
    vt.score_normalization.test_score_normalization(
        tp_support, tn_support, with_scores=False, **kwargs
    )
    # from wbia.algo.hots import score_normalization
    # test_score_normalization
    # learnkw = dict()
    # learntup = score_normalization.learn_score_normalization(
    #     tp_support, tn_support, return_all=False, **learnkw)
    # (score_domain, p_tp_given_score, clip_score) = learntup
    # Plotting
    # import plottool as pt
    # fnum = 1
    # pt.figure(fnum=fnum, pnum=(2, 1, 1), doclf=True, docla=True)
    # score_normalization.plot_support(tn_support, tp_support, fnum=fnum, pnum=(2, 1, 1))
    # score_normalization.plot_postbayes_pdf(
    #     score_domain, 1 - p_tp_given_score, p_tp_given_score,
    #     fnum=fnum, pnum=(2, 1, 2))
    # pass
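

# The flip-and-shift step above enforces the convention that true-positive
# support should score higher than true-negative support and be nonnegative
# before normalization. A synthetic check of that logic (values are
# illustrative only):
def _demo_flip_and_shift():
    tp_support = np.array([0.1, 0.2, 0.15])  # "same" pairs score low here
    tn_support = np.array([0.7, 0.9, 0.8])
    if tp_support.mean() < tn_support.mean():
        tp_support, tn_support = -tp_support, -tn_support
    bottom = min(tn_support.min(), tp_support.min())
    if bottom < 0:
        tp_support -= bottom
        tn_support -= bottom
    assert tp_support.mean() > tn_support.mean()
    assert tp_support.min() >= 0 and tn_support.min() >= 0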


if __name__ == '__main__':
    """
    CommandLine:
        python -m wbia_cnn.experiments
        python -m wbia_cnn.experiments --allexamples
        python -m wbia_cnn.experiments --allexamples --noface --nosrc
    """
    import multiprocessing

    multiprocessing.freeze_support()  # for win32
    import utool as ut  # NOQA

    ut.doctest_funcs()