diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ab274a9..c36ded1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,7 @@ jobs: checks: name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }} runs-on: ${{ matrix.runs-on }} + needs: [pre-commit] strategy: fail-fast: false matrix: @@ -36,7 +37,6 @@ jobs: - python-version: pypy-3.7 runs-on: ubuntu-latest - steps: - uses: actions/checkout@v2 @@ -47,16 +47,41 @@ jobs: - name: Install package run: python -m pip install .[test] - - name: Test package + - name: Test run: python -m pytest -ra + root: + name: ROOT test + runs-on: ubuntu-latest + needs: [pre-commit] + defaults: + run: + shell: bash -l {0} + + steps: + - uses: actions/checkout@v2 + + - uses: conda-incubator/setup-miniconda@v2 + with: + miniforge-variant: Mambaforge + use-mamba: true + environment-file: environment.yml + + - name: Install package + run: pip install . + + - name: Test root + run: pytest -ra tests/test_root.py + + dist: name: Distribution build runs-on: ubuntu-latest + needs: [pre-commit] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 - name: Build sdist and wheel run: pipx run build diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..6ebbe21 --- /dev/null +++ b/environment.yml @@ -0,0 +1,9 @@ +name: uhi +channels: + - conda-forge +dependencies: + - pip + - pytest + - root + - pytest + - boost-histogram diff --git a/noxfile.py b/noxfile.py index 1b6651b..4c82bef 100644 --- a/noxfile.py +++ b/noxfile.py @@ -17,7 +17,7 @@ def lint(session): Run the linter. """ session.install("pre-commit") - session.run("pre-commit", "run", "--all-files") + session.run("pre-commit", "run", "--all-files", *session.posargs) @nox.session(python=ALL_PYTHONS) @@ -26,7 +26,7 @@ def tests(session): Run the unit and regular tests. """ session.install(".[test]") - session.run("pytest") + session.run("pytest", *session.posargs) @nox.session @@ -63,3 +63,14 @@ def build(session): session.install("build") session.run("python", "-m", "build") + + +@nox.session(venv_backend="conda") +def root_tests(session): + """ + Test against ROOT. + """ + + session.conda_install("--channel=conda-forge", "ROOT", "pytest", "boost-histogram") + session.install(".") + session.run("pytest", "tests/test_root.py") diff --git a/src/uhi/numpy_plottable.py b/src/uhi/numpy_plottable.py index d4df773..4e5c4f8 100644 --- a/src/uhi/numpy_plottable.py +++ b/src/uhi/numpy_plottable.py @@ -160,6 +160,172 @@ def variances(self) -> Optional[np.ndarray]: _: PlottableHistogram = cast(NumPyPlottableHistogram, None) +def _roottarray_asnumpy( + tarr: Any, shape: Optional[Tuple[int, ...]] = None +) -> np.ndarray: + llv = tarr.GetArray() + arr: np.ndarray = np.frombuffer(llv, dtype=llv.typecode, count=tarr.GetSize()) + if shape is not None: + return np.reshape(arr, shape, order="F") + else: + return arr + + +class ROOTAxis: + def __init__(self, tAxis: Any) -> None: + self.tAx = tAxis + + def __len__(self) -> int: + return self.tAx.GetNbins() # type: ignore + + def __getitem__(self, index: int) -> Any: + pass + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, ROOTAxis): + return NotImplemented + return len(self) == len(other) and all( + aEdges == bEdges for aEdges, bEdges in zip(self, other) + ) + + def __iter__(self) -> Union[Iterator[Tuple[float, float]], Iterator[str]]: + pass + + @staticmethod + def create(tAx: Any) -> Union["DiscreteROOTAxis", "ContinuousROOTAxis"]: + if all(tAx.GetBinLabel(i + 1) for i in range(tAx.GetNbins())): + return DiscreteROOTAxis(tAx) + else: + return ContinuousROOTAxis(tAx) + + +class ContinuousROOTAxis(ROOTAxis): + @property + def traits(self) -> PlottableTraits: + return Traits(circular=False, discrete=False) + + def __getitem__(self, index: int) -> Tuple[float, float]: + return (self.tAx.GetBinLowEdge(index + 1), self.tAx.GetBinUpEdge(index + 1)) + + def __iter__(self) -> Iterator[Tuple[float, float]]: + for i in range(len(self)): + yield self[i] + + +class DiscreteROOTAxis(ROOTAxis): + @property + def traits(self) -> PlottableTraits: + return Traits(circular=False, discrete=True) + + def __getitem__(self, index: int) -> str: + return self.tAx.GetBinLabel(index + 1) # type: ignore + + def __iter__(self) -> Iterator[str]: + for i in range(len(self)): + yield self[i] + + +class ROOTPlottableHistBase: + """Common base for ROOT histograms and TProfile""" + + def __init__(self, thist: Any) -> None: + self.thist: Any = thist + nDim = thist.GetDimension() + self._shape: Tuple[int, ...] = tuple( + getattr(thist, f"GetNbins{ax}")() + 2 for ax in "XYZ"[:nDim] + ) + self.axes: Tuple[Union[ContinuousROOTAxis, DiscreteROOTAxis], ...] = tuple( + ROOTAxis.create(getattr(thist, f"Get{ax}axis")()) for ax in "XYZ"[:nDim] + ) + + @property + def name(self) -> str: + return self.thist.GetName() # type: ignore + + +class ROOTPlottableHistogram(ROOTPlottableHistBase): + def __init__(self, thist: Any) -> None: + super().__init__(thist) + + @property + def hasWeights(self) -> bool: + return bool(self.thist.GetSumw2() and self.thist.GetSumw2N()) + + @property + def kind(self) -> str: + return Kind.COUNT + + def values(self) -> np.ndarray: + return _roottarray_asnumpy(self.thist, shape=self._shape)[ # type: ignore + tuple([slice(1, -1)] * len(self._shape)) + ] + + def variances(self) -> np.ndarray: + if self.hasWeights: + return _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[ # type: ignore + tuple([slice(1, -1)] * len(self._shape)) + ] + else: + return self.values() + + def counts(self) -> np.ndarray: + if not self.hasWeights: + return self.values() + + sumw = self.values() + return np.divide( # type: ignore + sumw ** 2, + self.variances(), + out=np.zeros_like(sumw, dtype=np.float64), + where=sumw != 0, + ) + + +class ROOTPlottableProfile(ROOTPlottableHistBase): + def __init__(self, thist: Any) -> None: + super().__init__(thist) + + @property + def kind(self) -> str: + return Kind.MEAN + + def values(self) -> np.ndarray: + return np.array( # type: ignore + [self.thist.GetBinContent(i) for i in range(self.thist.GetNcells())] + ).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))] + + def variances(self) -> np.ndarray: + return ( # type: ignore + np.array([self.thist.GetBinError(i) for i in range(self.thist.GetNcells())]) + ** 2 + ).reshape(self._shape, order="F")[tuple([slice(1, -1)] * len(self._shape))] + + def counts(self) -> np.ndarray: + sumw = _roottarray_asnumpy(self.thist, shape=self._shape)[ + tuple([slice(1, -1)] * len(self._shape)) + ] + if not (self.thist.GetSumw2() and self.thist.GetSumw2N()): + return sumw # type: ignore + + sumw2 = _roottarray_asnumpy(self.thist.GetSumw2(), shape=self._shape)[ + tuple([slice(1, -1)] * len(self._shape)) + ] + return np.divide( # type: ignore + sumw ** 2, + sumw2, + out=np.zeros_like(sumw, dtype=np.float64), + where=sumw != 0, + ) + + +if TYPE_CHECKING: + # Verify that the above class is a valid PlottableHistogram + _axis = cast(ContinuousROOTAxis, None) + _axis2: PlottableAxisGeneric[str] = cast(DiscreteROOTAxis, None) + _ = cast(ROOTPlottableHistogram, None) + _ = cast(ROOTPlottableProfile, None) + + def ensure_plottable_histogram(hist: Any) -> PlottableHistogram: """ Ensure a histogram follows the PlottableHistogram Protocol. @@ -206,5 +372,14 @@ def ensure_plottable_histogram(hist: Any) -> PlottableHistogram: # Standard tuple return NumPyPlottableHistogram(*(np.asarray(h) for h in hist)) + elif hasattr(hist, "InheritsFrom") and hist.InheritsFrom("TH1"): + if any( + hist.InheritsFrom(profCls) + for profCls in ("TProfile", "TProfile2D", "TProfile3D") + ): + return ROOTPlottableProfile(hist) + + return ROOTPlottableHistogram(hist) + else: raise TypeError(f"Can't be used on this type of object: {hist!r}") diff --git a/tests/test_root.py b/tests/test_root.py new file mode 100644 index 0000000..4b5695e --- /dev/null +++ b/tests/test_root.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest +from pytest import approx + +from uhi.numpy_plottable import ensure_plottable_histogram + +ROOT = pytest.importorskip("ROOT") + + +def test_root_imported() -> None: + assert ROOT.TString("hi") == "hi" + + +def test_root_th1f_convert() -> None: + th = ROOT.TH1F("h1", "h1", 50, -2.5, 2.5) + th.FillRandom("gaus", 10000) + h = ensure_plottable_histogram(th) + assert all(th.GetBinContent(i + 1) == approx(iv) for i, iv in enumerate(h.values())) + assert all( + th.GetBinError(i + 1) == approx(ie) + for i, ie in enumerate(np.sqrt(h.variances())) # type: ignore + ) + + +def test_root_th2f_convert() -> None: + th = ROOT.TH2F("h2", "h2", 50, -2.5, 2.5, 50, -2.5, 2.5) + _ = ROOT.TF2("xyg", "xygaus", -2.5, 2.5, -2.5, 2.5) + th.FillRandom("xyg", 10000) + h = ensure_plottable_histogram(th) + assert all( + th.GetBinContent(i + 1, j + 1) == approx(iv) + for i, row in enumerate(h.values()) + for j, iv in enumerate(row) + ) + assert all( + th.GetBinError(i + 1, j + 1) == approx(ie) + for i, row in enumerate(np.sqrt(h.variances())) # type: ignore + for j, ie in enumerate(row) + )