# Source code for rlscope.profiler.log_stacktrace

"""
Code for recording and logging stack traces of training scripts to help determine
how Python interfaces with native C libraries.

This code was useful for determining how Python calls into PyTorch; PyTorch
has multiple native shared libraries it calls into.
"""
import textwrap
import traceback
import contextlib
from io import StringIO

import typing

# import tensorflow as tf

from rlscope.profiler.rlscope_logging import logger
from rlscope.profiler import wrap_util

# Intercept tf.Session.run(...) calls to see when calls to TensorFlow graph computations are made.
#
# Never called with --python-mode...

def print_indent(ss, indent):
    """
    Write the leading whitespace for ``indent`` levels to the stream ``ss``.

    One level is two spaces.  Nothing is written when ``indent`` is ``0``
    or ``None``.
    """
    if indent is not None and indent != 0:
        ss.write('  ' * indent)

def with_indent(txt, indent):
    """
    Return ``txt`` with every line prefixed by ``indent`` levels of two spaces.

    ``txt`` is returned unchanged when ``indent`` is ``0`` or ``None``.
    Indentation is delegated to ``textwrap.indent``.
    """
    if indent is not None and indent != 0:
        return textwrap.indent(txt, '  ' * indent)
    return txt

class LoggedStackTrace:
    """
    One recorded call stack for a logged function, plus bookkeeping about how
    many times it was seen and whether it has been printed since the last call.
    """
    def __init__(self, name, format_stack):
        """
        :param name:
            Name of the logged function this stack leads to.
        :param format_stack:
            Sequence of formatted frame strings, outermost frame first
            (e.g. the result of ``traceback.format_stack()``).
        """
        self.name = name
        self.format_stack = format_stack
        # Number of times add_call() has been invoked for this stack.
        self.num_calls = 0
        # True once print() has run and no call has been recorded since.
        self.printed = False

    def add_call(self):
        """Record one more call through this stack and mark it as not yet printed."""
        self.num_calls += 1
        self.printed = False

    def _kept_frames(self, skip_last):
        """Return the frames to print, i.e. all but the innermost ``skip_last``."""
        # Clamp at zero: without it, skip_last > len(format_stack) yields a
        # negative slice bound, which keeps frames from the front instead of
        # dropping everything.
        num_keep = max(0, len(self.format_stack) - skip_last)
        return self.format_stack[:num_keep]

    def print(self, ss, skip_last=0, indent=0):
        """
        Write the stack to the stream ``ss`` and mark it as printed.

        :param ss:
            Writable text stream (anything with ``write``).
        :param skip_last:
            Number of innermost frames (the end of ``format_stack``) to omit.
            If it exceeds the number of frames, nothing is written.
        :param indent:
            Indentation level passed to ``with_indent``.
        """
        keep_stack = self._kept_frames(skip_last)
        ss.write(with_indent(''.join(keep_stack), indent))
        self.printed = True

class _LoggedStackTraces:
    def __init__(self):
        # traceback.format_stack() ->
        self.stacktraces = dict()

    def _key(self, name, format_stack):
        return tuple(format_stack)

    def log_call(self, name, format_stack):
        key = self._key(name, format_stack)
        stacktrace = self.stacktraces.get(key, None)
        if stacktrace is None:
            stacktrace = LoggedStackTrace(name, format_stack)
            self.stacktraces[key] = stacktrace
        stacktrace.add_call()

    def num_to_print(self):
        n = 0
        for st in self.stacktraces.values():
            if not st.printed:
                n += 1
        return n

    def print(self, ss, skip_last=0, indent=0):
        # Only print stacktraces for functions that have been called since we last printed.
        stacktraces = [st for st in self.stacktraces.values() if not st.printed]
        # Sort by number of calls
        stacktraces.sort(key=lambda st: (st.num_calls, st.name))
        print_indent(ss, indent)
        ss.write("Stacktraces ordered by number of calls (since last printed)\n")
        for i, st in enumerate(stacktraces):
            print_indent(ss, indent+1)
            ss.write("Stacktrace[{i}] num_calls={num_calls}: {name}\n".format(
                i=i,
                num_calls=st.num_calls,
                name=st.name,
            ))
            st.print(ss, indent=indent+2, skip_last=skip_last)

    def wrap_module(self, module, should_wrap=None):
        wrap_util.wrap_module(LoggedCall, module, should_wrap=should_wrap)

    def unwrap_module(self, module):
        wrap_util.unwrap_module(LoggedCall, module)

    def wrap_func(self, module, name, should_wrap=None):
        wrap_util.wrap_func(LoggedCall, module, name, should_wrap=should_wrap)

    def unwrap_func(self, module, name):
        wrap_util.unwrap_func(LoggedCall, module, name)

def log_call(func, name, *args, **kwargs):
    """
    Call ``func(*args, **kwargs)`` and return its result, first recording the
    current call stack under ``name`` when the global ``LoggedStackTraces``
    registry is set.
    """
    if LoggedStackTraces is None:
        return func(*args, **kwargs)
    # format_stack() is called directly in this frame so that the innermost
    # recorded frame is always this function.
    stack = traceback.format_stack()
    LoggedStackTraces.log_call(name, stack)
    return func(*args, **kwargs)

class LoggedCall:
    """
    Callable wrapper that records the current call stack each time the wrapped
    function is invoked, then forwards the call unchanged.
    """
    def __init__(self, func, name=None):
        """
        :param func:
            The callable to wrap.
        :param name:
            Name to record calls under.  Defaults to ``func.__name__``; for
            callables without a ``__name__`` (e.g. ``functools.partial``
            objects or callable instances) ``repr(func)`` is used instead.
        """
        self.func = func
        if name is None:
            name = getattr(func, '__name__', None)
            if name is None:
                # Not every callable has __name__; don't fail at wrap time.
                name = repr(func)
        self.name = name

    def __call__(self, *args, **kwargs):
        """Record the call stack (if the global registry is set) and return ``func``'s result."""
        if LoggedStackTraces is not None:
            # format_stack() is called directly in this frame so that the
            # innermost recorded frame is always this method.
            stack = traceback.format_stack()
            LoggedStackTraces.log_call(self.name, stack)
        return self.func(*args, **kwargs)

# Module-level registry that LoggedCall.__call__ and log_call() record into.
# Both check it against None before use, so binding it to None turns recording off.
LoggedStackTraces = _LoggedStackTraces()

# LoggedStackTraces = None
# def setup_logging_stack_traces(FLAGS):
#     global LoggedStackTraces
#     WRAP_TF_SESSION_RUN = FLAGS.log_stacktrace_freq is not None
#     if WRAP_TF_SESSION_RUN:
#         LoggedStackTraces = _LoggedStackTraces()
#
#         original_tf_Session_run = tf.compat.v1.Session.run
#         def wrapped_tf_Session_run(self, fetches, feed_dict=None, options=None, run_metadata=None):
#             return log_call(original_tf_Session_run, "tf.Session.run", self, fetches, feed_dict=feed_dict, options=options, run_metadata=run_metadata)
#         tf.compat.v1.Session.run = wrapped_tf_Session_run
#
#         from tensorflow.python import pywrap_tfe
#
#         original_pywrap_tfe_TFE_Py_Execute = pywrap_tfe.TFE_Py_Execute
#         def wrapped_pywrap_tfe_TFE_Py_Execute(*args, **kwargs):
#             return log_call(original_pywrap_tfe_TFE_Py_Execute, "TFE_Py_Execute", *args, **kwargs)
#         pywrap_tfe.TFE_Py_Execute = wrapped_pywrap_tfe_TFE_Py_Execute
#
#         original_pywrap_tfe_TFE_Py_FastPathExecute = pywrap_tfe.TFE_Py_FastPathExecute
#         def wrapped_pywrap_tfe_TFE_Py_FastPathExecute(*args, **kwargs):
#             return log_call(original_pywrap_tfe_TFE_Py_FastPathExecute, "TFE_Py_FastPathExecute", *args, **kwargs)
#         pywrap_tfe.TFE_Py_FastPathExecute = wrapped_pywrap_tfe_TFE_Py_FastPathExecute


@contextlib.contextmanager
def with_log_stacktraces():
    """
    Context manager that logs the recorded stack traces when the block exits.

    The body of the ``with`` statement runs unchanged.  On exit, whether the
    body finished normally or raised, ``log_stacktraces()`` is called.

    Yields:
        Nothing (``None``).
    """
    try:
        yield
    finally:
        log_stacktraces()
def log_stacktraces():
    """
    Log, via ``logger.info``, every stack trace recorded in the global
    ``LoggedStackTraces`` registry that has not been printed yet.

    Does nothing when the registry is ``None`` or has nothing new to print.
    """
    if LoggedStackTraces is not None and LoggedStackTraces.num_to_print() > 0:
        ss = StringIO()
        # The innermost recorded frame is the wrapper's own call to
        # traceback.format_stack(), so it is dropped with skip_last=1.
        #
        # When calls go through a hand-written wrapper around log_call(...):
        #   stack[-1] = Call to "traceback.format_stack()"
        #   stack[-2] = Call to "return log_call(...)"
        # in which case the alternative would be:
        #   LoggedStackTraces.print(ss, skip_last=2, indent=0)
        LoggedStackTraces.print(ss, skip_last=1, indent=0)
        logger.info(ss.getvalue().rstrip())