I thought this was an interesting task, so I decided to give it a try. Here is what I came up with:
import ast
import inspect
import re
import sys
import __future__
# `ast_Call(func, args, keywords)` builds an ast.Call node on every supported
# interpreter version.
if sys.version_info < (3, 5):
    def ast_Call(func, args, keywords):
        """Build an ast.Call node on Python 3.4 and below.

        Those versions take two extra positional fields (starargs, kwargs),
        which are filled with None here.
        """
        return ast.Call(func, args, keywords, None, None)
else:
    # From 3.5 on the constructor already has the three-argument form.
    ast_Call = ast.Call
# Matches a line consisting only of a comment.  Group 1 is the leading
# indentation, group 2 the comment text without the '#' and at most one
# whitespace character that follows it.
COMMENT_RE = re.compile(r'^(\s*)#\s?(.*)$')


def convert_comment_to_print(line):
    """Turn a full-line comment into an equivalent print call.

    Lines that are not whole-line comments (including code with a trailing
    comment) are returned unchanged.  For a comment line, the result keeps
    the original indentation, prints the comment text, and always ends with
    a newline.
    """
    found = COMMENT_RE.match(line)
    if found is None:
        return line
    indent, text = found.groups()
    return '{}print({!r})\n'.format(indent, text)
def _is_string_statement(statement):
    """Return True if `statement` is an expression statement that consists
    only of a string literal (i.e. it looks like a docstring)."""
    if not isinstance(statement, ast.Expr):
        return False
    value = statement.value
    if sys.version_info >= (3, 8):
        # Since 3.8 the parser emits ast.Constant for all literals; ast.Str
        # is deprecated there and no longer exists in the newest releases.
        return isinstance(value, ast.Constant) and isinstance(value.value, str)
    return isinstance(value, ast.Str)


def convert_docstrings_to_prints(syntax_tree):
    """Replace every docstring-like statement in an AST by a print call.

    Walks `syntax_tree` and rewrites each expression statement consisting
    only of a string literal, found in a `body`, `orelse` or `finalbody`
    statement list, into `print(<that string>)`.

    The AST is modified in-place; nothing is returned.
    """
    # Snapshot the nodes first so the nodes created below are not visited.
    for node in list(ast.walk(syntax_tree)):
        for bodylike_field in ('body', 'orelse', 'finalbody'):
            statements = getattr(node, bodylike_field, None)
            # Lambda, IfExp and Expression nodes use `body`/`orelse` for a
            # single expression node, not a statement list; skip those.
            if not isinstance(statements, list):
                continue
            for statement in statements:
                if not _is_string_statement(statement):
                    continue
                text = statement.value
                # Use a fresh Name node per call and give the new nodes the
                # position of the string they wrap, so that later line-number
                # adjustments treat them like any other node.
                print_name = ast.copy_location(
                    ast.Name('print', ast.Load()), text)
                call = ast_Call(print_name, [text], [])
                statement.value = ast.copy_location(call, text)
def get_future_flags(module_or_func):
    """Return the compile flags for the `__future__` features in use.

    `module_or_func` may be a module, or a function (anything exposing
    `__globals__`).  A feature counts as imported when the name of the
    feature is bound to the corresponding `__future__` feature object,
    either as an attribute of `module_or_func` itself or in the globals of
    the function.

    Returns a single integer containing the bitwise OR of the compiler
    flags of all features found (0 if there are none).
    """
    # A function object does not expose its module's names as attributes;
    # the imported feature objects live in its `__globals__` dictionary.
    func_globals = getattr(module_or_func, '__globals__', None)
    if not isinstance(func_globals, dict):
        func_globals = {}

    flags = 0
    for feature_name in __future__.all_feature_names:
        feature = getattr(__future__, feature_name)
        imported = (
            getattr(module_or_func, feature_name, None) is feature
            or func_globals.get(feature_name) is feature
        )
        if imported:
            flags |= getattr(feature, 'compiler_flag', 0)
    return flags
def eval_function(syntax_tree, func_globals, filename, lineno, compile_flags,
                  *args, **kwargs):
    """Helper for `trace`: run the function defined by `syntax_tree`.

    The first statement of the tree must be the function definition.  An
    extra outermost decorator is added to it; when the definition is
    executed, that decorator calls the freshly built function with `args`
    and `kwargs` and records what it returns.

    `func_globals` is used as the globals for the execution, `filename` and
    `lineno` position the code for error reporting, and `compile_flags` is
    passed on to `compile`.  Returns the return value of the function.
    """
    # A list is used as a mutable cell so the nested function can store the
    # value without needing `nonlocal`.
    outcome = [None]

    def _trace_exec_decorator(compiled_func):
        outcome[0] = compiled_func(*args, **kwargs)

    hook = ast.Name('_trace_exec_decorator', ast.Load())
    definition = syntax_tree.body[0]
    definition.decorator_list.insert(0, hook)

    # Shift line numbers so they refer to the original file, then fill in
    # positions for the nodes that were created programmatically.
    ast.increment_lineno(syntax_tree, lineno - 1)
    ast.fix_missing_locations(syntax_tree)

    code = compile(syntax_tree, filename, 'exec', compile_flags, True)
    exec(code, func_globals, {'_trace_exec_decorator': _trace_exec_decorator})
    return outcome[0]
def trace(func, *args, **kwargs):
    """Run `func` with the given arguments, printing its comments.

    The source of `func` is read again, every whole-line comment and every
    docstring in it is turned into a print call, and the rewritten function
    is executed with `args` and `kwargs` in the globals of `func`.  Whenever
    execution reaches a former comment or docstring, it is printed to
    stdout.

    Functions whose source is indented (methods, nested functions) are
    supported as well.  Since only the globals of `func` are made available,
    names a nested function takes from an enclosing function are not.

    Returns the return value of the rewritten function.
    """
    filename = inspect.getsourcefile(func)
    lines, lineno = inspect.getsourcelines(func)
    modified_source = ''.join(convert_comment_to_print(line) for line in lines)

    # Source taken from inside a class or another function starts indented
    # and would not compile on its own.  Wrapping it in an `if True:` block
    # makes it valid without touching any line, so multi-line strings stay
    # exactly as written.
    is_indented = modified_source[:1] in (' ', '\t')
    if is_indented:
        modified_source = 'if True:\n' + modified_source

    compile_flags = get_future_flags(func)
    syntax_tree = compile(modified_source, filename, 'exec',
                          ast.PyCF_ONLY_AST | compile_flags, True)

    if is_indented:
        # Drop the wrapper again so the function definition is the first
        # statement, and compensate for the line the wrapper added.
        syntax_tree.body = syntax_tree.body[0].body
        lineno -= 1

    convert_docstrings_to_prints(syntax_tree)
    return eval_function(syntax_tree, func.__globals__,
                         filename, lineno, compile_flags, *args, **kwargs)
This is a little longer, because I tried to cover the most important cases. The code may not be the most readable, but I hope it is clear enough to follow.
How it works:
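- First, the source of the function is read with `inspect.getsourcelines`. (Note: `inspect` can only find source code that is available in a file; for other cases, such as functions defined interactively, a package like `dill` may help.)
- Then, line by line, comments are converted into `print` calls. (Only whole-line comments are handled, trailing comments are left alone.)
- The modified source is parsed into an AST.
- In the AST, docstrings are replaced by `print` calls.
- The AST is compiled.
- The compiled function is executed. For this to work like the original function, its context has to be restored (globals, file name and line numbers, `__future__` imports, and so on). Finally, the return value of the function is passed back to the caller.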
The code is meant to work on both Python 2 and 3 (at least on 2.7 and 3.6).
To use it, call the function through `trace` instead of calling it directly:
result = trace(func, 2)
A complete example, assuming the code above is saved as `trace_comments.py`:
from trace_comments import trace
from dateutil.easter import easter, EASTER_ORTHODOX
def func(x):
    """Function docstring."""
    shifted = x + 1
    return shifted if shifted > 0 else -1 * shifted
# Demo: run only when this file is executed as a script.
if __name__ == '__main__':
    # Call `func` through `trace`, once with an argument that takes the
    # `result > 0` branch and once with one that takes the `else` branch.
    result1 = trace(func, 2)
    print("result1 = {}".format(result1))
    result2 = trace(func, -10)
    print("result2 = {}".format(result2))
    # For comparison, a direct call that does not go through `trace`.
    result3 = func(42)
    print("result3 = {}".format(result3))
    print("-----")
    # `trace` applied to a function imported from a third-party package
    # (dateutil), first with one positional argument ...
    print(trace(easter, 2018))
    print("-----")
    # ... and then with EASTER_ORTHODOX passed as a second positional argument.
    print(trace(easter, 2018, EASTER_ORTHODOX))