Skip to content

Commit

Permalink
Always return a cycle in digraph_find_cycle if no node is specified…
Browse files Browse the repository at this point in the history
… and a cycle exists (#1181)

* Handle find arbitrary cycle case

* Find node in cycle more smartly

* Implement find_node_in_arbitrary_cycle

* Switch to Tarjan SCC for single pass DFS

* Improve cycle checking logic

* More assert_cycle!

* Cargo fmt

* Add test case for no cycle and no source

* assertCycle for existing unit tests

* Add more tests

* Add self loop Python test

* Add release notes and fix test

* Address PR comments

* Use less traits

* Update release notes

---------

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
IvanIsCoding and mergify[bot] committed May 16, 2024
1 parent 12f8af5 commit 8e81911
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 36 deletions.
12 changes: 12 additions & 0 deletions releasenotes/notes/fix-digraph-find-cycle-141e302ff4a8fcd4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
fixes:
- |
Fixed the behavior of :func:`~rustworkx.digraph_find_cycle` when
no source node was provided. Previously, the function would start looking
for a cycle at an arbitrary node which was not guaranteed to return a cycle.
Now, the function will smartly choose a source node to start the search from
such that if a cycle exists, it will be found.
other:
- |
The `rustworkx-core` function `rustworkx_core::connectivity::find_cycle` now
requires the `petgraph::visit::Visitable` trait.
112 changes: 82 additions & 30 deletions rustworkx-core/src/connectivity/find_cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
// under the License.

use hashbrown::{HashMap, HashSet};
use petgraph::algo;
use petgraph::visit::{
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount,
EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable,
};
use petgraph::Direction::Outgoing;
use std::hash::Hash;
Expand Down Expand Up @@ -57,22 +58,22 @@ where
G: GraphBase,
G: NodeCount,
G: EdgeCount,
for<'b> &'b G: GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
G::NodeId: Eq + Hash,
{
// Find a cycle in the given graph and return it as a list of edges
let mut graph_nodes: HashSet<G::NodeId> = graph.node_identifiers().collect();
let mut cycle: Vec<(G::NodeId, G::NodeId)> = Vec::with_capacity(graph.edge_count());
let temp_value: G::NodeId;
// If source is not set get an arbitrary node from the set of graph
// nodes we've not "examined"
// If source is not set get a node in an arbitrary cycle if it exists,
// otherwise return that there is no cycle
let source_index = match source {
Some(source_value) => source_value,
None => {
temp_value = *graph_nodes.iter().next().unwrap();
graph_nodes.remove(&temp_value);
temp_value
}
None => match find_node_in_arbitrary_cycle(&graph) {
Some(node_in_cycle) => node_in_cycle,
None => {
return Vec::new();
}
},
};
// Stack (ie "pushdown list") of vertices already in the spanning tree
let mut stack: Vec<G::NodeId> = vec![source_index];
Expand Down Expand Up @@ -119,11 +120,47 @@ where
cycle
}

fn find_node_in_arbitrary_cycle<G>(graph: &G) -> Option<G::NodeId>
where
G: GraphBase,
G: NodeCount,
G: EdgeCount,
for<'b> &'b G:
GraphBase<NodeId = G::NodeId> + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable,
G::NodeId: Eq + Hash,
{
for scc in algo::kosaraju_scc(&graph) {
if scc.len() > 1 {
return Some(scc[0]);
}
}
for node in graph.node_identifiers() {
for neighbor in graph.neighbors_directed(node, Outgoing) {
if neighbor == node {
return Some(node);
}
}
}
None
}

#[cfg(test)]
mod tests {
use crate::connectivity::find_cycle;
use petgraph::prelude::*;

// Utility to assert cycles in the response
macro_rules! assert_cycle {
($g: expr, $cycle: expr) => {{
for i in 0..$cycle.len() {
let (s, t) = $cycle[i];
assert!($g.contains_edge(s, t));
let (next_s, _) = $cycle[(i + 1) % $cycle.len()];
assert_eq!(t, next_s);
}
}};
}

#[test]
fn test_find_cycle_source() {
let edge_list = vec![
Expand All @@ -141,20 +178,13 @@ mod tests {
(8, 9),
];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let mut res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(0, 1), (1, 2), (2, 3), (3, 0)]);
res = find_cycle(&graph, Some(NodeIndex::new(1)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(1, 2), (2, 3), (3, 0), (0, 1)]);
res = find_cycle(&graph, Some(NodeIndex::new(5)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
for i in [0, 1, 2, 3].iter() {
let idx = NodeIndex::new(*i);
let res = find_cycle(&graph, Some(idx));
assert_cycle!(graph, res);
assert_eq!(res[0].0, idx);
}
let res = find_cycle(&graph, Some(NodeIndex::new(5)));
assert_eq!(res, []);
}

Expand All @@ -176,10 +206,32 @@ mod tests {
];
let mut graph = DiGraph::<i32, i32>::from_edges(edge_list);
graph.add_edge(NodeIndex::new(1), NodeIndex::new(1), 0);
let res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0)))
.iter()
.map(|(s, t)| (s.index(), t.index()))
.collect();
assert_eq!(res, [(1, 1)]);
let res = find_cycle(&graph, Some(NodeIndex::new(0)));
assert_eq!(res[0].0, NodeIndex::new(1));
assert_cycle!(graph, res);
}

#[test]
fn test_self_loop_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3), (2, 2)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_cycle!(graph, res);
}

#[test]
fn test_cycle_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 2)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_cycle!(graph, res);
}

#[test]
fn test_no_cycle_no_source() {
let edge_list = vec![(0, 1), (1, 2), (2, 3)];
let graph = DiGraph::<i32, i32>::from_edges(edge_list);
let res = find_cycle(&graph, None);
assert_eq!(res, []);
}
}
45 changes: 39 additions & 6 deletions tests/digraph/test_find_cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import unittest

import rustworkx
import rustworkx.generators


class TestFindCycle(unittest.TestCase):
Expand All @@ -36,30 +37,38 @@ def setUp(self):
]
)

def assertCycle(self, first_node, graph, res):
self.assertEqual(first_node, res[0][0])
for i in range(len(res)):
s, t = res[i]
self.assertTrue(graph.has_edge(s, t))
next_s, _ = res[(i + 1) % len(res)]
self.assertEqual(t, next_s)

def test_find_cycle(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(6)))
graph.add_edges_from_no_data(
[(0, 1), (0, 3), (0, 5), (1, 2), (2, 3), (3, 4), (4, 5), (4, 0)]
)
res = rustworkx.digraph_find_cycle(graph, 0)
self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], res)
self.assertCycle(0, graph, res)

def test_find_cycle_multiple_roots_same_cycles(self):
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
self.assertCycle(0, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 1)
self.assertEqual(res, [(1, 2), (2, 3), (3, 0), (0, 1)])
self.assertCycle(1, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 5)
self.assertEqual(res, [])

def test_find_cycle_disconnected_graphs(self):
self.graph.add_nodes_from(["A", "B", "C"])
self.graph.add_edges_from_no_data([(10, 11), (12, 10), (11, 12)])
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)])
self.assertCycle(0, self.graph, res)
res = rustworkx.digraph_find_cycle(self.graph, 10)
self.assertEqual(res, [(10, 11), (11, 12), (12, 10)])
self.assertCycle(10, self.graph, res)

def test_invalid_types(self):
graph = rustworkx.PyGraph()
Expand All @@ -69,4 +78,28 @@ def test_invalid_types(self):
def test_self_loop(self):
self.graph.add_edge(1, 1, None)
res = rustworkx.digraph_find_cycle(self.graph, 0)
self.assertEqual([(1, 1)], res)
self.assertCycle(1, self.graph, res)

def test_no_cycle_no_source(self):
g = rustworkx.generators.directed_grid_graph(10, 10)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(res, [])

def test_cycle_no_source(self):
g = rustworkx.generators.directed_path_graph(1000)
a = g.add_node(1000)
b = g.node_indices()[-2]
g.add_edge(b, a, None)
g.add_edge(a, b, None)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(len(res), 2)
self.assertTrue(res[0] == res[1][::-1])

def test_cycle_self_loop(self):
g = rustworkx.generators.directed_path_graph(1000)
a = g.add_node(1000)
b = g.node_indices()[-1]
g.add_edge(b, a, None)
g.add_edge(a, a, None)
res = rustworkx.digraph_find_cycle(g)
self.assertEqual(res, [(a, a)])

0 comments on commit 8e81911

Please sign in to comment.