Skip to content

Commit

Permalink
current allocation statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed May 17, 2024
1 parent fc49fe9 commit b9d57db
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 2 deletions.
55 changes: 55 additions & 0 deletions python/rmm/rmm/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from contextlib import contextmanager
from typing import Dict, Optional

import rmm.mr

Expand Down Expand Up @@ -66,3 +67,57 @@ def statistics():
"RMM memory source stack was changed while in the context"
)
rmm.mr.set_current_device_resource(prior_mr)


def get_statistics() -> Optional[Dict[str, int]]:
"""Get the current allocation statistics
Return
------
If enabled, returns the current tracked statistics.
If disabled, returns None.
"""
mr = rmm.mr.get_current_device_resource()
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.allocation_counts
return None


def push_statistics() -> Optional[Dict[str, int]]:
"""Push new counters on the current allocation statistics stack
This returns the current tracked statistics and push a new set
of zero counters on the stack of statistics.
If statistics are disabled (the current memory resource is not an
instance of `StatisticsResourceAdaptor`), this function is a no-op.
Return
------
If enabled, returns the current tracked statistics _before_ the pop.
If disabled, returns None.
"""
mr = rmm.mr.get_current_device_resource()
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.push_counters()
return None


def pop_statistics() -> Optional[Dict[str, int]]:
"""Pop the counters of the current allocation statistics stack
This returns the counters of current tracked statistics and pops
them from the stack.
If statistics are disabled (the current memory resource is not an
instance of `StatisticsResourceAdaptor`), this function is a no-op.
Return
------
If enabled, returns the popped counters.
If disabled, returns None.
"""
mr = rmm.mr.get_current_device_resource()
if isinstance(mr, rmm.mr.StatisticsResourceAdaptor):
return mr.pop_counters()
return None
54 changes: 52 additions & 2 deletions python/rmm/rmm/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
import pytest

import rmm.mr
from rmm.statistics import statistics
from rmm.statistics import (
get_statistics,
pop_statistics,
push_statistics,
statistics,
)


@pytest.fixture
Expand Down Expand Up @@ -116,7 +121,7 @@ def test_counter_stack(stats_mr):
"total_count": 1,
}
del b1
# pop returns the stats from the top before the pop.
# pop returns the popped stats
# Note, the bytes and counts can be negative
assert stats_mr.pop_counters() == { # stats from stack level 2
"current_bytes": -16,
Expand Down Expand Up @@ -165,3 +170,48 @@ def test_counter_stack(stats_mr):
del buffers
with pytest.raises(IndexError, match="cannot pop the last counter pair"):
stats_mr.pop_counters()


def test_current_statistics(stats_mr):
b1 = rmm.DeviceBuffer(size=10)
assert get_statistics() == {
"current_bytes": 16,
"current_count": 1,
"peak_bytes": 16,
"peak_count": 1,
"total_bytes": 16,
"total_count": 1,
}
b2 = rmm.DeviceBuffer(size=20)
assert push_statistics() == {
"current_bytes": 48,
"current_count": 2,
"peak_bytes": 48,
"peak_count": 2,
"total_bytes": 48,
"total_count": 2,
}
del b1
assert pop_statistics() == {
"current_bytes": -16,
"current_count": -1,
"peak_bytes": 0,
"peak_count": 0,
"total_bytes": 0,
"total_count": 0,
}
del b2
assert get_statistics() == {
"current_bytes": 0,
"current_count": 0,
"peak_bytes": 48,
"peak_count": 2,
"total_bytes": 48,
"total_count": 2,
}


def test_statistics_disabled():
assert get_statistics() is None
assert push_statistics() is None
assert get_statistics() is None

0 comments on commit b9d57db

Please sign in to comment.