DDSA Solutions

Count Spanning Trees in a Graph

Time: O(n³)
Space: O(n²)

Intuition

Kirchhoff's Matrix-Tree theorem states that the number of spanning trees of a graph equals any cofactor of its Laplacian matrix. In particular, it equals the determinant of the (n−1)×(n−1) submatrix obtained by deleting any one row together with the matching column. The Laplacian holds the vertex degrees on its diagonal and the negated edge multiplicities off the diagonal, so computing that determinant via Gaussian elimination yields the count in O(n³) time.
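To make concrete what the theorem counts, here is a brute-force cross-check. It is a minimal sketch, not part of the original solution: it tries every subset of n−1 edges and keeps those that contain no cycle (an acyclic set of n−1 edges on n vertices is necessarily a spanning tree). The class name `BruteForceSpanningTrees` and its `count` method are hypothetical, and the bitmask enumeration assumes fewer than 31 edges, so it only suits tiny inputs, such as testing the determinant-based answer.

```java
// Hypothetical brute-force cross-check (not the original solution).
// Exponential in the number of edges; assumes edges.length < 31.
class BruteForceSpanningTrees {

    // Union-find "find" with path halving.
    private static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    // True if the edges selected by `mask` contain no cycle (self-loops count as cycles).
    private static boolean isAcyclic(int n, int[][] edges, int mask) {
        int[] parent = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
        for (int e = 0; e < edges.length; e++) {
            if ((mask & (1 << e)) == 0) {
                continue;
            }
            int ru = find(parent, edges[e][0]);
            int rv = find(parent, edges[e][1]);
            if (ru == rv) {
                return false; // this edge would close a cycle
            }
            parent[ru] = rv;
        }
        return true;
    }

    // Counts subsets of exactly n-1 edges that are acyclic, i.e. spanning trees.
    static long count(int n, int[][] edges) {
        if (n == 1) {
            return 1;
        }
        long total = 0;
        for (int mask = 0; mask < (1 << edges.length); mask++) {
            if (Integer.bitCount(mask) == n - 1 && isAcyclic(n, edges, mask)) {
                total++;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        int[][] triangle = {{0, 1}, {0, 2}, {1, 2}};
        System.out.println(count(3, triangle)); // prints 3
    }
}
```

Parallel edges are treated as distinct edges here, which matches the Laplacian's −(edge count) off-diagonal entries.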

Algorithm

  1. Build the Laplacian matrix L: L[u][u] = degree of u; L[u][v] = −(number of edges between u and v).
  2. Extract the (n−1)×(n−1) submatrix by removing the last row and column.
  3. Compute the determinant via Gaussian elimination:
     a) For each pivot column i, find the row at or below i with the largest absolute value in column i (partial pivoting).
     b) Swap rows if necessary; each swap multiplies the determinant by −1.
     c) Eliminate all entries below the pivot via row operations.
  4. Multiply the diagonal entries (together with the accumulated sign) to get the determinant.
  5. Round to the nearest integer and return.

Example Walkthrough

Input: n = 3, edges = [[0,1],[0,2],[1,2]]

  1. Degrees: 0→2, 1→2, 2→2.
  2. Laplacian: [[2,-1,-1],[-1,2,-1],[-1,-1,2]].
  3. Submatrix (remove last row/col): [[2,-1],[-1,2]].
  4. Determinant: 2*2 - (-1)*(-1) = 4 - 1 = 3.
  5. A triangle has exactly 3 spanning trees: every subset of 2 of its edges is one.

Output: 3
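The walkthrough can be reproduced in a few lines. This sketch is illustrative: the class `TriangleWalkthrough` and its `laplacian` and `lastMinor2x2` helpers are hypothetical names, not from the original. It builds the Laplacian exactly as in step 1 of the algorithm and evaluates the 2×2 minor with the closed-form ad − bc rule, which only applies to a 3-vertex graph.

```java
import java.util.Arrays;

// Hypothetical re-run of the example walkthrough.
class TriangleWalkthrough {

    // Step 1: degree on the diagonal, -(number of u-v edges) off the diagonal.
    static int[][] laplacian(int n, int[][] edges) {
        int[][] lap = new int[n][n];
        for (int[] e : edges) {
            int u = e[0];
            int v = e[1];
            lap[u][u]++;
            lap[v][v]++;
            lap[u][v]--;
            lap[v][u]--;
        }
        return lap;
    }

    // Steps 3-4 for a 3x3 Laplacian: removing the last row and column leaves
    // [[a, b], [c, d]] = [[lap[0][0], lap[0][1]], [lap[1][0], lap[1][1]]], so det = ad - bc.
    static int lastMinor2x2(int[][] lap) {
        return lap[0][0] * lap[1][1] - lap[0][1] * lap[1][0];
    }

    public static void main(String[] args) {
        int[][] lap = laplacian(3, new int[][]{{0, 1}, {0, 2}, {1, 2}});
        System.out.println(Arrays.deepToString(lap)); // [[2, -1, -1], [-1, 2, -1], [-1, -1, 2]]
        System.out.println(lastMinor2x2(lap));        // 3
    }
}
```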

Common Pitfalls

  • The Laplacian matrix must be constructed correctly: degree on diagonal, negative adjacency off-diagonal.
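  • The elimination runs in floating point, so a pivot must be compared against an epsilon (1e-9) rather than exact zero; a pivot below it means the submatrix is singular (the graph is disconnected and has 0 spanning trees). The final product must also be rounded to the nearest integer.
  • Removing the last row and column is an arbitrary choice: deleting row i together with column i gives the same value for every i. What matters is removing the same index from both the rows and the columns.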
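If the floating-point pitfall is a concern, one option (a swapped-in technique, not the original solution's method) is Bareiss fraction-free elimination, which keeps every intermediate value an integer, so no epsilon or rounding is needed. The sketch below assumes intermediate products fit in a long (BigInteger would lift that limit); `ExactSpanningTrees`, `count`, and `bareissDeterminant` are hypothetical names. Its `skip` parameter also illustrates the last pitfall: deleting any one row together with the matching column gives the same count.

```java
// Hypothetical exact alternative using Bareiss fraction-free elimination.
// Assumes all intermediate products fit in a long.
class ExactSpanningTrees {

    // Spanning-tree count = determinant of the Laplacian with row `skip`
    // and column `skip` removed (any 0 <= skip < n gives the same value).
    static long count(int n, int[][] edges, int skip) {
        long[][] lap = new long[n][n];
        for (int[] e : edges) {
            int u = e[0];
            int v = e[1];
            lap[u][u]++;
            lap[v][v]++;
            lap[u][v]--;
            lap[v][u]--;
        }
        long[][] m = new long[n - 1][n - 1];
        for (int i = 0, r = 0; i < n; i++) {
            if (i == skip) {
                continue; // drop row `skip`
            }
            for (int j = 0, c = 0; j < n; j++) {
                if (j == skip) {
                    continue; // drop column `skip`
                }
                m[r][c++] = lap[i][j];
            }
            r++;
        }
        return bareissDeterminant(m);
    }

    // Each updated entry is a minor of the original matrix, so dividing by
    // the previous pivot is always exact.
    static long bareissDeterminant(long[][] m) {
        int k = m.length;
        if (k == 0) {
            return 1; // single-vertex graph: one (empty) spanning tree
        }
        long sign = 1;
        long prev = 1;
        for (int i = 0; i < k - 1; i++) {
            if (m[i][i] == 0) {
                // Find a nonzero pivot below; if there is none, the determinant is 0.
                int r = i + 1;
                while (r < k && m[r][i] == 0) {
                    r++;
                }
                if (r == k) {
                    return 0;
                }
                long[] tmp = m[i];
                m[i] = m[r];
                m[r] = tmp;
                sign = -sign; // a row swap flips the sign
            }
            for (int j = i + 1; j < k; j++) {
                for (int l = i + 1; l < k; l++) {
                    m[j][l] = (m[j][l] * m[i][i] - m[j][i] * m[i][l]) / prev;
                }
            }
            prev = m[i][i];
        }
        return sign * m[k - 1][k - 1];
    }

    public static void main(String[] args) {
        int[][] triangle = {{0, 1}, {0, 2}, {1, 2}};
        for (int skip = 0; skip < 3; skip++) {
            System.out.println(count(3, triangle, skip)); // prints 3 each time
        }
    }
}
```

Compared with the double-based elimination, this trades the epsilon test and final rounding for integer overflow as the limiting factor.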
Count Spanning Trees in a Graph.java

// Approach: Apply Kirchhoff's Matrix-Tree theorem: the number of spanning trees equals any cofactor of the Laplacian matrix.
// Build the Laplacian (degree diagonal - adjacency), remove the last row and column to get an (n-1)×(n-1) matrix,
// then compute its determinant via Gaussian elimination with partial pivoting.
// Time: O(n³) Space: O(n²)

class Solution {

    public int countSpanTree(int n, int[][] edges) {
        if (n == 1) {
            return 1;
        }

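        // Build the Laplacian: degree on the diagonal, -(edge multiplicity) off it.
        // A self-loop (u == v) adds and subtracts 2 on lap[u][u], so it contributes nothing.
        int[][] lap = new int[n][n];
        for (int[] e : edges) {
            int u = e[0];
            int v = e[1];
            lap[u][u]++;
            lap[v][v]++;
            lap[u][v]--;
            lap[v][u]--;
        }
        // Remove the last row and column to get the (n-1)x(n-1) principal minor.
        double[][] mat = new double[n - 1][n - 1];
        for (int i = 0; i < n - 1; i++) {
            for (int j = 0; j < n - 1; j++) {
                mat[i][j] = lap[i][j];
            }
        }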
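        // determinant() already returns a rounded long. Calling Math.round on a long
        // would resolve to Math.round(float) and lose precision above 2^24, so cast directly.
        long det = determinant(mat, n - 1);
        return (int) det;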
    }

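    // Determinant of a size x size matrix by Gaussian elimination with
    // partial pivoting; the matrix is modified in place.
    private long determinant(double[][] mat, int size) {
        double det = 1;
        for (int i = 0; i < size; i++) {
            // Partial pivoting: pick the row at or below i with the largest |mat[j][i]|.
            int pivot = i;
            for (int j = i; j < size; j++) {
                if (Math.abs(mat[j][i]) > Math.abs(mat[pivot][i])) {
                    pivot = j;
                }
            }
            // A (near-)zero pivot means the minor is singular: the graph is disconnected.
            if (Math.abs(mat[pivot][i]) < 1e-9) {
                return 0;
            }
            // Each row swap flips the sign of the determinant.
            if (i != pivot) {
                double[] temp = mat[i];
                mat[i] = mat[pivot];
                mat[pivot] = temp;
                det *= -1;
            }

            // The determinant of the resulting triangular matrix is the product of its diagonal.
            det *= mat[i][i];
            // Eliminate every entry below the pivot in column i.
            for (int j = i + 1; j < size; j++) {
                double factor = mat[j][i] / mat[i][i];
                for (int k = i; k < size; k++) {
                    mat[j][k] -= factor * mat[i][k];
                }
            }
        }

        // Round away the floating-point error accumulated during elimination.
        return Math.round(det);
    }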
}