|
7 | 7 |
|
8 | 8 | _CMD_USAGE = "python -m memory_profiler script_file.py" |
9 | 9 |
|
10 | | -from functools import wraps |
| 10 | +from asyncio import coroutine, iscoroutinefunction |
| 11 | +from contextlib import contextmanager |
| 12 | +from functools import partial, wraps |
| 13 | +import builtins |
11 | 14 | import inspect |
12 | 15 | import linecache |
13 | 16 | import logging |
|
18 | 21 | import time |
19 | 22 | import traceback |
20 | 23 | import warnings |
21 | | -import contextlib |
22 | 24 |
|
23 | 25 | if sys.platform == "win32": |
24 | 26 | # any value except signal.CTRL_C_EVENT and signal.CTRL_BREAK_EVENT |
|
43 | 45 | line_cell_magic = lambda func: func |
44 | 46 | magics_class = lambda cls: cls |
45 | 47 |
|
46 | | -PY2 = sys.version_info[0] == 2 |
47 | | - |
48 | 48 | _TWO_20 = float(2 ** 20) |
49 | 49 |
|
50 | | -if PY2: |
51 | | - import __builtin__ as builtins |
52 | | - to_str = lambda x: x |
53 | | - from future_builtins import filter |
54 | | -else: |
55 | | - import builtins |
56 | | - to_str = lambda x: str(x) |
57 | 50 |
|
58 | 51 | # .. get available packages .. |
59 | 52 | try: |
@@ -580,7 +573,7 @@ def f(*args, **kwds): |
580 | 573 |
|
581 | 574 | return f |
582 | 575 |
|
583 | | - @contextlib.contextmanager |
| 576 | + @contextmanager |
584 | 577 | def call_on_stack(self, func, *args, **kwds): |
585 | 578 | self.current_stack_level += 1 |
586 | 579 | self.stack[func].append(self.current_stack_level) |
@@ -701,16 +694,27 @@ def add_function(self, func): |
701 | 694 | else: |
702 | 695 | self.code_map.add(code) |
703 | 696 |
|
| 697 | + @contextmanager |
| 698 | + def _count_ctxmgr(self): |
| 699 | + self.enable_by_count() |
| 700 | + try: |
| 701 | + yield |
| 702 | + finally: |
| 703 | + self.disable_by_count() |
| 704 | + |
704 | 705 | def wrap_function(self, func): |
705 | 706 | """ Wrap a function to profile it. |
706 | 707 | """ |
707 | 708 |
|
708 | | - def f(*args, **kwds): |
709 | | - self.enable_by_count() |
710 | | - try: |
711 | | - return func(*args, **kwds) |
712 | | - finally: |
713 | | - self.disable_by_count() |
| 709 | + if iscoroutinefunction(func): |
| 710 | + @coroutine |
| 711 | + def f(*args, **kwargs): |
| 712 | + with self._count_ctxmgr(): |
| 713 | + yield from func(*args, **kwargs) |
| 714 | + else: |
| 715 | + def f(*args, **kwds): |
| 716 | + with self._count_ctxmgr(): |
| 717 | + return func(*args, **kwds) |
714 | 718 |
|
715 | 719 | return f |
716 | 720 |
|
@@ -831,7 +835,7 @@ def show_results(prof, stream=None, precision=1): |
831 | 835 | inc = u'' |
832 | 836 | occurences = u'' |
833 | 837 | tmp = template.format(lineno, total_mem, inc, occurences, all_lines[lineno - 1]) |
834 | | - stream.write(to_str(tmp)) |
| 838 | + stream.write(tmp) |
835 | 839 | stream.write(u'\n\n') |
836 | 840 |
|
837 | 841 |
|
@@ -1121,12 +1125,23 @@ def profile(func=None, stream=None, precision=1, backend='psutil'): |
1121 | 1125 | if not tracemalloc.is_tracing(): |
1122 | 1126 | tracemalloc.start() |
1123 | 1127 | if func is not None: |
1124 | | - @wraps(func) |
1125 | | - def wrapper(*args, **kwargs): |
1126 | | - prof = LineProfiler(backend=backend) |
1127 | | - val = prof(func)(*args, **kwargs) |
1128 | | - show_results(prof, stream=stream, precision=precision) |
1129 | | - return val |
| 1128 | + get_prof = partial(LineProfiler, backend=backend) |
| 1129 | + show_results_bound = partial( |
| 1130 | + show_results, stream=stream, precision=precision |
| 1131 | + ) |
| 1132 | + if iscoroutinefunction(func): |
| 1133 | + @coroutine |
| 1134 | + def wrapper(*args, **kwargs): |
| 1135 | + prof = get_prof() |
| 1136 | + val = yield from prof(func)(*args, **kwargs) |
| 1137 | + show_results_bound(prof) |
| 1138 | + return val |
| 1139 | + else: |
| 1140 | + def wrapper(*args, **kwargs): |
| 1141 | + prof = get_prof() |
| 1142 | + val = prof(func)(*args, **kwargs) |
| 1143 | + show_results_bound(prof) |
| 1144 | + return val |
1130 | 1145 |
|
1131 | 1146 | return wrapper |
1132 | 1147 | else: |
@@ -1198,16 +1213,13 @@ def run_module_with_profiler(module, profiler, backend, passed_args=[]): |
1198 | 1213 | ns = dict(_CLEAN_GLOBALS, profile=profiler) |
1199 | 1214 | _backend = choose_backend(backend) |
1200 | 1215 | sys.argv = [module] + passed_args |
1201 | | - if PY2: |
| 1216 | + if _backend == 'tracemalloc' and has_tracemalloc: |
| 1217 | + tracemalloc.start() |
| 1218 | + try: |
1202 | 1219 | run_module(module, run_name="__main__", init_globals=ns) |
1203 | | - else: |
1204 | | - if _backend == 'tracemalloc' and has_tracemalloc: |
1205 | | - tracemalloc.start() |
1206 | | - try: |
1207 | | - run_module(module, run_name="__main__", init_globals=ns) |
1208 | | - finally: |
1209 | | - if has_tracemalloc and tracemalloc.is_tracing(): |
1210 | | - tracemalloc.stop() |
| 1220 | + finally: |
| 1221 | + if has_tracemalloc and tracemalloc.is_tracing(): |
| 1222 | + tracemalloc.stop() |
1211 | 1223 |
|
1212 | 1224 |
|
1213 | 1225 | class LogFile(object): |
|
0 commit comments