Coverage for sleapyfaces/project.py: 65%
63 statements
« prev ^ index » next coverage.py v7.0.2, created at 2023-01-03 12:07 -0800
1import os
2from sleapyfaces.structs import CustomColumn, File, FileConstructor
3from sleapyfaces.experiment import Experiment
4from sleapyfaces.normalize import mean_center, z_score, pca
5from dataclasses import dataclass
6import pandas as pd
class Project:
    """Base class for a project of experiments.

    Args:
        DAQFile (str): The naming convention for the DAQ files (e.g. "*_events.csv" or "DAQOutput.csv")
        BehFile (str): The naming convention for the experimental structure files (e.g. "*_config.json" or "BehMetadata.json")
        SLEAPFile (str): The naming convention for the SLEAP files (e.g. "*_sleap.h5" or "SLEAP.h5")
        VideoFile (str): The naming convention for the video files (e.g. "*.mp4" or "video.avi")
        base (str): Base path of the project (e.g. "/specialk_cs/2p/raw/CSE009")
        iterator (dict[str, str] | None): Iterator for the project files, with keys as
            the label and values as the folder name (e.g. {"week 1": "20211105",
            "week 2": "20211112"}). When omitted, one entry per sorted subdirectory
            of `base` is generated, labelled "week 1", "week 2", ...
        get_glob (bool): Whether to use glob to find the files (e.g. True or False)
            NOTE: if get_glob is True, make sure to include the file extension in the
            naming convention
    """

    def __init__(
        self,
        DAQFile: str,
        BehFile: str,
        SLEAPFile: str,
        VideoFile: str,
        base: str,
        iterator: dict[str, str] | None = None,
        get_glob: bool = False,
    ):
        self.base = base
        self.DAQFile = DAQFile
        self.BehFile = BehFile
        self.SLEAPFile = SLEAPFile
        self.VideoFile = VideoFile
        self.get_glob = get_glob
        # BUG FIX: the original signature used a mutable default argument
        # (iterator={}) and mutated it in place, so weeks auto-discovered for
        # one Project leaked into every later Project built without an
        # explicit iterator. Use a None sentinel and build a fresh dict.
        if not iterator:
            iterator = {}
            weeks = sorted(
                week
                for week in os.listdir(self.base)
                if os.path.isdir(os.path.join(self.base, week))
            )
            for i, week in enumerate(weeks):
                iterator[f"week {i+1}"] = week
        self.iterator = iterator
        self.files = []
        self.exprs = []
        for name, folder in self.iterator.items():
            path = os.path.join(self.base, folder)
            files = FileConstructor(
                File(path, self.DAQFile, self.get_glob),
                File(path, self.SLEAPFile, self.get_glob),
                File(path, self.BehFile, self.get_glob),
                File(path, self.VideoFile, self.get_glob),
            )
            self.files.append(files)
            self.exprs.append(Experiment(name, files))

    def buildColumns(self, columns: list, values: list):
        """Builds the custom columns for the project and builds the data for each experiment

        Args:
            columns (list[str]): the column titles
            values (list[any]): the data for each column (parallel to `columns`)

        Initializes attributes:
            custom_columns (list[CustomColumn]): list of custom columns
            all_data (pd.DataFrame): the data for all experiments concatenated together
        """
        self.custom_columns = [
            CustomColumn(col, val) for col, val in zip(columns, values)
        ]
        exprs_list = []
        names_list = []
        for expr in self.exprs:
            expr.buildData(self.custom_columns)
            exprs_list.append(expr.sleap.tracks)
            names_list.append(expr.name)
        # Experiment names become the outer level of the concatenated index.
        self.all_data = pd.concat(exprs_list, keys=names_list)

    def buildTrials(
        self,
        TrackedData: list[str],
        Reduced: list[bool],
        start_buffer: int = 10000,
        end_buffer: int = 13000,
    ):
        """Parses the data from each experiment into its individual trials

        Args:
            TrackedData (list[str]): The title of the columns from the DAQ data to be tracked
            Reduced (list[bool]): The corresponding boolean for whether the DAQ data is to be reduced (`True`) or not (`False`)
            start_buffer (int, optional): The time in milliseconds before the trial start to capture. Defaults to 10000.
            end_buffer (int, optional): The time in milliseconds after the trial start to capture. Defaults to 13000.

        Initializes attributes:
            exprs[i].trials (pd.DataFrame): the data frame containing the concatenated trial data for each experiment
            exprs[i].trialData (list[pd.DataFrame]): the list of data frames containing the trial data for each trial for each experiment
        """
        for expr in self.exprs:
            expr.buildTrials(TrackedData, Reduced, start_buffer, end_buffer)

    def meanCenter(self):
        """Recursively mean centers the data for each trial for each experiment

        Initializes attributes:
            all_data (pd.DataFrame): the mean centered data for all trials and experiments concatenated together
        """
        mean_all = []
        for expr in self.exprs:
            # BUG FIX: the original indexed trialData with the experiment
            # index (i) instead of the trial index (j), so every slot held a
            # mean-centered copy of the same trial. Center each trial.
            centered = [
                mean_center(trial, expr.sleap.track_names)
                for trial in expr.trialData
            ]
            combined = pd.concat(centered, axis=0, keys=range(len(centered)))
            # Second pass: mean center across the whole experiment.
            mean_all.append(mean_center(combined, expr.sleap.track_names))
        self.all_data = pd.concat(mean_all, keys=list(self.iterator.keys()))

    def zScore(self):
        """Z scores the mean centered data for each experiment

        Updates attributes:
            all_data (pd.DataFrame): the z-scored data for all experiments concatenated together
        """
        # Track names are assumed identical across experiments, so the first
        # experiment's names are used for the whole frame.
        self.all_data = z_score(self.all_data, self.exprs[0].sleap.track_names)

    def normalize(self):
        """Runs the mean centering and z scoring functions

        Updates attributes:
            all_data (pd.DataFrame): the fully normalized data for all experiments concatenated together
        """
        analyze_all = [expr.normalizeTrials() for expr in self.exprs]
        analyze_all = pd.concat(analyze_all, keys=list(self.iterator.keys()))
        self.all_data = z_score(analyze_all, self.exprs[0].sleap.track_names)

    def visualize(self):
        """Reduces `all_data` to 2 and 3 dimensions using PCA

        Initializes attributes:
            pcas (dict[str, pd.DataFrame]): a dictionary containing the 2 and 3 dimensional PCA data for each experiment (the keys are 'pca2d', 'pca3d')
        """
        self.pcas = pca(self.all_data, self.exprs[0].sleap.track_names)