import numba
import numpy as np
from .disjoint_set import RankDisjointSetType, ds_rank_create, ds_find, ds_union_by_rank
from .numba_kdtree import (
NumbaKDTreeType,
parallel_tree_query,
rdist,
point_to_node_lower_bound_rdist,
NumbaKDTree,
)
@numba.njit(
numba.float32[:, ::1](
RankDisjointSetType,
numba.int32[::1],
numba.types.Array(numba.float32, 1, "A"),
numba.int64[::1],
),
locals={"i": numba.types.int64},
cache=True,
)
def merge_components(
disjoint_set, candidate_neighbors, candidate_neighbor_distances, point_components
):
component_edges = {
np.int64(0): (np.int64(0), np.int64(1), np.float32(0.0)) for i in range(0)
}
# Find the best edges from each component
for i in range(candidate_neighbors.shape[0]):
from_component = np.int64(point_components[i])
if from_component in component_edges:
if candidate_neighbor_distances[i] < component_edges[from_component][2]:
component_edges[from_component] = (
numba.int64(i),
numba.int64(candidate_neighbors[i]),
numba.float32(candidate_neighbor_distances[i]),
)
else:
component_edges[from_component] = (
numba.int64(i),
numba.int64(candidate_neighbors[i]),
numba.float32(candidate_neighbor_distances[i]),
)
result = np.empty((len(component_edges), 3), dtype=np.float32)
result_idx = 0
# Add the best edges to the edge set and merge the relevant components
for edge in component_edges.values():
from_component = ds_find(disjoint_set, numba.int32(edge[0]))
to_component = ds_find(disjoint_set, numba.int32(edge[1]))
if from_component != to_component:
result[result_idx] = (
numba.float32(edge[0]),
numba.float32(edge[1]),
numba.float32(edge[2]),
)
result_idx += 1
ds_union_by_rank(disjoint_set, from_component, to_component)
return result[:result_idx]
@numba.njit(
numba.void(
NumbaKDTreeType,
RankDisjointSetType,
numba.int64[::1],
numba.int64[::1],
),
locals={
"i": numba.types.int32,
"j": numba.types.int32,
"idx": numba.types.int32,
"left": numba.types.int32,
"right": numba.types.int32,
"candidate_component": numba.types.int32,
},
parallel=True,
cache=True,
fastmath=True,
)
def update_component_vectors(tree, disjoint_set, node_components, point_components):
for i in numba.prange(point_components.shape[0]):
point_components[i] = ds_find(disjoint_set, np.int32(i))
for i in range(tree.idx_start.shape[0] - 1, -1, -1):
# Access node information from the separate arrays
is_leaf = tree.is_leaf[i]
idx_start = tree.idx_start[i]
idx_end = tree.idx_end[i]
# Case 1:
# If the node is a leaf we need to check that every point
# in the node is of the same component
if is_leaf:
candidate_component = point_components[tree.idx_array[idx_start]]
for j in range(idx_start + 1, idx_end):
idx = tree.idx_array[j]
if point_components[idx] != candidate_component:
break
else:
node_components[i] = candidate_component
# Case 2:
# If the node is not a leaf we only need to check
# that both child nodes are in the same component
else:
left = 2 * i + 1
right = left + 1
if node_components[left] == node_components[right]:
node_components[i] = node_components[left]
@numba.njit(
numba.void(
NumbaKDTreeType,
numba.int32,
numba.float32[::1],
numba.float32[::1],
numba.int32[::1],
numba.float32,
numba.types.Array(numba.float32, 1, "A"),
numba.int64,
numba.int64[::1],
numba.int64[::1],
numba.float32,
numba.float32[::1],
),
locals={
"i": numba.types.int32,
"idx": numba.types.int32,
"left": numba.types.int32,
"right": numba.types.int32,
"d": numba.types.float32,
"dist_lower_bound_left": numba.types.float32,
"dist_lower_bound_right": numba.types.float32,
},
cache=True,
fastmath=True,
)
def component_aware_query_recursion(
tree,
node,
point,
heap_p,
heap_i,
current_core_distance,
core_distances,
current_component,
node_components,
point_components,
dist_lower_bound,
component_nearest_neighbor_dist,
):
# Access node information from the separate arrays
is_leaf = tree.is_leaf[node]
idx_start = tree.idx_start[node]
idx_end = tree.idx_end[node]
# ------------------------------------------------------------
# Case 1a: query point is outside node radius:
# trim it from the query
if dist_lower_bound > heap_p[0]:
return
# ------------------------------------------------------------
# Case 1b: we can't improve on the best distance for this component
# trim it from the query
elif (
dist_lower_bound > component_nearest_neighbor_dist[0]
or current_core_distance > component_nearest_neighbor_dist[0]
):
return
# ------------------------------------------------------------
# Case 1c: node contains only points in same component as query
# trim it from the query
elif node_components[node] == current_component:
return
# ------------------------------------------------------------
# Case 2: this is a leaf node. Update set of nearby points
elif is_leaf:
for i in range(idx_start, idx_end):
idx = tree.idx_array[i]
if (
point_components[idx] != current_component
and core_distances[idx] < component_nearest_neighbor_dist[0]
):
d = max(
rdist(point, tree.data[idx]),
current_core_distance,
core_distances[idx],
)
if d < heap_p[0]:
heap_p[0] = d
heap_i[0] = idx
if d < component_nearest_neighbor_dist[0]:
component_nearest_neighbor_dist[0] = d
# ------------------------------------------------------------
# Case 3: Node is not a leaf. Recursively query subnodes
# starting with the closest
else:
left = numba.int32(2 * node + 1)
right = numba.int32(left + 1)
dist_lower_bound_left = point_to_node_lower_bound_rdist(
tree.node_bounds[0, left], tree.node_bounds[1, left], point
)
dist_lower_bound_right = point_to_node_lower_bound_rdist(
tree.node_bounds[0, right], tree.node_bounds[1, right], point
)
# recursively query subnodes
if dist_lower_bound_left <= dist_lower_bound_right:
component_aware_query_recursion(
tree,
left,
point,
heap_p,
heap_i,
current_core_distance,
core_distances,
current_component,
node_components,
point_components,
dist_lower_bound_left,
component_nearest_neighbor_dist,
)
component_aware_query_recursion(
tree,
right,
point,
heap_p,
heap_i,
current_core_distance,
core_distances,
current_component,
node_components,
point_components,
dist_lower_bound_right,
component_nearest_neighbor_dist,
)
else:
component_aware_query_recursion(
tree,
right,
point,
heap_p,
heap_i,
current_core_distance,
core_distances,
current_component,
node_components,
point_components,
dist_lower_bound_right,
component_nearest_neighbor_dist,
)
component_aware_query_recursion(
tree,
left,
point,
heap_p,
heap_i,
current_core_distance,
core_distances,
current_component,
node_components,
point_components,
dist_lower_bound_left,
component_nearest_neighbor_dist,
)
return
@numba.njit(
numba.types.Tuple((numba.float32[::1], numba.int32[::1]))(
NumbaKDTreeType,
numba.int64[::1],
numba.int64[::1],
numba.types.Array(numba.float32, 1, "A"),
),
locals={
"i": numba.types.int32,
"distance_lower_bound": numba.types.float32,
"current_component": numba.types.int32,
},
parallel=True,
cache=True,
fastmath=True,
)
def boruvka_tree_query(tree, node_components, point_components, core_distances):
candidate_distances = np.full(tree.data.shape[0], np.inf, dtype=np.float32)
candidate_indices = np.full(tree.data.shape[0], -1, dtype=np.int32)
component_nearest_neighbor_dist = np.full(
tree.data.shape[0], np.inf, dtype=np.float32
)
data = tree.data.astype(np.float32)
for i in numba.prange(tree.data.shape[0]):
distance_lower_bound = point_to_node_lower_bound_rdist(
tree.node_bounds[0, 0], tree.node_bounds[1, 0], tree.data[i]
)
heap_p, heap_i = candidate_distances[i : i + 1], candidate_indices[i : i + 1]
component_aware_query_recursion(
tree,
numba.int32(0),
data[i],
heap_p,
heap_i,
core_distances[i],
core_distances,
point_components[i],
node_components,
point_components,
distance_lower_bound,
component_nearest_neighbor_dist[
point_components[i] : point_components[i] + 1
],
)
return candidate_distances, candidate_indices
@numba.njit(inline="always", cache=True)
def calculate_block_size(n_components, n_points, num_threads):
"""Calculate adaptive block size based on component sizes."""
if n_components == 0:
points_per_component = n_points
else:
points_per_component = n_points / n_components
if points_per_component < 10:
block_size = num_threads * 512 # Weak pruning, large blocks
elif points_per_component < 100:
block_size = num_threads * 128 # Moderate pruning
elif points_per_component < 1000:
block_size = num_threads * 32 # Good pruning
else:
block_size = num_threads * 8 # Excellent pruning, small blocks
# Ensure reasonable bounds
block_size = max(num_threads, min(block_size, n_points // 4 + 1))
return int(block_size)
@numba.njit(
[
"void(float32[:], float32[:], int32[:], int32, int32)",
"void(float64[:], float64[:], int64[:], int64, int64)",
],
locals={
"i": numba.types.int32,
"component": numba.types.int32,
"block_bound": numba.types.float32,
},
cache=True,
fastmath=True,
inline="always",
)
def update_component_bounds_from_block(
component_nearest_neighbor_dist,
block_component_bounds,
point_components,
block_start,
block_end,
):
"""Update global component bounds from block results."""
for i in range(block_start, block_end):
component = point_components[i]
block_bound = block_component_bounds[i - block_start]
if block_bound < component_nearest_neighbor_dist[component]:
component_nearest_neighbor_dist[component] = block_bound
@numba.njit(
numba.types.Tuple((numba.float32[::1], numba.int32[::1]))(
NumbaKDTreeType,
numba.int64[::1],
numba.int64[::1],
numba.types.Array(numba.float32, 1, "A"),
numba.int64,
),
locals={
"block_start": numba.types.int32,
"block_end": numba.types.int32,
"block_size_actual": numba.types.int32,
"i": numba.types.int32,
"distance_lower_bound": numba.types.float32,
"current_component": numba.types.int32,
},
parallel=True,
cache=True,
fastmath=True,
)
def boruvka_tree_query_reproducible(
tree, node_components, point_components, core_distances, block_size
):
"""Reproducible version using block-based processing to avoid race conditions."""
candidate_distances = np.full(tree.data.shape[0], np.inf, dtype=np.float32)
candidate_indices = np.full(tree.data.shape[0], -1, dtype=np.int32)
component_nearest_neighbor_dist = np.full(
tree.data.shape[0], np.inf, dtype=np.float32
)
data = tree.data.astype(np.float32)
# Reusable buffer for block component bounds (allocate once, reuse)
max_block_component_bounds = np.full(block_size, np.inf, dtype=np.float32)
# Process points in blocks
for block_start in range(0, tree.data.shape[0], block_size):
block_end = min(block_start + block_size, tree.data.shape[0])
block_size_actual = block_end - block_start
# Reset only the portion we'll use (more cache-friendly)
max_block_component_bounds[:block_size_actual] = np.inf
# Parallel processing within the block
for i in numba.prange(block_start, block_end):
distance_lower_bound = point_to_node_lower_bound_rdist(
tree.node_bounds[0, 0], tree.node_bounds[1, 0], tree.data[i]
)
heap_p, heap_i = (
candidate_distances[i : i + 1],
candidate_indices[i : i + 1],
)
# Use current global bounds for this component
current_component = point_components[i]
local_component_bound = component_nearest_neighbor_dist[
current_component : current_component + 1
]
component_aware_query_recursion(
tree,
numba.int32(0),
data[i],
heap_p,
heap_i,
core_distances[i],
core_distances,
point_components[i],
node_components,
point_components,
distance_lower_bound,
local_component_bound,
)
# Store the potentially updated bound for this point
max_block_component_bounds[i - block_start] = local_component_bound[0]
# Sequential update of global component bounds after the block
update_component_bounds_from_block(
component_nearest_neighbor_dist,
max_block_component_bounds,
point_components,
block_start,
block_end,
)
return candidate_distances, candidate_indices
@numba.njit(
locals={
"i": numba.types.int32,
"j": numba.types.int32,
"k": numba.types.int32,
"result_idx": numba.types.int32,
"from_component": numba.types.int32,
"to_component": numba.types.int32,
},
parallel=True,
cache=True,
)
def initialize_boruvka_from_knn(
knn_indices, knn_distances, core_distances, disjoint_set
):
# component_edges = {0:(np.int32(0), np.int32(1), np.float32(0.0)) for i in range(0)}
component_edges = np.full((knn_indices.shape[0], 3), -1, dtype=np.float64)
for i in numba.prange(knn_indices.shape[0]):
for j in range(1, knn_indices.shape[1]):
k = np.int32(knn_indices[i, j])
if core_distances[i] >= core_distances[k]:
# Use max of core distance and actual distance as edge weight
edge_weight = max(core_distances[i], knn_distances[i, j])
component_edges[i] = (
np.float64(i),
np.float64(k),
np.float64(edge_weight),
)
break
result = np.empty((len(component_edges), 3), dtype=np.float64)
result_idx = 0
# Add the best edges to the edge set and merge the relevant components
for edge in component_edges:
if edge[0] < 0:
continue
from_component = ds_find(disjoint_set, np.int32(edge[0]))
to_component = ds_find(disjoint_set, np.int32(edge[1]))
if from_component != to_component:
result[result_idx] = (
np.float64(edge[0]),
np.float64(edge[1]),
np.float64(edge[2]),
)
result_idx += 1
ds_union_by_rank(disjoint_set, from_component, to_component)
return result[:result_idx].astype(np.float32)
[docs]
@numba.njit(
numba.float32[:, ::1](
NumbaKDTreeType,
numba.int64,
numba.int64,
numba.types.boolean,
),
cache=True,
)
def parallel_boruvka(tree, n_threads, min_samples=10, reproducible=False):
components_disjoint_set = ds_rank_create(tree.data.shape[0])
point_components = np.arange(tree.data.shape[0])
node_components = np.full(tree.idx_start.shape[0], -1)
n_components = point_components.shape[0]
if min_samples > 1:
distances, neighbors = parallel_tree_query(
tree, tree.data, k=numba.int64(min_samples + 1), output_rdist=True
)
core_distances = distances.T[-1]
initial_edges = initialize_boruvka_from_knn(
neighbors, distances, core_distances, components_disjoint_set
)
update_component_vectors(
tree, components_disjoint_set, node_components, point_components
)
else:
core_distances = np.zeros(tree.data.shape[0], dtype=np.float32)
distances, neighbors = parallel_tree_query(
tree, tree.data, k=numba.int64(2), output_rdist=True
)
initial_edges = initialize_boruvka_from_knn(
neighbors, distances, core_distances, components_disjoint_set
)
update_component_vectors(
tree, components_disjoint_set, node_components, point_components
)
# Count initial components after initialization
n_components = len(np.unique(point_components))
# Use list to accumulate edges, then convert at end (more efficient than vstack)
# all_edges = [initial_edges]
# all_edges = initial_edges
max_edges = tree.data.shape[0] - 1
all_edges = np.empty((max_edges, 3), dtype=np.float32)
n_edges = numba.int64(len(initial_edges))
all_edges[:n_edges] = initial_edges
while n_components > 1:
if reproducible:
# Calculate adaptive block size based on current component sizes
block_size = calculate_block_size(
n_components, tree.data.shape[0], n_threads
)
candidate_distances, candidate_indices = boruvka_tree_query_reproducible(
tree, node_components, point_components, core_distances, block_size
)
else:
candidate_distances, candidate_indices = boruvka_tree_query(
tree, node_components, point_components, core_distances
)
new_edges = merge_components(
components_disjoint_set,
candidate_indices,
candidate_distances,
point_components,
)
# Update component count more efficiently - subtract merged components
n_components -= len(new_edges)
update_component_vectors(
tree, components_disjoint_set, node_components, point_components
)
if len(new_edges) > 0:
# # all_edges.append(new_edges)
# all_edges = np.vstack((all_edges, new_edges)).astype(np.float32)
all_edges[n_edges : n_edges + len(new_edges)] = new_edges
n_edges += numba.int64(len(new_edges))
all_edges[:, 2] = np.sqrt(all_edges.T[2])
return all_edges