Coverage for onionizer/onionizer/onionizer.py: 99%

108 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-04-02 17:43 +0200

1import functools 

2import inspect 

3from abc import ABC 

4from contextlib import ExitStack 

5from typing import Callable, Any, Iterable, Sequence, TypeVar, Generator 

6 

# Type variable shared by the "send" and "return" slots of OnionGenerator.
T = TypeVar("T")  # pragma: no mutate

# Annotation for middleware generators: they may yield anything, and the value
# sent back into them has the same type as the value they finally return.
OnionGenerator = Generator[Any, T, T]  # pragma: no mutate

# Marker a middleware yields to keep the current arguments untouched.
# NOTE(review): _refine detects it with an identity test (`is UNCHANGED`); with
# a plain int this relies on small-int caching, and a middleware yielding the
# literal 123 would be treated the same way — consider a dedicated sentinel.
UNCHANGED = 123  # pragma: no mutate

# Names exported by `from onionizer import *`.
__all__ = [
    "wrap_around",
    "decorate",
    "OnionGenerator",
    "UNCHANGED",
    "PositionalArgs",
    "MixedArgs",
    "KeywordArgs",
    "postprocessor",
    "preprocessor",
    "as_decorator",
]

25 

26 

27def _capture_last_message(coroutine, value_to_send: Any) -> Any: 

28 try: 

29 coroutine.send(value_to_send) 

30 except StopIteration as e: 

31 # expected if the generator is exhausted 

32 return e.value 

33 else: 

34 raise RuntimeError( 

35 "Generator did not exhaust. Your function should yield exactly once." 

36 ) 

37 

38 

def _leave_the_onion(coroutines: Sequence, output: Any) -> Any:
    """Carry ``output`` back through the middlewares, innermost first.

    Each generator receives the current output and its return value becomes
    the output handed to the next (outer) generator.

    :param coroutines: started middleware generators, in the order entered
    :type coroutines: Sequence
    :param output: the value produced by the wrapped function
    :type output: Any
    :return: the output after every middleware has post-processed it
    """
    # reversed to respect the onion model: last entered, first left
    return functools.reduce(
        lambda current, layer: _capture_last_message(layer, current),
        reversed(coroutines),
        output,
    )

44 

45 

def as_decorator(middleware):
    """Build a decorator that wraps a function with one single middleware.

    :param middleware: the middleware to apply
    :return: the decorator produced by ``decorate`` for a one-element list
    """
    single_layer = [middleware]
    return decorate(single_layer)

48 

49 

def decorate(middlewares):
    """Build a decorator that wraps a function with ``middlewares``.

    ``middlewares`` may be an iterable of middlewares or a single middleware.
    A single middleware is either a callable or a context manager (an object
    with ``__enter__`` and ``__exit__``), matching what ``wrap_around``
    accepts as elements of its list.

    :param middlewares: an iterable of middlewares, or one middleware
    :return: a decorator applying ``wrap_around(func, middlewares)``
    :raises TypeError: if ``middlewares`` is neither iterable, callable nor a
        context manager
    """
    if not isinstance(middlewares, Iterable):
        is_context_manager = hasattr(middlewares, "__enter__") and hasattr(
            middlewares, "__exit__"
        )
        if not (callable(middlewares) or is_context_manager):
            raise TypeError(
                "middlewares must be a list of coroutines or a single coroutine"
            )
        # a lone middleware is treated as a one-layer onion
        middlewares = [middlewares]

    def decorator(func):
        return wrap_around(func, middlewares)

    return decorator

63 

64 

def wrap_around(
    func: Callable[..., Any], middlewares: list, sigcheck: bool = True
) -> Callable[..., Any]:
    """
    It takes a function and a list of middlewares,
    and returns a function that calls the middlewares in order, then the
    function, then the middlewares in reverse order

    def func(x, y):
        return x + y

    def middleware1(x, y):
        result = yield (x + 1, y), {}
        return result

    def middleware2(x, y):
        result = yield (x, y + 1), {}
        return result


    wrapped_func = onionizer.wrap_around(func, [middleware1, middleware2])
    result = wrapped_func(0, 0)

    assert result == 2

    A middleware is either a generator function that yields exactly once
    (the yielded value describes the arguments passed further in, the value
    sent back is the output coming from further in, and its return value is
    the output passed further out) or a context manager, which is entered
    before the inner layers run and exited once they are done.

    :param func: the function to be wrapped
    :type func: Callable[..., Any]
    :param middlewares: a list of functions that will be called in order
    :type middlewares: list
    :param sigcheck: when True, every middleware must declare the same
        parameters as ``func`` (context managers and middlewares flagged with
        ``ignore_signature_check = True`` are exempt)
    :type sigcheck: bool
    :return: A function that wraps the original function with the middlewares.
    :raises TypeError: if ``func`` is not callable or ``middlewares`` is not
        iterable; at call time, if a middleware does not return a generator
    :raises ValueError: if ``sigcheck`` is True and a signature differs
    """
    if isinstance(middlewares, Iterable):
        # Snapshot the middlewares: a one-shot iterable (e.g. a generator)
        # would otherwise be consumed by the signature check below and the
        # wrapped function would silently run without any middleware.
        middlewares = list(middlewares)
    _check_validity(func, middlewares, sigcheck)

    @functools.wraps(func)
    def wrapped_func(*args, **kwargs):
        arguments = MixedArgs(args, kwargs)
        coroutines = []
        with ExitStack() as stack:
            # programmatic support for context manager, possibly nested !
            # https://docs.python.org/3/library/contextlib.html#contextlib.ExitStack
            for middleware in middlewares:
                if hasattr(middleware, "__enter__") and hasattr(middleware, "__exit__"):
                    stack.enter_context(middleware)
                    continue
                coroutine = arguments.call_function(middleware)
                # Check for the generator protocol up front instead of
                # catching AttributeError around send(): that would also
                # swallow AttributeErrors raised by the middleware's own code.
                if not hasattr(coroutine, "send"):
                    middleware_name = getattr(
                        middleware, "__name__", repr(middleware)
                    )
                    raise TypeError(
                        f"Middleware {middleware_name} is not a coroutine. "
                        f"Did you forget to use a yield statement?"
                    )
                coroutines.append(coroutine)
                # run the middleware up to its yield to learn the arguments
                # it wants to pass to the next layer
                raw_arguments = coroutine.send(None)
                arguments = _refine(raw_arguments, arguments)
            # just reached the core of the onion
            output = arguments.call_function(func)
            # now we go back to the surface
            output = _leave_the_onion(coroutines, output)
            return output

    return wrapped_func

126 

127 

128def _check_validity(func, middlewares, sigcheck): 

129 if not callable(func): 

130 raise TypeError("func must be callable") 

131 if not isinstance(middlewares, Iterable): 

132 raise TypeError("middlewares must be a list of coroutines") 

133 if sigcheck: 

134 _inspect_signatures(func, middlewares) 

135 

136 

137def _inspect_signatures(func, middlewares): 

138 func_signature = inspect.signature(func) 

139 func_signature_params = func_signature.parameters 

140 for middleware in middlewares: 

141 if not ( 

142 hasattr(middleware, "ignore_signature_check") 

143 and middleware.ignore_signature_check is True 

144 ) and not all(hasattr(middleware, attr) for attr in ("__enter__", "__exit__")): 

145 middleware_signature = inspect.signature(middleware) 

146 middleware_signature_params = middleware_signature.parameters 

147 if middleware_signature_params != func_signature_params: 

148 raise ValueError( 

149 f"Expected arguments of the target function mismatch " 

150 f"middleware expected arguments. {func.__name__}{func_signature} " 

151 f"differs with {middleware.__name__}{middleware_signature}" 

152 ) 

153 

154 

class ArgsMode(ABC):
    """Base class for objects describing how to pass arguments to a callable.

    Subclasses store some arguments and override ``call_function`` to apply
    them to the function they are given.
    """

    def call_function(self, func: Callable[..., Any]):
        """Call ``func`` with the stored arguments; subclasses must override."""
        raise NotImplementedError

158 

159 

class PositionalArgs(ArgsMode):
    """Arguments that are passed to the callable positionally only."""

    def __init__(self, *args):
        # kept as the tuple collected by ``*args``
        self.args = args

    def call_function(self, func: Callable[..., Any]):
        """Return ``func`` called with the stored positional arguments."""
        positional = self.args
        return func(*positional)

166 

167 

class KeywordArgs(ArgsMode):
    """Arguments that are passed to the callable by keyword only."""

    def __init__(self, kwargs):
        # a single mapping of parameter name -> value (not ``**kwargs``)
        self.kwargs = kwargs

    def call_function(self, func: Callable[..., Any]):
        """Return ``func`` called with the stored keyword arguments."""
        named = self.kwargs
        return func(**named)

174 

175 

class MixedArgs(ArgsMode):
    """Arguments made of both a positional part and a keyword part."""

    def __init__(self, args, kwargs):
        # ``args`` is unpacked with ``*`` and ``kwargs`` with ``**`` on call
        self.args, self.kwargs = args, kwargs

    def call_function(self, func: Callable[..., Any]):
        """Return ``func`` called with the stored positional and keyword arguments."""
        positional, named = self.args, self.kwargs
        return func(*positional, **named)

183 

184 

def _refine(arguments, previous_arguments):
    """Turn the value yielded by a middleware into an ``ArgsMode`` instance.

    :param arguments: what the middleware yielded: ``UNCHANGED``, an
        ``ArgsMode`` instance, or an ``(args, kwargs)`` pair
    :param previous_arguments: the arguments in effect before this middleware
    :return: ``previous_arguments`` for ``UNCHANGED``, the ``ArgsMode``
        instance itself, or a ``MixedArgs`` built from the pair
    :raises TypeError: if the yielded value is none of the accepted forms
    """
    if arguments is UNCHANGED:
        return previous_arguments
    if isinstance(arguments, ArgsMode):
        return arguments
    # Strings and bytes are Sequences too, but a two-character string is
    # never a meaningful (args, kwargs) pair.
    is_pair = (
        isinstance(arguments, Sequence)
        and not isinstance(arguments, (str, bytes, bytearray))
        and len(arguments) == 2
    )
    if not is_pair:
        raise TypeError(
            "arguments must be a tuple of length 2, "
            "maybe use onionizer.PositionalArgs or onionizer.MixedArgs instead"
        )
    args, kwargs = arguments
    # Catch pairs such as ``yield (1, 2)`` here, with a hint, instead of
    # letting them fail later inside ``func(*args, **kwargs)``.
    try:
        iter(args)
    except TypeError:
        args_unpackable = False
    else:
        args_unpackable = True
    kwargs_unpackable = hasattr(kwargs, "keys") and hasattr(kwargs, "__getitem__")
    if not (args_unpackable and kwargs_unpackable):
        raise TypeError(
            "arguments must be a tuple (args, kwargs) where args is iterable "
            "and kwargs is a mapping, "
            "maybe use onionizer.PositionalArgs or onionizer.MixedArgs instead"
        )
    return MixedArgs(args, kwargs)

197 

198 

def preprocessor(func):
    """Turn an argument-transforming function into a middleware.

    The resulting middleware yields whatever ``func`` returns for the incoming
    arguments and hands back, unchanged, the output it is sent afterwards.

    :param func: callable computing the new arguments from the incoming ones
    :return: a generator function usable as a middleware
    """

    def wrapper(*args, **kwargs) -> OnionGenerator:
        new_arguments = func(*args, **kwargs)
        # the output of the inner layers is passed through untouched
        untouched_output = yield new_arguments
        return untouched_output

    return functools.wraps(func)(wrapper)

206 

207 

def postprocessor(func):
    """Turn an output-transforming function into a middleware.

    The resulting middleware leaves the arguments untouched (it yields
    ``UNCHANGED``) and returns ``func`` applied to the output it is sent.

    :param func: callable taking the inner output and returning the new one
    :return: a generator function usable as a middleware, flagged so that
        the signature check skips it
    """

    def wrapper(*args, **kwargs) -> OnionGenerator:
        inner_output = yield UNCHANGED
        processed = func(inner_output)
        return processed

    middleware = functools.wraps(func)(wrapper)
    # func takes the output, not the wrapped function's arguments, so its
    # signature is not expected to match the target's
    middleware.ignore_signature_check = True
    return middleware