Coverage for src/importnb/entry_points.py: 82%

45 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2022-12-04 10:42 -0800

1from .loader import Loader 

2from dataclasses import dataclass, field 

3from types import MethodType 

4from contextlib import contextmanager, ExitStack 

5 

6 

7def _get_importnb_entry_points(): 

8 try: 

9 from importlib.metadata import entry_points 

10 

11 yield from entry_points()["importnb"] 

12 except ModuleNotFoundError: 

13 from importlib_metadata import entry_points 

14 

15 yield from entry_points(group="importnb") 

16 

17 

# Only the ``imports`` context manager is exported by ``from ... import *``.
__all__ = ("imports",)

# Registry of discovered entry points, mapping entry-point name -> entry-point
# value string (e.g. "module:attr"); filled lazily by get_importnb_entry_points().
ENTRY_POINTS = {}

20 

21 

def get_importnb_entry_points():
    """Discover importnb entry points and record them in the registry.

    Each entry point from the ``importnb`` group is stored in ``ENTRY_POINTS``
    as ``name -> value``; a later entry with the same name overwrites an
    earlier one. The (mutated, not rebound) registry dict is returned.
    """
    # the registry is only mutated in place, so no ``global`` is needed
    ENTRY_POINTS.update(
        (ep.name, ep.value) for ep in _get_importnb_entry_points()
    )
    return ENTRY_POINTS

28 

29 

def loader_from_alias(alias):
    """Load an object from an entry-point style value specification.

    ``alias`` has the form ``"package.module:attr.path"``. Following the
    entry-point object-reference syntax, whitespace around the colon and a
    trailing ``[extras]`` marker are tolerated. The ``:attr.path`` part may
    be omitted, in which case the imported module itself is returned.

    Raises ModuleNotFoundError/ImportError if the module cannot be imported,
    and AttributeError if the attribute path does not exist on the module.
    """
    from importlib import import_module
    from operator import attrgetter

    # extras (``mod:obj [extra]``) only matter to installers; drop them
    spec = alias.partition("[")[0]
    module_name, sep, member = spec.partition(":")
    module = import_module(module_name.strip())
    if not sep:
        # a bare module reference names the module object itself
        return module
    # attrgetter resolves dotted paths such as "Class.attribute"
    return attrgetter(member.strip())(module)

38 

39 

def loader_from_ep(alias):
    """Resolve an importnb loader from a registered alias or an explicit value.

    A string containing ``:`` is treated as a ``module:attr`` specification
    and loaded directly. Any other string is looked up in the entry-point
    registry, which is discovered first if it is still empty.

    Raises ValueError when the alias is not registered.
    """
    # an explicit "module:attr" specification bypasses the registry
    if ":" in alias:
        return loader_from_alias(alias)

    # get_importnb_entry_points() fills and returns the same registry dict
    registry = ENTRY_POINTS or get_importnb_entry_points()
    if alias not in registry:
        raise ValueError(f"{alias} is not a valid loader alias.")
    return loader_from_alias(registry[alias])

52 

53 

@contextmanager
def imports(*names):
    """Activate importnb loaders, selected by alias or ``module:attr`` value.

    Each name is resolved with ``loader_from_ep`` and the resolved loader is
    instantiated and entered as a context manager. A loader resolved more
    than once is entered only once. All entered loaders are exited together
    when the block ends. Yields the underlying ``ExitStack``.
    """
    entered = set()
    with ExitStack() as stack:
        # map is lazy, so each name is resolved just before it is entered
        for loader in map(loader_from_ep, names):
            if loader in entered:
                continue
            stack.enter_context(loader())
            entered.add(loader)
        yield stack

65 

66 

def list_aliases():
    """Return the names of every importnb entry point as a list.

    The registry is discovered whenever it is still empty.
    """
    # get_importnb_entry_points() fills and returns the same registry dict
    registry = ENTRY_POINTS or get_importnb_entry_points()
    return list(registry)