Coverage for /Users/jerry/Development/yenta/yenta/pipeline/Pipeline.py: 98%


170 statements  

import logging
import pickle
import shutil

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import networkx as nx
from colorama import Fore, Style
from more_itertools import split_after

from yenta.artifacts.Artifact import Artifact
from yenta.config import settings
from yenta.tasks.Task import TaskDef, ParameterType, ResultSpec

logger = logging.getLogger(__name__)


class InvalidTaskResultError(Exception):
    pass


class InvalidParameterError(Exception):
    pass


class PipelineConfigError(Exception):
    pass


class TaskStatus(str, Enum):

    SUCCESS = 'success'
    FAILURE = 'failure'


@dataclass
class TaskResult:
    """ Holds the result of a specific task execution. """

    values: Dict[str, Any] = field(default_factory=dict)
    """ A dictionary whose keys are value names and whose values are the corresponding computed values. """

    artifacts: Dict[str, Artifact] = field(default_factory=dict)
    """ A dictionary whose keys are artifact names and whose values are Artifacts.
        (Declared here because `PipelineResult.artifacts` looks artifacts up on task results.) """

    status: Optional[TaskStatus] = None
    """ Whether the task succeeded or failed. """

    error: Optional[str] = None
    """ Error message associated with task failure. """
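
# Usage sketch (illustrative, not part of the original module): a task body
# may return either a TaskResult or a plain dict, which _wrap_task_output
# below coerces into a TaskResult.
#
#     result = TaskResult(values={'total': 42}, status=TaskStatus.SUCCESS)
#     assert result.values['total'] == 42
#     assert result.error is None
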

@dataclass
class PipelineResult:
    """ Holds the intermediate state of the pipeline: which tasks have been
        executed, what they produced, and what inputs they were given. """

    task_results: Dict[str, TaskResult] = field(default_factory=dict)
    """ A dictionary whose keys are task names and whose values are the results of those task executions. """

    task_inputs: Dict[str, 'PipelineResult'] = field(default_factory=dict)
    """ A dictionary whose keys are task names and whose values are the inputs used in executing those tasks. """

    def values(self, task_name: str, value_name: str):
        """ Return the value named `value_name` that was produced by task `task_name`.

        :param str task_name: The name of the task
        :param str value_name: The name of the value
        :return: the unwrapped value produced by the task
        :rtype: Union[list, int, bool, float, str]
        """
        return self.task_results[task_name].values[value_name]

    def artifacts(self, task_name: str, artifact_name: str):
        """ Return the artifact named `artifact_name` that was produced by task `task_name`.

        :param str task_name: The name of the task
        :param str artifact_name: The name of the artifact
        :return: The artifact produced by the task
        :rtype: Artifact
        """
        return self.task_results[task_name].artifacts[artifact_name]

    def from_spec(self, spec: ResultSpec):
        """ Return either the value or the artifact of a given task, as described by
        a ResultSpec. Delegates the actual work to the `values` and `artifacts` methods.

        :param ResultSpec spec: The result spec
        :return: either the value or the artifact computed from the spec
        """
        func = getattr(self, spec.result_type)
        return func(spec.result_task_name, spec.result_var_name)
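
# Dispatch sketch (illustrative; uses only the ResultSpec attributes read
# above): spec.result_type names the method to call, so
#
#     spec.result_type = 'values'       # selects PipelineResult.values
#     pipeline_result.from_spec(spec)   # == pipeline_result.values(
#                                       #        spec.result_task_name,
#                                       #        spec.result_var_name)
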

class Pipeline:

    def __init__(self, *tasks, name='default'):

        self._tasks = tasks
        self.task_graph = nx.DiGraph()
        self.execution_order = []
        self.name = name
        self.store_path = settings.YENTA_STORE_PATH / self.name

        self.store_path.mkdir(exist_ok=True, parents=True)

        self.build_task_graph()

        self._tasks_executed = set()
        self._tasks_reused = set()

    def _clear_pipeline_cache(self):
        """ Delete the pipeline cache. Only used for testing purposes. """
        shutil.rmtree(self.store_path)  # pragma: no cover

    def build_task_graph(self) -> None:
        """ Construct the task graph for the pipeline.

        :return: None
        """
        logger.debug('Building task graph')
        for task in self._tasks:
            self.task_graph.add_node(task.task_def.name, task=task)
            for dependency in (task.task_def.depends_on or []):
                dependency = dependency.split('.')[0]
                self.task_graph.add_edge(dependency, task.task_def.name)

        logger.debug('Computing execution order')
        try:
            self.execution_order = list(nx.algorithms.dag.lexicographical_topological_sort(self.task_graph))
        except nx.NetworkXUnfeasible as ex:
            print(Fore.RED + 'Unable to build execution graph because the pipeline contains cyclic dependencies.')
            raise ex
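
    # Graph sketch (hypothetical task names): if b depends on 'a' and c
    # depends on 'a' and 'b', the graph gets edges a->b, a->c and b->c, and
    # the lexicographical topological sort yields the execution order
    # ['a', 'b', 'c']. A cycle such as a -> b -> a makes the sort raise
    # nx.NetworkXUnfeasible, which is reported and re-raised above.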

    @staticmethod
    def _wrap_task_output(raw_output: Union[dict, TaskResult], task_name: str) -> TaskResult:
        """ Wrap the raw output of a task in a TaskResult.

        :param Union[dict, TaskResult] raw_output: The raw output of a task.
        :param str task_name: The name of the task.
        :return: A TaskResult containing the output
        :rtype: TaskResult
        """

        if isinstance(raw_output, dict):
            output: TaskResult = TaskResult(**raw_output)
        elif isinstance(raw_output, TaskResult):
            output = raw_output
        else:
            raise InvalidTaskResultError(f'Task {task_name} returned invalid result of type {type(raw_output)}, '
                                         f'expected either a dict or a TaskResult')

        return output
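
    # Sketch (illustrative): these two calls produce equivalent TaskResults,
    # while any other return type raises InvalidTaskResultError.
    #
    #     Pipeline._wrap_task_output({'values': {'x': 1}}, 'my_task')
    #     Pipeline._wrap_task_output(TaskResult(values={'x': 1}), 'my_task')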

    @staticmethod
    def build_args_dict(task, args: PipelineResult) -> Dict[str, Any]:
        """ Build the args dictionary for executing a task.

        :param task: The task itself, which has a `task_def` attached to it.
        :param PipelineResult args: The results of the pipeline up to this point
        :return: A dictionary whose keys correspond to the arguments expected by
                 the task to be executed, and whose values are the values to be
                 passed in.
        :rtype: Dict[str, Any]
        """

        logger.debug('Building args dictionary')
        task_def: TaskDef = task.task_def
        args_dict = {}

        for spec in task_def.param_specs:
            if spec.param_type == ParameterType.PIPELINE_RESULTS:
                args_dict[spec.param_name] = args
            elif spec.param_type == ParameterType.EXPLICIT:
                args_dict[spec.param_name] = args.values(spec.result_spec.result_task_name,
                                                         spec.result_spec.result_var_name)

        return args_dict
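
    # Binding sketch (derived from the branches above): a PIPELINE_RESULTS
    # parameter receives the whole `args` PipelineResult, while an EXPLICIT
    # parameter receives the single unwrapped value its result_spec points
    # at, i.e. args.values(result_task_name, result_var_name). The resulting
    # dict is splatted into the task function by invoke_task below.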

    def invoke_task(self, task, **kwargs) -> TaskResult:
        """ Call the function that represents the task with the supplied kwargs.

        :param Callable task: The task function.
        :param dict kwargs: The arguments obtained from `build_args_dict`.
        :return: The task result
        :rtype: TaskResult
        """

        output = task(**kwargs)
        return self._wrap_task_output(output, task.task_def.name)

    @staticmethod
    def merge_pipeline_results(res1: PipelineResult, res2: PipelineResult) -> PipelineResult:
        """ Combine two pipeline results. Where they share keys, the entries
        of the second result overwrite those of the first.

        :param PipelineResult res1: The first result.
        :param PipelineResult res2: The second result.
        :return: The merged result.
        :rtype: PipelineResult
        """

        return PipelineResult(task_results={**res1.task_results, **res2.task_results},
                              task_inputs={**res1.task_inputs, **res2.task_inputs})
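
    # Merge sketch (illustrative): later entries win, mirroring dict unpacking.
    #
    #     r1 = PipelineResult(task_results={'a': TaskResult(values={'x': 1})})
    #     r2 = PipelineResult(task_results={'a': TaskResult(values={'x': 2})})
    #     Pipeline.merge_pipeline_results(r1, r2).values('a', 'x')  # -> 2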

    def cache_result(self, task_name: str, result: PipelineResult):
        """ Write the result and inputs of a single task to disk.

        :param str task_name: The name of the task to cache.
        :param PipelineResult result: The pipeline result to cache from.
        :return: None
        """
        task_path = self.store_path / task_name
        task_path.mkdir(exist_ok=True, parents=True)

        task_cache = task_path / 'result.pk'
        with open(task_cache, 'wb') as f:
            pickle.dump(result.task_results[task_name], f)

        task_cache = task_path / 'inputs.pk'
        with open(task_cache, 'wb') as f:
            pickle.dump(result.task_inputs[task_name], f)

    @staticmethod
    def load_pipeline(store_path: Path) -> PipelineResult:
        """ Load a previously cached pipeline state from disk.

        :param Path store_path: The directory containing the per-task caches.
        :return: The pipeline state.
        :rtype: PipelineResult
        """
        logger.debug(f'Loading pipeline from {store_path}')
        pipeline = PipelineResult()
        if store_path.exists():
            for task_path in store_path.iterdir():
                if task_path.is_dir():
                    task_name = task_path.stem
                    with open(task_path / 'inputs.pk', 'rb') as f:
                        inputs = pickle.load(f)
                    with open(task_path / 'result.pk', 'rb') as f:
                        result = pickle.load(f)
                    pipeline.task_inputs[task_name] = inputs
                    pipeline.task_results[task_name] = result

        return pipeline
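
    # On-disk layout implied by cache_result and load_pipeline:
    #
    #     <YENTA_STORE_PATH>/<pipeline name>/<task name>/result.pk   # pickled TaskResult
    #     <YENTA_STORE_PATH>/<pipeline name>/<task name>/inputs.pk   # pickled input PipelineResult
    #
    # load_pipeline simply walks the task directories and unpickles both files.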

    @staticmethod
    def reuse_inputs(task_name: str, previous_result: PipelineResult, args: PipelineResult) -> bool:
        """ Determine whether the cached result from the previous run of this task can be
        reused, or whether the task should be executed again.

        :param str task_name: The name of the task.
        :param PipelineResult previous_result: The previous pipeline result.
        :param PipelineResult args: The arguments with which this task is being called.
        :return: True if the previous run succeeded with identical inputs, False otherwise.
        :rtype: bool
        """
        previous_inputs = previous_result.task_inputs.get(task_name, None)
        if previous_inputs and previous_result.task_results.get(task_name).status == TaskStatus.SUCCESS:
            return previous_inputs == args

        return False
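
    # Memoization note: because PipelineResult and TaskResult are dataclasses,
    # `previous_inputs == args` is a generated field-by-field comparison, so
    # the check amounts to a deep equality test on the nested result dicts.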

    def run_pipeline(self, up_to: str = None, force_rerun: List[str] = None) -> PipelineResult:
        """ Execute the tasks in the pipeline.

        :param str up_to: If supplied, execute the pipeline only up to this task.
        :param List[str] force_rerun: Optionally force the listed tasks to be executed.
        :return: The final pipeline state.
        :rtype: PipelineResult
        """

        previous_result: PipelineResult = self.load_pipeline(self.store_path)
        result = PipelineResult()
        self._tasks_reused.clear()
        self._tasks_executed.clear()

        for task_name in list(split_after(self.execution_order, lambda x: x == up_to))[0]:
            logger.debug(f'Starting execution of {task_name}')
            task_node = self.task_graph.nodes.get(task_name, None)
            if not task_node:
                raise PipelineConfigError(f'Dependency on nonexistent task: {task_name}')
            task = task_node['task']
            args = PipelineResult()
            dependencies_succeeded = True
            for dependency in (task.task_def.depends_on or []):
                dependency = dependency.split('.')[0]
                args.task_results[dependency] = result.task_results[dependency]
                if result.task_results[dependency].status == TaskStatus.FAILURE:
                    dependencies_succeeded = False
                    break

            if dependencies_succeeded:
                if task.task_def.pure and task_name not in (force_rerun or []) and \
                        self.reuse_inputs(task_name, previous_result, args):
                    logger.debug(f'Reusing previous results of {task_name}')
                    self._tasks_reused.add(task_name)
                    output = previous_result.task_results[task_name]
                    marker = Fore.YELLOW + '\u2014' + Fore.WHITE
                else:
                    args_dict = self.build_args_dict(task, args)
                    try:
                        logger.debug(f'Calling function to execute {task_name}')
                        output = self.invoke_task(task, **args_dict)
                        output.status = TaskStatus.SUCCESS
                        marker = Fore.GREEN + '\u2714' + Fore.WHITE
                        self._tasks_executed.add(task_name)
                    except Exception as ex:
                        import traceback
                        print(Fore.RED)
                        traceback.print_exc()
                        print(Fore.WHITE)
                        logger.error(f'Caught exception executing {task_name}: {ex}')
                        output = TaskResult(status=TaskStatus.FAILURE, error=str(ex))
                        marker = Fore.RED + '\u2718' + Fore.WHITE

                print(Fore.WHITE + Style.BRIGHT + f'[{marker}] {task_name}')

                result.task_results[task_name] = output
                result.task_inputs[task_name] = args

                result = self.merge_pipeline_results(previous_result, result)
                self.cache_result(task_name, result)

        return result
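
# End-to-end usage sketch (hypothetical task objects; the task decorator
# lives in yenta.tasks and is not shown in this module):
#
#     pipeline = Pipeline(fetch, transform, report, name='nightly')
#     state = pipeline.run_pipeline(up_to='report', force_rerun=['fetch'])
#     state.values('transform', 'row_count')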