Coverage for src/importnb/loader.py: 90%

279 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-01-08 11:18 -0800

1# coding: utf-8 

2"""# `loader` 

3 

4the loading machinery for notebooks style documents, and less. 

5notebooks combine code, markdown, and raw cells to create a complete document. 

6the importnb loader provides an interface for transforming these objects to valid python. 

7""" 

8 

9 

10import ast 

11import inspect 

12import re 

13import shlex 

14import sys 

15import textwrap 

16from contextlib import contextmanager 

17from dataclasses import asdict, dataclass, field 

18from functools import partial 

19from importlib import _bootstrap as bootstrap 

20from importlib import reload 

21from importlib._bootstrap import _init_module_attrs, _requires_builtin 

22from importlib._bootstrap_external import FileFinder, decode_source 

23from importlib.machinery import ModuleSpec, SourceFileLoader 

24from importlib.util import LazyLoader, find_spec 

25from pathlib import Path 

26from types import ModuleType 

27 

28from . import get_ipython 

29from .decoder import LineCacheNotebookDecoder, quote 

30from .docstrings import update_docstring 

31from .finder import FileModuleSpec, FuzzyFinder, get_loader_details, get_loader_index 

32 

33__all__ = "Notebook", "reload" 

34 

# (major, minor) of the running interpreter; used to gate version-specific AST construction.
VERSION = sys.version_info.major, sys.version_info.minor

# Pattern for a `%%` cell-magic marker preceded only by optional whitespace.
# A raw string keeps `\s` a regex escape rather than an (invalid) string escape.
MAGIC = re.compile(r"^\s*%{2}", re.MULTILINE)

# Compiler flag that permits top-level `await`; falls back to 0 when `ast` does not define it.
ALLOW_TOP_LEVEL_AWAIT = getattr(ast, "PyCF_ALLOW_TOP_LEVEL_AWAIT", 0x0)

39 

40 

def _get_co_flags_set(co_flags):
    """Split a code object's ``co_flags`` bit field into its individual flags.

    Args:
        co_flags: the integer bit field, e.g. ``code.co_flags``.

    Returns:
        A set with one power-of-two integer per bit that is set in
        ``co_flags``; an empty set when no bits are set.
    """
    # Walk every bit position the value occupies rather than a fixed window,
    # so high flag bits are reported instead of being silently dropped.
    return {1 << bit for bit in range(co_flags.bit_length()) if co_flags & (1 << bit)}

54 

55 

class SourceModule(ModuleType):
    """A module type that can be passed wherever a filesystem path is accepted."""

    def __fspath__(self):
        """Return the module's ``__file__`` as its filesystem path."""
        source_path = self.__file__
        return source_path

59 

60 

@dataclass
class Interface:
    """Configuration options shared by the importing machinery in this module."""

    # when set, create_module uses this as the module's __name__
    name: str = None
    # location of the source file to load
    path: str = None
    # when true, the `loader` property wraps the loader class with LazyLoader.factory
    lazy: bool = False
    # file suffixes registered alongside the loader in __enter__
    extensions: tuple = field(default_factory=lambda: [".ipy", ".ipynb"])
    # chooses FuzzyFinder over FileFinder in the `finder` property
    include_fuzzy_finder: bool = True
    # when true, Notebook.source_to_nodes passes the tree through update_docstring
    include_markdown_docstring: bool = True
    # when false, Notebook.visit keeps only imports and class/function definitions
    include_non_defs: bool = True
    # NOTE(review): not read anywhere in this file - confirm where it is consumed
    include_await: bool = True
    # class instantiated by create_module to build the module object
    module_type: ModuleType = field(default=SourceModule)
    # when true, Notebook.code comments out source that starts with a `%%` magic
    no_magic: bool = False

    # index at which __enter__ inserted this loader into the loader details
    # (None when nothing was inserted); left out of the repr
    _loader_hook_position: int = field(default=0, repr=False)

    def __new__(cls, name=None, path=None, **kwargs):
        """Allocate the instance and initialise it with ``name``/``path`` merged into the options."""
        instance = super().__new__(cls)
        instance.__init__(**dict(kwargs, name=name, path=path))
        return instance

83 

84 

class Loader(Interface, SourceFileLoader):
    """The simplest implementation of a Notebook Source File Loader.

    This class breaks down the loading process into finer steps:
    raw text -> python source -> ast nodes -> code object -> executed module.
    It is also a context manager that installs itself on ``sys.path_hooks``
    on entry and removes itself on exit.
    """

    # NOTE(review): this class is not itself decorated with @dataclass, so this
    # `field(...)` is presumably not picked up as a dataclass field override and
    # instances keep the default inherited from Interface - confirm intent.
    extensions: tuple = field(default_factory=[".py"].copy)

    @property
    def loader(self):
        """generate a new loader based on the state of an existing loader."""
        loader = type(self)
        if self.lazy:
            loader = LazyLoader.factory(loader)
        # name and path are supplied per module by the finder, so they are
        # dropped; every other option is bound onto the loader factory.
        params = asdict(self)
        params.pop("name")
        params.pop("path")
        return partial(loader, **params)

    @property
    def finder(self):
        """generate a new finder based on the state of an existing loader"""
        return FuzzyFinder if self.include_fuzzy_finder else FileFinder

    def raw_to_source(self, source):
        """transform a string from a raw file to python source."""
        if self.path and self.path.endswith(".ipynb"):
            # when we encounter notebooks we apply different transformers to the diff cell types
            return LineCacheNotebookDecoder(
                code=self.code, raw=self.raw, markdown=self.markdown
            ).decode(source, self.path)

        # for a normal file we just apply the code transformer.
        return self.code(source)

    def source_to_nodes(self, source, path="<unknown>", *, _optimize=-1):
        """parse source string as python ast"""
        flags = ast.PyCF_ONLY_AST
        return bootstrap._call_with_frames_removed(
            compile, source, path, "exec", flags=flags, dont_inherit=True, optimize=_optimize
        )

    def nodes_to_code(self, nodes, path="<unknown>", *, _optimize=-1):
        """compile ast nodes to python code object"""
        flags = ALLOW_TOP_LEVEL_AWAIT
        return bootstrap._call_with_frames_removed(
            compile, nodes, path, "exec", flags=flags, dont_inherit=True, optimize=_optimize
        )

    def source_to_code(self, source, path="<unknown>", *, _optimize=-1):
        """tangle python source to compiled code by:

        1. parsing the source as ast nodes
        2. compiling the ast nodes as python code
        """
        nodes = self.source_to_nodes(source, path, _optimize=_optimize)
        return self.nodes_to_code(nodes, path, _optimize=_optimize)

    def get_data(self, path):
        """get_data injects an input transformation before the raw text.

        this method allows notebook json to be transformed line for line into
        vertically sparse python code.

        The bytes are read from ``self.path``; the ``path`` argument is not used.
        """
        return self.raw_to_source(decode_source(super().get_data(self.path)))

    def create_module(self, spec):
        """an overloaded create_module method injecting fuzzy finder setup logic."""
        module = self.module_type(str(spec.name))
        _init_module_attrs(spec, module)
        if self.name:
            module.__name__ = self.name

        if module.__file__.endswith((".ipynb", ".ipy")):
            module.get_ipython = get_ipython

        if getattr(spec, "alias", None):
            # put a fuzzy spec on the modules to avoid re importing it.
            # there is a funky trick you do with the fuzzy finder where you
            # load multiple versions with different finders.
            sys.modules[spec.alias] = module

        return module

    def exec_module(self, module):
        """Execute the module.

        Raises:
            ImportError: when ``get_code`` returns ``None``.

        Any exception raised while loading removes the spec's ``alias`` entry
        (if it has one) from ``sys.modules`` before propagating.
        """
        # importlib uses module.__name__, but when running modules as __main__ name will change.
        # this approach uses the original name on the spec.
        try:
            code = self.get_code(module.__spec__.name)

            # from importlib
            if code is None:
                raise ImportError(
                    f"cannot load module {module.__name__!r} when get_code() returns None"
                )

            if inspect.CO_COROUTINE not in _get_co_flags_set(code.co_flags):
                # if there isn't any async non sense then we proceed with convention.
                bootstrap._call_with_frames_removed(exec, code, module.__dict__)
            else:
                self.aexec_module_sync(module)

        except BaseException:
            alias = getattr(module.__spec__, "alias", None)
            if alias:
                sys.modules.pop(alias, None)

            # bare raise keeps the original traceback untouched
            raise

    def aexec_module_sync(self, module):
        """Run :meth:`aexec_module` to completion from synchronous code.

        ``anyio`` drives the coroutine when the program has already imported
        it; otherwise an asyncio event loop is used.
        """
        if "anyio" in sys.modules:
            import anyio

            anyio.run(self.aexec_module, module)
        else:
            import asyncio

            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # no current event loop is available (e.g. newer interpreters
                # or non-main threads): create one and make it current.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
            loop.run_until_complete(self.aexec_module(module))

    async def aexec_module(self, module):
        """an async exec_module method permitting top-level await."""
        # there is some redundancy in this approach, but it starts getting asynchier.
        nodes = self.source_to_nodes(self.get_data(self.path))

        # iterate through the nodes and compile individual statements
        for node in nodes.body:
            co = bootstrap._call_with_frames_removed(
                compile,
                ast.Module([node], []),
                module.__file__,
                "exec",
                flags=ALLOW_TOP_LEVEL_AWAIT,
            )
            if inspect.CO_COROUTINE in _get_co_flags_set(co.co_flags):
                # when something async is encountered we compile it with the single flag
                # this lets us use eval to retrieve our coroutine.
                co = bootstrap._call_with_frames_removed(
                    compile,
                    ast.Interactive([node]),
                    module.__file__,
                    "single",
                    flags=ALLOW_TOP_LEVEL_AWAIT,
                )
                await bootstrap._call_with_frames_removed(
                    eval, co, module.__dict__, module.__dict__
                )
            else:
                bootstrap._call_with_frames_removed(exec, co, module.__dict__, module.__dict__)

    def code(self, str):
        """Transform the text of code into python source using the module-level ``dedent``."""
        return dedent(str)

    @classmethod
    @_requires_builtin
    def is_package(cls, fullname):
        """Return True for a name without a dot, otherwise defer to the parent class.

        NOTE(review): this is wrapped by importlib's ``_requires_builtin`` and
        also declared a classmethod, while the parent ``is_package`` looks like
        an instance method - confirm this path is actually exercised.
        """
        if "." not in fullname:
            return True
        return super().is_package(fullname)

    def __enter__(self):
        """Install this loader on ``sys.path_hooks`` unless its extensions are already handled."""
        path_id, loader_id, details = get_loader_index(".py")
        for _, suffixes in details:
            if all(map(suffixes.__contains__, self.extensions)):
                # an existing entry already covers every extension: nothing to install
                self._loader_hook_position = None
                return self

        # insert directly after the entry that get_loader_index located for ".py"
        self._loader_hook_position = loader_id + 1
        details.insert(self._loader_hook_position, (self.loader, self.extensions))
        sys.path_hooks[path_id] = self.finder.path_hook(*details)
        # cached finders were built from the old hook and must be rebuilt
        sys.path_importer_cache.clear()
        return self

    def __exit__(self, *excepts):
        """Remove the loader installed by ``__enter__``, if one was installed."""
        if self._loader_hook_position is not None:
            path_id, details = get_loader_details()
            details.pop(self._loader_hook_position)
            sys.path_hooks[path_id] = self.finder.path_hook(*details)
            sys.path_importer_cache.clear()

    @classmethod
    def load_file(cls, filename, main=True, **kwargs):
        """Import a notebook as a module from a filename.

        dir: The directory to load the file from.
        main: Load the module in the __main__ context.

        >>> assert Notebook.load_file('foo.ipynb')
        """
        name = "__main__" if main else filename
        loader = cls(name, str(filename), **kwargs)
        spec = FileModuleSpec(name, loader, origin=loader.path)
        module = loader.create_module(spec)
        loader.exec_module(module)
        return module

    @classmethod
    def load_module(cls, module, main=False, **kwargs):
        """Import a notebook as a module.

        main: Load the module in the __main__ context.

        >>> assert Notebook.load_module('foo')
        """
        from importlib.util import module_from_spec
        from runpy import _get_module_details

        # NOTE(review): **kwargs is accepted but not forwarded to cls() - confirm intent.
        with cls() as loader:
            mod_name, spec, code = _get_module_details(module)
            module = module_from_spec(spec)
            if main:
                sys.modules["__main__"] = module
                module.__name__ = "__main__"
            spec.loader.exec_module(module)
            return module

    @classmethod
    def load_argv(cls, argv=None, *, parser=None):
        """load a module based on python arguments

        load a notebook from its file name
        >>> Notebook.load_argv("foo.ipynb --arg abc")

        load the same notebook from a module alias.
        >>> Notebook.load_argv("-m foo --arg abc")

        ``argv`` may be a list of arguments, a single command-line string
        (split with ``shlex``), or ``None`` to use ``sys.argv[1:]``. When the
        parsed arguments name nothing to load, the parser's help is printed.
        """
        if parser is None:
            parser = cls.get_argparser()

        if argv is None:
            argv = sys.argv[1:]

        if isinstance(argv, str):
            argv = shlex.split(argv)

        module = cls.load_ns(parser.parse_args(argv))
        if module is None:
            return parser.print_help()

        return module

    @classmethod
    def load_ns(cls, ns):
        """load a module from a namespace, used when loading module from sys.argv parameters.

        Returns the loaded module, or ``None`` when the namespace specifies
        neither code, a module, nor a file.
        """
        if ns.tasks:
            # i don't quite why we need to do this here, but we do. so don't move it
            from doit.cmd_base import ModuleTaskLoader
            from doit.doit_cmd import DoitMain

        if ns.code:
            with main_argv(sys.argv[0], ns.args):
                result = cls.load_code(ns.code)
        elif ns.module:
            # make the module importable: prefer the requested directory,
            # otherwise make sure the current directory ("") is on sys.path.
            if ns.dir:
                if ns.dir not in sys.path:
                    sys.path.insert(0, ns.dir)
            elif "" in sys.path:
                pass
            else:
                sys.path.insert(0, "")
            with main_argv(ns.module, ns.args):
                result = cls.load_module(ns.module, main=True)
        elif ns.file:
            where = Path(ns.dir, ns.file) if ns.dir else Path(ns.file)
            with main_argv(str(where), ns.args):
                # NOTE(review): `where` is only used for sys.argv[0]; the file
                # itself is loaded from ns.file without ns.dir - confirm intent.
                result = cls.load_file(ns.file)
        else:
            return
        if ns.tasks:
            DoitMain(ModuleTaskLoader(result)).run(ns.args or ["help"])
        return result

    @classmethod
    def load_code(cls, code, argv=None, mod_name=None, script_name=None, main=False):
        """load a module from raw source code

        The module is named ``"__main__"`` when ``main`` is true, otherwise
        ``mod_name``, falling back to ``"<raw code>"``. ``argv`` is not used.
        """
        from runpy import _run_module_code

        self = cls()
        name = "__main__" if main else (mod_name or "<raw code>")

        return _dict_module(
            _run_module_code(self.raw_to_source(code), mod_name=name, script_name=script_name)
        )

    @staticmethod
    def get_argparser(parser=None):
        """Return an argument parser with the importnb command-line options added.

        A new ``ArgumentParser`` is created when ``parser`` is ``None``;
        otherwise the options are added to the parser that was passed in.
        """
        from argparse import REMAINDER, ArgumentParser

        if parser is None:
            parser = ArgumentParser("importnb", description="run notebooks as python code")
        parser.add_argument("file", nargs="?", help="run a file")
        parser.add_argument("args", nargs=REMAINDER, help="arguments to pass to script")
        parser.add_argument("-m", "--module", help="run a module")
        parser.add_argument("-c", "--code", help="run raw code")
        parser.add_argument("-d", "--dir", help="path to run script in")
        parser.add_argument("-t", "--tasks", action="store_true", help="run doit tasks")
        return parser

382 

383 

def comment(str):
    """Turn text into python comments by prefixing its lines with ``"# "``."""
    commented = textwrap.indent(text=str, prefix="# ")
    return commented

386 

387 

class DefsOnly(ast.NodeTransformer):
    """Reduce a module to its imports and its class/function definitions."""

    # the only top-level statement types that survive the transformation
    INCLUDE = (
        ast.Import,
        ast.ImportFrom,
        ast.ClassDef,
        ast.FunctionDef,
        ast.AsyncFunctionDef,
    )

    def visit_Module(self, node):
        """Build a new ``ast.Module`` holding only the top-level statements in ``INCLUDE``."""
        kept = list(filter(lambda stmt: isinstance(stmt, self.INCLUDE), node.body))
        if sys.version_info[:2] >= (3, 8):
            # from 3.8 on, Module also carries the type_ignores list
            return ast.Module(kept, node.type_ignores)
        return ast.Module(kept)

396 

397 

class Notebook(Loader):
    """Notebook is a user friendly file finder and module loader for notebook source code.

    > Remember, restart and run all or it didn't happen.

    Notebook provides several useful options.

    * Lazy module loading. A module is executed the first time it is used in a script.
    """

    # raw_to_source is inherited from Loader unchanged; it dispatches to the
    # code/raw/markdown transformers defined here.

    def markdown(self, str):
        """Transform the text of a markdown cell into python source via ``quote``."""
        return quote(str)

    def raw(self, str):
        """Transform the text of a raw cell into python comments."""
        return comment(str)

    def visit(self, nodes):
        """Return the tree as is, or only its imports/definitions when ``include_non_defs`` is false."""
        if self.include_non_defs:
            return nodes
        return DefsOnly().visit(nodes)

    def code(self, str):
        """Transform the text of code into python source.

        With ``no_magic`` set, text that starts with a `%%` magic is commented
        out; everything else goes through the parent transformer.
        """
        if self.no_magic and MAGIC.match(str):
            return comment(str)
        return super().code(str)

    def source_to_nodes(self, source, path="<unknown>", *, _optimize=-1):
        """parse source string as python ast, then apply the notebook specific transforms."""
        # forward the caller's optimization level instead of dropping it
        nodes = super().source_to_nodes(source, path, _optimize=_optimize)
        if self.include_markdown_docstring:
            nodes = update_docstring(nodes)
        nodes = self.visit(nodes)
        # transformed trees may contain nodes without position information
        return ast.fix_missing_locations(nodes)

442 

443 

def _dict_module(ns):
    """Build a module object from a namespace dictionary.

    The module takes its name and docstring from the ``__name__`` and
    ``__doc__`` entries of ``ns``, and every entry of ``ns`` is copied into
    the module's namespace.
    """
    name, doc = ns.get("__name__"), ns.get("__doc__")
    module = ModuleType(name, doc)
    vars(module).update(ns)
    return module

448 

449 

@contextmanager
def main_argv(prog, args=None):
    """Temporarily replace ``sys.argv`` with ``[prog, *args]``.

    Args:
        prog: the value to use as ``sys.argv[0]``.
        args: the remaining arguments; when ``None``, ``sys.argv`` is left
            untouched.

    The previous ``sys.argv`` is restored when the block exits, including
    when the block raises.
    """
    if args is None:
        yield
        return

    prior, sys.argv = sys.argv, [prog, *args]
    try:
        yield
    finally:
        # restore even on error so a failing script cannot leak its argv
        sys.argv = prior

458 

459 

# Choose the code transformer: IPython's cell transformer when it can be
# imported, otherwise a plain-text fallback.
try:
    from IPython.core.inputtransformer2 import TransformerManager

    dedent = TransformerManager().transform_cell
except ModuleNotFoundError:

    def dedent(body):
        """Comment out text that starts with a `%%` magic; otherwise strip common indentation."""
        if MAGIC.match(body):
            return textwrap.indent(body, "# ")
        return textwrap.dedent(body)