diff --git a/test/test_neighborhood.py b/test/test_neighborhood.py new file mode 100644 index 0000000..77bd6af --- /dev/null +++ b/test/test_neighborhood.py @@ -0,0 +1,49 @@ +"""Testing the neighborhood module.""" + +import pytest +import toponetx as tnx + +import topoembedx as tex + + +class TestNeighborhood: + """Test the neighborhood module of TopoEmbedX.""" + + def test_neighborhood_from_complex_raise_error(self): + """Testing if right assertion is raised for incorrect type.""" + with pytest.raises(TypeError) as e: + tex.neighborhood.neighborhood_from_complex(1) + + assert ( + str(e.value) + == """Input Complex can only be a Simplicial, Cell or Combinatorial Complex.""" + ) + + def test_neighborhood_from_complex_matrix_dimension_cell_complex(self): + """Testing the matrix dimensions for the adjacency and coadjacency matrices.""" + # Testing for the case of Cell Complex + cc1 = tnx.classes.CellComplex( + [[0, 1, 2, 3], [1, 2, 3, 4], [1, 3, 4, 5, 6, 7, 8]] + ) + + cc2 = tnx.classes.CellComplex([[0, 1, 2], [1, 2, 3]]) + + ind, A = tex.neighborhood.neighborhood_from_complex(cc1) + assert A.todense().shape == tuple([9, 9]) + assert len(ind) == 9 + + ind, A = tex.neighborhood.neighborhood_from_complex(cc2) + assert A.todense().shape == tuple([4, 4]) + assert len(ind) == 4 + + ind, A = tex.neighborhood.neighborhood_from_complex( + cc1, neighborhood_type="!adj" + ) + assert A.todense().shape == tuple([9, 9]) + assert len(ind) == 9 + + ind, A = tex.neighborhood.neighborhood_from_complex( + cc2, neighborhood_type="!adj" + ) + assert A.todense().shape == tuple([4, 4]) + assert len(ind) == 4 diff --git a/topoembedx/neighborhood.py b/topoembedx/neighborhood.py index 67a3fa2..fee5089 100644 --- a/topoembedx/neighborhood.py +++ b/topoembedx/neighborhood.py @@ -14,7 +14,7 @@ def neighborhood_from_complex( Parameters ---------- - complex : SimplicialComplex or CellComplex or CombinatorialComplex or CombinatorialComplex + complex : SimplicialComplex or CellComplex or CombinatorialComplex The complex to compute the neighborhood for. neighborhood_type : str The type of neighborhood to compute. "adj" for adjacency matrix, "coadj" for coadjacency matrix. @@ -41,8 +41,7 @@ def neighborhood_from_complex( Raises ------ ValueError - If the input `complex` is not a SimplicialComplex, CellComplex, CombinatorialComplex, or - CombinatorialComplex. + If the input `complex` is not a SimplicialComplex, CellComplex or CombinatorialComplex """ if isinstance(complex, SimplicialComplex) or isinstance(complex, CellComplex): if neighborhood_type == "adj": @@ -50,9 +49,7 @@ def neighborhood_from_complex( else: ind, A = complex.coadjacency_matrix(neighborhood_dim["adj"], index=True) - elif isinstance(complex, CombinatorialComplex) or isinstance( - complex, CombinatorialComplex - ): + elif isinstance(complex, CombinatorialComplex): if neighborhood_type == "adj": ind, A = complex.adjacency_matrix( neighborhood_dim["adj"], neighborhood_dim["coadj"], index=True @@ -62,8 +59,8 @@ def neighborhood_from_complex( neighborhood_dim["coadj"], neighborhood_dim["adj"], index=True ) else: - ValueError( - "input complex must be SimplicialComplex,CellComplex,CombinatorialComplex, or CombinatorialComplex " + raise TypeError( + """Input Complex can only be a Simplicial, Cell or Combinatorial Complex.""" ) return ind, A