Skip to content

Commit

Permalink
Context to enable allocation statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed May 17, 2024
1 parent 0e2f19c commit 7f7f940
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 35 deletions.
76 changes: 43 additions & 33 deletions python/rmm/rmm/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,37 +38,6 @@ def enable_statistics() -> None:
)


@contextmanager
def statistics():
"""Context to enable allocation statistics temporarily.
Warning
-------
This modifies the current RMM memory resource. StatisticsResourceAdaptor
is pushed onto the current RMM memory resource stack when entering the
context and popped again when exiting. If statistics has been enabled for
the current RMM resource stack already, this is a no-op.
Raises
------
ValueError
If the RMM memory source stack was changed while in the context.
"""

# Save the current memory resource for later cleanup
prior_mr = rmm.mr.get_current_device_resource()
enable_statistics()
try:
current_mr = rmm.mr.get_current_device_resource()
yield
finally:
if current_mr is not rmm.mr.get_current_device_resource():
raise ValueError(
"RMM memory source stack was changed while in the context"
)
rmm.mr.set_current_device_resource(prior_mr)


def get_statistics() -> Optional[Dict[str, int]]:
"""Get the current allocation statistics
Expand All @@ -90,7 +59,7 @@ def push_statistics() -> Optional[Dict[str, int]]:
of zero counters on the stack of statistics.
If statistics are disabled (the current memory resource is not an
instance of `StatisticsResourceAdaptor`), this function is a no-op.
instance of StatisticsResourceAdaptor), this function is a no-op.
Return
------
Expand All @@ -110,7 +79,7 @@ def pop_statistics() -> Optional[Dict[str, int]]:
them from the stack.
If statistics are disabled (the current memory resource is not an
instance of `StatisticsResourceAdaptor`), this function is a no-op.
instance of StatisticsResourceAdaptor), this function is a no-op.
Return
------
Expand All @@ -121,3 +90,44 @@ def pop_statistics() -> Optional[Dict[str, int]]:
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.pop_counters()
return None


@contextmanager
def statistics():
"""Context to enable allocation statistics.
If statistics has been enabled already (the current memory resource is an
instance of StatisticsResourceAdaptor), new counters are pushed on the
current allocation statistics stack when entering the context and popped
again when exiting using `push_statistics()` and `push_statistics()`.
If statistics has not been enabled, StatisticsResourceAdaptor is set as
the current RMM memory resource when entering the context and removed
again when exiting.
Raises
------
ValueError
If the current RMM memory source was changed while in the context.
"""

if push_statistics() is None:
# Save the current non-statistics memory resource for later cleanup
prior_non_stats_mr = rmm.mr.get_current_device_resource()
enable_statistics()
else:
prior_non_stats_mr = None

try:
current_mr = rmm.mr.get_current_device_resource()
yield
finally:
if current_mr is not rmm.mr.get_current_device_resource():
raise ValueError(
"RMM memory source stack was changed "
"while in the statistics context"
)
if prior_non_stats_mr is None:
pop_statistics()
else:
rmm.mr.set_current_device_resource(prior_non_stats_mr)
37 changes: 35 additions & 2 deletions python/rmm/rmm/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,46 @@ def stats_mr():


def test_context():
prior_mr = rmm.mr.get_current_device_resource()
mr0 = rmm.mr.get_current_device_resource()
assert get_statistics() is None
with statistics():
mr1 = rmm.mr.get_current_device_resource()
assert isinstance(
rmm.mr.get_current_device_resource(),
rmm.mr.StatisticsResourceAdaptor,
)
assert rmm.mr.get_current_device_resource() is prior_mr
b1 = rmm.DeviceBuffer(size=20)
assert get_statistics() == {
"current_bytes": 32,
"current_count": 1,
"peak_bytes": 32,
"peak_count": 1,
"total_bytes": 32,
"total_count": 1,
}
with statistics():
mr2 = rmm.mr.get_current_device_resource()
assert mr1 is mr2
b2 = rmm.DeviceBuffer(size=10)
assert get_statistics() == {
"current_bytes": 16,
"current_count": 1,
"peak_bytes": 16,
"peak_count": 1,
"total_bytes": 16,
"total_count": 1,
}
assert get_statistics() == {
"current_bytes": 48,
"current_count": 2,
"peak_bytes": 48,
"peak_count": 2,
"total_bytes": 48,
"total_count": 2,
}
del b1
del b2
assert rmm.mr.get_current_device_resource() is mr0


def test_multiple_mr(stats_mr):
Expand Down

0 comments on commit 7f7f940

Please sign in to comment.