Coverage for /opt/homebrew/lib/python3.11/site-packages/_pytest/assertion/rewrite.py: 33%
624 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
1"""Rewrite assertion AST to produce nice error messages."""
2import ast
3import errno
4import functools
5import importlib.abc
6import importlib.machinery
7import importlib.util
8import io
9import itertools
10import marshal
11import os
12import struct
13import sys
14import tokenize
15import types
16from pathlib import Path
17from pathlib import PurePath
18from typing import Callable
19from typing import Dict
20from typing import IO
21from typing import Iterable
22from typing import Iterator
23from typing import List
24from typing import Optional
25from typing import Sequence
26from typing import Set
27from typing import Tuple
28from typing import TYPE_CHECKING
29from typing import Union
31from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
32from _pytest._io.saferepr import saferepr
33from _pytest._version import version
34from _pytest.assertion import util
35from _pytest.assertion.util import ( # noqa: F401
36 format_explanation as _format_explanation,
37)
38from _pytest.config import Config
39from _pytest.main import Session
40from _pytest.pathlib import absolutepath
41from _pytest.pathlib import fnmatch_ex
42from _pytest.stash import StashKey
44if TYPE_CHECKING:
45 from _pytest.assertion import AssertionState
# Stash key under which the AssertionState is kept on the Config; the hook
# reads it back via ``self.config.stash[assertstate_key]``.
assertstate_key = StashKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
# Runs under -O (``__debug__`` is False) use a ".pyo" suffix, so their cached
# files get a different name than those of normal runs.
PYC_EXT = ".py" + ("c" if __debug__ else "o")
# Suffix appended to a module's base name to form its cached pyc file name.
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
    """PEP302/PEP451 import hook which rewrites asserts.

    Acting as a meta path finder, it claims conftest files, files given on the
    command line, files matching the ``python_files`` patterns and modules
    registered through :meth:`mark_rewrite`. Acting as their loader, it
    executes a copy of their code whose ``assert`` statements were rewritten.
    That code is cached in pytest-specific pyc files.
    """

    def __init__(self, config: Config) -> None:
        self.config = config
        try:
            self.fnpats = config.getini("python_files")
        except ValueError:
            # NOTE(review): presumably raised when "python_files" is not a
            # registered ini option; fall back to the default test patterns.
            self.fnpats = ["test_*.py", "*_test.py"]
        self.session: Optional[Session] = None
        # module name -> source path of every module executed by this hook
        self._rewritten_names: Dict[str, Path] = {}
        # names registered through mark_rewrite() (their submodules match too)
        self._must_rewrite: Set[str] = set()
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
        # which might result in infinite recursion (#3506)
        self._writing_pyc = False
        # module basenames that _early_rewrite_bailout() never skips
        self._basenames_to_check_rewrite = {"conftest"}
        # memoized _is_marked_for_rewrite() answers; reset by mark_rewrite()
        self._marked_for_rewrite_cache: Dict[str, bool] = {}
        # whether the session's initial paths were folded into the basenames
        self._session_paths_checked = False

    def set_session(self, session: Optional[Session]) -> None:
        """Attach the current session, or detach it by passing None.

        The session's initial paths are picked up lazily by the next
        _early_rewrite_bailout() call.
        """
        self._session_paths_checked = False
        self.session = session

    # Indirection so we can mock calls to find_spec originated from the hook during testing
    _find_spec = importlib.machinery.PathFinder.find_spec

    def find_spec(
        self,
        name: str,
        path: Optional[Sequence[Union[str, bytes]]] = None,
        target: Optional[types.ModuleType] = None,
    ) -> Optional[importlib.machinery.ModuleSpec]:
        """Return a spec loaded by this hook if module *name* should be rewritten.

        Returns None, deferring to the remaining finders, in these cases:
        while a pyc is being written, for modules that are cheaply ruled out,
        for anything that is not an existing source file, and for files that
        are not rewrite targets.
        """
        if self._writing_pyc:
            return None
        state = self.config.stash[assertstate_key]
        if self._early_rewrite_bailout(name, state):
            return None
        state.trace(f"find_module called for: {name}")

        # Type ignored because mypy is confused about the `self` binding here.
        spec = self._find_spec(name, path)  # type: ignore
        # Only existing files handled by SourceFileLoader can be rewritten.
        # This rules out modules that were not found, namespace packages
        # (no origin, nothing to rewrite) and other kinds of loaders.
        is_source_file = (
            spec is not None
            and spec.origin is not None
            and isinstance(spec.loader, importlib.machinery.SourceFileLoader)
            and os.path.exists(spec.origin)
        )
        if not is_source_file:
            return None
        fn = spec.origin

        if not self._should_rewrite(name, fn, state):
            return None

        return importlib.util.spec_from_file_location(
            name,
            fn,
            loader=self,
            submodule_search_locations=spec.submodule_search_locations,
        )

    def create_module(
        self, spec: importlib.machinery.ModuleSpec
    ) -> Optional[types.ModuleType]:
        """Let the import machinery create the module object (default behaviour)."""
        return None

    def exec_module(self, module: types.ModuleType) -> None:
        """Execute the assert-rewritten code of *module* in its namespace.

        A valid cached pyc is reused if one exists. Otherwise the source is
        rewritten and, unless bytecode writing is disabled or the cache
        directory cannot be created, the result is cached.
        """
        spec = module.__spec__
        assert spec is not None
        assert spec.origin is not None
        fn = Path(spec.origin)
        state = self.config.stash[assertstate_key]

        self._rewritten_names[module.__name__] = fn

        # The requested module looks like a test file, so rewrite it. This is
        # the most magical part of the process: load the source, rewrite the
        # asserts, and load the rewritten source. We also cache the rewritten
        # module code in a special pyc. We must be aware of the possibility of
        # concurrent pytest processes rewriting and loading pycs. To avoid
        # tricky race conditions, we maintain the following invariant: The
        # cached pyc is always a complete, valid pyc. Operations on it must be
        # atomic. POSIX's atomic rename comes in handy.
        can_write = not sys.dont_write_bytecode
        cache_dir = get_cache_dir(fn)
        if can_write and not try_makedirs(cache_dir):
            can_write = False
            state.trace(f"read only directory: {cache_dir}")

        pyc = cache_dir / (fn.name[:-3] + PYC_TAIL)
        # Notice that even if we're in a read-only directory, I'm going
        # to check for a cached pyc. This may not be optimal...
        code = _read_pyc(fn, pyc, state.trace)
        if code is not None:
            state.trace(f"found cached rewritten pyc for {fn}")
        else:
            state.trace(f"rewriting {fn!r}")
            source_stat, code = _rewrite_test(fn, self.config)
            if can_write:
                # Block find_spec() while writing (see _writing_pyc in __init__).
                self._writing_pyc = True
                try:
                    _write_pyc(state, code, source_stat, pyc)
                finally:
                    self._writing_pyc = False
        exec(code, module.__dict__)

    def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
        """A fast way to get out of rewriting modules.

        Profiling has shown that the call to PathFinder.find_spec (inside of
        the find_spec from this class) is a major slowdown, so, this method
        tries to filter what we're sure won't be rewritten before getting to
        it. It returns True only when the module can be skipped.
        """
        if self.session is not None and not self._session_paths_checked:
            self._session_paths_checked = True
            for initial_path in self.session._initialpaths:
                # Make something as c:/projects/my_project/path.py ->
                # ['c:', 'projects', 'my_project', 'path.py']
                file_part = str(initial_path).split(os.path.sep)[-1]
                # add 'path' to basenames to be checked.
                self._basenames_to_check_rewrite.add(os.path.splitext(file_part)[0])

        # Note: conftest already by default in _basenames_to_check_rewrite.
        name_parts = name.split(".")
        if name_parts[-1] in self._basenames_to_check_rewrite:
            return False

        # For matching the name it must be as if it was a filename.
        as_filename = PurePath(*name_parts).with_suffix(".py")
        # A pattern with a directory part ("tests/**.py" for example) must be
        # matched against the full path, so the name alone cannot rule it out.
        if any(
            os.path.dirname(pat) or fnmatch_ex(pat, as_filename)
            for pat in self.fnpats
        ):
            return False

        if self._is_marked_for_rewrite(name, state):
            return False

        state.trace(f"early skip of rewriting module: {name}")
        return True

    def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
        """Decide whether the module *name* located at *fn* gets rewritten."""
        # always rewrite conftest files
        if os.path.basename(fn) == "conftest.py":
            state.trace(f"rewriting conftest file: {fn!r}")
            return True

        if self.session is not None and self.session.isinitpath(absolutepath(fn)):
            state.trace(f"matched test file (was specified on cmdline): {fn!r}")
            return True

        # modules not passed explicitly on the command line are only
        # rewritten if they match the naming convention for test files
        candidate = PurePath(fn)
        if any(fnmatch_ex(pat, candidate) for pat in self.fnpats):
            state.trace(f"matched test file {fn!r}")
            return True

        return self._is_marked_for_rewrite(name, state)

    def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
        """Return whether *name*, or a package containing it, was passed to mark_rewrite().

        Answers are memoized in _marked_for_rewrite_cache.
        """
        cached = self._marked_for_rewrite_cache.get(name)
        if cached is not None:
            return cached

        marked_hit = False
        for marked in self._must_rewrite:
            if name == marked or name.startswith(marked + "."):
                state.trace(f"matched marked file {name!r} (from {marked!r})")
                marked_hit = True
                break
        self._marked_for_rewrite_cache[name] = marked_hit
        return marked_hit

    def mark_rewrite(self, *names: str) -> None:
        """Mark import names as needing to be rewritten.

        The named module or package as well as any nested modules will
        be rewritten on import. A warning is issued for names already
        imported without this hook, unless their docstring opts out of
        rewriting.
        """
        already_imported = (
            set(names).intersection(sys.modules).difference(self._rewritten_names)
        )
        for mod_name in already_imported:
            imported = sys.modules[mod_name]
            if AssertionRewriter.is_rewrite_disabled(imported.__doc__ or ""):
                continue
            if isinstance(imported.__loader__, type(self)):
                continue
            self._warn_already_imported(mod_name)
        self._must_rewrite.update(names)
        # New marks may change previously memoized answers.
        self._marked_for_rewrite_cache.clear()

    def _warn_already_imported(self, name: str) -> None:
        """Issue a config-time warning that module *name* cannot be rewritten."""
        from _pytest.warning_types import PytestAssertRewriteWarning

        warning = PytestAssertRewriteWarning(
            f"Module already imported so cannot be rewritten: {name}"
        )
        self.config.issue_config_time_warning(warning, stacklevel=5)

    def get_data(self, pathname: Union[str, bytes]) -> bytes:
        """Optional PEP302 get_data API: return the raw bytes of *pathname*."""
        with open(pathname, "rb") as stream:
            content = stream.read()
        return content

    if sys.version_info >= (3, 10):
        if sys.version_info < (3, 12):
            from importlib.abc import TraversableResources
        else:
            from importlib.resources.abc import TraversableResources

        def get_resource_reader(self, name: str) -> TraversableResources:  # type: ignore
            """Return a resource reader rooted at the source file of module *name*.

            Only modules executed by this hook are known (KeyError otherwise).
            """
            if sys.version_info >= (3, 11):
                from importlib.resources.readers import FileReader
            else:
                from importlib.readers import FileReader

            loader_stub = types.SimpleNamespace(path=self._rewritten_names[name])
            return FileReader(loader_stub)  # type:ignore[no-any-return]
def _write_pyc_fp(
    fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
    """Write *co* to the binary stream *fp* in the pyc layout read by _read_pyc.

    Layout: 4-byte magic number, 4 zero flag bytes (PEP 552 timestamp-based
    pyc), source mtime and size as little-endian unsigned 32-bit ints, then
    the marshalled code object.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate.
    # The header expects 32-bit numbers for size and mtime (#4903), so both
    # are truncated to their low 32 bits.
    truncated_mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
    truncated_size = source_stat.st_size & 0xFFFFFFFF
    header = b"".join(
        (
            importlib.util.MAGIC_NUMBER,
            b"\x00\x00\x00\x00",  # https://www.python.org/dev/peps/pep-0552/
            struct.pack("<LL", truncated_mtime, truncated_size),
        )
    )
    fp.write(header)
    fp.write(marshal.dumps(co))
def _unlink_best_effort(path: str) -> None:
    """Remove *path*, ignoring any OSError (e.g. the file was never created)."""
    try:
        os.unlink(path)
    except OSError:
        pass


def _write_pyc(
    state: "AssertionState",
    co: types.CodeType,
    source_stat: os.stat_result,
    pyc: Path,
) -> bool:
    """Write the rewritten code *co* to *pyc* atomically; return True on success.

    The data is first written to a process-specific temporary file
    (``<pyc>.<pid>``) and then moved into place with ``os.replace``. As a
    result, the file at *pyc* is never partially written. Writing the cache
    is best-effort: failures are only traced, and the temporary file is
    removed so it does not pile up in the cache directory.
    """
    proc_pyc = f"{pyc}.{os.getpid()}"
    try:
        with open(proc_pyc, "wb") as fp:
            _write_pyc_fp(fp, source_stat, co)
    except OSError as e:
        state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
        # A partially written temp file may have been left behind.
        _unlink_best_effort(proc_pyc)
        return False

    try:
        os.replace(proc_pyc, pyc)
    except OSError as e:
        state.trace(f"error writing pyc file at {pyc}: {e}")
        # we ignore any failure to write the cache file
        # there are many reasons, permission-denied, pycache dir being a
        # file etc.
        # The complete temp file was not moved into place; don't leak it.
        _unlink_best_effort(proc_pyc)
        return False
    return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
    """Read and rewrite *fn* and return the code object.

    Returns ``(stat, code)``. ``stat`` is the result of ``os.stat(fn)``,
    taken before the source is read, and ``code`` is the compiled module
    with its asserts rewritten.
    """
    source_stat = os.stat(fn)
    raw_source = fn.read_bytes()
    filename = str(fn)
    module_ast = ast.parse(raw_source, filename=filename)
    rewrite_asserts(module_ast, raw_source, filename, config)
    compiled = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, compiled
def _read_pyc(
    source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
    """Possibly read a pytest pyc containing rewritten code.

    The 16-byte header must carry the current interpreter's magic number,
    zero flags, and the 32-bit truncated mtime and size of *source*. Every
    rejection except a missing/unopenable *pyc* is reported via *trace*.

    Return rewritten code if successful or None if not.
    """
    try:
        fp = open(pyc, "rb")
    except OSError:
        return None
    with fp:
        try:
            source_stat = os.stat(source)
            header = fp.read(16)
        except OSError as e:
            trace(f"_read_pyc({source}): OSError {e}")
            return None
        # Check for invalid or out of date pyc file.
        if len(header) != 16:
            trace(f"_read_pyc({source}): invalid pyc (too short)")
            return None
        if header[:4] != importlib.util.MAGIC_NUMBER:
            trace(f"_read_pyc({source}): invalid pyc (bad magic number)")
            return None
        if header[4:8] != b"\x00\x00\x00\x00":
            trace(f"_read_pyc({source}): invalid pyc (unsupported flags)")
            return None
        # mtime and size are stored as little-endian unsigned 32-bit ints.
        stored_mtime, stored_size = struct.unpack("<LL", header[8:16])
        if stored_mtime != int(source_stat.st_mtime) & 0xFFFFFFFF:
            trace(f"_read_pyc({source}): out of date")
            return None
        if stored_size != source_stat.st_size & 0xFFFFFFFF:
            trace(f"_read_pyc({source}): invalid pyc (incorrect size)")
            return None
        try:
            code = marshal.load(fp)
        except Exception as e:
            trace(f"_read_pyc({source}): marshal.load error {e}")
            return None
        if not isinstance(code, types.CodeType):
            trace(f"_read_pyc({source}): not a code object")
            return None
        return code
def rewrite_asserts(
    mod: ast.Module,
    source: bytes,
    module_path: Optional[str] = None,
    config: Optional[Config] = None,
) -> None:
    """Rewrite the assert statements in mod.

    *mod* is modified in place. *source* is the raw module source, used to
    recover the original assertion text.
    """
    rewriter = AssertionRewriter(module_path, config, source)
    rewriter.run(mod)
def _saferepr(obj: object) -> str:
    r"""Get a safe repr of an object for assertion error messages.

    The assertion formatting (util.format_explanation()) requires
    newlines to be escaped since they are a special character for it.
    Normally assertion.util.format_explanation() does this but for a
    custom repr it is possible to contain one of the special escape
    sequences, especially '\n{' and '\n}' are likely to be present in
    JSON reprs. The repr length limit depends on the configured verbosity.
    """
    limit = _get_maxsize_for_saferepr(util._config)
    text = saferepr(obj, maxsize=limit)
    return text.replace("\n", "\\n")
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
    """Get `maxsize` configuration for saferepr based on the given config object.

    Verbosity 0 (or no config) gives DEFAULT_REPR_MAX_SIZE, verbosity 1 gives
    ten times that, and verbosity 2 or more gives None (no limit).
    """
    verbosity = 0 if config is None else config.getoption("verbose")
    if verbosity >= 2:
        return None
    scale = 10 if verbosity >= 1 else 1
    return DEFAULT_REPR_MAX_SIZE * scale
def _format_assertmsg(obj: object) -> str:
    r"""Format the custom assertion message given.

    For strings this simply replaces newlines with '\n~' so that
    util.format_explanation() will preserve them instead of escaping
    newlines, and doubles '%' so the text survives %-formatting. For other
    objects saferepr() is used first.
    """
    # reprlib appears to have a bug which means that if a string
    # contains a newline it gets escaped, however if an object has a
    # .__repr__() which contains newlines it does not get escaped.
    # However in either case we want to preserve the newline.
    if isinstance(obj, str):
        text = obj
        extra_replacements: tuple = ()
    else:
        text = saferepr(obj)
        # saferepr output carries escaped newlines as a literal backslash-n.
        extra_replacements = (("\\n", "\n~"),)

    for old, new in (("\n", "\n~"), ("%", "%%")) + extra_replacements:
        text = text.replace(old, new)
    return text
def _should_repr_global_name(obj: object) -> bool:
    """Return whether a non-local name's value should be shown by its repr.

    Callables and objects with a ``__name__`` are not shown this way; if
    probing ``__name__`` raises, the repr is used.
    """
    if callable(obj):
        return False
    try:
        named = hasattr(obj, "__name__")
    except Exception:
        return True
    return not named
def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
    """Join operand explanations of a boolean operation into one explanation.

    The operands are joined with " or " when *is_or* is truthy and with
    " and " otherwise. The result is wrapped in parentheses, and '%' is
    doubled so the text survives the later %-formatting step.
    """
    joiner = " or " if is_or else " and "
    explanation = "(" + joiner.join(explanations) + ")"
    return explanation.replace("%", "%%")
def _call_reprcompare(
    ops: Sequence[str],
    results: Sequence[bool],
    expls: Sequence[str],
    each_obj: Sequence[object],
) -> str:
    """Return the explanation for the comparison that decided a chain.

    ``ops[i]`` compared ``each_obj[i]`` with ``each_obj[i + 1]``, producing
    ``results[i]`` and the default explanation ``expls[i]``. The scan stops
    at the first falsy result, or at one whose truth value cannot be
    computed; if none is falsy, the last examined entry is used. A registered
    ``util._reprcompare`` hook may replace the default explanation for it.
    """
    for i, (_op, res, expl) in enumerate(zip(ops, results, expls)):
        try:
            failed = not res
        except Exception:
            # Even taking the truth value raised: treat it as the failure.
            failed = True
        if failed:
            break
    reprcompare_hook = util._reprcompare
    if reprcompare_hook is not None:
        custom = reprcompare_hook(ops[i], each_obj[i], each_obj[i + 1])
        if custom is not None:
            return custom
    return expl
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
    """Forward a passing assertion to the registered assertion-pass callback, if any."""
    callback = util._assertion_pass
    if callback is None:
        return
    callback(lineno, orig, expl)
def _check_if_assertion_pass_impl() -> bool:
    """Check if any plugins implement the pytest_assertion_pass hook
    in order not to generate explanation unnecessarily (might be expensive).

    Returns the truth value of ``util._assertion_pass``.
    """
    return bool(util._assertion_pass)
# %-style templates used to explain unary operations; %s is the operand.
UNARY_MAP = {
    ast.Not: "not %s",
    ast.Invert: "~%s",
    ast.USub: "-%s",
    ast.UAdd: "+%s",
}

# Source symbol of each binary / comparison operator, used when building
# explanations. The explanations are %-formatted later, hence "%%" for Mod.
BINOP_MAP = {
    # bitwise
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.BitAnd: "&",
    ast.LShift: "<<",
    ast.RShift: ">>",
    # arithmetic
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%%",  # escaped for string formatting
    # comparisons
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    # power (kept at this position to preserve the mapping's order)
    ast.Pow: "**",
    # identity / membership
    ast.Is: "is",
    ast.IsNot: "is not",
    ast.In: "in",
    ast.NotIn: "not in",
    # matrix multiplication
    ast.MatMult: "@",
}
def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
    """Yield node and all of its descendants in depth-first pre-order.

    A node is yielded before its children, and children follow the order of
    :func:`ast.iter_child_nodes`. An explicit stack replaces the nested
    ``yield from`` generators. With nested generators, each node cost
    O(depth) to deliver, and deeply nested trees could exceed the recursion
    limit; with the stack, the cost per node no longer depends on tree depth.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        yield current
        # Push children reversed so the first child is the next one popped.
        pending.extend(reversed(list(ast.iter_child_nodes(current))))
@functools.lru_cache(maxsize=1)
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    *src* is tokenized with tokenize.tokenize(). For every ``assert``
    keyword, this records the source text that follows it: everything after
    the keyword, cut at a top-level ``,`` (which starts the assertion
    message) or at the end of the logical line, with trailing whitespace and
    backslashes stripped. The text is keyed by the keyword's line number.

    Only the most recent *src* is cached (maxsize=1). visit_Assert calls
    this once per assert with the same module source.
    """
    ret: Dict[int, str] = {}

    # Parser state for the assert currently being collected:
    # depth - bracket nesting level, used to tell the message-separating
    #         comma apart from commas inside calls/literals
    # lines - collected source fragments of the test expression
    # assert_lineno - line of the current ``assert`` keyword (None = idle)
    # seen_lines - physical line numbers already appended to ``lines``
    depth = 0
    lines: List[str] = []
    assert_lineno: Optional[int] = None
    seen_lines: Set[int] = set()

    def _write_and_reset() -> None:
        # Store the collected expression text and return to the idle state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # First token after ``assert``: keep its line from this token on.
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # lines[-1] was sliced from the token after ``assert``, so
                    # translate the comma's column into that shortened string.
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # End of the logical line: the assert had no message.
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # Continuation onto a new physical line: take it whole.
                lines.append(line)
                seen_lines.add(lineno)

    return ret
590class AssertionRewriter(ast.NodeVisitor):
591 """Assertion rewriting implementation.
593 The main entrypoint is to call .run() with an ast.Module instance,
594 this will then find all the assert statements and rewrite them to
595 provide intermediate values and a detailed assertion error. See
596 http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
597 for an overview of how this works.
599 The entry point here is .run() which will iterate over all the
600 statements in an ast.Module and for each ast.Assert statement it
601 finds call .visit() with it. Then .visit_Assert() takes over and
602 is responsible for creating new ast statements to replace the
603 original assert statement: it rewrites the test of an assertion
604 to provide intermediate values and replace it with an if statement
605 which raises an assertion error with a detailed explanation in
606 case the expression is false and calls pytest_assertion_pass hook
607 if expression is true.
609 For this .visit_Assert() uses the visitor pattern to visit all the
610 AST nodes of the ast.Assert.test field, each visit call returning
611 an AST node and the corresponding explanation string. During this
612 state is kept in several instance attributes:
614 :statements: All the AST statements which will replace the assert
615 statement.
617 :variables: This is populated by .variable() with each variable
618 used by the statements so that they can all be set to None at
619 the end of the statements.
621 :variable_counter: Counter to create new unique variables needed
622 by statements. Variables are created using .variable() and
623 have the form of "@py_assert0".
625 :expl_stmts: The AST statements which will be executed to get
626 data from the assertion. This is the code which will construct
627 the detailed assertion message that is used in the AssertionError
628 or for the pytest_assertion_pass hook.
630 :explanation_specifiers: A dict filled by .explanation_param()
631 with %-formatting placeholders and their corresponding
632 expressions to use in the building of an assertion message.
633 This is used by .pop_format_context() to build a message.
635 :stack: A stack of the explanation_specifiers dicts maintained by
636 .push_format_context() and .pop_format_context() which allows
637 to build another %-formatted string while already building one.
639 This state is reset on every new assert statement visited and used
640 by the other visitors.
641 """
643 def __init__(
644 self, module_path: Optional[str], config: Optional[Config], source: bytes
645 ) -> None:
646 super().__init__()
647 self.module_path = module_path
648 self.config = config
649 if config is not None:
650 self.enable_assertion_pass_hook = config.getini(
651 "enable_assertion_pass_hook"
652 )
653 else:
654 self.enable_assertion_pass_hook = False
655 self.source = source
657 def run(self, mod: ast.Module) -> None:
658 """Find all assert statements in *mod* and rewrite them."""
659 if not mod.body:
660 # Nothing to do.
661 return
663 # We'll insert some special imports at the top of the module, but after any
664 # docstrings and __future__ imports, so first figure out where that is.
665 doc = getattr(mod, "docstring", None)
666 expect_docstring = doc is None
667 if doc is not None and self.is_rewrite_disabled(doc):
668 return
669 pos = 0
670 lineno = 1
671 for item in mod.body:
672 if (
673 expect_docstring
674 and isinstance(item, ast.Expr)
675 and isinstance(item.value, ast.Str)
676 ):
677 doc = item.value.s
678 if self.is_rewrite_disabled(doc):
679 return
680 expect_docstring = False
681 elif (
682 isinstance(item, ast.ImportFrom)
683 and item.level == 0
684 and item.module == "__future__"
685 ):
686 pass
687 else:
688 break
689 pos += 1
690 # Special case: for a decorated function, set the lineno to that of the
691 # first decorator, not the `def`. Issue #4984.
692 if isinstance(item, ast.FunctionDef) and item.decorator_list:
693 lineno = item.decorator_list[0].lineno
694 else:
695 lineno = item.lineno
696 # Now actually insert the special imports.
697 if sys.version_info >= (3, 10):
698 aliases = [
699 ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
700 ast.alias(
701 "_pytest.assertion.rewrite",
702 "@pytest_ar",
703 lineno=lineno,
704 col_offset=0,
705 ),
706 ]
707 else:
708 aliases = [
709 ast.alias("builtins", "@py_builtins"),
710 ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
711 ]
712 imports = [
713 ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
714 ]
715 mod.body[pos:pos] = imports
717 # Collect asserts.
718 nodes: List[ast.AST] = [mod]
719 while nodes:
720 node = nodes.pop()
721 for name, field in ast.iter_fields(node):
722 if isinstance(field, list):
723 new: List[ast.AST] = []
724 for i, child in enumerate(field):
725 if isinstance(child, ast.Assert):
726 # Transform assert.
727 new.extend(self.visit(child))
728 else:
729 new.append(child)
730 if isinstance(child, ast.AST):
731 nodes.append(child)
732 setattr(node, name, new)
733 elif (
734 isinstance(field, ast.AST)
735 # Don't recurse into expressions as they can't contain
736 # asserts.
737 and not isinstance(field, ast.expr)
738 ):
739 nodes.append(field)
741 @staticmethod
742 def is_rewrite_disabled(docstring: str) -> bool:
743 return "PYTEST_DONT_REWRITE" in docstring
745 def variable(self) -> str:
746 """Get a new variable."""
747 # Use a character invalid in python identifiers to avoid clashing.
748 name = "@py_assert" + str(next(self.variable_counter))
749 self.variables.append(name)
750 return name
752 def assign(self, expr: ast.expr) -> ast.Name:
753 """Give *expr* a name."""
754 name = self.variable()
755 self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
756 return ast.Name(name, ast.Load())
758 def display(self, expr: ast.expr) -> ast.expr:
759 """Call saferepr on the expression."""
760 return self.helper("_saferepr", expr)
762 def helper(self, name: str, *args: ast.expr) -> ast.expr:
763 """Call a helper in this module."""
764 py_name = ast.Name("@pytest_ar", ast.Load())
765 attr = ast.Attribute(py_name, name, ast.Load())
766 return ast.Call(attr, list(args), [])
768 def builtin(self, name: str) -> ast.Attribute:
769 """Return the builtin called *name*."""
770 builtin_name = ast.Name("@py_builtins", ast.Load())
771 return ast.Attribute(builtin_name, name, ast.Load())
773 def explanation_param(self, expr: ast.expr) -> str:
774 """Return a new named %-formatting placeholder for expr.
776 This creates a %-formatting placeholder for expr in the
777 current formatting context, e.g. ``%(py0)s``. The placeholder
778 and expr are placed in the current format context so that it
779 can be used on the next call to .pop_format_context().
780 """
781 specifier = "py" + str(next(self.variable_counter))
782 self.explanation_specifiers[specifier] = expr
783 return "%(" + specifier + ")s"
785 def push_format_context(self) -> None:
786 """Create a new formatting context.
788 The format context is used for when an explanation wants to
789 have a variable value formatted in the assertion message. In
790 this case the value required can be added using
791 .explanation_param(). Finally .pop_format_context() is used
792 to format a string of %-formatted values as added by
793 .explanation_param().
794 """
795 self.explanation_specifiers: Dict[str, ast.expr] = {}
796 self.stack.append(self.explanation_specifiers)
798 def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
799 """Format the %-formatted string with current format context.
801 The expl_expr should be an str ast.expr instance constructed from
802 the %-placeholders created by .explanation_param(). This will
803 add the required code to format said string to .expl_stmts and
804 return the ast.Name instance of the formatted string.
805 """
806 current = self.stack.pop()
807 if self.stack:
808 self.explanation_specifiers = self.stack[-1]
809 keys = [ast.Str(key) for key in current.keys()]
810 format_dict = ast.Dict(keys, list(current.values()))
811 form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
812 name = "@py_format" + str(next(self.variable_counter))
813 if self.enable_assertion_pass_hook:
814 self.format_variables.append(name)
815 self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
816 return ast.Name(name, ast.Load())
818 def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
819 """Handle expressions we don't have custom code for."""
820 assert isinstance(node, ast.expr)
821 res = self.assign(node)
822 return res, self.explanation_param(self.display(res))
824 def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
825 """Return the AST statements to replace the ast.Assert instance.
827 This rewrites the test of an assertion to provide
828 intermediate values and replace it with an if statement which
829 raises an assertion error with a detailed explanation in case
830 the expression is false.
831 """
832 if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
833 from _pytest.warning_types import PytestAssertRewriteWarning
834 import warnings
836 # TODO: This assert should not be needed.
837 assert self.module_path is not None
838 warnings.warn_explicit(
839 PytestAssertRewriteWarning(
840 "assertion is always true, perhaps remove parentheses?"
841 ),
842 category=None,
843 filename=self.module_path,
844 lineno=assert_.lineno,
845 )
847 self.statements: List[ast.stmt] = []
848 self.variables: List[str] = []
849 self.variable_counter = itertools.count()
851 if self.enable_assertion_pass_hook:
852 self.format_variables: List[str] = []
854 self.stack: List[Dict[str, ast.expr]] = []
855 self.expl_stmts: List[ast.stmt] = []
856 self.push_format_context()
857 # Rewrite assert into a bunch of statements.
858 top_condition, explanation = self.visit(assert_.test)
860 negation = ast.UnaryOp(ast.Not(), top_condition)
862 if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
863 msg = self.pop_format_context(ast.Str(explanation))
865 # Failed
866 if assert_.msg:
867 assertmsg = self.helper("_format_assertmsg", assert_.msg)
868 gluestr = "\n>assert "
869 else:
870 assertmsg = ast.Str("")
871 gluestr = "assert "
872 err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
873 err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
874 err_name = ast.Name("AssertionError", ast.Load())
875 fmt = self.helper("_format_explanation", err_msg)
876 exc = ast.Call(err_name, [fmt], [])
877 raise_ = ast.Raise(exc, None)
878 statements_fail = []
879 statements_fail.extend(self.expl_stmts)
880 statements_fail.append(raise_)
882 # Passed
883 fmt_pass = self.helper("_format_explanation", msg)
884 orig = _get_assertion_exprs(self.source)[assert_.lineno]
885 hook_call_pass = ast.Expr(
886 self.helper(
887 "_call_assertion_pass",
888 ast.Num(assert_.lineno),
889 ast.Str(orig),
890 fmt_pass,
891 )
892 )
893 # If any hooks implement assert_pass hook
894 hook_impl_test = ast.If(
895 self.helper("_check_if_assertion_pass_impl"),
896 self.expl_stmts + [hook_call_pass],
897 [],
898 )
899 statements_pass = [hook_impl_test]
901 # Test for assertion condition
902 main_test = ast.If(negation, statements_fail, statements_pass)
903 self.statements.append(main_test)
904 if self.format_variables:
905 variables = [
906 ast.Name(name, ast.Store()) for name in self.format_variables
907 ]
908 clear_format = ast.Assign(variables, ast.NameConstant(None))
909 self.statements.append(clear_format)
911 else: # Original assertion rewriting
912 # Create failure message.
913 body = self.expl_stmts
914 self.statements.append(ast.If(negation, body, []))
915 if assert_.msg:
916 assertmsg = self.helper("_format_assertmsg", assert_.msg)
917 explanation = "\n>assert " + explanation
918 else:
919 assertmsg = ast.Str("")
920 explanation = "assert " + explanation
921 template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
922 msg = self.pop_format_context(template)
923 fmt = self.helper("_format_explanation", msg)
924 err_name = ast.Name("AssertionError", ast.Load())
925 exc = ast.Call(err_name, [fmt], [])
926 raise_ = ast.Raise(exc, None)
928 body.append(raise_)
930 # Clear temporary variables by setting them to None.
931 if self.variables:
932 variables = [ast.Name(name, ast.Store()) for name in self.variables]
933 clear = ast.Assign(variables, ast.NameConstant(None))
934 self.statements.append(clear)
935 # Fix locations (line numbers/column offsets).
936 for stmt in self.statements:
937 for node in traverse_node(stmt):
938 ast.copy_location(node, assert_)
939 return self.statements
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
    """Build the explanation for a bare name used inside an assertion.

    The generated conditional expression shows ``self.display(name)`` when,
    at assertion time, the identifier is present in ``locals()`` or
    ``_should_repr_global_name(name)`` is true; otherwise it shows the
    identifier text itself.

    Returns the unchanged ``name`` node and the explanation placeholder
    created by ``self.explanation_param``.
    """
    locals_call = ast.Call(self.builtin("locals"), [], [])
    # "<id> in locals()" OR the helper's verdict on global names.
    should_repr = ast.BoolOp(
        ast.Or(),
        [
            ast.Compare(ast.Str(name.id), [ast.In()], [locals_call]),
            self.helper("_should_repr_global_name", name),
        ],
    )
    shown = ast.IfExp(should_repr, self.display(name), ast.Str(name.id))
    return name, self.explanation_param(shown)
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
    """Rewrite an ``and``/``or`` expression while keeping short-circuiting.

    Every operand is assigned to the same temporary ``res_var``. The
    assignment for operand ``i + 1`` is nested inside an ``ast.If`` on
    ``cond``, so it only runs when operand ``i`` did not already decide the
    result. The explanation statements in ``self.expl_stmts`` are nested
    under the same conditions, so only operands that were evaluated add an
    entry to the explanation list.

    Returns a ``Load`` name for the result variable and the explanation
    placeholder for the whole boolean operation.
    """
    res_var = self.variable()
    # Runtime list that receives one formatted explanation per evaluated
    # operand (presumably ``self.assign`` stores it in a fresh temporary --
    # confirm against ``assign``).
    expl_list = self.assign(ast.List([], ast.Load()))
    app = ast.Attribute(expl_list, "append", ast.Load())
    # 1 for ``or``, 0 for ``and``; also forwarded to ``_format_boolop``.
    is_or = int(isinstance(boolop.op, ast.Or))
    # ``body`` is where the current operand's assignment is appended. The
    # outer statement/explanation lists are saved and restored after the
    # loop, because the loop redirects them into nested ``If`` bodies.
    body = save = self.statements
    fail_save = self.expl_stmts
    levels = len(boolop.values) - 1
    self.push_format_context()
    # Process each operand, short-circuiting if needed.
    for i, v in enumerate(boolop.values):
        if i:
            # Explain this operand only under the same condition that
            # guards its evaluation.
            fail_inner: List[ast.stmt] = []
            # cond is set in a prior loop iteration below
            self.expl_stmts.append(ast.If(cond, fail_inner, []))  # noqa
            self.expl_stmts = fail_inner
        self.push_format_context()
        res, expl = self.visit(v)
        body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
        expl_format = self.pop_format_context(ast.Str(expl))
        call = ast.Call(app, [expl_format], [])
        self.expl_stmts.append(ast.Expr(call))
        if i < levels:
            # Go on to the next operand only if this one did not decide the
            # result: truthy for ``and``, falsy (hence ``not``) for ``or``.
            cond: ast.expr = res
            if is_or:
                cond = ast.UnaryOp(ast.Not(), cond)
            inner: List[ast.stmt] = []
            self.statements.append(ast.If(cond, inner, []))
            self.statements = body = inner
    self.statements = save
    self.expl_stmts = fail_save
    expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
    expl = self.pop_format_context(expl_template)
    return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
    """Rewrite a unary operation such as ``not x`` or ``-x``.

    The operand is rewritten first, and the operation on the rewritten
    operand is stored via ``self.assign``. The explanation is the
    ``UNARY_MAP`` pattern for this operator filled with the operand's
    explanation.
    """
    template = UNARY_MAP[type(unary.op)]
    operand, operand_text = self.visit(unary.operand)
    result = self.assign(ast.UnaryOp(unary.op, operand))
    return result, template % (operand_text,)
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
    """Rewrite a binary operation such as ``a + b``.

    The left operand is rewritten before the right one. The operation is
    stored via ``self.assign``, and the explanation is
    ``"(<left> <symbol> <right>)"``, where the symbol comes from
    ``BINOP_MAP``.
    """
    op_symbol = BINOP_MAP[type(binop.op)]
    lhs, lhs_text = self.visit(binop.left)
    rhs, rhs_text = self.visit(binop.right)
    result = self.assign(ast.BinOp(lhs, binop.op, rhs))
    return result, f"({lhs_text} {op_symbol} {rhs_text})"
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
    """Rewrite a function call appearing inside an assertion.

    The callee, then each positional argument, then each keyword value is
    rewritten, in that order. The rebuilt call is stored via
    ``self.assign``. The explanation shows the displayed result followed by
    a ``{result = func(args...)}`` block that spells out the call.
    """
    func, func_text = self.visit(call.func)
    arg_nodes = []
    pieces = []
    for positional in call.args:
        node, text = self.visit(positional)
        arg_nodes.append(node)
        pieces.append(text)
    kw_nodes = []
    for kw in call.keywords:
        node, text = self.visit(kw.value)
        kw_nodes.append(ast.keyword(kw.arg, node))
        # ``**mapping`` unpacking appears as a keyword whose ``arg`` is None.
        pieces.append(f"{kw.arg}={text}" if kw.arg else f"**{text}")
    call_text = f"{func_text}({', '.join(pieces)})"
    result = self.assign(ast.Call(func, arg_nodes, kw_nodes))
    result_text = self.explanation_param(self.display(result))
    return result, f"{result_text}\n{{{result_text} = {call_text}\n}}"
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
    """Rewrite ``*iterable`` unpacking, e.g. as a call argument.

    The wrapped value is rewritten, a new ``Starred`` node is built with the
    original context, and the explanation is the inner explanation with a
    leading ``*``.
    """
    inner, inner_text = self.visit(starred.value)
    return ast.Starred(inner, starred.ctx), f"*{inner_text}"
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
    """Rewrite an attribute read such as ``obj.field``.

    Only ``Load`` contexts are rewritten: the owner expression is rewritten,
    the attribute access is stored via ``self.assign``, and the explanation
    shows the displayed value followed by a ``{value = owner.attr}`` block.
    Any other context falls back to ``self.generic_visit``.
    """
    if isinstance(attr.ctx, ast.Load):
        owner, owner_text = self.visit(attr.value)
        result = self.assign(ast.Attribute(owner, attr.attr, ast.Load()))
        result_text = self.explanation_param(self.display(result))
        detail = f"{{{result_text} = {owner_text}.{attr.attr}\n}}"
        return result, f"{result_text}\n{detail}"
    return self.generic_visit(attr)
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
    """Rewrite a (possibly chained) comparison such as ``a < b <= c``.

    Each link ``left op right`` of the chain is stored in its own fresh
    temporary variable. The explanation is produced by the
    ``_call_reprcompare`` helper. It receives four tuples: the operator
    symbols, the link result variables, the static per-link explanation
    strings, and the operand expressions.

    Returns the result expression and the explanation placeholder. For a
    single comparison the result is its link variable; for a chain it is an
    ``and`` over all link variables.
    """
    self.push_format_context()
    left_res, left_expl = self.visit(comp.left)
    # Parenthesize nested comparisons/boolean ops so the explanation text
    # keeps the original grouping.
    if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
        left_expl = f"({left_expl})"
    # One fresh temporary per comparison link.
    res_variables = [self.variable() for _ in comp.ops]
    load_names = [ast.Name(v, ast.Load()) for v in res_variables]
    store_names = [ast.Name(v, ast.Store()) for v in res_variables]
    expls = []
    syms = []
    results = [left_res]
    for store_name, op, next_operand in zip(
        store_names, comp.ops, comp.comparators
    ):
        next_res, next_expl = self.visit(next_operand)
        if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
            next_expl = f"({next_expl})"
        results.append(next_res)
        sym = BINOP_MAP[op.__class__]
        syms.append(ast.Str(sym))
        expl = f"{left_expl} {sym} {next_expl}"
        expls.append(ast.Str(expl))
        # NOTE(review): every link is appended at the same level of
        # self.statements, so later links are evaluated even when an earlier
        # link is false. Python's chained comparison would short-circuit.
        res_expr = ast.Compare(left_res, [op], [next_res])
        self.statements.append(ast.Assign([store_name], res_expr))
        left_res, left_expl = next_res, next_expl
    # Use pytest.assertion.util._reprcompare if that's available.
    expl_call = self.helper(
        "_call_reprcompare",
        ast.Tuple(syms, ast.Load()),
        ast.Tuple(load_names, ast.Load()),
        ast.Tuple(expls, ast.Load()),
        ast.Tuple(results, ast.Load()),
    )
    if len(comp.ops) > 1:
        res: ast.expr = ast.BoolOp(ast.And(), load_names)
    else:
        res = load_names[0]
    return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir: Path) -> bool:
    """Try to create ``cache_dir`` along with any missing parent directories.

    Returns True if the directory was created or already exists. Returns
    False if it cannot be created for one of these reasons:

    - a path component is not a directory (e.g. we are inside a zip file,
      or the path names a regular file);
    - permission is denied;
    - the filesystem is read-only (EROFS).

    Any other OSError is re-raised.
    """
    try:
        os.makedirs(cache_dir, exist_ok=True)
    except OSError as err:
        not_a_directory = isinstance(
            err, (FileNotFoundError, NotADirectoryError, FileExistsError)
        )
        no_permission = isinstance(err, PermissionError)
        # EROFS has no dedicated OSError subclass, so compare the errno.
        read_only = err.errno == errno.EROFS
        if not_a_directory or no_permission or read_only:
            return False
        raise
    return True
def get_cache_dir(file_path: Path) -> Path:
    """Return the directory that holds the rewritten .pyc for ``file_path``.

    If ``sys.pycache_prefix`` is set (Python 3.8+), the source file's parent
    directories, minus the path anchor, are mirrored under that prefix. For
    example, prefix ``/tmp/pycs`` and file ``/home/user/proj/test_app.py``
    give ``/tmp/pycs/home/user/proj``. Otherwise the classic ``__pycache__``
    directory next to the source file is used.
    """
    if sys.version_info < (3, 8) or not sys.pycache_prefix:
        return file_path.parent / "__pycache__"
    # Drop the anchor (first part) and the file name (last part).
    mirrored_dirs = file_path.parts[1:-1]
    return Path(sys.pycache_prefix, *mirrored_dirs)