Skip to content

Commit

Permalink
Move Minimum Spanning Tree Algorithm to its own module (#624)
Browse files Browse the repository at this point in the history
* refact: move minimum spanning tree algo to its own module

* refact: move min_spanning_tree benches to a different test file
  • Loading branch information
BryanCruz committed Apr 1, 2024
1 parent 4678de4 commit 08b0ad9
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 218 deletions.
60 changes: 60 additions & 0 deletions benches/min_spanning_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#![feature(test)]

extern crate petgraph;
extern crate test;

use test::Bencher;

#[allow(dead_code)]
mod common;
use common::{digraph, ungraph};

use petgraph::algo::min_spanning_tree;

#[bench]
fn min_spanning_tree_praust_undir_bench(bench: &mut Bencher) {
let a = ungraph().praust_a();
let b = ungraph().praust_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_praust_dir_bench(bench: &mut Bencher) {
let a = digraph().praust_a();
let b = digraph().praust_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_full_undir_bench(bench: &mut Bencher) {
let a = ungraph().full_a();
let b = ungraph().full_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_full_dir_bench(bench: &mut Bencher) {
let a = digraph().full_a();
let b = digraph().full_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_petersen_undir_bench(bench: &mut Bencher) {
let a = ungraph().petersen_a();
let b = ungraph().petersen_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_petersen_dir_bench(bench: &mut Bencher) {
let a = digraph().petersen_a();
let b = digraph().petersen_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}
50 changes: 1 addition & 49 deletions benches/unionfind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use test::Bencher;
mod common;
use common::*;

use petgraph::algo::{connected_components, is_cyclic_undirected, min_spanning_tree};
use petgraph::algo::{connected_components, is_cyclic_undirected};

#[bench]
fn connected_components_praust_undir_bench(bench: &mut Bencher) {
Expand Down Expand Up @@ -106,51 +106,3 @@ fn is_cyclic_undirected_petersen_dir_bench(bench: &mut Bencher) {

bench.iter(|| (is_cyclic_undirected(&a), is_cyclic_undirected(&b)));
}

#[bench]
fn min_spanning_tree_praust_undir_bench(bench: &mut Bencher) {
let a = ungraph().praust_a();
let b = ungraph().praust_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_praust_dir_bench(bench: &mut Bencher) {
let a = digraph().praust_a();
let b = digraph().praust_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_full_undir_bench(bench: &mut Bencher) {
let a = ungraph().full_a();
let b = ungraph().full_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_full_dir_bench(bench: &mut Bencher) {
let a = digraph().full_a();
let b = digraph().full_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_petersen_undir_bench(bench: &mut Bencher) {
let a = ungraph().petersen_a();
let b = ungraph().petersen_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}

#[bench]
fn min_spanning_tree_petersen_dir_bench(bench: &mut Bencher) {
let a = digraph().petersen_a();
let b = digraph().petersen_b();

bench.iter(|| (min_spanning_tree(&a), min_spanning_tree(&b)));
}
117 changes: 117 additions & 0 deletions src/algo/min_spanning_tree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
//! Minimum Spanning Tree algorithms.

use std::collections::{BinaryHeap, HashMap};

use crate::prelude::*;

use crate::data::Element;
use crate::scored::MinScored;
use crate::unionfind::UnionFind;
use crate::visit::{Data, IntoNodeReferences, NodeRef};
use crate::visit::{IntoEdgeReferences, NodeIndexable};

/// \[Generic\] Compute a *minimum spanning tree* of a graph.
///
/// The input graph is treated as if undirected.
///
/// Using Kruskal's algorithm with runtime **O(|E| log |E|)**. We actually
/// return a minimum spanning forest, i.e. a minimum spanning tree for each connected
/// component of the graph.
///
/// The resulting graph has all the vertices of the input graph (with identical node indices),
/// and **|V| - c** edges, where **c** is the number of connected components in `g`.
///
/// Use `from_elements` to create a graph from the resulting iterator.
pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
where
G::NodeWeight: Clone,
G::EdgeWeight: Clone + PartialOrd,
G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
{
// Initially each vertex is its own disjoint subgraph, track the connectedness
// of the pre-MST with a union & find datastructure.
let subgraphs = UnionFind::new(g.node_bound());

let edges = g.edge_references();
let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
for edge in edges {
sort_edges.push(MinScored(
edge.weight().clone(),
(edge.source(), edge.target()),
));
}

MinSpanningTree {
graph: g,
node_ids: Some(g.node_references()),
subgraphs,
sort_edges,
node_map: HashMap::new(),
node_count: 0,
}
}

/// An iterator producing a minimum spanning forest of a graph.
#[derive(Debug, Clone)]
pub struct MinSpanningTree<G>
where
G: Data + IntoNodeReferences,
{
graph: G,
node_ids: Option<G::NodeReferences>,
subgraphs: UnionFind<usize>,
#[allow(clippy::type_complexity)]
sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
node_map: HashMap<usize, usize>,
node_count: usize,
}

impl<G> Iterator for MinSpanningTree<G>
where
G: IntoNodeReferences + NodeIndexable,
G::NodeWeight: Clone,
G::EdgeWeight: PartialOrd,
{
type Item = Element<G::NodeWeight, G::EdgeWeight>;

fn next(&mut self) -> Option<Self::Item> {
let g = self.graph;
if let Some(ref mut iter) = self.node_ids {
if let Some(node) = iter.next() {
self.node_map.insert(g.to_index(node.id()), self.node_count);
self.node_count += 1;
return Some(Element::Node {
weight: node.weight().clone(),
});
}
}
self.node_ids = None;

// Kruskal's algorithm.
// Algorithm is this:
//
// 1. Create a pre-MST with all the vertices and no edges.
// 2. Repeat:
//
// a. Remove the shortest edge from the original graph.
// b. If the edge connects two disjoint trees in the pre-MST,
// add the edge.
while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
// check if the edge would connect two disjoint parts
let (a_index, b_index) = (g.to_index(a), g.to_index(b));
if self.subgraphs.union(a_index, b_index) {
let (&a_order, &b_order) =
match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
(Some(a_id), Some(b_id)) => (a_id, b_id),
_ => panic!("Edge references unknown node"),
};
return Some(Element::Edge {
source: a_order,
target: b_order,
weight: score,
});
}
}
None
}
}
112 changes: 2 additions & 110 deletions src/algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ pub mod floyd_warshall;
pub mod isomorphism;
pub mod k_shortest_path;
pub mod matching;
pub mod min_spanning_tree;
pub mod page_rank;
pub mod simple_paths;
pub mod tred;

use std::collections::{BinaryHeap, HashMap};
use std::num::NonZeroUsize;

use crate::prelude::*;
Expand All @@ -29,10 +29,7 @@ use super::visit::{
IntoNodeIdentifiers, NodeCompactIndexable, NodeIndexable, Reversed, VisitMap, Visitable,
};
use super::EdgeType;
use crate::data::Element;
use crate::scored::MinScored;
use crate::visit::Walker;
use crate::visit::{Data, IntoNodeReferences, NodeRef};

pub use astar::astar;
pub use bellman_ford::{bellman_ford, find_negative_cycle};
Expand All @@ -45,6 +42,7 @@ pub use isomorphism::{
};
pub use k_shortest_path::k_shortest_path;
pub use matching::{greedy_matching, maximum_matching, Matching};
pub use min_spanning_tree::min_spanning_tree;
pub use page_rank::page_rank;
pub use simple_paths::all_simple_paths;

Expand Down Expand Up @@ -637,112 +635,6 @@ where
condensed
}

/// \[Generic\] Compute a *minimum spanning tree* of a graph.
///
/// The input graph is treated as if undirected.
///
/// Using Kruskal's algorithm with runtime **O(|E| log |E|)**. We actually
/// return a minimum spanning forest, i.e. a minimum spanning tree for each connected
/// component of the graph.
///
/// The resulting graph has all the vertices of the input graph (with identical node indices),
/// and **|V| - c** edges, where **c** is the number of connected components in `g`.
///
/// Use `from_elements` to create a graph from the resulting iterator.
pub fn min_spanning_tree<G>(g: G) -> MinSpanningTree<G>
where
G::NodeWeight: Clone,
G::EdgeWeight: Clone + PartialOrd,
G: IntoNodeReferences + IntoEdgeReferences + NodeIndexable,
{
// Initially each vertex is its own disjoint subgraph, track the connectedness
// of the pre-MST with a union & find datastructure.
let subgraphs = UnionFind::new(g.node_bound());

let edges = g.edge_references();
let mut sort_edges = BinaryHeap::with_capacity(edges.size_hint().0);
for edge in edges {
sort_edges.push(MinScored(
edge.weight().clone(),
(edge.source(), edge.target()),
));
}

MinSpanningTree {
graph: g,
node_ids: Some(g.node_references()),
subgraphs,
sort_edges,
node_map: HashMap::new(),
node_count: 0,
}
}

/// An iterator producing a minimum spanning forest of a graph.
#[derive(Debug, Clone)]
pub struct MinSpanningTree<G>
where
G: Data + IntoNodeReferences,
{
graph: G,
node_ids: Option<G::NodeReferences>,
subgraphs: UnionFind<usize>,
#[allow(clippy::type_complexity)]
sort_edges: BinaryHeap<MinScored<G::EdgeWeight, (G::NodeId, G::NodeId)>>,
node_map: HashMap<usize, usize>,
node_count: usize,
}

impl<G> Iterator for MinSpanningTree<G>
where
G: IntoNodeReferences + NodeIndexable,
G::NodeWeight: Clone,
G::EdgeWeight: PartialOrd,
{
type Item = Element<G::NodeWeight, G::EdgeWeight>;

fn next(&mut self) -> Option<Self::Item> {
let g = self.graph;
if let Some(ref mut iter) = self.node_ids {
if let Some(node) = iter.next() {
self.node_map.insert(g.to_index(node.id()), self.node_count);
self.node_count += 1;
return Some(Element::Node {
weight: node.weight().clone(),
});
}
}
self.node_ids = None;

// Kruskal's algorithm.
// Algorithm is this:
//
// 1. Create a pre-MST with all the vertices and no edges.
// 2. Repeat:
//
// a. Remove the shortest edge from the original graph.
// b. If the edge connects two disjoint trees in the pre-MST,
// add the edge.
while let Some(MinScored(score, (a, b))) = self.sort_edges.pop() {
// check if the edge would connect two disjoint parts
let (a_index, b_index) = (g.to_index(a), g.to_index(b));
if self.subgraphs.union(a_index, b_index) {
let (&a_order, &b_order) =
match (self.node_map.get(&a_index), self.node_map.get(&b_index)) {
(Some(a_id), Some(b_id)) => (a_id, b_id),
_ => panic!("Edge references unknown node"),
};
return Some(Element::Edge {
source: a_order,
target: b_order,
weight: score,
});
}
}
None
}
}

/// An algorithm error: a cycle was found in the graph.
#[derive(Clone, Debug, PartialEq)]
pub struct Cycle<N>(N);
Expand Down

0 comments on commit 08b0ad9

Please sign in to comment.