Coverage for src/importnb/loader.py: 90%
279 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-08 11:18 -0800
« prev ^ index » next coverage.py v6.5.0, created at 2023-01-08 11:18 -0800
1# coding: utf-8
2"""# `loader`
4the loading machinery for notebooks style documents, and less.
5notebooks combine code, markdown, and raw cells to create a complete document.
6the importnb loader provides an interface for transforming these objects to valid python.
7"""
10import ast
11import inspect
12import re
13import shlex
14import sys
15import textwrap
16from contextlib import contextmanager
17from dataclasses import asdict, dataclass, field
18from functools import partial
19from importlib import _bootstrap as bootstrap
20from importlib import reload
21from importlib._bootstrap import _init_module_attrs, _requires_builtin
22from importlib._bootstrap_external import FileFinder, decode_source
23from importlib.machinery import ModuleSpec, SourceFileLoader
24from importlib.util import LazyLoader, find_spec
25from pathlib import Path
26from types import ModuleType
28from . import get_ipython
29from .decoder import LineCacheNotebookDecoder, quote
30from .docstrings import update_docstring
31from .finder import FileModuleSpec, FuzzyFinder, get_loader_details, get_loader_index
# public names exported by ``from importnb.loader import *``
__all__ = ("Notebook", "reload")
# (major, minor) of the running interpreter; used for version-dependent ast handling.
VERSION = sys.version_info.major, sys.version_info.minor

# ``^`` plus optional whitespace: with ``.match`` this tests whether a cell begins
# with an IPython ``%%`` cell magic. Raw string so ``\s`` is a regex escape rather
# than an (invalid) python string escape, which warns on modern interpreters.
MAGIC = re.compile(r"^\s*%{2}", re.MULTILINE)

# compile flag that permits top-level ``await``; 0 (a no-op flag) on interpreters
# whose ast module does not provide it.
ALLOW_TOP_LEVEL_AWAIT = getattr(ast, "PyCF_ALLOW_TOP_LEVEL_AWAIT", 0x0)
41def _get_co_flags_set(co_flags):
42 """return a deconstructed set of code flags from a code object."""
43 flags = set()
44 for i in range(12):
45 flag = 1 << i
46 if co_flags & flag:
47 flags.add(flag)
48 co_flags ^= flag
49 if not co_flags:
50 break
51 else:
52 flags.intersection_update(flags)
53 return flags
class SourceModule(ModuleType):
    """A module type usable wherever a filesystem path is accepted.

    Implements the ``os.PathLike`` protocol by reporting the module's ``__file__``.
    """

    def __fspath__(self):
        location = self.__file__
        return location
@dataclass
class Interface:
    """a configuration python importing interface

    Fields (as read elsewhere in this module):

    * ``name``: when truthy, ``create_module`` uses it as the module ``__name__``.
    * ``path``: the source file; ``.ipynb`` paths are decoded as notebooks.
    * ``lazy``: wrap generated loaders with ``importlib.util.LazyLoader``.
    * ``extensions``: file suffixes the loader registers for in ``__enter__``.
    * ``include_fuzzy_finder``: use ``FuzzyFinder`` instead of ``FileFinder``.
    * ``include_markdown_docstring``: let ``Notebook`` apply ``update_docstring``.
    * ``include_non_defs``: when False, ``Notebook`` keeps only imports and definitions.
    * ``include_await``: not read anywhere in this module.
    * ``module_type``: the class instantiated by ``create_module``.
    * ``no_magic``: ``Notebook`` comments out cells starting with a ``%%`` magic.
    """

    name: str = None
    path: str = None
    lazy: bool = False
    extensions: tuple = field(default_factory=lambda: [".ipy", ".ipynb"])
    include_fuzzy_finder: bool = True
    include_markdown_docstring: bool = True
    include_non_defs: bool = True
    include_await: bool = True
    module_type: ModuleType = SourceModule
    no_magic: bool = False

    # index at which ``__enter__`` inserted this loader into the path-hook details
    # (None when an existing entry already covered the extensions); hidden from repr.
    _loader_hook_position: int = field(default=0, repr=False)

    def __new__(cls, name=None, path=None, **kwargs):
        # allows ``name``/``path`` positionally alongside keyword configuration.
        # NOTE(review): because an instance of ``cls`` is returned, python calls
        # ``__init__`` again with the original arguments, so initialization runs twice.
        instance = super().__new__(cls)
        instance.__init__(**{**kwargs, "name": name, "path": path})
        return instance
class Loader(Interface, SourceFileLoader):
    """The simplest implementation of a Notebook Source File Loader.
    This class breaks down the loading process into finer steps.

    The pipeline is: raw file text -> python source (``raw_to_source``) ->
    ast nodes (``source_to_nodes``) -> code object (``nodes_to_code``) ->
    executed module (``exec_module``). Used as a context manager, a loader
    registers itself next to the ``.py`` loader in the path hook so that
    ``import`` statements can find files with its ``extensions``.
    """

    # NOTE(review): Loader is not decorated with @dataclass, so this declaration
    # does not replace Interface's field default; instances initialized through
    # Interface.__init__ still get [".ipy", ".ipynb"]. Confirm this is intended.
    extensions: tuple = field(default_factory=[".py"].copy)

    @property
    def loader(self):
        """generate a new loader based on the state of an existing loader.

        Returns a ``functools.partial`` of this loader's class (wrapped by
        ``LazyLoader.factory`` when ``lazy`` is set) with every dataclass field
        pre-bound except ``name`` and ``path``, which the finder passes when it
        instantiates a loader for a concrete file.
        """
        loader = type(self)
        if self.lazy:
            loader = LazyLoader.factory(loader)
        # asdict also carries private fields such as ``_loader_hook_position``;
        # name/path are per-module and must come from the finder instead.
        params = asdict(self)
        params.pop("name")
        params.pop("path")
        return partial(loader, **params)

    @property
    def finder(self):
        """generate a new finder based on the state of an existing loader"""
        return FuzzyFinder if self.include_fuzzy_finder else FileFinder

    def raw_to_source(self, source):
        """transform a string from a raw file to python source.

        ``.ipynb`` files are decoded with separate transformers per cell type
        (``code``, ``raw``, ``markdown``); any other file goes whole through ``code``.
        """
        if self.path and self.path.endswith(".ipynb"):
            # NOTE(review): ``raw`` and ``markdown`` are defined on Notebook, not on
            # Loader, so a bare Loader pointed at a notebook raises AttributeError.
            return LineCacheNotebookDecoder(
                code=self.code, raw=self.raw, markdown=self.markdown
            ).decode(source, self.path)

        # for a normal file we just apply the code transformer.
        return self.code(source)

    def source_to_nodes(self, source, path="<unknown>", *, _optimize=-1):
        """parse source string as python ast"""
        flags = ast.PyCF_ONLY_AST
        return bootstrap._call_with_frames_removed(
            compile, source, path, "exec", flags=flags, dont_inherit=True, optimize=_optimize
        )

    def nodes_to_code(self, nodes, path="<unknown>", *, _optimize=-1):
        """compile ast nodes to python code object (top-level await permitted)"""
        flags = ALLOW_TOP_LEVEL_AWAIT
        return bootstrap._call_with_frames_removed(
            compile, nodes, path, "exec", flags=flags, dont_inherit=True, optimize=_optimize
        )

    def source_to_code(self, source, path="<unknown>", *, _optimize=-1):
        """tangle python source to compiled code by:
        1. parsing the source as ast nodes
        2. compiling the ast nodes as python code"""
        nodes = self.source_to_nodes(source, path, _optimize=_optimize)
        return self.nodes_to_code(nodes, path, _optimize=_optimize)

    def get_data(self, path):
        """get_data injects an input transformation before the raw text.

        this method allows notebook json to be transformed line for line into
        vertically sparse python code.

        NOTE(review): the ``path`` argument is ignored; the bytes are always read
        from ``self.path``.
        """
        return self.raw_to_source(decode_source(super().get_data(self.path)))

    def create_module(self, spec):
        """an overloaded create_module method injecting fuzzy finder setup up logic.

        Instantiates ``module_type``, initializes the standard import attributes
        from ``spec``, applies the ``name`` override, exposes ``get_ipython`` on
        ``.ipynb``/``.ipy`` modules and registers ``spec.alias`` in ``sys.modules``.
        """
        module = self.module_type(str(spec.name))
        _init_module_attrs(spec, module)
        if self.name:
            module.__name__ = self.name

        # specs without a location leave ``__file__`` unset (or None); treat as no suffix.
        filename = getattr(module, "__file__", None) or ""
        if filename.endswith((".ipynb", ".ipy")):
            module.get_ipython = get_ipython

        if getattr(spec, "alias", None):
            # put a fuzzy spec on the modules to avoid re importing it.
            # there is a funky trick you do with the fuzzy finder where you
            # load multiple versions with different finders.
            sys.modules[spec.alias] = module

        return module

    def exec_module(self, module):
        """Execute the module.

        Code containing top-level await (``CO_COROUTINE`` flag) is driven through
        ``aexec_module_sync``; everything else is exec'd normally. On any failure
        the fuzzy alias registered by ``create_module`` is removed before re-raising.
        """
        # importlib uses module.__name__, but when running modules as __main__ name will change.
        # this approach uses the original name on the spec.
        try:
            code = self.get_code(module.__spec__.name)

            # from importlib
            if code is None:
                raise ImportError(
                    f"cannot load module {module.__name__!r} when " "get_code() returns None"
                )

            if inspect.CO_COROUTINE not in _get_co_flags_set(code.co_flags):
                # if there isn't any async non sense then we proceed with convention.
                bootstrap._call_with_frames_removed(exec, code, module.__dict__)
            else:
                self.aexec_module_sync(module)

        except BaseException:
            alias = getattr(module.__spec__, "alias", None)
            if alias:
                sys.modules.pop(alias, None)
            # bare raise keeps the original traceback untouched
            raise

    def aexec_module_sync(self, module):
        """Run ``aexec_module`` to completion from synchronous code.

        anyio drives the coroutine when it has already been imported in this
        process; otherwise the loop from ``asyncio.get_event_loop`` is used.
        """
        if "anyio" in sys.modules:
            import anyio

            anyio.run(self.aexec_module, module)
        else:
            from asyncio import get_event_loop

            # NOTE(review): get_event_loop() emits a DeprecationWarning on 3.12+ when
            # no current loop exists; consider asyncio.run / asyncio.Runner.
            get_event_loop().run_until_complete(self.aexec_module(module))

    async def aexec_module(self, module):
        """an async exec_module method permitting top-level await.

        Each top-level statement is compiled and executed separately; statements
        whose code object is a coroutine are recompiled in "single" mode and the
        resulting awaitable from ``eval`` is awaited.
        """
        # there is so redudancy in this approach, but it starts getting asynchier.
        nodes = self.source_to_nodes(self.get_data(self.path))

        # iterate through the nodes and compile individual statements
        for node in nodes.body:
            co = bootstrap._call_with_frames_removed(
                compile,
                ast.Module([node], []),
                module.__file__,
                "exec",
                flags=ALLOW_TOP_LEVEL_AWAIT,
            )
            if inspect.CO_COROUTINE in _get_co_flags_set(co.co_flags):
                # when something async is encountered we compile it with the single flag
                # this lets us use eval to retreive our coroutine.
                co = bootstrap._call_with_frames_removed(
                    compile,
                    ast.Interactive([node]),
                    module.__file__,
                    "single",
                    flags=ALLOW_TOP_LEVEL_AWAIT,
                )
                await bootstrap._call_with_frames_removed(
                    eval, co, module.__dict__, module.__dict__
                )
            else:
                bootstrap._call_with_frames_removed(exec, co, module.__dict__, module.__dict__)

    def code(self, str):
        """transform code text into python source with the module-level ``dedent``."""
        return dedent(str)

    @classmethod
    @_requires_builtin
    def is_package(cls, fullname):
        """Report whether ``fullname`` is a package.

        NOTE(review): ``_requires_builtin`` raises ImportError for any name not in
        ``sys.builtin_module_names``, so for notebook/file modules this never
        returns. For built-in names, undotted names return True and dotted names
        defer to the superclass.
        """
        if "." not in fullname:
            return True
        return super().is_package(fullname)

    def __enter__(self):
        """Register this loader for its ``extensions`` in the ``.py`` path hook.

        If an existing loader entry already claims every extension nothing changes
        and ``_loader_hook_position`` becomes None so ``__exit__`` is a no-op.
        Otherwise the loader is inserted at ``loader_id + 1`` (just after the entry
        reported by ``get_loader_index(".py")``), the path hook is rebuilt, and the
        importer cache is cleared so later imports use the new hook.
        """
        path_id, loader_id, details = get_loader_index(".py")
        for _, known_extensions in details:
            if all(map(known_extensions.__contains__, self.extensions)):
                self._loader_hook_position = None
                return self

        self._loader_hook_position = loader_id + 1
        details.insert(self._loader_hook_position, (self.loader, self.extensions))
        sys.path_hooks[path_id] = self.finder.path_hook(*details)
        sys.path_importer_cache.clear()
        return self

    def __exit__(self, *excepts):
        """Undo ``__enter__``: remove the inserted loader entry and rebuild the hook."""
        if self._loader_hook_position is not None:
            path_id, details = get_loader_details()
            details.pop(self._loader_hook_position)
            sys.path_hooks[path_id] = self.finder.path_hook(*details)
            sys.path_importer_cache.clear()

    @classmethod
    def load_file(cls, filename, main=True, **kwargs):
        """Import a notebook as a module from a filename.

        filename: The path of the file to load.
        main: Load the module in the __main__ context.
        kwargs: Extra configuration forwarded to the loader constructor.

        >>> assert Notebook.load_file('foo.ipynb')
        """
        name = main and "__main__" or filename
        loader = cls(name, str(filename), **kwargs)
        spec = FileModuleSpec(name, loader, origin=loader.path)
        module = loader.create_module(spec)
        loader.exec_module(module)
        return module

    @classmethod
    def load_module(cls, module, main=False, **kwargs):
        """Import a notebook as a module.

        main: Load the module in the __main__ context.
        kwargs: Extra configuration forwarded to the loader constructor.

        >>> assert Notebook.load_module('foo')
        """
        from importlib.util import module_from_spec
        from runpy import _get_module_details

        # keeping the hook installed while the module executes lets the module's
        # own imports resolve through this loader configuration as well.
        with cls(**kwargs):
            spec = _get_module_details(module)[1]
            module = module_from_spec(spec)
            if main:
                sys.modules["__main__"] = module
                module.__name__ = "__main__"
            spec.loader.exec_module(module)
            return module

    @classmethod
    def load_argv(cls, argv=None, *, parser=None):
        """load a module based on python arguments

        load a notebook from its file name
        >>> Notebook.load_argv("foo.ipynb --arg abc")

        load the same notebook from a module alias.
        >>> Notebook.load_argv("-m foo --arg abc")
        """
        if parser is None:
            parser = cls.get_argparser()

        if argv is None:
            from sys import argv

            argv = argv[1:]

        if isinstance(argv, str):
            argv = shlex.split(argv)

        module = cls.load_ns(parser.parse_args(argv))
        if module is None:
            return parser.print_help()

        return module

    @classmethod
    def load_ns(cls, ns):
        """load a module from a namespace, used when loading module from sys.argv parameters.

        Exactly one of ``ns.code``, ``ns.module`` or ``ns.file`` is run (in that
        priority); returns None when none is given. With ``ns.tasks`` the result
        is handed to doit.
        """
        if ns.tasks:
            # i don't quite why we need to do this here, but we do. so don't move it
            from doit.cmd_base import ModuleTaskLoader
            from doit.doit_cmd import DoitMain

        if ns.code:
            with main_argv(sys.argv[0], ns.args):
                result = cls.load_code(ns.code)
        elif ns.module:
            if ns.dir:
                if ns.dir not in sys.path:
                    sys.path.insert(0, ns.dir)
            elif "" not in sys.path:
                # mirror ``python -m``: make the current directory importable
                sys.path.insert(0, "")
            with main_argv(ns.module, ns.args):
                result = cls.load_module(ns.module, main=True)
        elif ns.file:
            where = Path(ns.dir, ns.file) if ns.dir else Path(ns.file)
            # NOTE(review): ``where`` only sets argv[0]; the file itself is loaded
            # from ``ns.file`` regardless of ``ns.dir``. Confirm this is intended.
            with main_argv(str(where), ns.args):
                result = cls.load_file(ns.file)
        else:
            return
        if ns.tasks:
            DoitMain(ModuleTaskLoader(result)).run(ns.args or ["help"])
        return result

    @classmethod
    def load_code(cls, code, argv=None, mod_name=None, script_name=None, main=False):
        """load a module from raw source code

        argv: accepted for API compatibility; currently unused.
        Returns a module object wrapping the namespace produced by runpy.
        """
        from runpy import _run_module_code

        loader = cls()
        name = main and "__main__" or mod_name or "<raw code>"

        return _dict_module(
            _run_module_code(loader.raw_to_source(code), mod_name=name, script_name=script_name)
        )

    @staticmethod
    def get_argparser(parser=None):
        """Add the importnb command line arguments to ``parser`` (a new one if None)."""
        from argparse import REMAINDER, ArgumentParser

        if parser is None:
            parser = ArgumentParser("importnb", description="run notebooks as python code")
        parser.add_argument("file", nargs="?", help="run a file")
        parser.add_argument("args", nargs=REMAINDER, help="arguments to pass to script")
        parser.add_argument("-m", "--module", help="run a module")
        parser.add_argument("-c", "--code", help="run raw code")
        parser.add_argument("-d", "--dir", help="path to run script in")
        parser.add_argument("-t", "--tasks", action="store_true", help="run doit tasks")
        return parser
def comment(str):
    """Turn ``str`` into python comments by prefixing each non-blank line with ``"# "``.

    Whitespace-only lines are left unprefixed (``textwrap.indent``'s default).
    """
    commented = textwrap.indent(str, "# ")
    return commented
class DefsOnly(ast.NodeTransformer):
    """An ast transformer keeping only imports, classes and function definitions.

    Applied to a module, every other top-level statement (assignments, bare
    expressions, loops, ...) is dropped, so executing the result only performs
    imports and defines names.
    """

    INCLUDE = ast.Import, ast.ImportFrom, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef

    def visit_Module(self, node):
        body = [x for x in node.body if isinstance(x, self.INCLUDE)]
        # ast.Module gained ``type_ignores`` in python 3.8. Detect the field rather
        # than the interpreter version, and tolerate hand-built modules that never
        # set it (compile requires the field to be present).
        if "type_ignores" in ast.Module._fields:
            return ast.Module(body=body, type_ignores=getattr(node, "type_ignores", []))
        return ast.Module(body=body)
class Notebook(Loader):
    """Notebook is a user friendly file finder and module loader for notebook source code.

    > Remember, restart and run all or it didn't happen.

    Notebook provides several useful options.

    * Lazy module loading. A module is executed the first time it is used in a script.
    * ``include_non_defs=False`` keeps only imports, classes and function definitions.
    * ``include_markdown_docstring`` passes the parsed module through ``update_docstring``.
    * ``no_magic`` comments out code cells that start with a ``%%`` cell magic.
    """

    def markdown(self, str):
        """transform a markdown cell into python source via ``quote``."""
        return quote(str)

    def raw(self, str):
        """transform a raw cell into python comments."""
        return comment(str)

    def visit(self, nodes):
        """filter parsed nodes down to definitions unless ``include_non_defs`` is set."""
        if self.include_non_defs:
            return nodes
        return DefsOnly().visit(nodes)

    def code(self, str):
        """transform a code cell; with ``no_magic``, cell-magic cells become comments."""
        if self.no_magic and MAGIC.match(str):
            return comment(str)
        return super().code(str)

    def source_to_nodes(self, source, path="<unknown>", *, _optimize=-1):
        """parse source as ast, then apply the docstring and definition filters."""
        # forward ``_optimize`` so this override honours the same contract as Loader's
        nodes = super().source_to_nodes(source, path, _optimize=_optimize)
        if self.include_markdown_docstring:
            nodes = update_docstring(nodes)
        nodes = self.visit(nodes)
        # nodes built by the transformers above may lack line/column info
        return ast.fix_missing_locations(nodes)
444def _dict_module(ns):
445 m = ModuleType(ns.get("__name__"), ns.get("__doc__"))
446 m.__dict__.update(ns)
447 return m
@contextmanager
def main_argv(prog, args=None):
    """Temporarily set ``sys.argv`` to ``[prog, *args]``.

    When ``args`` is None, ``sys.argv`` is left untouched. Otherwise the prior
    value is restored on exit, including when the body raises.
    """
    if args is None:
        yield
        return

    prior = sys.argv
    sys.argv = [prog, *args]
    try:
        yield
    finally:
        # restore even on error so a failing script does not leak its argv
        sys.argv = prior
try:
    from IPython.core.inputtransformer2 import TransformerManager

    # with IPython available, code is cleaned up by IPython's cell transformer,
    # which converts IPython-specific syntax into plain python.
    dedent = TransformerManager().transform_cell
except ImportError:
    # ImportError (not only ModuleNotFoundError) so that an installed but broken
    # IPython also falls back to the plain-python transformer.

    def dedent(body):
        """Fallback cell transformer used when IPython cannot be imported.

        Cells beginning with a ``%%`` cell magic are commented out wholesale;
        anything else has its common leading whitespace removed.
        """
        if MAGIC.match(body):
            return textwrap.indent(body, "# ")
        return textwrap.dedent(body)