Source code for evoc.numba_kdtree

import numba
import numpy as np

from collections import namedtuple

# Flat container for a KD-tree: the point matrix, the permutation of point
# indices, one array per node attribute, and the per-node bounding boxes.
NumbaKDTree = namedtuple(
    "NumbaKDTree",
    "data idx_array idx_start idx_end radius is_leaf node_bounds",
)

# The four per-node attribute arrays on their own, in the same order in which
# they appear inside NumbaKDTree.
NodeData = namedtuple("NodeData", "idx_start idx_end radius is_leaf")

# Numba type for a ``NodeData`` tuple. The element types are listed in the
# same order as the ``NodeData`` fields (idx_start, idx_end, radius, is_leaf);
# each ``[::1]`` declares a one-dimensional C-contiguous array.
# NOTE(review): this name is not referenced anywhere else in the visible part
# of this module -- confirm whether external code still relies on it.
NodeDataType = numba.types.NamedTuple(
    [
        numba.types.intp[::1],  # idx_start
        numba.types.intp[::1],  # idx_end
        numba.types.float32[::1],  # radius
        numba.types.bool_[::1],  # is_leaf
    ],
    NodeData,
)
# A throwaway, minimally sized tree whose only purpose is to let numba infer
# the type of a NumbaKDTree. The array contents are never read; only their
# dtypes, dimensionality and layout matter.
_sentinel_kdtree = NumbaKDTree(
    np.empty((1, 1), dtype=np.float32),  # data
    np.empty(1, dtype=np.intp),  # idx_array
    np.empty(1, dtype=np.intp),  # idx_start
    np.empty(1, dtype=np.intp),  # idx_end
    np.empty(1, dtype=np.float32),  # radius
    np.empty(1, dtype=np.bool_),  # is_leaf
    np.empty((2, 1, 1), dtype=np.float32),  # node_bounds
)

# Numba type used in the explicit signatures of the query functions.
NumbaKDTreeType = numba.typeof(_sentinel_kdtree)


def kdtree_to_numba(sklearn_kdtree):
    """Convert a tree exposing ``get_arrays()`` into a ``NumbaKDTree``.

    Parameters
    ----------
    sklearn_kdtree
        Object whose ``get_arrays()`` returns the 4-tuple
        ``(data, idx_array, node_data, node_bounds)``, where ``node_data``
        has the fields ``idx_start``, ``idx_end``, ``radius`` and ``is_leaf``.

    Returns
    -------
    NumbaKDTree
        Tuple of arrays with the dtypes and C-contiguous layout of the
        module-level ``NumbaKDTreeType`` (float32 data / radius / bounds,
        intp indices, boolean leaf flags).

    Notes
    -----
    The node fields are read by key (``node_data["idx_start"]``). Key access
    works for plain NumPy structured arrays as well as record arrays, whereas
    attribute access (``node_data.idx_start``) works only for record arrays.

    A field of a structured array is a strided view, and the source arrays
    may use other dtypes, so every array is passed through
    ``np.ascontiguousarray`` with the target dtype. Arrays that already match
    are returned without copying.
    """
    data, idx_array, node_data, node_bounds = sklearn_kdtree.get_arrays()
    return NumbaKDTree(
        np.ascontiguousarray(data, dtype=np.float32),
        np.ascontiguousarray(idx_array, dtype=np.intp),
        np.ascontiguousarray(node_data["idx_start"], dtype=np.intp),
        np.ascontiguousarray(node_data["idx_end"], dtype=np.intp),
        np.ascontiguousarray(node_data["radius"], dtype=np.float32),
        np.ascontiguousarray(node_data["is_leaf"], dtype=np.bool_),
        np.ascontiguousarray(node_bounds, dtype=np.float32),
    )


@numba.njit(
    cache=True,
    fastmath=True,
    locals={
        "n_dims": numba.types.intp,
        "box_min": numba.types.float32[::1],
        "box_max": numba.types.float32[::1],
        "sq_radius": numba.types.float32,
        "half_extent": numba.types.float32,
        "point": numba.types.float32[::1],
    },
)
def _init_node(
    data,
    node_bounds,
    idx_array,
    idx_start_array,
    idx_end_array,
    radius_array,
    is_leaf_array,
    node,
    idx_start,
    idx_end,
):
    """Fill in bounding box, index range and radius for a single node.

    The node covers the points ``data[idx_array[idx_start:idx_end]]``. The
    per-dimension minima are written to ``node_bounds[0, node]``, the maxima
    to ``node_bounds[1, node]``, the index range to ``idx_start_array[node]``
    and ``idx_end_array[node]``, and the half-diagonal of the bounding box to
    ``radius_array[node]``.

    ``is_leaf_array`` is accepted but neither read nor written here.
    """
    n_dims = data.shape[1]

    # Views into the bounds array: writes below land directly in node_bounds.
    box_min = node_bounds[0, node, :]
    box_max = node_bounds[1, node, :]

    # Start from an "empty" box so that the first point sets both bounds.
    for dim in range(n_dims):
        box_min[dim] = np.inf
        box_max[dim] = -np.inf

    # Grow the box until it contains every point assigned to this node.
    for pos in range(idx_start, idx_end):
        point = data[idx_array[pos]]
        for dim in range(n_dims):
            box_min[dim] = min(box_min[dim], point[dim])
            box_max[dim] = max(box_max[dim], point[dim])

    # Squared length of the half-diagonal of the box.
    sq_radius = 0.0
    for dim in range(n_dims):
        half_extent = abs(box_max[dim] - box_min[dim]) * 0.5
        sq_radius += half_extent * half_extent

    idx_start_array[node] = idx_start
    idx_end_array[node] = idx_end
    radius_array[node] = np.sqrt(sq_radius)


@numba.njit(
    "intp(float32[:,::1], intp[::1], intp, intp)",
    cache=True,
    locals={
        "n_dims": numba.types.intp,
        "best_dim": numba.types.intp,
        "best_extent": numba.types.float32,
        "dim": numba.types.intp,
        "pos": numba.types.intp,
        "hi": numba.types.float32,
        "lo": numba.types.float32,
        "value": numba.types.float32,
        "extent": numba.types.float32,
    },
)
def _find_node_split_dim(data, idx_array, idx_start, idx_end):
    """Return the dimension along which the node's points spread the most.

    Considers the points ``data[idx_array[idx_start:idx_end]]``. A dimension
    only replaces the current best when its spread (max minus min) is
    strictly larger, so ties go to the lowest dimension and 0 is returned
    when no dimension has a positive spread.
    """
    n_dims = data.shape[1]
    best_dim = 0
    best_extent = 0

    for dim in range(n_dims):
        # Seed both extremes with the first point of the range.
        hi = data[idx_array[idx_start], dim]
        lo = hi
        for pos in range(idx_start + 1, idx_end):
            value = data[idx_array[pos], dim]
            hi = max(hi, value)
            lo = min(lo, value)

        extent = hi - lo
        if extent > best_extent:
            best_extent = extent
            best_dim = dim

    return best_dim


@numba.njit(
    "int8(float32[:,::1], intp, intp, intp)",
    fastmath=True,
    cache=True,
    locals={
        "a": numba.types.float32,
        "b": numba.types.float32,
    },
)
def _compare_indices(data, axis, idx1, idx2):
    """Three-way comparison of two points along one axis.

    Returns -1, 0 or 1 depending on whether ``data[idx1, axis]`` orders
    before, equal to, or after ``data[idx2, axis]``. Equal coordinates are
    ordered by the point indices themselves, so 0 is returned only when
    ``idx1 == idx2``.
    """
    a = data[idx1, axis]
    b = data[idx2, axis]

    if a < b:
        return -1
    if a > b:
        return 1

    # Same coordinate: fall back to the index values to get a total order.
    if idx1 < idx2:
        return -1
    if idx1 > idx2:
        return 1
    return 0


@numba.njit(
    "void(float32[:,::1], intp[::1], intp, intp, intp)",
    fastmath=True,
    cache=True,
    locals={
        "pos": numba.types.intp,
        "moving": numba.types.intp,
        "slot": numba.types.intp,
    },
)
def _insertion_sort_indices(data, idx_array, axis, left, right):
    """Insertion-sort ``idx_array[left:right]`` in place.

    Entries are ordered by ``_compare_indices`` along ``axis``.
    """
    for pos in range(left + 1, right):
        moving = idx_array[pos]

        # ``slot`` is the position the moving entry will finally occupy;
        # shift larger entries one step right until that position is found.
        slot = pos
        while (
            slot > left
            and _compare_indices(data, axis, idx_array[slot - 1], moving) > 0
        ):
            idx_array[slot] = idx_array[slot - 1]
            slot -= 1

        idx_array[slot] = moving


@numba.njit(
    "void(float32[:,::1], intp[::1], intp, intp, intp, intp)",
    fastmath=True,
    cache=True,
    locals={
        "parent": numba.types.intp,
        "first_child": numba.types.intp,
        "second_child": numba.types.intp,
        "largest": numba.types.intp,
    },
)
def _sift_down_indices(data, idx_array, axis, offset, start, end):
    """Restore the max-heap property below one heap position.

    The heap lives in ``idx_array[offset:offset + end]``. ``start`` and
    ``end`` are positions relative to ``offset``, and entries are ordered by
    ``_compare_indices`` along ``axis``.
    """
    parent = start
    first_child = 2 * parent + 1

    while first_child < end:
        # Pick the largest of the parent and its (up to two) children.
        largest = parent
        if (
            _compare_indices(
                data, axis, idx_array[offset + largest], idx_array[offset + first_child]
            )
            < 0
        ):
            largest = first_child

        second_child = first_child + 1
        if (
            second_child < end
            and _compare_indices(
                data,
                axis,
                idx_array[offset + largest],
                idx_array[offset + second_child],
            )
            < 0
        ):
            largest = second_child

        # Parent already dominates both children: the heap is valid again.
        if largest == parent:
            break

        idx_array[offset + parent], idx_array[offset + largest] = (
            idx_array[offset + largest],
            idx_array[offset + parent],
        )
        parent = largest
        first_child = 2 * parent + 1


@numba.njit(
    "void(float32[:,::1], intp[::1], intp, intp, intp)",
    cache=True,
    locals={
        "count": numba.types.intp,
        "parent": numba.types.intp,
        "last": numba.types.intp,
    },
)
def _heapsort_indices(data, idx_array, axis, left, right):
    """Heap-sort ``idx_array[left:right]`` in place along ``axis``."""
    count = right - left

    # Heapify: sift down every position that has at least one child,
    # from the last such position back to the root.
    for parent in range(count // 2 - 1, -1, -1):
        _sift_down_indices(data, idx_array, axis, left, parent, count)

    # Repeatedly move the current maximum behind the shrinking heap and
    # repair the heap that remains in front of it.
    for last in range(count - 1, 0, -1):
        idx_array[left], idx_array[left + last] = (
            idx_array[left + last],
            idx_array[left],
        )
        _sift_down_indices(data, idx_array, axis, left, 0, last)


@numba.njit(
    "intp(float32[:,::1], intp[::1], intp, intp, intp)",
    fastmath=True,
    cache=True,
    locals={
        "mid": numba.types.intp,
        "last": numba.types.intp,
        "lo_id": numba.types.intp,
        "mid_id": numba.types.intp,
        "hi_id": numba.types.intp,
    },
)
def _median_of_three_pivot(data, idx_array, axis, left, right):
    """Order the first, middle and last entries and return the middle slot.

    Works on ``idx_array[left:right]``. After the call the entries at
    ``left``, the middle position and ``right - 1`` are in ascending order
    according to ``_compare_indices``, so the middle position holds their
    median. That position is returned.
    """
    last = right - 1
    mid = (left + last) // 2

    # Local mirrors of the three candidate entries.
    lo_id = idx_array[left]
    mid_id = idx_array[mid]
    hi_id = idx_array[last]

    # Step 1: put the first two candidates in order.
    if _compare_indices(data, axis, lo_id, mid_id) > 0:
        idx_array[left], idx_array[mid] = idx_array[mid], idx_array[left]
        lo_id, mid_id = mid_id, lo_id

    # Step 2: put the last two in order. If that moved a new value into the
    # middle, the first pair may need ordering once more.
    if _compare_indices(data, axis, mid_id, hi_id) > 0:
        idx_array[mid], idx_array[last] = idx_array[last], idx_array[mid]
        mid_id, hi_id = hi_id, mid_id

        if _compare_indices(data, axis, lo_id, mid_id) > 0:
            idx_array[left], idx_array[mid] = idx_array[mid], idx_array[left]

    return mid


@numba.njit(
    "intp(float32[:,::1], intp[::1], intp, intp, intp, intp)",
    fastmath=True,
    cache=True,
    locals={
        "pivot_original_idx": numba.types.intp,
        "i": numba.types.intp,
        "j": numba.types.intp,
    },
)
def _partition_indices(data, idx_array, axis, left, right, pivot_idx):
    """Partition ``idx_array[left:right]`` around a pivot, in place.

    Parameters
    ----------
    data : float32[:, ::1]
        Point coordinates; only column ``axis`` is compared.
    idx_array : intp[::1]
        Point indices; the slice ``[left:right]`` is rearranged.
    axis : intp
        Dimension used for the comparison.
    left, right : intp
        Half-open range to partition.
    pivot_idx : intp
        Position (within ``idx_array``) of the entry to use as pivot.

    Returns
    -------
    intp
        Final position of the pivot. Entries before it compare below the
        pivot and entries after it compare greater than or equal to it,
        using the ordering of ``_compare_indices``.
    """
    # Park the pivot in the last slot so the scan below never moves it.
    idx_array[pivot_idx], idx_array[right - 1] = (
        idx_array[right - 1],
        idx_array[pivot_idx],
    )
    # Only the pivot's point index is needed: _compare_indices looks up the
    # coordinate itself and breaks ties on the index.
    pivot_original_idx = idx_array[right - 1]

    i = left
    j = right - 2

    while True:
        # Find element from left that should be on right
        while (
            i <= j
            and _compare_indices(data, axis, idx_array[i], pivot_original_idx) < 0
        ):
            i += 1

        # Find element from right that should be on left
        while (
            i <= j
            and _compare_indices(data, axis, idx_array[j], pivot_original_idx) >= 0
        ):
            j -= 1

        if i >= j:
            break

        # Both entries are on the wrong side: exchange them and move inward.
        idx_array[i], idx_array[j] = idx_array[j], idx_array[i]
        i += 1
        j -= 1

    # Move pivot to final position
    idx_array[i], idx_array[right - 1] = idx_array[right - 1], idx_array[i]
    return i


@numba.njit(
    "void(float32[:,::1], intp[::1], intp, intp, intp, intp, intp)",
    cache=True,
    locals={
        "pivot_pos": numba.types.intp,
    },
)
def _introselect_impl(data, idx_array, axis, left, right, nth, depth_limit):
    """Iterative selection loop behind ``_introselect``.

    Repeatedly partitions ``idx_array[left:right]`` and narrows the range to
    the side that contains position ``nth``. The loop stops as soon as the
    pivot lands on ``nth``. A range of 16 entries or fewer is finished with
    insertion sort, and once ``depth_limit`` partitioning rounds are used up
    the remaining range is heap-sorted instead.
    """
    while right - left > 16:
        # Partitioning budget exhausted: sort what is left and stop.
        if depth_limit == 0:
            _heapsort_indices(data, idx_array, axis, left, right)
            return
        depth_limit -= 1

        # Median-of-three pivot selection followed by the partition step.
        pivot_pos = _partition_indices(
            data,
            idx_array,
            axis,
            left,
            right,
            _median_of_three_pivot(data, idx_array, axis, left, right),
        )

        if pivot_pos == nth:
            # The pivot is exactly the requested element.
            return

        # Keep only the side of the pivot that contains ``nth``.
        if nth < pivot_pos:
            right = pivot_pos
        else:
            left = pivot_pos + 1

    # Few enough entries remain to sort them directly.
    _insertion_sort_indices(data, idx_array, axis, left, right)


@numba.njit(
    "void(float32[:,::1], intp[::1], intp, intp, intp, intp)",
    cache=True,
    locals={
        "span": numba.types.intp,
        "depth_budget": numba.types.intp,
    },
)
def _introselect(data, idx_array, axis, left, right, nth):
    """Rearrange ``idx_array[left:right]`` around position ``nth``.

    Ranges of at most 16 entries are insertion-sorted outright. Larger
    ranges are handed to ``_introselect_impl`` with a budget of
    ``2 * int(log2(size))`` partitioning rounds.
    """
    span = right - left

    if span > 16:
        # Cap on partitioning rounds before the implementation falls back
        # to heapsort.
        depth_budget = 2 * int(np.log2(span))
        _introselect_impl(data, idx_array, axis, left, right, nth, depth_budget)
    else:
        # Small range: a full insertion sort is used instead of selection.
        _insertion_sort_indices(data, idx_array, axis, left, right)


@numba.njit(
    "void(float32[:, ::1], intp[::1], intp[::1], intp[::1], float32[::1], bool_[::1], float32[:, :, ::1], intp, intp, intp)",
    cache=True,
)
def _recursive_build_tree(
    data,
    idx_array,
    idx_start_array,
    idx_end_array,
    radius_array,
    is_leaf_array,
    node_bounds,
    idx_start,
    idx_end,
    node,
):
    """Build the subtree rooted at ``node`` over ``idx_array[idx_start:idx_end]``.

    The node's bounds, index range and radius are recorded first. The node
    becomes a leaf when its first child would fall outside the node arrays
    or when it holds fewer than two points. Otherwise the index range is
    split at its midpoint along the widest dimension and both halves are
    built recursively as children ``2 * node + 1`` and ``2 * node + 2``.
    """
    _init_node(
        data,
        node_bounds,
        idx_array,
        idx_start_array,
        idx_end_array,
        radius_array,
        is_leaf_array,
        node,
        idx_start,
        idx_end,
    )

    first_child = 2 * node + 1
    no_room_for_children = first_child >= is_leaf_array.shape[0]
    too_few_points = idx_end - idx_start < 2

    if no_room_for_children or too_few_points:
        is_leaf_array[node] = True
        return

    is_leaf_array[node] = False

    # Rearrange the node's indices so that the midpoint entry is in its
    # sorted position along the dimension with the largest spread.
    split = idx_start + (idx_end - idx_start) // 2
    axis = _find_node_split_dim(data, idx_array, idx_start, idx_end)
    _introselect(data, idx_array, axis, idx_start, idx_end, split)

    # Lower half goes to the first child ...
    _recursive_build_tree(
        data,
        idx_array,
        idx_start_array,
        idx_end_array,
        radius_array,
        is_leaf_array,
        node_bounds,
        idx_start,
        split,
        first_child,
    )
    # ... and the upper half to the second child.
    _recursive_build_tree(
        data,
        idx_array,
        idx_start_array,
        idx_end_array,
        radius_array,
        is_leaf_array,
        node_bounds,
        split,
        idx_end,
        first_child + 1,
    )


def build_kdtree(data, leaf_size=40):
    """Build a ``NumbaKDTree`` over the rows of ``data``.

    Parameters
    ----------
    data : array-like of shape (n_samples, n_features)
        Points to index. The compiled tree builder is declared for
        C-contiguous float32 input only, so the input is converted with
        ``np.ascontiguousarray(data, dtype=np.float32)``. An array that
        already has that dtype and layout is used as-is, without a copy.
    leaf_size : int, default=40
        Controls the depth of the tree; must be at least 1.

    Returns
    -------
    NumbaKDTree
        The (possibly converted) data together with the index permutation,
        the per-node index ranges, radii and leaf flags, and the per-node
        bounding boxes.

    Raises
    ------
    ValueError
        If ``leaf_size`` is smaller than 1.
    """
    data = np.ascontiguousarray(data, dtype=np.float32)
    n_samples = data.shape[0]
    n_features = data.shape[1]

    if leaf_size < 1:
        raise ValueError("leaf_size must be greater than or equal to 1")

    # determine number of levels in the tree, and from this
    # the number of nodes in the tree. This results in leaf nodes
    # with numbers of points between leaf_size and 2 * leaf_size
    n_levels = int(np.log2(max(1, (n_samples - 1) / leaf_size)) + 1)
    # Plain Python int: cannot overflow the way a fixed-width integer could.
    n_nodes = (2**n_levels) - 1

    # allocate arrays for storage
    idx_array = np.arange(n_samples, dtype=np.intp)
    idx_start_array = np.zeros(n_nodes, dtype=np.intp)
    idx_end_array = np.zeros(n_nodes, dtype=np.intp)
    radius_array = np.zeros(n_nodes, dtype=np.float32)
    is_leaf_array = np.zeros(n_nodes, dtype=np.bool_)
    node_bounds = np.zeros((2, n_nodes, n_features), dtype=np.float32)

    # Fills every array above in place, starting from the root (node 0)
    # which covers all samples.
    _recursive_build_tree(
        data,
        idx_array,
        idx_start_array,
        idx_end_array,
        radius_array,
        is_leaf_array,
        node_bounds,
        0,
        n_samples,
        0,
    )

    return NumbaKDTree(
        data,
        idx_array,
        idx_start_array,
        idx_end_array,
        radius_array,
        is_leaf_array,
        node_bounds,
    )


@numba.njit(
    [
        "f4(f4[::1],f4[::1])",
        "f8(f8[::1],f8[::1])",
        "f8(f4[::1],f8[::1])",
    ],
    fastmath=True,
    cache=True,
    locals={
        "dim": numba.types.intp,
        # intp rather than a 16-bit index, so vectors with more than 65535
        # components are indexed correctly.
        "i": numba.types.intp,
        "diff": numba.types.float32,
        "result": numba.types.float32,
    },
)
def rdist(x, y):
    """Return the squared Euclidean distance between ``x`` and ``y``.

    The number of components is taken from ``x``. The running sum and the
    per-component difference are declared float32 in ``locals``, so the
    accumulation happens in single precision for every signature.
    """
    result = 0.0
    dim = x.shape[0]
    for i in range(dim):
        diff = x[i] - y[i]
        result += diff * diff

    return result


@numba.njit(
    [
        "f4(f4[::1],f4[::1],f4[::1])",
        "f4(f8[::1],f8[::1],f4[::1])",
        "f4(f8[::1],f8[::1],f8[::1])",
    ],
    fastmath=True,
    cache=True,
    locals={
        "dim": numba.types.intp,
        # intp rather than a 16-bit index (see rdist).
        "i": numba.types.intp,
        "d_lo": numba.types.float32,
        "d_hi": numba.types.float32,
        "d": numba.types.float32,
        "result": numba.types.float32,
    },
)
def point_to_node_lower_bound_rdist(upper, lower, pt):
    """Return a lower bound on the squared distance from ``pt`` to a box.

    For every dimension the contribution is ``upper[i] - pt[i]`` when the
    point lies below ``upper[i]``, plus ``pt[i] - lower[i]`` when the point
    lies above ``lower[i]``; the squared contributions are summed.

    NOTE(review): the parameter names are the reverse of how this module
    calls the function. Every call site in this file passes
    ``node_bounds[0, node]`` as ``upper`` and ``node_bounds[1, node]`` as
    ``lower``, and ``_init_node`` stores the per-dimension minima in
    ``node_bounds[0]`` and the maxima in ``node_bounds[1]``. With those
    arguments the result is zero for a point inside the box and the squared
    distance to the box otherwise. The names are kept for compatibility.
    """
    result = 0.0
    dim = pt.shape[0]
    for i in range(dim):
        d_lo = upper[i] - pt[i] if upper[i] > pt[i] else 0.0
        d_hi = pt[i] - lower[i] if pt[i] > lower[i] else 0.0
        d = d_lo + d_hi
        result += d * d

    return result


@numba.njit(
    [
        "i4(f4[::1],i4[::1],f4,i4)",
        "i4(f8[::1],i4[::1],f8,i4)",
    ],
    fastmath=True,
    locals={
        "size": numba.types.intp,
        # Heap positions are intp rather than 16-bit, so the child position
        # 2 * i + 1 cannot wrap around for large heaps.
        "i": numba.types.intp,
        "ic1": numba.types.intp,
        "ic2": numba.types.intp,
        "i_swap": numba.types.intp,
    },
    cache=True,
)
def simple_heap_push(priorities, indices, p, n):
    """Replace the root of a max-heap with ``(p, n)`` if ``p`` is smaller.

    ``priorities`` holds the heap keys (largest at position 0) and
    ``indices`` the payload stored alongside each key; both are updated in
    place.

    Returns 0 without touching the heap when ``p >= priorities[0]`` and 1
    after a successful insertion.
    """
    if p >= priorities[0]:
        return 0

    size = priorities.shape[0]

    # insert val at position zero
    priorities[0] = p
    indices[0] = n

    # descend the heap, swapping values until the max heap criterion is met
    i = 0
    while True:
        ic1 = 2 * i + 1
        ic2 = ic1 + 1

        if ic1 >= size:
            # No children: the new entry stays here.
            break
        elif ic2 >= size:
            # Only a left child exists.
            if priorities[ic1] > p:
                i_swap = ic1
            else:
                break
        elif priorities[ic1] >= priorities[ic2]:
            # The left child is the larger of the two.
            if p < priorities[ic1]:
                i_swap = ic1
            else:
                break
        else:
            # The right child is the larger of the two.
            if p < priorities[ic2]:
                i_swap = ic2
            else:
                break

        # Pull the larger child up and continue from its old position.
        priorities[i] = priorities[i_swap]
        indices[i] = indices[i_swap]

        i = i_swap

    priorities[i] = p
    indices[i] = n

    return 1


@numba.njit(
    fastmath=True,
    cache=True,
    locals={
        "left_child": numba.types.intp,
        "right_child": numba.types.intp,
        "swap": numba.types.intp,
    },
)
def siftdown(heap1, heap2, elt):
    """Sift position ``elt`` down a max-heap keyed on ``heap1``.

    ``heap2`` is permuted in lock-step with ``heap1``. Both are modified in
    place and the heap size is ``heap1.shape[0]``.
    """
    while elt * 2 + 1 < heap1.shape[0]:
        left_child = elt * 2 + 1
        right_child = left_child + 1
        swap = elt

        if heap1[swap] < heap1[left_child]:
            swap = left_child

        if right_child < heap1.shape[0] and heap1[swap] < heap1[right_child]:
            swap = right_child

        if swap == elt:
            break
        else:
            heap1[elt], heap1[swap] = heap1[swap], heap1[elt]
            heap2[elt], heap2[swap] = heap2[swap], heap2[elt]
            elt = swap


@numba.njit(parallel=True, cache=True)
def deheap_sort(distances, indices):
    """Turn every row from a max-heap into an ascending sorted row.

    Each row of ``distances`` is treated as a max-heap and ``indices`` is
    permuted alongside it. Rows are processed in a ``prange`` loop. Both
    arrays are modified in place and returned as a tuple.
    """
    for i in numba.prange(indices.shape[0]):
        # starting from the end of the array and moving back
        for j in range(indices.shape[1] - 1, 0, -1):
            # Move the current maximum behind the shrinking heap ...
            indices[i, 0], indices[i, j] = indices[i, j], indices[i, 0]
            distances[i, 0], distances[i, j] = distances[i, j], distances[i, 0]

            # ... and repair the heap formed by the first j entries.
            siftdown(distances[i, :j], indices[i, :j], 0)

    return distances, indices


@numba.njit(
    numba.void(
        NumbaKDTreeType,
        numba.types.intp,
        numba.float32[::1],
        numba.float32[::1],
        numba.int32[::1],
        numba.float32,
    ),
    fastmath=True,
    cache=True,
    locals={
        "node": numba.types.intp,
        "left": numba.types.intp,
        "right": numba.types.intp,
        "d": numba.types.float32,
        "idx": numba.types.uint32,
        "idx_start": numba.types.intp,
        "idx_end": numba.types.intp,
        "is_leaf": numba.types.boolean,
        "i": numba.types.intp,
        "dist_lower_bound_left": numba.types.float32,
        "dist_lower_bound_right": numba.types.float32,
    },
)
def tree_query_recursion(
    tree,
    node,
    point,
    heap_p,
    heap_i,
    dist_lower_bound,
):
    """Update one query's neighbor heap with the subtree rooted at ``node``.

    Parameters
    ----------
    tree : NumbaKDTree
        Tree to search.
    node : intp
        Node at which to start.
    point : float32[::1]
        Query point.
    heap_p, heap_i : float32[::1], int32[::1]
        Max-heap of the best squared distances found so far and the matching
        point indices; both are updated in place.
    dist_lower_bound : float32
        Lower bound on the squared distance from ``point`` to ``node``.
    """
    # Get node information
    idx_start = tree.idx_start[node]
    idx_end = tree.idx_end[node]
    is_leaf = tree.is_leaf[node]

    # ------------------------------------------------------------
    # Case 1: query point is outside node radius:
    #         trim it from the query
    if dist_lower_bound > heap_p[0]:
        return

    # ------------------------------------------------------------
    # Case 2: this is a leaf node.  Update set of nearby points
    elif is_leaf:
        for i in range(idx_start, idx_end):
            idx = tree.idx_array[i]
            d = rdist(point, tree.data[idx])
            if d < heap_p[0]:
                simple_heap_push(heap_p, heap_i, d, idx)

    # ------------------------------------------------------------
    # Case 3: Node is not a leaf.  Recursively query subnodes
    #         starting with the closest
    else:
        left = 2 * node + 1
        right = left + 1
        dist_lower_bound_left = point_to_node_lower_bound_rdist(
            tree.node_bounds[0, left], tree.node_bounds[1, left], point
        )
        dist_lower_bound_right = point_to_node_lower_bound_rdist(
            tree.node_bounds[0, right], tree.node_bounds[1, right], point
        )

        # recursively query subnodes
        if dist_lower_bound_left <= dist_lower_bound_right:
            tree_query_recursion(
                tree, left, point, heap_p, heap_i, dist_lower_bound_left
            )
            tree_query_recursion(
                tree, right, point, heap_p, heap_i, dist_lower_bound_right
            )
        else:
            tree_query_recursion(
                tree, right, point, heap_p, heap_i, dist_lower_bound_right
            )
            tree_query_recursion(
                tree, left, point, heap_p, heap_i, dist_lower_bound_left
            )

    return


@numba.njit(
    numba.types.Tuple((numba.float32[:, ::1], numba.int32[:, ::1]))(
        NumbaKDTreeType,
        numba.float32[:, ::1],
        numba.int64,
        numba.types.boolean,
    ),
    parallel=True,
    fastmath=True,
    cache=True,
    locals={
        "i": numba.types.intp,
        "distance_lower_bound": numba.types.float32,
    },
)
def parallel_tree_query(
    tree, data, k=numba.int64(10), output_rdist=numba.types.boolean(False)
):
    """Find the ``k`` nearest tree points for every row of ``data``.

    Parameters
    ----------
    tree : NumbaKDTree
        Tree to search.
    data : float32[:, ::1]
        Query points, one per row.
    k : int64
        Number of neighbors to return per query point.
    output_rdist : bool
        When true the squared distances are returned; otherwise their square
        roots are returned.

    Returns
    -------
    (distances, indices)
        Arrays of shape ``(data.shape[0], k)`` (float32 and int32), each row
        sorted by ascending distance. Slots that were never filled keep
        their initial values: infinite distance and index -1.
    """
    result = (
        np.full((data.shape[0], k), np.inf, dtype=np.float32),
        np.full((data.shape[0], k), -1, dtype=np.int32),
    )

    for i in numba.prange(data.shape[0]):
        distance_lower_bound = point_to_node_lower_bound_rdist(
            tree.node_bounds[0, 0], tree.node_bounds[1, 0], data[i]
        )
        # Row views: the recursion updates the result arrays in place.
        heap_priorities, heap_indices = result[0][i], result[1][i]
        tree_query_recursion(
            tree,
            numba.intp(0),
            data[i],
            heap_priorities,
            heap_indices,
            distance_lower_bound,
        )

    if output_rdist:
        return deheap_sort(result[0], result[1])
    else:
        return deheap_sort(np.sqrt(result[0]), result[1])