Coverage for src/importnb/parameterize.py: 0%

91 statements  

« prev     ^ index     » next       coverage.py v6.4.4, created at 2022-10-02 18:31 -0700

1# coding: utf-8 

2"""# Parameterize 

3 

4The parameterize loader allows notebooks to be used as functions and command-line tools. A `Parameterize` loader converts literal AST assignments into keyword arguments for the module.

5""" 

6 

7import argparse 

8import ast 

9import inspect 

10import sys 

11from copy import deepcopy 

12from functools import partial, partialmethod 

13from importlib.util import find_spec, spec_from_loader 

14from inspect import Parameter, Signature, signature 

15from pathlib import Path 

16 

17from .loader import _GTE38, Notebook, module_from_spec 

18 

19if _GTE38: 

20 from importlib._bootstrap import _load_unlocked 

21else: 

22 from importlib._bootstrap import _installed_safely 

23 

24 

class FindReplace(ast.NodeTransformer):
    """Turn top-level literal assignments of a module AST into parameters.

    The first ``name = <literal>`` statement for each name at module level is
    registered both as an ``argparse`` option ``--name`` on ``parser`` and as
    a keyword-only :class:`inspect.Parameter`. The statement is then rewritten
    in place:

    * if ``name`` is present in ``globals``, the statement is replaced by the
      no-op expression ``"Skipped"`` so the injected value is not overwritten;
    * otherwise, if the command line supplies ``--name VALUE``, the assigned
      value is replaced by ``VALUE``.

    Only the top level of the module is visited; ``generic_visit`` does not
    recurse into nested statements.
    """

    def __init__(self, globals, parser):
        """
        Parameters
        ----------
        globals : dict
            Values supplied by the caller; assignments to these names are
            skipped.
        parser : argparse.ArgumentParser
            Parser collecting one ``--name`` option per parameter. It may be
            shared between visitors, so options can already be registered.
        """
        self.globals = globals
        self.parser = parser
        # Command-line arguments not yet consumed by an earlier parameter.
        self.argv = sys.argv[1:]
        self.parameters = []
        # Names this visitor already turned into parameters; later
        # assignments to the same name are left as ordinary code.
        self._seen = set()

    def _add_option(self, target, parameter):
        """Register ``--target`` on the parser, keeping any existing option."""
        if isinstance(parameter, bool):
            # A flag flips the default: True defaults get store_false.
            extras = dict(
                action="store_" + ["true", "false"][parameter],
                help="{} = {}".format(target, not parameter),
            )
        else:
            # argparse %-formats help strings, so a literal "%" in the default
            # (e.g. "50%") would crash format_help() unless escaped.
            extras = dict(
                help="{} : {} = {}".format(
                    target, type(parameter).__name__, parameter
                ).replace("%", "%%")
            )
        try:
            self.parser.add_argument("--%s" % target, default=parameter, **extras)
        except argparse.ArgumentError:
            # The parser may be shared with an earlier visitor (re-execution),
            # in which case the option already exists; keep that registration.
            pass

    def visit_Assign(self, node):
        """Register and possibly rewrite a single-name literal assignment.

        Returns the original node, a node with a replaced value, or an
        ``Expr("Skipped")`` placeholder.
        """
        if not (len(node.targets) == 1 and isinstance(node.targets[0], ast.Name)):
            return node
        target = node.targets[0].id
        if target in self._seen:
            # Only the first literal assignment defines the parameter.
            # Registering it again would add a duplicate Parameter (Signature
            # raises ValueError) and rewrite this reassignment with the first
            # default.
            return node
        try:
            parameter = ast.literal_eval(node.value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a literal: the statement stays ordinary code.
            return node

        # NOTE(review): every name qualifies here. A former
        # `target[0].lower()` guard was always true; if only lowercase names
        # were meant to become parameters, `str.islower` may have been
        # intended. Confirm before restricting.
        self._seen.add(target)
        self._add_option(target, parameter)
        self.parameters.append(
            Parameter(target, Parameter.KEYWORD_ONLY, default=parameter)
        )
        if "-h" in self.argv or "--help" in self.argv:
            # visit_Module prints the help once every option is registered.
            return node

        ns, self.argv = self.parser.parse_known_args(self.argv)
        if target in self.globals:
            # The caller injects this name; don't let the notebook overwrite it.
            skipped = ast.copy_location(ast.Expr(ast.Constant("Skipped")), node)
            ast.copy_location(skipped.value, node)
            return skipped

        value = getattr(ns, target, parameter)
        if value != parameter:
            if isinstance(parameter, str):
                # argparse already yields the raw text; parsing it as Python
                # would turn `--name world` into a NameError.
                node.value = ast.copy_location(ast.Constant(value), node.value)
            else:
                node.value = ast.parse(str(value)).body[0].value
        return node

    @property
    def signature(self):
        """Keyword-only :class:`inspect.Signature` of the collected parameters."""
        return Signature(self.parameters)

    def visit_Module(self, node):
        """Visit each top-level statement and use the docstring as help text."""
        node.body = [self.visit(statement) for statement in node.body]
        self.parser.description = ast.get_docstring(node)
        # Parse once more so a pending -h/--help prints the completed help.
        self.parser.parse_known_args(self.argv)
        return node

    def generic_visit(self, node):
        """Return other nodes unchanged without descending into them."""
        return node

82 

83 

def copy_(module):
    """Return a new module of the same type whose namespace copies *module*'s.

    The copy is shallow: the new module gets its own ``__dict__``, but every
    value in it is the same object found in ``vars(module)``.
    """
    duplicate = type(module)(module.__name__)
    namespace = vars(module)
    duplicate.__dict__.update(**namespace)
    return duplicate

87 

88 

class Parameterize(Notebook):
    """Notebook loader that exposes literal top-level assignments as parameters.

    Values given via ``globals`` or extra keyword arguments are injected into
    the module namespace before execution, and the matching assignments in
    the notebook are replaced by a no-op (see ``FindReplace``).
    """

    __slots__ = Notebook.__slots__ + ("globals",)

    def __init__(
        self,
        fullname=None,
        path=None,
        *,
        lazy=False,
        fuzzy=True,
        markdown_docstring=True,
        position=0,
        globals=None,
        main=False,
        **_globals
    ):
        """
        Parameters
        ----------
        fullname, path, lazy, fuzzy, position, main
            Forwarded to ``Notebook.__init__``.
        markdown_docstring
            Accepted for signature compatibility; not used in this method.
        globals : dict, optional
            Values to inject into the module. The mapping is copied, so the
            caller's dict is never modified.
        **_globals
            Further values to inject; they take precedence over ``globals``.
        """
        super().__init__(
            fullname, path, lazy=lazy, fuzzy=fuzzy, position=position, main=main
        )
        # NOTE(review): markdown_docstring is not forwarded to Notebook;
        # confirm whether that is intentional.
        # Build a new dict rather than updating `globals` in place: passing
        # e.g. `globals=globals()` must not leak keyword values into the
        # caller's namespace.
        self.globals = {**(globals or {}), **_globals}
        self._visitor = FindReplace(
            self.globals, argparse.ArgumentParser(prog=self.name)
        )

    def exec_module(self, module):
        """Inject ``self.globals`` into *module*, then execute it.

        A fresh visitor is created per execution but reuses the existing
        parser, so options registered earlier are kept.
        """
        self._visitor = FindReplace(self.globals, self._visitor.parser)
        module.__dict__.update(**self.globals)
        return super().exec_module(module)

    def visit(self, node):
        """Apply the parameter rewrite before the base Notebook transformation."""
        return super().visit(self._visitor.visit(node))

    @classmethod
    def load(cls, object, **globals):
        """Load *object* with ``Notebook.load`` and wrap it via ``parameterize``."""
        return parameterize(super().load(object), **globals)

125 

126 

127""" with Parameterize(): 

128 reload(foo) 

129 

130 with Parameterize(a=1234123): 

131 reload(foo) 

132 

133 with Parameterize(a="🤘"): 

134 reload(foo) 

135""" 

136 

137""" import foo 

138""" 

139 

140 

def parameterize(object, **globals):
    """Return a callable that re-executes a notebook module with new parameters.

    Parameters
    ----------
    object : module or str
        A module, or the dotted name of one to locate with ``find_spec``.
        Its ``__loader__`` must expose ``name`` and ``path``.
    **globals
        Default values injected into every execution; keyword arguments
        passed to the returned callable override them.

    Returns
    -------
    callable
        ``call(**parameters)`` executes a fresh copy of the module with the
        merged values and returns that copy. ``call.__signature__`` lists the
        notebook's literal parameters; ``call.__doc__`` is the module
        docstring or, if it is empty, the generated ``argparse`` help.

    Raises
    ------
    ModuleNotFoundError
        If ``object`` is a name that ``find_spec`` cannot resolve.
    """
    with Parameterize(**globals):
        if isinstance(object, str):
            spec = find_spec(object)
            if spec is None:
                raise ModuleNotFoundError(
                    "No module named {!r}".format(object), name=object
                )
            object = module_from_spec(spec)

    object.__loader__ = Parameterize(
        object.__loader__.name, object.__loader__.path, **globals
    )

    def call(**parameters):
        # Copy the unchanged template on every call. Rebinding `object` here
        # would make each call start from the previous call's executed
        # module and leak its leftover names.
        module = copy_(object)
        keywords = {**globals, **parameters}
        if _GTE38:
            Parameterize(module.__name__, module.__file__, **keywords).exec_module(
                module
            )
        else:
            with _installed_safely(module):
                Parameterize(
                    module.__name__, module.__file__, **keywords
                ).exec_module(module)
        return module

    # Compiling runs the loader's visitor, which fills its parser and
    # signature with the notebook's parameters.
    object.__loader__.get_code(object.__name__)
    call.__doc__ = object.__doc__ or object.__loader__._visitor.parser.format_help()
    call.__signature__ = object.__loader__._visitor.signature
    return call