Coverage for .tox/py39/lib/python3.9/site-packages/cows/catalogue.py: 100.00%
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import numpy as np
3from .filament import label_skeleton, find_filaments
def gen_catalogue(data, periodic=False, sort=True):
    ''' Generate a catalogue of filaments

    Generates a catalogue of filaments given an array containing a
    cleaned and separated skeleton.

    Parameters
    ----------
    data : ndarray, 3D
        An array containing the separated skeleton. Zeros represent
        background, ones are endpoints and, twos are regular cells.
    periodic: bool
        If True, the skeletonization uses periodic boundary conditions
        for the input array. Input array must be 3D.
    sort : boolean
        If sort=True, the filaments are sorted by filament length in
        descending order and reassigned consecutive IDs starting at one,
        such that the longest filament has an ID of one.

    Returns
    -------
    catalogue : ndarray, 2D
        A catalogue containing, for each cell, a row of filament ID,
        filament length, X-, Y-, Z-position, X-, Y-, Z-direction
        (8 columns). If no filaments are found, an empty array of shape
        (0, 8) is returned.

    Notes
    -----
    The function assumes that values larger than zero are part of the
    skeleton.

    Filament length is defined here as the number of member cells.

    Cells of the same filament are expected to occupy consecutive rows
    of the catalogue returned by ``find_filaments``; filament lengths
    are derived from runs of equal IDs.
    '''

    # Box size passed to _get_direction to undo periodic wrapping of
    # neighbour offsets.
    # NOTE(review): only the first axis length is used for all three
    # position columns, i.e. this assumes a cubic input array -- confirm.
    ncells = data.shape[0]

    # Classify the skeleton to identify endpoints and regular cells
    data = label_skeleton(data, periodic=periodic)

    # Connect cells within a 3x3x3 neighbourhood from endpoint to endpoint
    # and store in the first column
    _, cat = find_filaments(data, periodic=periodic)
    catalogue = np.zeros([cat.shape[0], 8], order='c')
    catalogue[:, 0] = cat[:, 0]

    # Return the empty catalogue if no filaments were found
    if catalogue.shape[0] == 0:
        return catalogue

    # Store the filament cell positions
    catalogue[:, 2:5] = cat[:, 1:]

    # Filament lengths: a new run starts wherever the ID changes, so the
    # run lengths are the differences between consecutive run starts.
    ids = cat[:, 0]
    run_starts = np.flatnonzero(np.diff(ids) != 0) + 1
    group_lengths = np.diff(np.concatenate(([0], run_starts, [ids.size])))
    catalogue[:, 1] = np.repeat(group_lengths, group_lengths)

    # Calculate and store the filament cell directions
    catalogue[:, 5:8] = _get_direction(cat[:, 0], cat[:, 1:], ncells)

    # Sort the catalogue by filament length in descending order
    if sort:
        # np.lexsort treats the LAST key as primary: order by length, break
        # ties by ID, then reverse for descending order.
        sort_idx = np.lexsort([catalogue[:, 0], catalogue[:, 1]])[::-1]
        catalogue = catalogue[sort_idx]
        group_lengths = np.sort(group_lengths)[::-1]
        # Relabel 1..K in the new order. Counting the groups (instead of
        # using the largest input ID) stays correct when the input IDs are
        # not exactly 1..K.
        new_ids = np.arange(group_lengths.size) + 1
        catalogue[:, 0] = np.repeat(new_ids, group_lengths)

    return catalogue
def _get_direction(index, pos, box_size):
    '''
    Calculates the direction of a filament cell based on the location
    of that cell's neighbours. The direction vector is normalised.

    The direction of a cell is the sum of the offset vectors to the
    cells in the previous and next rows, counted only when those rows
    belong to the same filament.

    Parameters
    ----------
    index : ndarray, 1D
        Filament ID of each cell. Cells of the same filament must occupy
        consecutive rows.
    pos : ndarray, 2D, shape (N, 3)
        Position of each cell.
    box_size : int or array_like of shape (3,)
        Box size used to undo periodic wrapping of neighbour offsets. An
        array gives a separate size for each column of ``pos`` (it is
        broadcast against the offsets).

    Returns
    -------
    direction : ndarray, shape (N, 3)
        Unit direction vector per cell. A cell with no neighbour in its
        own filament (e.g. a single-cell filament) has a zero-length sum
        and therefore yields NaN.
    '''

    assert pos.ndim == 2
    assert pos.shape[1] == 3

    direction = np.zeros(pos.shape)

    # Consecutive rows are neighbours only if they share a filament ID.
    # Testing equality (rather than 1 - diff) also masks boundaries where
    # the ID jumps by more than one or decreases.
    same_filament = np.diff(index) == 0

    # Offset between consecutive cells. For neighbouring cells, an offset
    # of +/-(box_size - 1) across a periodic boundary is mapped back to -/+1.
    step = np.mod((pos[:-1] - pos[1:]) + 1, box_size) - 1

    # Zero the offsets that span two different filaments
    step = step * same_filament[:, None]

    # Each offset contributes to both cells of the pair
    direction[:-1] += step
    direction[1:] += step

    # Normalise direction vector
    norm = np.sqrt(np.sum(direction**2, axis=1))
    return direction / norm[:, None]