Coverage for /Users/jerry/Development/yenta/yenta/pipeline/Pipeline.py: 98%
import io
import json
import logging
import tempfile
import pickle
import shutil

from dataclasses import dataclass, field, asdict
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Dict, List, Union, Any

import networkx as nx
from colorama import Fore, Style
from more_itertools import split_after

from yenta.artifacts.Artifact import Artifact
from yenta.config import settings
from yenta.tasks.Task import TaskDef, ParameterType, ResultSpec

logger = logging.getLogger(__name__)


class InvalidTaskResultError(Exception):
    pass


class InvalidParameterError(Exception):
    pass


class PipelineConfigError(Exception):
    pass


class TaskStatus(str, Enum):

    SUCCESS = 'success'
    FAILURE = 'failure'


@dataclass
class TaskResult:
    """ Holds the result of a specific task execution """

    values: Dict[str, Any] = field(default_factory=dict)
    """ A dictionary whose keys are value names and whose values are... values."""

    status: TaskStatus = None
    """ Whether the task succeeded or failed."""

    error: str = None
    """ Error message associated with task failure."""


@dataclass
class PipelineResult:
    """ Holds the intermediate results of a step in the pipeline, where the keys of the dicts
    are the names of the tasks that have been executed and the values are TaskResults"""

    task_results: Dict[str, TaskResult] = field(default_factory=dict)
    """ A dictionary whose keys are task names and whose values are the results of that task execution."""

    task_inputs: Dict[str, 'PipelineResult'] = field(default_factory=dict)
    """ A dictionary whose keys are task names and whose values are the inputs used in executing that task."""

    def values(self, task_name: str, value_name: str):
        """ Return the value named `value_name` that was produced by task `task_name`.

        :param str task_name: The name of the task
        :param str value_name: The name of the value
        :return: the unwrapped value produced by the task
        :rtype: Union[list, int, bool, float, str]
        """
        return self.task_results[task_name].values[value_name]

    def artifacts(self, task_name: str, artifact_name: str):
        """ Return the artifact named `artifact_name` that was produced by the task `task_name`.

        :param str task_name: The name of the task
        :param str artifact_name: The name of the artifact
        :return: The artifact produced by the task
        :rtype: Artifact
        """
        return self.task_results[task_name].artifacts[artifact_name]

    def from_spec(self, spec: ResultSpec):
        """ Return either the value or the artifact of a given task, as computed by
        a ResultSpec. Delegates the actual work to the `values` and `artifacts` functions.

        :param ResultSpec spec: The result spec
        :return: either the value or the artifact computed from the spec
        """
        func = getattr(self, spec.result_type)
        return func(spec.result_task_name, spec.result_var_name)
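
# Usage sketch (an illustration added to this listing, not part of the module): how
# values are read back out of a PipelineResult. The task name 'add' and value name
# 'sum' are hypothetical.
#
#     state = PipelineResult(task_results={'add': TaskResult(values={'sum': 3},
#                                                            status=TaskStatus.SUCCESS)})
#     state.values('add', 'sum')      # -> 3
#
# A ResultSpec whose result_type is 'values' reaches the same value through from_spec(),
# since from_spec() simply dispatches on that attribute name.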


class Pipeline:

    def __init__(self, *tasks, name='default'):

        self._tasks = tasks
        self.task_graph = nx.DiGraph()
        self.execution_order = []
        self.name = name
        self.store_path = settings.YENTA_STORE_PATH / self.name

        self.store_path.mkdir(exist_ok=True, parents=True)

        self.build_task_graph()

        self._tasks_executed = set()
        self._tasks_reused = set()

    def _clear_pipeline_cache(self):
        """ Delete the pipeline cache. Only used for testing purposes. """
        shutil.rmtree(self.store_path)  # pragma: no cover

    def build_task_graph(self) -> None:
        """ Construct the task graph for the pipeline

        :return: None
        """
        logger.debug('Building task graph')
        for task in self._tasks:
            self.task_graph.add_node(task.task_def.name, task=task)
            for dependency in (task.task_def.depends_on or []):
                dependency = dependency.split('.')[0]
                self.task_graph.add_edge(dependency, task.task_def.name)

        logger.debug('Computing execution order')
        try:
            self.execution_order = list(nx.algorithms.dag.lexicographical_topological_sort(self.task_graph))
        except nx.NetworkXUnfeasible as ex:
            print(Fore.RED + 'Unable to build execution graph because pipeline contains cyclic dependencies.')
            raise ex

    @staticmethod
    def _wrap_task_output(raw_output: Union[dict, TaskResult], task_name: str) -> TaskResult:
        """ Wrap the raw output of a task in a TaskResult.

        :param Union[dict, TaskResult] raw_output: The raw output of a task.
        :param task_name: The name of the task.
        :return: A TaskResult containing the output
        :rtype: TaskResult
        """

        if isinstance(raw_output, dict):
            output: TaskResult = TaskResult(**raw_output)
        elif isinstance(raw_output, TaskResult):
            output = raw_output
        else:
            raise InvalidTaskResultError(f'Task {task_name} returned invalid result of type {type(raw_output)}, '
                                         f'expected either a dict or a TaskResult')

        return output

    @staticmethod
    def build_args_dict(task, args: PipelineResult) -> Dict[str, Any]:
        """ Build the args dictionary for executing a task.

        :param task: The task itself, which has a `task_def` attached to it.
        :param PipelineResult args: The results of the pipeline up to this point
        :return: A dictionary whose keys correspond to the arguments expected by
            the task to be executed, and whose values are the values to be
            passed in.
        :rtype: Dict[str, Any]
        """

        logger.debug('Building args dictionary')
        task_def: TaskDef = task.task_def
        args_dict = {}

        for spec in task_def.param_specs:
            if spec.param_type == ParameterType.PIPELINE_RESULTS:
                args_dict[spec.param_name] = args
            elif spec.param_type == ParameterType.EXPLICIT:
                args_dict[spec.param_name] = args.values(spec.result_spec.result_task_name,
                                                         spec.result_spec.result_var_name)

        return args_dict
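
    # Sketch of the mapping built above (illustrative, not taken from this file): a
    # parameter declared with ParameterType.PIPELINE_RESULTS receives the whole
    # accumulated PipelineResult, while a ParameterType.EXPLICIT parameter receives a
    # single unwrapped value from a named upstream task. For a hypothetical task
    #
    #     def report(pipeline, total): ...
    #
    # with 'pipeline' declared as PIPELINE_RESULTS and 'total' bound to the value 'sum'
    # of an upstream task 'add', the resulting dictionary would be roughly
    #
    #     {'pipeline': args, 'total': args.values('add', 'sum')}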

    def invoke_task(self, task, **kwargs) -> TaskResult:
        """ Call the function that represents the task with the supplied kwargs.

        :param Callable task: The task function.
        :param dict kwargs: The arguments obtained from `build_args_dict`.
        :return: The task result
        :rtype: TaskResult
        """

        output = task(**kwargs)
        return self._wrap_task_output(output, task.task_def.name)

    @staticmethod
    def merge_pipeline_results(res1: PipelineResult, res2: PipelineResult) -> PipelineResult:
        """ Combine two different pipeline results. If they share keys,
        the results of the second pipeline will overwrite those of
        the first.

        :param PipelineResult res1: The first result.
        :param PipelineResult res2: The second result.
        :return: The merged result.
        :rtype: PipelineResult
        """

        return PipelineResult(task_results={**res1.task_results, **res2.task_results},
                              task_inputs={**res1.task_inputs, **res2.task_inputs})
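
    # Merge semantics in brief (values are illustrative): shared keys are taken from the
    # second argument, exactly as in a plain dict merge.
    #
    #     old = PipelineResult(task_results={'a': TaskResult(values={'x': 1})})
    #     new = PipelineResult(task_results={'a': TaskResult(values={'x': 2})})
    #     Pipeline.merge_pipeline_results(old, new).values('a', 'x')   # -> 2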

    def cache_result(self, task_name: str, result: PipelineResult):
        """ Write the pipeline results to a file.

        :param str task_name: The name of the task to cache.
        :param PipelineResult result: The results.
        :return: None
        """
        task_path = self.store_path / task_name
        task_path.mkdir(exist_ok=True, parents=True)

        task_cache = task_path / 'result.pk'
        with open(task_cache, 'wb') as f:
            pickle.dump(result.task_results[task_name], f)

        task_cache = task_path / 'inputs.pk'
        with open(task_cache, 'wb') as f:
            pickle.dump(result.task_inputs[task_name], f)

    @staticmethod
    def load_pipeline(store_path: Path) -> PipelineResult:
        """ Load a pipeline from file.

        :param Path store_path: The directory in which the pipeline cache is stored.
        :return: The pipeline.
        :rtype: PipelineResult
        """
        logger.debug(f'Loading pipeline from {store_path}')
        pipeline = PipelineResult()
        if store_path.exists():
            for task_path in store_path.iterdir():
                if task_path.is_dir():
                    task_name = task_path.stem
                    with open(task_path / 'inputs.pk', 'rb') as f:
                        inputs = pickle.load(f)
                    with open(task_path / 'result.pk', 'rb') as f:
                        result = pickle.load(f)
                    pipeline.task_inputs[task_name] = inputs
                    pipeline.task_results[task_name] = result

        return pipeline

    @staticmethod
    def reuse_inputs(task_name: str, previous_result: PipelineResult, args: PipelineResult) -> bool:
        """ Determine whether inputs from the previous instance of this task should be reused
        or whether the task should be executed again.

        :param str task_name: The name of the task.
        :param PipelineResult previous_result: The previous pipeline result.
        :param PipelineResult args: The arguments with which this task is being called.
        :return: True if the previous result can be reused, False otherwise.
        :rtype: bool
        """
        previous_inputs = previous_result.task_inputs.get(task_name, None)
        if previous_inputs and previous_result.task_results.get(task_name).status == TaskStatus.SUCCESS:
            return previous_inputs == args

        return False
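
    # Illustrative behaviour of reuse_inputs (the task names and values are hypothetical):
    # a cached result is only reused when the task previously succeeded *and* is being
    # called with identical inputs.
    #
    #     prev = PipelineResult(
    #         task_results={'add': TaskResult(values={'sum': 3}, status=TaskStatus.SUCCESS)},
    #         task_inputs={'add': PipelineResult()})
    #     Pipeline.reuse_inputs('add', prev, PipelineResult())     # -> True, same (empty) inputs
    #     Pipeline.reuse_inputs('add', prev,
    #                           PipelineResult(task_results={'load': TaskResult()}))   # -> False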

    def run_pipeline(self, up_to: str = None, force_rerun: List[str] = None) -> PipelineResult:
        """ Execute the tasks in the pipeline.

        :param str up_to: If supplied, execute the pipeline only up to this task.
        :param List[str] force_rerun: Optionally force the listed tasks to be executed.
        :return: The final pipeline state.
        :rtype: PipelineResult
        """

        previous_result: PipelineResult = self.load_pipeline(self.store_path)
        result = PipelineResult()
        self._tasks_reused.clear()
        self._tasks_executed.clear()

        for task_name in list(split_after(self.execution_order, lambda x: x == up_to))[0]:
            logger.debug(f'Starting execution of {task_name}')
            task_node = self.task_graph.nodes.get(task_name, None)
            if not task_node:
                raise PipelineConfigError(f'Dependency on nonexistent task: {task_name}')
            task = task_node['task']
            args = PipelineResult()
            dependencies_succeeded = True
            for dependency in (task.task_def.depends_on or []):
                dependency = dependency.split('.')[0]
                args.task_results[dependency] = result.task_results[dependency]
                if result.task_results[dependency].status == TaskStatus.FAILURE:
                    dependencies_succeeded = False
                    break

            if dependencies_succeeded:
                if task.task_def.pure and task_name not in (force_rerun or []) and \
                        self.reuse_inputs(task_name, previous_result, args):
                    logger.debug(f'Reusing previous results of {task_name}')
                    self._tasks_reused.add(task_name)
                    output = previous_result.task_results[task_name]
                    marker = Fore.YELLOW + u'\u2014' + Fore.WHITE
                else:
                    args_dict = self.build_args_dict(task, args)
                    try:
                        logger.debug(f'Calling function to execute {task_name}')
                        output = self.invoke_task(task, **args_dict)
                        output.status = TaskStatus.SUCCESS
                        marker = Fore.GREEN + u'\u2714' + Fore.WHITE
                        self._tasks_executed.add(task_name)
                    except Exception as ex:
                        import traceback
                        print(Fore.RED)
                        traceback.print_exc()
                        print(Fore.WHITE)
                        logger.error(f'Caught exception executing {task_name}: {ex}')
                        output = TaskResult(status=TaskStatus.FAILURE, error=str(ex))
                        marker = Fore.RED + u'\u2718' + Fore.WHITE

                print(Fore.WHITE + Style.BRIGHT + f'[{marker}] {task_name}')

                result.task_results[task_name] = output
                result.task_inputs[task_name] = args

                result = self.merge_pipeline_results(previous_result, result)
                self.cache_result(task_name, result)

        return result
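
# End-to-end usage sketch. Everything below is an assumption-laden illustration: yenta
# tasks are callables carrying a `task_def` (TaskDef) attribute, and they are presumed
# here to be declared with a `task` decorator from yenta.tasks.Task; the exact decorator
# spelling and the way parameters are bound to upstream values are not defined in this
# file.
#
#     # @task
#     # def fetch():
#     #     return {'values': {'n': 41}}
#     #
#     # @task(depends_on=['fetch'])
#     # def increment(pipeline):          # assumed to be bound as PIPELINE_RESULTS
#     #     return {'values': {'n': pipeline.values('fetch', 'n') + 1}}
#     #
#     # pipe = Pipeline(fetch, increment, name='demo')
#     # state = pipe.run_pipeline()                      # caches under YENTA_STORE_PATH/demo
#     # state.values('increment', 'n')                   # -> 42
#     # pipe.run_pipeline(force_rerun=['increment'])     # re-executes one task despite the cache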