Skip to content

mlx3d.nn

mlx3d.nn

FusedMLP

Bases: Module

A small ReLU MLP with a fused Metal forward path.

Parameters:

Name Type Description Default
layer_dims list[int]

sizes [in, h1, ..., out]; every hidden/in/out dimension must be <= 64. ReLU is applied after every layer except the last.

required
Source code in src/mlx3d/nn/fused_mlp.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
class FusedMLP(nn.Module):
    """A small ReLU MLP with a fused Metal forward path.

    Args:
        layer_dims: sizes ``[in, h1, ..., out]``; at least two entries are
            needed, and every hidden/in/out dimension must be ``>= 1`` and
            ``<= 64``. ReLU is applied after every layer except the last.

    Raises:
        ValueError: if fewer than two sizes are given, or any size is
            non-positive or exceeds the fused kernel's width limit.
    """

    def __init__(self, layer_dims: list[int]):
        super().__init__()
        dims = list(layer_dims)
        if any(d > _MAX_WIDTH for d in dims):
            raise ValueError(f"FusedMLP supports dimensions <= {_MAX_WIDTH}; got {dims}.")
        # A non-positive fan-in would make the init scale below divide by zero
        # (or turn complex), and with < 2 sizes there is no layer to run.
        if len(dims) < 2 or any(d < 1 for d in dims):
            raise ValueError(
                f"FusedMLP needs at least two positive layer sizes; got {dims}."
            )
        self.layer_dims = dims
        # Weights stored (in, out) so x @ W matches the kernel's row-major layout.
        self.weights = []
        self.biases = []
        for din, dout in zip(dims[:-1], dims[1:]):
            # Normal init scaled by sqrt(2 / fan_in); biases start at zero.
            scale = (2.0 / din) ** 0.5
            self.weights.append(mx.random.normal((din, dout)) * scale)
            self.biases.append(mx.zeros((dout,)))

    def __call__(self, x: mx.array) -> mx.array:
        """Differentiable MLX forward (use for training).

        Args:
            x: ``(..., in)`` input features.

        Returns:
            ``(..., out)`` activations; no ReLU on the final layer.
        """
        h = x
        n = len(self.weights)
        for i, (w, b) in enumerate(zip(self.weights, self.biases)):
            h = h @ w + b
            if i < n - 1:
                h = nn.relu(h)
        return h

    def forward_fused(self, x: mx.array) -> mx.array:
        """Fused single-kernel forward; matches :meth:`__call__` exactly.

        A correct reference for the fused-MLP idea. See the module note: MLX's
        native matmuls are currently faster on Apple GPUs, so prefer
        :meth:`__call__` in practice.

        Args:
            x: ``(..., in)`` input features; leading dimensions are flattened
                into rows for the kernel and restored on the result.

        Returns:
            ``(..., out)`` float32 activations.

        Raises:
            ValueError: if the last dimension of ``x`` is not ``layer_dims[0]``
                (the kernel would otherwise index the input with the wrong
                row stride).
        """
        in_dim = self.layer_dims[0]
        out_dim = self.layer_dims[-1]
        if x.ndim < 1 or x.shape[-1] != in_dim:
            raise ValueError(
                f"forward_fused expects input with last dimension {in_dim}; "
                f"got shape {tuple(x.shape)}."
            )
        lead_shape = tuple(x.shape[:-1])
        rows = 1
        for extent in lead_shape:
            rows *= int(extent)
        if rows == 0:
            # Nothing to compute; avoid dispatching a zero-sized grid.
            return mx.zeros((*lead_shape, out_dim), dtype=mx.float32)

        x2d = x.reshape(rows, in_dim)
        # Kernel metadata: dims = [num_layers, *layer_dims], meta = [rows].
        dims = mx.array([len(self.weights), *self.layer_dims], dtype=mx.int32)
        meta = mx.array([rows], dtype=mx.int32)
        # All layers' parameters are packed back to back into flat buffers.
        flat_w = mx.concatenate([w.reshape(-1) for w in self.weights])
        flat_b = mx.concatenate([b.reshape(-1) for b in self.biases])
        tg = 256
        # Round the thread count up to a whole number of threadgroups.
        grid = ((rows + tg - 1) // tg * tg, 1, 1)
        (out,) = _fused_kernel(
            inputs=[
                mx.contiguous(x2d.astype(mx.float32)),
                mx.contiguous(flat_w.astype(mx.float32)),
                mx.contiguous(flat_b.astype(mx.float32)),
                dims,
                meta,
            ],
            output_shapes=[(rows * out_dim,)],
            output_dtypes=[mx.float32],
            grid=grid,
            threadgroup=(tg, 1, 1),
        )
        return out.reshape(*lead_shape, out_dim)

__call__(x)

Differentiable MLX forward (use for training).

Source code in src/mlx3d/nn/fused_mlp.py
 94
 95
 96
 97
 98
 99
100
101
102
def __call__(self, x: mx.array) -> mx.array:
    """Differentiable MLX forward (use for training).

    Runs every ``(weight, bias)`` pair as an affine map and applies ReLU
    after each one except the final layer.
    """
    final_index = len(self.weights) - 1
    activations = x
    for depth, (weight, bias) in enumerate(zip(self.weights, self.biases)):
        activations = activations @ weight + bias
        # The output layer stays linear.
        if depth != final_index:
            activations = nn.relu(activations)
    return activations

forward_fused(x)

Fused single-kernel forward; matches `__call__` exactly.

A correct reference for the fused-MLP idea. See the module note: MLX's native matmuls are currently faster on Apple GPUs, so prefer `__call__` in practice.

Source code in src/mlx3d/nn/fused_mlp.py
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def forward_fused(self, x: mx.array) -> mx.array:
    """Fused single-kernel forward; matches :meth:`__call__` exactly.

    A correct reference for the fused-MLP idea. See the module note: MLX's
    native matmuls are currently faster on Apple GPUs, so prefer
    :meth:`__call__` in practice.

    Args:
        x: ``(..., in)`` input features; leading dimensions are flattened
            into rows for the kernel and restored on the result.

    Returns:
        ``(..., out)`` float32 activations.

    Raises:
        ValueError: if the last dimension of ``x`` is not ``layer_dims[0]``
            (the kernel would otherwise index the input with the wrong row
            stride).
    """
    in_dim = self.layer_dims[0]
    out_dim = self.layer_dims[-1]
    if x.ndim < 1 or x.shape[-1] != in_dim:
        raise ValueError(
            f"forward_fused expects input with last dimension {in_dim}; "
            f"got shape {tuple(x.shape)}."
        )
    lead_shape = tuple(x.shape[:-1])
    rows = 1
    for extent in lead_shape:
        rows *= int(extent)
    if rows == 0:
        # Nothing to compute; avoid dispatching a zero-sized grid.
        return mx.zeros((*lead_shape, out_dim), dtype=mx.float32)

    x2d = x.reshape(rows, in_dim)
    # Kernel metadata: dims = [num_layers, *layer_dims], meta = [rows].
    dims = mx.array([len(self.weights), *self.layer_dims], dtype=mx.int32)
    meta = mx.array([rows], dtype=mx.int32)
    # All layers' parameters are packed back to back into flat buffers.
    flat_w = mx.concatenate([w.reshape(-1) for w in self.weights])
    flat_b = mx.concatenate([b.reshape(-1) for b in self.biases])
    tg = 256
    # Round the thread count up to a whole number of threadgroups.
    grid = ((rows + tg - 1) // tg * tg, 1, 1)
    (out,) = _fused_kernel(
        inputs=[
            mx.contiguous(x2d.astype(mx.float32)),
            mx.contiguous(flat_w.astype(mx.float32)),
            mx.contiguous(flat_b.astype(mx.float32)),
            dims,
            meta,
        ],
        output_shapes=[(rows * out_dim,)],
        output_dtypes=[mx.float32],
        grid=grid,
        threadgroup=(tg, 1, 1),
    )
    return out.reshape(*lead_shape, out_dim)

HashGridEncoding

Bases: Module

Trainable multi-resolution 3D hash-grid encoder.

This follows the Instant-NGP idea: points are normalized to a unit cube, each level performs trilinear interpolation over hashed grid vertices, and all level features are concatenated.

Source code in src/mlx3d/nn/hashgrid.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class HashGridEncoding(nn.Module):
    """Trainable multi-resolution 3D hash-grid encoder.

    Modeled on Instant-NGP: inputs are mapped from ``bounds`` into the unit
    cube, every level trilinearly blends the features stored at the eight
    hashed corners of the enclosing grid cell, and the per-level features are
    concatenated along the last axis.
    """

    def __init__(
        self,
        num_levels: int = 12,
        features_per_level: int = 2,
        log2_hashmap_size: int = 15,
        base_resolution: int = 16,
        finest_resolution: int = 512,
        bounds: tuple[float, float] = (-1.0, 1.0),
    ):
        super().__init__()
        self.num_levels = int(num_levels)
        self.features_per_level = int(features_per_level)
        self.hashmap_size = 1 << int(log2_hashmap_size)
        self.base_resolution = int(base_resolution)
        self.finest_resolution = int(finest_resolution)
        self.bounds = bounds
        # Geometric growth factor taking the base resolution to the finest one
        # over the available levels; a single level never grows.
        if num_levels > 1:
            log_growth = math.log(finest_resolution / base_resolution) / max(num_levels - 1, 1)
            self.per_level_scale = math.exp(log_growth)
        else:
            self.per_level_scale = 1.0
        # One small-magnitude feature table per level.
        table_shape = (self.hashmap_size, self.features_per_level)
        self.tables = [
            mx.random.uniform(low=-1e-4, high=1e-4, shape=table_shape)
            for _ in range(self.num_levels)
        ]

    @property
    def output_dim(self) -> int:
        """Width of the encoded feature vector (levels x features per level)."""
        return self.num_levels * self.features_per_level

    def _hash(self, coords: mx.array) -> mx.array:
        """Map integer ``(..., 3)`` grid coordinates to table row indices.

        Each axis is multiplied by its own uint32 constant, the products are
        XOR-ed together, and the result is reduced modulo the table size.
        """
        lattice = coords.astype(mx.uint32)
        mixed = lattice[..., 0] * mx.array(1, dtype=mx.uint32)
        mixed = mixed ^ (lattice[..., 1] * mx.array(2654435761, dtype=mx.uint32))
        mixed = mixed ^ (lattice[..., 2] * mx.array(805459861, dtype=mx.uint32))
        return (mixed % self.hashmap_size).astype(mx.int32)

    def __call__(self, x: mx.array) -> mx.array:
        """Encode ``(..., 3)`` points into ``(..., output_dim)`` features.

        Points outside ``bounds`` are clamped onto the unit cube before lookup.
        """
        lo, hi = self.bounds
        extent = max(float(hi - lo), 1e-8)
        unit = mx.clip((x - lo) / extent, 0.0, 1.0)

        # The eight cell corners, x varying fastest, then y, then z.
        corner_offsets = mx.array(
            [[c & 1, (c >> 1) & 1, (c >> 2) & 1] for c in range(8)],
            dtype=mx.int32,
        )

        per_level = []
        for level, table in enumerate(self.tables):
            res = max(int(math.floor(self.base_resolution * (self.per_level_scale**level))), 2)
            scaled = unit * (res - 1)
            base = mx.floor(scaled).astype(mx.int32)
            frac = scaled - base.astype(mx.float32)
            # Corners past the last grid vertex are clamped back onto it.
            corners = mx.minimum(base[..., None, :] + corner_offsets, res - 1)
            feats = table[self._hash(corners)]

            # Trilinear weight: per axis take frac for the far corner and
            # (1 - frac) for the near one, then multiply x, y, z in that order.
            blend = None
            for axis in range(3):
                t = frac[..., axis : axis + 1]
                axis_weight = mx.where(corner_offsets[:, axis] == 1, t, 1.0 - t)
                blend = axis_weight if blend is None else blend * axis_weight

            per_level.append(mx.sum(feats * blend[..., None], axis=-2))
        return mx.concatenate(per_level, axis=-1)

HashGridNeRF

Bases: Module

A compact hash-grid NeRF (Instant-NGP style).

The hash-grid hyperparameters (`num_levels`, `features_per_level`, `log2_hashmap_size`, `base_resolution`, `finest_resolution`) are forwarded to `mlx3d.nn.HashGridEncoding`.

Parameters:

Name Type Description Default
bounds tuple[float, float]

axis-aligned scene bounds the hash grid covers; sample points should lie within this cube (density is zeroed outside it).

(-1.5, 1.5)
geo_feat_dim int

size of the geometry feature passed to the color MLP.

15
hidden_dim int

width of both small MLPs.

64
dir_freqs int

positional-encoding frequencies for the view direction.

4
Source code in src/mlx3d/nn/instant_ngp.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class HashGridNeRF(nn.Module):
    """A compact hash-grid NeRF (Instant-NGP style).

    ``num_levels``, ``features_per_level``, ``log2_hashmap_size``,
    ``base_resolution`` and ``finest_resolution`` are passed straight through
    to :class:`~mlx3d.nn.HashGridEncoding`.

    Args:
        bounds: axis-aligned scene bounds the hash grid covers; sample points
            should lie within this cube (density is zeroed outside it).
        geo_feat_dim: size of the geometry feature passed to the color MLP.
        hidden_dim: width of both small MLPs.
        dir_freqs: positional-encoding frequencies for the view direction.
    """

    def __init__(
        self,
        bounds: tuple[float, float] = (-1.5, 1.5),
        num_levels: int = 16,
        features_per_level: int = 2,
        log2_hashmap_size: int = 19,
        base_resolution: int = 16,
        finest_resolution: int = 1024,
        geo_feat_dim: int = 15,
        hidden_dim: int = 64,
        dir_freqs: int = 4,
    ):
        super().__init__()
        self.bounds = bounds
        grid_config = dict(
            num_levels=num_levels,
            features_per_level=features_per_level,
            log2_hashmap_size=log2_hashmap_size,
            base_resolution=base_resolution,
            finest_resolution=finest_resolution,
            bounds=bounds,
        )
        self.encoding = HashGridEncoding(**grid_config)
        self.dir_enc = PositionalEncoding(dir_freqs)
        self.geo_feat_dim = geo_feat_dim

        position_width = self.encoding.output_dim
        direction_width = 3 * self.dir_enc.output_dim_multiplier

        # Density branch: hash features -> [density logit, geometry features].
        self.sigma_l1 = nn.Linear(position_width, hidden_dim)
        self.sigma_l2 = nn.Linear(hidden_dim, 1 + geo_feat_dim)
        # Color branch: [geometry features, encoded direction] -> RGB.
        self.color_l1 = nn.Linear(geo_feat_dim + direction_width, hidden_dim)
        self.color_l2 = nn.Linear(hidden_dim, hidden_dim)
        self.color_l3 = nn.Linear(hidden_dim, 3)

    def __call__(self, points: mx.array, directions: mx.array) -> tuple[mx.array, mx.array]:
        """Evaluate the field.

        Args:
            points: ``(..., 3)`` sample positions.
            directions: ``(..., 3)`` view directions, one per point.

        Returns:
            ``(density, rgb)``: density with the trailing axis dropped, and
            sigmoid-squashed RGB with a trailing axis of 3.
        """
        sigma_out = self.sigma_l2(nn.relu(self.sigma_l1(self.encoding(points))))
        geometry = sigma_out[..., 1:]

        # Points beyond the scene box get clamped onto its faces by the hash
        # grid and would all read the same meaningless features, so the field
        # could not localize the object. There is no geometry out there
        # anyway, so density is forced to zero outside the box.
        lo, hi = self.bounds
        in_box = mx.all((points >= lo) & (points <= hi), axis=-1)
        density = _trunc_exp(sigma_out[..., 0]) * in_box

        color_in = mx.concatenate([geometry, self.dir_enc(directions)], axis=-1)
        hidden = nn.relu(self.color_l1(color_in))
        hidden = nn.relu(self.color_l2(hidden))
        return density, mx.sigmoid(self.color_l3(hidden))

NeRF

Bases: Module

The original NeRF MLP: density from position, color from position + view.

Source code in src/mlx3d/nn/nerf.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
class NeRF(nn.Module):
    """The original NeRF MLP: density from position, color from position + view.

    The encoded position is fed through ``num_layers`` ReLU layers and is
    concatenated onto the hidden state again just before layer ``skip_layer``.
    """

    def __init__(
        self,
        pos_freqs: int = 10,
        dir_freqs: int = 4,
        hidden_dim: int = 256,
        num_layers: int = 8,
        skip_layer: int = 4,
    ):
        super().__init__()
        self.pos_enc = PositionalEncoding(pos_freqs)
        self.dir_enc = PositionalEncoding(dir_freqs)
        pos_dim = 3 * self.pos_enc.output_dim_multiplier
        dir_dim = 3 * self.dir_enc.output_dim_multiplier
        self.skip_layer = skip_layer

        trunk = []
        for depth in range(num_layers):
            # The first layer reads the encoded position; later ones read the
            # hidden state. The skip layer additionally sees the position.
            fan_in = pos_dim if depth == 0 else hidden_dim
            if depth == skip_layer:
                fan_in += pos_dim
            trunk.append(nn.Linear(fan_in, hidden_dim))
        self.layers = trunk

        self.density_head = nn.Linear(hidden_dim, 1)
        self.feature = nn.Linear(hidden_dim, hidden_dim)
        self.color_hidden = nn.Linear(hidden_dim + dir_dim, hidden_dim // 2)
        self.color_head = nn.Linear(hidden_dim // 2, 3)

    def __call__(self, points: mx.array, directions: mx.array) -> tuple[mx.array, mx.array]:
        """Evaluate density and color.

        Args:
            points: (..., 3) sample positions.
            directions: (..., 3) normalized view directions (broadcastable).

        Returns:
            ``(density, rgb)`` with shapes (...,) and (..., 3).
        """
        encoded_pos = self.pos_enc(points)
        hidden = encoded_pos
        for depth, layer in enumerate(self.layers):
            if depth == self.skip_layer:
                # Skip connection: re-inject the encoded position.
                hidden = mx.concatenate([hidden, encoded_pos], axis=-1)
            hidden = nn.relu(layer(hidden))

        # Density depends on position only and is kept non-negative by ReLU.
        density = nn.relu(self.density_head(hidden)[..., 0])

        # Color additionally conditions on the encoded view direction.
        color_in = mx.concatenate([self.feature(hidden), self.dir_enc(directions)], axis=-1)
        color_hidden = nn.relu(self.color_hidden(color_in))
        return density, mx.sigmoid(self.color_head(color_hidden))

__call__(points, directions)

Evaluate density and color.

Parameters:

Name Type Description Default
points array

(..., 3) sample positions.

required
directions array

(..., 3) normalized view directions (broadcastable).

required

Returns:

Type Description
tuple[array, array]

(density, rgb) with shapes (...,) and (..., 3).

Source code in src/mlx3d/nn/nerf.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def __call__(self, points: mx.array, directions: mx.array) -> tuple[mx.array, mx.array]:
    """Evaluate density and color.

    Args:
        points: (..., 3) sample positions.
        directions: (..., 3) normalized view directions (broadcastable).

    Returns:
        ``(density, rgb)`` with shapes (...,) and (..., 3).
    """
    encoded_pos = self.pos_enc(points)
    hidden = encoded_pos
    for depth, layer in enumerate(self.layers):
        if depth == self.skip_layer:
            # Skip connection: re-inject the encoded position.
            hidden = mx.concatenate([hidden, encoded_pos], axis=-1)
        hidden = nn.relu(layer(hidden))

    # Density depends on position only and is kept non-negative by ReLU.
    density = nn.relu(self.density_head(hidden)[..., 0])

    # Color additionally conditions on the encoded view direction.
    color_in = mx.concatenate([self.feature(hidden), self.dir_enc(directions)], axis=-1)
    color_hidden = nn.relu(self.color_hidden(color_in))
    return density, mx.sigmoid(self.color_head(color_hidden))

PositionalEncoding

Bases: Module

Sinusoidal positional encoding from the NeRF paper.

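Maps x to [x, enc(x_1), ..., enc(x_D)], where for each input coordinate x_d the block enc(x_d) is [sin(2^0 x_d), ..., sin(2^{L-1} x_d), cos(2^0 x_d), ..., cos(2^{L-1} x_d)] — all sines first, then all cosines. The leading x is present only when include_input is set.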

Source code in src/mlx3d/nn/nerf.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from the NeRF paper.

    For an input with ``D`` trailing coordinates and ``L = num_freqs``, each
    coordinate ``x_d`` contributes the block
    ``[sin(2^0 x_d), ..., sin(2^{L-1} x_d), cos(2^0 x_d), ..., cos(2^{L-1} x_d)]``
    and the ``D`` blocks are concatenated in coordinate order. When
    ``include_input`` is set, the raw ``x`` is prepended.

    Args:
        num_freqs: number of frequency bands ``L`` (``>= 0``).
        include_input: whether to prepend the unencoded input.

    Raises:
        ValueError: if ``num_freqs`` is negative.
    """

    def __init__(self, num_freqs: int, include_input: bool = True):
        super().__init__()
        if num_freqs < 0:
            raise ValueError(f"num_freqs must be >= 0; got {num_freqs}.")
        self.num_freqs = num_freqs
        self.include_input = include_input
        # Underscore-prefixed, as in the original: the band multipliers
        # 2^0 .. 2^(L-1) are constants, not something to train.
        self._freqs = 2.0 ** mx.arange(num_freqs)

    @property
    def output_dim_multiplier(self) -> int:
        """Output width per input coordinate (``2 * L``, plus 1 with the input)."""
        return 2 * self.num_freqs + (1 if self.include_input else 0)

    def __call__(self, x: mx.array) -> mx.array:
        """Encode ``(..., D)`` inputs into ``(..., D * output_dim_multiplier)``."""
        xb = x[..., None] * self._freqs  # (..., D, L)
        enc = mx.concatenate([mx.sin(xb), mx.cos(xb)], axis=-1)  # (..., D, 2L)
        # Spell the flattened width out instead of using -1: a -1 cannot be
        # inferred for a zero-size array, which is what num_freqs == 0 yields.
        enc = enc.reshape(*x.shape[:-1], x.shape[-1] * 2 * self.num_freqs)
        if self.include_input:
            enc = mx.concatenate([x, enc], axis=-1)
        return enc

OccupancyGrid

A dense res^3 occupancy cache over an axis-aligned box.

Parameters:

Name Type Description Default
resolution int

cells per axis.

128
bounds tuple[float, float]

(lo, hi) world extent of the grid (same on every axis).

(-1.5, 1.5)
Source code in src/mlx3d/nn/occupancy.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
class OccupancyGrid:
    """A dense ``res^3`` occupancy cache over an axis-aligned box.

    The ``occupancy`` array is indexed ``[z, y, x]``: that is the order in
    which :meth:`_cell_centers` lays out the cells and in which
    :meth:`update` stores the thresholded densities, and :meth:`query`
    flattens its cell indices the same way.

    Args:
        resolution: cells per axis.
        bounds: ``(lo, hi)`` world extent of the grid (same on every axis).
    """

    def __init__(self, resolution: int = 128, bounds: tuple[float, float] = (-1.5, 1.5)):
        self.resolution = int(resolution)
        self.bounds = bounds
        # Size the array with the int-coerced resolution so a float-valued
        # argument cannot end up in the shape. Starts fully empty.
        r = self.resolution
        self.occupancy = mx.zeros((r, r, r), dtype=mx.bool_)

    def _cell_centers(self) -> mx.array:
        """Return the ``(r, r, r, 3)`` world-space cell centers, indexed ``[z, y, x]``.

        The last axis holds ``(x, y, z)`` coordinates.
        """
        lo, hi = self.bounds
        r = self.resolution
        c = (mx.arange(r, dtype=mx.float32) + 0.5) / r  # cell centers in [0, 1)
        coords = lo + c * (hi - lo)
        # "ij" indexing: the first output varies along axis 0, so axis 0 is z.
        gz, gy, gx = mx.meshgrid(coords, coords, coords, indexing="ij")
        return mx.stack([gx, gy, gz], axis=-1)  # (r, r, r, 3)

    def update(
        self,
        density_fn: Callable[[mx.array], mx.array],
        threshold: float = 0.01,
        chunk: int = 1 << 18,
    ) -> None:
        """Refresh occupancy by thresholding the density field at cell centers.

        Args:
            density_fn: callable mapping ``(P, 3)`` points to ``(P,)`` densities.
            threshold: cells with density above this are marked occupied.
            chunk: points evaluated per batch (bounds memory for fine grids).
        """
        centers = self._cell_centers().reshape(-1, 3)
        out = []
        for s in range(0, centers.shape[0], chunk):
            # The grid is a cache, not part of the differentiable graph.
            out.append(mx.stop_gradient(density_fn(centers[s : s + chunk])))
        density = mx.concatenate(out) if len(out) > 1 else out[0]
        r = self.resolution
        self.occupancy = (density > threshold).reshape(r, r, r)
        mx.eval(self.occupancy)

    def query(self, points: mx.array) -> mx.array:
        """Return a boolean mask of which ``(..., 3)`` points fall in occupied cells.

        Points outside the grid bounds are reported empty.
        """
        lo, hi = self.bounds
        r = self.resolution
        norm = (points - lo) / (hi - lo)  # -> [0, 1]
        idx = mx.floor(norm * r).astype(mx.int32)  # per-axis cell index (x, y, z)
        inside = mx.all((idx >= 0) & (idx < r), axis=-1)
        # Clamp so the gather stays in range; `inside` masks those lookups out.
        ci = mx.clip(idx, 0, r - 1)
        # The grid is stored [z, y, x], so z is the slowest-varying index and
        # x the fastest. Flattening as (x, y, z) would read the transposed cell.
        flat = (ci[..., 2] * r + ci[..., 1]) * r + ci[..., 0]
        occ = self.occupancy.reshape(-1)[flat]
        return occ & inside

    @property
    def occupied_fraction(self) -> float:
        """Fraction of cells currently marked occupied (useful for diagnostics)."""
        return float(self.occupancy.astype(mx.float32).mean())

occupied_fraction property

Fraction of cells currently marked occupied (useful for diagnostics).

update(density_fn, threshold=0.01, chunk=1 << 18)

Refresh occupancy by thresholding the density field at cell centers.

Parameters:

Name Type Description Default
density_fn Callable[[array], array]

callable mapping (P, 3) points to (P,) densities.

required
threshold float

cells with density above this are marked occupied.

0.01
chunk int

points evaluated per batch (bounds memory for fine grids).

1 << 18
Source code in src/mlx3d/nn/occupancy.py
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def update(
    self,
    density_fn: Callable[[mx.array], mx.array],
    threshold: float = 0.01,
    chunk: int = 1 << 18,
) -> None:
    """Refresh occupancy by thresholding the density field at cell centers.

    The flattened cell centers are evaluated ``chunk`` points at a time with
    gradients stopped, and the thresholded result is reshaped back to the
    ``(r, r, r)`` grid and evaluated immediately.

    Args:
        density_fn: callable mapping ``(P, 3)`` points to ``(P,)`` densities.
        threshold: cells with density above this are marked occupied.
        chunk: points evaluated per batch (bounds memory for fine grids).
    """
    flat_centers = self._cell_centers().reshape(-1, 3)
    total = flat_centers.shape[0]
    pieces = [
        mx.stop_gradient(density_fn(flat_centers[start : start + chunk]))
        for start in range(0, total, chunk)
    ]
    # A single batch needs no concatenation.
    density = pieces[0] if len(pieces) <= 1 else mx.concatenate(pieces)
    side = self.resolution
    self.occupancy = (density > threshold).reshape(side, side, side)
    mx.eval(self.occupancy)

query(points)

Return a boolean mask of which (..., 3) points fall in occupied cells.

Points outside the grid bounds are reported empty.

Source code in src/mlx3d/nn/occupancy.py
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def query(self, points: mx.array) -> mx.array:
    """Return a boolean mask of which ``(..., 3)`` points fall in occupied cells.

    Points outside the grid bounds are reported empty.
    """
    lo, hi = self.bounds
    r = self.resolution
    norm = (points - lo) / (hi - lo)  # -> [0, 1]
    idx = mx.floor(norm * r).astype(mx.int32)  # per-axis cell index (x, y, z)
    inside = mx.all((idx >= 0) & (idx < r), axis=-1)
    # Clamp so the gather stays in range; `inside` masks those lookups out.
    ci = mx.clip(idx, 0, r - 1)
    # `_cell_centers` / `update` store the grid as [z, y, x], so z is the
    # slowest-varying index and x the fastest. Flattening as (x, y, z) would
    # read the transposed cell.
    flat = (ci[..., 2] * r + ci[..., 1]) * r + ci[..., 0]
    occ = self.occupancy.reshape(-1)[flat]
    return occ & inside

render_rays_occupancy(model, origins, directions, near, far, grid, num_samples=128, eval_fraction=1.0, stratified=False, white_background=False)

Render rays, evaluating model only at occupied samples.

Parameters:

Name Type Description Default
model Field

a field `model(points, directions) -> (density, rgb)` (e.g. `HashGridNeRF`).

required
origins array

(R, 3) ray origins.

required
directions array

(R, 3) ray directions.

required
near float

near sampling bound.

required
far float

far sampling bound.

required
grid OccupancyGrid

occupancy cache identifying non-empty space (kept fixed / detached).

required
num_samples int

samples per ray.

128
eval_fraction float

fraction of all R * num_samples samples to actually evaluate (the compaction budget). With sparse occupancy a small fraction covers every occupied sample; the rest are forced empty.

1.0
stratified bool

jitter samples (training) vs. deterministic (eval).

False
white_background bool

composite onto white.

False

Returns:

Type Description
dict[str, array]

Same dict as `mlx3d.renderer.volume_render`.

Source code in src/mlx3d/nn/accel.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def render_rays_occupancy(
    model: Field,
    origins: mx.array,
    directions: mx.array,
    near: float,
    far: float,
    grid: OccupancyGrid,
    num_samples: int = 128,
    eval_fraction: float = 1.0,
    stratified: bool = False,
    white_background: bool = False,
) -> dict[str, mx.array]:
    """Render rays, evaluating ``model`` only at occupied samples.

    Args:
        model: a field ``model(points, directions) -> (density, rgb)`` (e.g.
            :class:`HashGridNeRF`). It receives unit-length view directions,
            the same as in :func:`render_rays`.
        origins: ``(R, 3)`` ray origins.
        directions: ``(R, 3)`` ray directions (need not be normalized).
        near: near sampling bound.
        far: far sampling bound.
        grid: occupancy cache identifying non-empty space (kept fixed / detached).
        num_samples: samples per ray.
        eval_fraction: fraction of all ``R * num_samples`` samples to actually
            evaluate (the compaction budget). With sparse occupancy a small
            fraction covers every occupied sample; the rest are forced empty.
        stratified: jitter samples (training) vs. deterministic (eval).
        white_background: composite onto white.

    Returns:
        Same dict as :func:`~mlx3d.renderer.volume_render`.
    """
    r = origins.shape[0]
    points, t_vals = sample_along_rays(origins, directions, near, far, num_samples, stratified)

    # The field is conditioned on unit view directions, matching render_rays.
    # Sampling and compositing keep the raw directions. The floor on the norm
    # only keeps a degenerate zero-length direction from producing NaNs.
    norms = mx.linalg.norm(directions, axis=-1, keepdims=True)
    dirs_n = directions / mx.maximum(norms, 1e-12)
    view = mx.broadcast_to(dirs_n[:, None, :], points.shape)

    m = r * num_samples
    pts = points.reshape(m, 3)
    views = view.reshape(m, 3)
    occupied = grid.query(pts)  # (M,) bool, detached (grid is constant)

    budget = max(1, min(m, int(m * eval_fraction)))
    # Bring occupied samples to the front, then keep the first `budget`.
    order = mx.argsort((~occupied).astype(mx.int32))
    idx = order[:budget]
    keep = occupied[idx].astype(mx.float32)[:, None]  # zero out any empty stragglers

    density_c, rgb_c = model(pts[idx], views[idx])
    density_c = density_c[:, None] * keep
    rgb_c = rgb_c * keep

    # Scatter the evaluated samples back into dense per-ray buffers; every
    # sample that was not evaluated stays at zero density and zero color.
    density = mx.zeros((m, 1)).at[idx].add(density_c).reshape(r, num_samples)
    rgb = mx.zeros((m, 3)).at[idx].add(rgb_c).reshape(r, num_samples, 3)
    return volume_render(density, rgb, t_vals, directions, white_background)

render_rays(model, origins, directions, near, far, num_coarse=64, num_fine=0, fine_model=None, stratified=True, white_background=False)

Render a batch of rays with optional hierarchical sampling.

Returns a dict with rgb, depth, acc (and rgb_coarse when fine sampling is enabled).

Source code in src/mlx3d/nn/nerf.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def render_rays(
    model: NeRF,
    origins: mx.array,
    directions: mx.array,
    near: float,
    far: float,
    num_coarse: int = 64,
    num_fine: int = 0,
    fine_model: NeRF | None = None,
    stratified: bool = True,
    white_background: bool = False,
) -> dict[str, mx.array]:
    """Render a batch of rays with optional hierarchical sampling.

    A coarse pass samples ``num_coarse`` points per ray. When ``num_fine`` is
    positive, extra samples are drawn from the coarse weights, merged with the
    coarse ones, and rendered with ``fine_model`` (or ``model`` when no fine
    model is given).

    Returns a dict with ``rgb``, ``depth``, ``acc`` (and ``rgb_coarse`` when
    fine sampling is enabled).
    """
    # The models see unit view directions; sampling and compositing use the
    # directions as given.
    unit_dirs = directions / mx.linalg.norm(directions, axis=-1, keepdims=True)

    coarse_pts, t_coarse = sample_along_rays(
        origins, directions, near, far, num_coarse, stratified=stratified
    )
    coarse_view = mx.broadcast_to(unit_dirs[:, None, :], coarse_pts.shape)
    coarse_density, coarse_rgb = model(coarse_pts, coarse_view)
    coarse = volume_render(coarse_density, coarse_rgb, t_coarse, directions, white_background)

    if num_fine <= 0:
        return coarse

    # Importance-sample new depths from the interior coarse weights, placed
    # over the midpoints between neighbouring coarse samples.
    bin_mids = 0.5 * (t_coarse[:, 1:] + t_coarse[:, :-1])
    interior_weights = coarse["weights"][:, 1:-1]
    t_extra = sample_pdf(bin_mids, interior_weights, num_fine, deterministic=not stratified)
    t_merged = mx.sort(mx.concatenate([t_coarse, t_extra], axis=-1), axis=-1)

    fine_pts = origins[:, None, :] + t_merged[..., None] * directions[:, None, :]
    fine_view = mx.broadcast_to(unit_dirs[:, None, :], fine_pts.shape)
    field = model if fine_model is None else fine_model
    fine_density, fine_rgb = field(fine_pts, fine_view)

    result = volume_render(fine_density, fine_rgb, t_merged, directions, white_background)
    result["rgb_coarse"] = coarse["rgb"]
    return result