mlx3d.splatting

GaussianModel

A scene of 3D Gaussians with the standard 3DGS parameterization.

Raw (unconstrained) parameters live in self.params so they can be fed to MLX optimizers directly:

  • means (N, 3)
  • scales (N, 3): log of the per-axis standard deviations.
  • quats (N, 4): unnormalized rotation quaternions (w, x, y, z).
  • opacities (N,): logits; sigmoid gives opacity.
  • sh_dc (N, 1, 3) and sh_rest (N, K-1, 3): SH color coefficients.
Source code in src/mlx3d/splatting/model.py (lines 16–538):
class GaussianModel:
    """A scene of 3D Gaussians with the standard 3DGS parameterization.

    Raw (unconstrained) parameters live in ``self.params`` so they can be fed
    to MLX optimizers directly:

    - ``means`` (N, 3)
    - ``scales`` (N, 3): log of the per-axis standard deviations.
    - ``quats`` (N, 4): unnormalized rotation quaternions (w, x, y, z).
    - ``opacities`` (N,): logits; sigmoid gives opacity.
    - ``sh_dc`` (N, 1, 3) and ``sh_rest`` (N, K-1, 3): SH color coefficients,
      with ``K = num_sh_bases(sh_degree)``.
    """

    def __init__(self, params: dict[str, mx.array], sh_degree: int = 3):
        self.params = params
        self.sh_degree = sh_degree
        # Active degree grows during training (as in the reference implementation);
        # it is raised one band at a time by ``one_up_sh_degree``.
        self.active_sh_degree = 0

    # ---------------------------------------------------------------- helpers
    @staticmethod
    def _select_rows(
        opacity: np.ndarray,
        importance: np.ndarray,
        min_opacity: float,
        max_rows: int | None,
    ) -> np.ndarray:
        """Return ascending int32 indices of rows to keep.

        Rows with ``opacity >= min_opacity`` are kept. If none qualify, the
        single most important row is kept so the result is never empty. When
        ``max_rows`` is given and exceeded, only the ``max_rows``
        highest-importance kept rows survive. Sorting preserves the original
        row order of the survivors.
        """
        keep = opacity >= float(min_opacity)
        if not keep.any():
            keep[int(np.argmax(importance))] = True
        keep_idx = np.where(keep)[0]
        if max_rows is not None and keep_idx.size > max_rows:
            local = np.argpartition(-importance[keep_idx], max_rows - 1)[:max_rows]
            keep_idx = keep_idx[local]
        return np.sort(keep_idx.astype(np.int32))

    # ----------------------------------------------------------- construction
    @classmethod
    def from_points(
        cls,
        points: mx.array,
        colors: mx.array | None = None,
        sh_degree: int = 3,
        initial_opacity: float = 0.1,
        scale_init_max_ref: int = 10_000,
        scale_init_chunk_size: int = 1024,
        scale_init_max_scale: float | None = None,
    ) -> "GaussianModel":
        """Initialize from a point cloud (e.g. SfM points).

        Scales are set from the mean distance to the 3 nearest neighbors, as
        in the reference implementation. Rotations start as the identity
        quaternion, every opacity starts at ``initial_opacity`` and the
        higher-order SH coefficients start at zero.

        Args:
            points: (N, 3) positions; they become the Gaussian means.
            colors: optional (N, 3) colors converted with ``rgb_to_sh``;
                defaults to a constant 0.5.
            sh_degree: maximum SH degree of the model.
            initial_opacity: activated opacity of every new Gaussian. Must lie
                strictly between 0 and 1 (it is stored as a logit).
            scale_init_max_ref: Maximum reference points used for the initial
                nearest-neighbor scale estimate. Large COLMAP clouds otherwise
                spend seconds to minutes materializing huge distance tiles.
            scale_init_chunk_size: Query chunk size for the scale-estimation
                KNN. Lower values reduce peak memory during initialization.
            scale_init_max_scale: Optional cap for initial per-axis Gaussian
                scale. This is useful for COLMAP point clouds with sparse
                outliers whose nearest-neighbor distances would otherwise
                create full-screen splats.

        Raises:
            ValueError: if ``initial_opacity`` is not in the open interval (0, 1).
        """
        # The logit below is undefined at 0 and 1 (ZeroDivisionError / -inf).
        if not 0.0 < initial_opacity < 1.0:
            raise ValueError("initial_opacity must be strictly between 0 and 1.")

        from ..ops import knn_points

        N = points.shape[0]
        if colors is None:
            colors = mx.full((N, 3), 0.5)

        # Mean distance to the 3 nearest neighbors. For huge clouds, a random
        # reference subset keeps the O(N * R) search memory bounded without
        # meaningfully changing the scale estimate.
        ref_count = min(N, max(1, int(scale_init_max_ref)))
        if N > ref_count:
            ref = points[mx.random.permutation(N)[:ref_count]]
        else:
            ref = points

        if ref_count <= 1:
            mean_sq = mx.full((N,), 1e-8)
        else:
            k = min(4, ref_count)
            # ``d`` is treated as squared distances (it is square-rooted below).
            d, _ = knn_points(points, ref, K=k, chunk_size=scale_init_chunk_size)
            # If the reference set contains the query point, the first
            # neighbor is the point itself. Skip that zero-distance hit;
            # otherwise use the nearest available neighbors directly.
            self_hit = d[:, 0] <= 1e-12
            nearest = mx.where(self_hit[:, None], d[:, 1:k], d[:, : k - 1])
            mean_sq = mx.maximum(nearest.mean(axis=-1), 1e-8)
        scales = mx.log(mx.sqrt(mean_sq))[:, None] * mx.ones((1, 3))
        if scale_init_max_scale is not None:
            max_log_scale = float(np.log(max(scale_init_max_scale, 1e-8)))
            scales = mx.minimum(scales, mx.full(scales.shape, max_log_scale))

        # Identity rotation: w = 1, x = y = z = 0.
        quats = mx.zeros((N, 4))
        quats = quats.at[:, 0].add(mx.ones((N,)))

        K = num_sh_bases(sh_degree)
        params = {
            "means": mx.array(points),
            "scales": scales,
            "quats": quats,
            "opacities": mx.full((N,), float(np.log(initial_opacity / (1 - initial_opacity)))),
            "sh_dc": rgb_to_sh(mx.array(colors))[:, None, :],
            "sh_rest": mx.zeros((N, K - 1, 3)),
        }
        return cls(params, sh_degree=sh_degree)

    # ------------------------------------------------------------- activations
    @property
    def num_gaussians(self) -> int:
        return self.params["means"].shape[0]

    def __len__(self) -> int:
        return self.num_gaussians

    @property
    def scales_act(self) -> mx.array:
        """Per-axis standard deviations (``exp`` of the raw log-scales)."""
        return mx.exp(self.params["scales"])

    @property
    def opacities_act(self) -> mx.array:
        """Opacities in (0, 1) (``sigmoid`` of the raw logits)."""
        return mx.sigmoid(self.params["opacities"])

    @property
    def sh(self) -> mx.array:
        """All SH coefficients, (N, K, 3): DC band followed by the rest."""
        return mx.concatenate([self.params["sh_dc"], self.params["sh_rest"]], axis=1)

    def apply_2dgs_constraints(self, max_thickness: float) -> None:
        """Constrain Gaussians to thin surfels for 2DGS-style training.

        The covariance projection and rasterizer already support anisotropic
        oriented Gaussians. Keeping the third local scale small turns each
        Gaussian into an oriented disk while preserving standard 3DGS PLY
        compatibility.
        """
        thickness = max(float(max_thickness), 1e-8)
        max_log_thickness = float(np.log(thickness))
        constrained = self.params["scales"]
        z = mx.minimum(
            constrained[:, 2:3],
            mx.full((self.num_gaussians, 1), max_log_thickness, dtype=constrained.dtype),
        )
        self.params["scales"] = mx.concatenate([constrained[:, :2], z], axis=1)

    # ------------------------------------------------------------------ render
    def render(
        self,
        camera: Camera,
        background: mx.array | None = None,
        antialias: bool = False,
        projection: str = "ewa",
    ) -> dict:
        """Render color using the currently active SH degree."""
        return render_gaussians(
            camera,
            self.params["means"],
            self.params["quats"],
            self.scales_act,
            self.opacities_act,
            sh=self.sh,
            sh_degree=self.active_sh_degree,
            background=background,
            antialias=antialias,
            projection=projection,
        )

    def render_depth(
        self,
        camera: Camera,
        antialias: bool = False,
        projection: str = "ewa",
    ) -> dict:
        """Render depth via ``render_gaussian_depth``."""
        return render_gaussian_depth(
            camera,
            self.params["means"],
            self.params["quats"],
            self.scales_act,
            self.opacities_act,
            antialias=antialias,
            projection=projection,
        )

    def render_features(
        self,
        camera: Camera,
        features: mx.array,
        background: mx.array | None = None,
        normalize: bool = False,
        antialias: bool = False,
        projection: str = "ewa",
    ) -> dict:
        """Render arbitrary per-Gaussian feature channels from this model."""
        return render_gaussian_features(
            camera,
            self.params["means"],
            self.params["quats"],
            self.scales_act,
            self.opacities_act,
            features,
            background=background,
            normalize=normalize,
            antialias=antialias,
            projection=projection,
        )

    def one_up_sh_degree(self) -> None:
        """Activate one more SH band, up to ``self.sh_degree``."""
        if self.active_sh_degree < self.sh_degree:
            self.active_sh_degree += 1

    # --------------------------------------------------------- surface export
    def surfel_points(
        self,
        min_opacity: float = 0.01,
        max_points: int | None = None,
        orient_towards: tuple[float, float, float] | mx.array | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Return oriented surfel samples from the Gaussian centers.

        This is intended for 2DGS-style checkpoints, where the local ``z`` axis
        is the surfel normal and the first two local scales span the disk. Rows
        are filtered by activated opacity and, when capped, ranked by
        ``opacity * scale_x * scale_y`` so large opaque surfels survive first.

        Args:
            min_opacity: drop rows whose activated opacity is below this (the
                single most important row is kept if all would be dropped).
            max_points: optional positive cap on the number of returned rows.
            orient_towards: optional point; normals are flipped to face it.

        Returns:
            ``(points, normals)``, both (M, 3), with unit-length normals.

        Raises:
            ValueError: if ``max_points`` is given and not positive.
        """
        if max_points is not None and max_points <= 0:
            raise ValueError("max_points must be positive when provided.")
        if self.num_gaussians == 0:
            empty = mx.zeros((0, 3), dtype=mx.float32)
            return empty, empty

        from ..transforms import quaternion_to_matrix

        opacity = np.array(self.opacities_act)
        scales = np.array(self.scales_act)
        importance = opacity * scales[:, 0] * scales[:, 1]
        idx = mx.array(self._select_rows(opacity, importance, min_opacity, max_points))

        points = self.params["means"][idx]
        quats = self.params["quats"][idx]
        # Raw quaternions are unnormalized; normalize before building the
        # rotation (as densify_and_prune does) so column 2 is the true local z
        # axis. The clamp guards against degenerate zero quaternions.
        quats = quats / mx.maximum(mx.linalg.norm(quats, axis=-1, keepdims=True), 1e-12)
        normals = quaternion_to_matrix(quats)[:, :, 2]
        normals = normals / mx.maximum(mx.linalg.norm(normals, axis=-1, keepdims=True), 1e-8)
        if orient_towards is not None:
            target = mx.array(orient_towards, dtype=points.dtype)
            flip = mx.sum((target - points) * normals, axis=-1, keepdims=True) < 0
            normals = mx.where(flip, -normals, normals)
        return points, normals

    def extract_surface_mesh(
        self,
        resolution: int = 64,
        padding: float = 0.1,
        min_opacity: float = 0.01,
        max_points: int | None = 200_000,
        orient_towards: tuple[float, float, float] | mx.array | None = None,
    ) -> Meshes:
        """Reconstruct a mesh from oriented Gaussian surfels via Poisson reconstruction."""
        if resolution <= 1:
            raise ValueError("resolution must be greater than 1.")
        points, normals = self.surfel_points(
            min_opacity=min_opacity,
            max_points=max_points,
            orient_towards=orient_towards,
        )
        from ..ops import poisson_reconstruction

        return poisson_reconstruction(points, normals, resolution=resolution, padding=padding)

    # -------------------------------------------------------------- compaction
    def copy(self) -> "GaussianModel":
        """Return a detached copy of the Gaussian table."""
        out = GaussianModel({k: mx.array(v) for k, v in self.params.items()}, self.sh_degree)
        out.active_sh_degree = self.active_sh_degree
        return out

    def compact(
        self,
        min_opacity: float = 0.0,
        max_gaussians: int | None = None,
        target_sh_degree: int | None = None,
    ) -> "GaussianModel":
        """Return a smaller checkpoint by pruning low-importance Gaussians.

        Importance is a conservative, view-independent proxy:
        ``sigmoid(opacity) * max(scale)^2``. This keeps opaque large-footprint
        splats before transparent/subpixel ones, preserves the original order
        of retained rows for checkpoint diffability, and does not mutate this
        model.

        Args:
            min_opacity: prune Gaussians with activated opacity below this
                threshold. If the threshold removes every row, the most
                important Gaussian is kept so the checkpoint remains renderable.
            max_gaussians: optional hard cap; the highest-importance rows are
                retained.
            target_sh_degree: optional lower SH degree for color coefficient
                truncation. This reduces checkpoint size and render cost for
                view-dependent color at the expense of angular detail.

        Raises:
            ValueError: on a non-positive ``max_gaussians`` or a
                ``target_sh_degree`` outside ``[0, self.sh_degree]``.
        """
        if max_gaussians is not None and max_gaussians <= 0:
            raise ValueError("max_gaussians must be positive when provided.")
        if target_sh_degree is not None and not (0 <= target_sh_degree <= self.sh_degree):
            raise ValueError("target_sh_degree must be between 0 and the model sh_degree.")

        if self.num_gaussians == 0:
            return self.copy()

        opacity = np.array(self.opacities_act)
        max_scale = np.array(self.scales_act).max(axis=1)
        importance = opacity * max_scale * max_scale
        idx = mx.array(self._select_rows(opacity, importance, min_opacity, max_gaussians))
        params = {k: v[idx] for k, v in self.params.items()}

        sh_degree = self.sh_degree
        if target_sh_degree is not None:
            sh_degree = int(target_sh_degree)
            k = num_sh_bases(sh_degree)
            params["sh_rest"] = params["sh_rest"][:, : max(0, k - 1), :]

        out = GaussianModel(params, sh_degree=sh_degree)
        out.active_sh_degree = min(self.active_sh_degree, sh_degree)
        return out

    # ------------------------------------------------------------- checkpoints
    def save_ply(self, path: str) -> None:
        """Save in the standard 3DGS PLY layout (compatible with most viewers)."""
        p = self.params
        N = self.num_gaussians
        extra: dict[str, mx.array] = {}
        sh_dc = p["sh_dc"]
        for c in range(3):
            extra[f"f_dc_{c}"] = sh_dc[:, 0, c]
        rest = np.array(p["sh_rest"])  # (N, K-1, 3)
        # Channel-major as in 3DGS. The explicit width (instead of ``-1``) keeps
        # the reshape valid for an empty model, where numpy cannot infer -1.
        rest_t = rest.transpose(0, 2, 1).reshape(N, rest.shape[1] * rest.shape[2])
        for i in range(rest_t.shape[1]):
            extra[f"f_rest_{i}"] = mx.array(rest_t[:, i])
        extra["opacity"] = p["opacities"]
        for i in range(3):
            extra[f"scale_{i}"] = p["scales"][:, i]
        for i in range(4):
            extra[f"rot_{i}"] = p["quats"][:, i]
        save_ply(
            path,
            p["means"],
            normals=mx.zeros((N, 3)),
            extra=extra,
            binary=True,
        )

    @classmethod
    def load_ply(cls, path: str, sh_degree: int = 3) -> "GaussianModel":
        """Load a 3DGS-format PLY checkpoint.

        When the number of ``f_rest_*`` properties does not match
        ``sh_degree``, the degree is inferred from the file instead. The
        returned model has all of its SH bands active.

        Raises:
            ValueError: if the ``f_rest_*`` count corresponds to no SH degree.
        """
        data = load_ply(path)
        e = data.extra
        N = data.verts.shape[0]
        n_rest = sum(1 for key in e if key.startswith("f_rest_"))
        K = num_sh_bases(sh_degree)
        if n_rest != 3 * (K - 1):
            # Infer degree from the file. Reject counts that are not
            # 3 * (K - 1) for a valid K rather than failing inside reshape.
            K = n_rest // 3 + 1
            sh_degree = int(round(float(np.sqrt(K)))) - 1
            if n_rest % 3 != 0 or num_sh_bases(sh_degree) != K:
                raise ValueError(
                    f"Cannot infer an SH degree from {n_rest} f_rest_* properties in {path!r}."
                )
        sh_dc = mx.stack([e["f_dc_0"], e["f_dc_1"], e["f_dc_2"]], axis=-1)[:, None, :]
        if n_rest > 0:
            rest = np.stack([np.array(e[f"f_rest_{i}"]) for i in range(n_rest)], axis=1)
            # File layout is channel-major (N, 3, K-1); the model stores (N, K-1, 3).
            rest = rest.reshape(N, 3, K - 1).transpose(0, 2, 1)
            sh_rest = mx.array(rest.astype(np.float32))
        else:
            sh_rest = mx.zeros((N, K - 1, 3))
        params = {
            "means": data.verts,
            "scales": mx.stack([e[f"scale_{i}"] for i in range(3)], axis=-1),
            "quats": mx.stack([e[f"rot_{i}"] for i in range(4)], axis=-1),
            "opacities": e["opacity"],
            "sh_dc": sh_dc,
            "sh_rest": sh_rest,
        }
        model = cls(params, sh_degree=sh_degree)
        model.active_sh_degree = sh_degree
        return model

    # ----------------------------------------------------------- densification
    def select(self, keep_idx: np.ndarray) -> None:
        """Keep only the Gaussians at ``keep_idx`` (in-place).

        ``keep_idx`` may be any integer sequence; it is converted to int32.
        """
        idx = mx.array(np.asarray(keep_idx).astype(np.int32))
        self.params = {k: v[idx] for k, v in self.params.items()}

    def append(self, new_params: dict[str, mx.array]) -> None:
        """Concatenate new Gaussians (in-place); ``new_params`` must have every key."""
        self.params = {
            k: mx.concatenate([v, new_params[k]], axis=0) for k, v in self.params.items()
        }

    def densify_and_prune(
        self,
        grad_accum: mx.array,
        grad_count: mx.array,
        grad_threshold: float = 0.0002,
        scene_extent: float = 1.0,
        percent_dense: float = 0.01,
        min_opacity: float = 0.005,
        return_optimizer_state: bool = False,
    ) -> dict[str, object]:
        """Adaptive density control from the 3DGS paper.

        Args:
            grad_accum: (N,) or (N, 1) accumulated NDC-space positional
                gradient norms.
            grad_count: (N,) or (N, 1) number of accumulation steps each
                Gaussian was visible.
            grad_threshold: average-gradient threshold for densification.
            scene_extent: scene scale used by the clone/split and oversize
                thresholds.
            percent_dense: fraction of ``scene_extent`` separating clone
                (small) from split (large) Gaussians.
            min_opacity: original Gaussians below this activated opacity are
                pruned.
            return_optimizer_state: also return ``_keep_idx`` (the surviving
                original rows, which come first in the new table) and
                ``_new_count`` (the number of rows appended after them).

        Under-reconstructed regions (high positional gradient, small scale)
        are cloned; over-reconstructed ones (high gradient, large scale) are
        split in two. Nearly transparent or oversized Gaussians are pruned.
        Returns counts of cloned/split/pruned Gaussians; ``pruned`` counts
        every removed original row, split sources included.
        """
        # Flatten so (N, 1) inputs do not broadcast against (N,) into (N, N).
        avg_grad = np.array(grad_accum).reshape(-1) / np.maximum(
            np.array(grad_count).reshape(-1), 1.0
        )
        scales = np.array(self.scales_act)
        max_scale = scales.max(axis=1)
        high_grad = avg_grad > grad_threshold

        clone_mask = high_grad & (max_scale <= percent_dense * scene_extent)
        split_mask = high_grad & (max_scale > percent_dense * scene_extent)

        p_np = {k: np.array(v) for k, v in self.params.items()}

        # Clone: duplicate as-is.
        clone_idx = np.where(clone_mask)[0]
        clones = {k: v[clone_idx] for k, v in p_np.items()}

        # Split: two samples from each Gaussian, scales shrunk by 1.6.
        split_idx = np.where(split_mask)[0]
        splits: dict[str, np.ndarray] = {}
        if split_idx.size > 0:
            from ..transforms import quaternion_to_matrix

            # Normalize in numpy (clamped against zero-norm quaternions) and
            # hand a single MLX array to the conversion.
            q_np = p_np["quats"][split_idx]
            q_np = q_np / np.maximum(np.linalg.norm(q_np, axis=1, keepdims=True), 1e-12)
            R = np.array(quaternion_to_matrix(mx.array(q_np)))
            s = scales[split_idx]
            n2 = split_idx.size * 2
            # Sample offsets from each source Gaussian: N(0, diag(s^2)) rotated by R.
            samples = np.random.normal(size=(n2, 3)) * np.repeat(s, 2, axis=0)
            offsets = np.einsum("nij,nj->ni", np.repeat(R, 2, axis=0), samples)
            splits = {k: np.repeat(v[split_idx], 2, axis=0) for k, v in p_np.items()}
            splits["means"] = splits["means"] + offsets.astype(np.float32)
            splits["scales"] = splits["scales"] - float(np.log(1.6))

        # Prune originals: split sources, low opacity, oversized.
        opac = 1.0 / (1.0 + np.exp(-np.clip(p_np["opacities"], -50.0, 50.0)))
        prune_mask = (opac < min_opacity) | split_mask | (max_scale > 0.5 * scene_extent)
        keep_idx = np.where(~prune_mask)[0]

        # New layout: surviving originals, then clones, then split children.
        self.select(keep_idx)
        if clone_idx.size > 0:
            self.append({k: mx.array(v) for k, v in clones.items()})
        if split_idx.size > 0:
            self.append({k: mx.array(v) for k, v in splits.items()})

        stats: dict[str, object] = {
            "cloned": int(clone_idx.size),
            "split": int(split_idx.size),
            "pruned": int(prune_mask.sum()),
        }
        if return_optimizer_state:
            stats["_keep_idx"] = keep_idx.astype(np.int32)
            stats["_new_count"] = int(clone_idx.size + 2 * split_idx.size)
        return stats

    def relocate_mcmc(
        self,
        grad_accum: mx.array,
        grad_count: mx.array,
        relocate_frac: float = 0.02,
        min_opacity: float = 0.01,
        jitter_scale: float = 0.25,
    ) -> dict[str, object]:
        """Fixed-budget MCMC-style relocation of underused Gaussians.

        This keeps ``N`` constant: low-opacity or never-visible rows are
        replaced by jittered copies of high-gradient rows. It is inspired by
        MCMC 3DGS relocation, and is intended as an alternative to vanilla
        clone/split/prune density control.

        Args:
            grad_accum: (N,) or (N, 1) accumulated positional gradient norms.
            grad_count: (N,) or (N, 1) visibility counts; rows with a count of
                zero are treated as underused.
            relocate_frac: fraction of ``N`` (clamped to [0, 1]) that may move.
            min_opacity: rows below this activated opacity are underused.
            jitter_scale: position jitter in units of the source scales.

        Returns:
            ``{"relocated": k}``, plus ``"_moved_idx"`` (the overwritten rows)
            when ``k > 0``.
        """
        n = self.num_gaussians
        max_relocate = int(max(0, min(relocate_frac, 1.0)) * n)
        if max_relocate <= 0 or n <= 1:
            return {"relocated": 0}

        counts = grad_count.reshape(-1)
        avg_grad = grad_accum.reshape(-1) / mx.maximum(counts, 1.0)
        opac = mx.sigmoid(self.params["opacities"])
        underused = (opac < min_opacity) | (counts <= 0)
        underused_count = int(mx.sum(underused.astype(mx.int32)))
        k = max_relocate if underused_count == 0 else min(max_relocate, underused_count)
        k = min(k, n - 1)
        if k == 0:
            return {"relocated": 0}

        # Prefer underused rows (score in (1, 2]). If none exist, fall back to
        # the lowest-opacity rows (score in [-1, 0)).
        target_score = mx.where(underused, 2.0 - opac, -opac)
        dst = mx.argpartition(-target_score, kth=k - 1)[:k]
        # Destinations can never be sources.
        source_score = avg_grad.at[dst].add(-1e30 - avg_grad[dst])
        # Only n - k rows are eligible sources. When k exceeds that, reuse the
        # best sources cyclically (jitter still separates the copies) instead of
        # copying from rows that are themselves being overwritten.
        n_src = min(k, n - k)
        best = mx.argpartition(-source_score, kth=n_src - 1)[:n_src]
        src = best if n_src == k else best[mx.arange(k) % n_src]
        source_scales = mx.exp(self.params["scales"][src])
        reset_opacity = float(np.log(0.05 / 0.95))
        split_shrink = float(np.log(1.6))
        for name, arr in self.params.items():
            values = arr[src]
            if name == "means" and jitter_scale > 0:
                values = values + mx.random.normal(values.shape) * source_scales * float(
                    jitter_scale
                )
            elif name == "opacities":
                values = mx.full(values.shape, reset_opacity, dtype=values.dtype)
            elif name == "scales":
                values = values - split_shrink
            # Overwrite the destination rows: arr[dst] becomes values.
            self.params[name] = arr.at[dst].add(values - arr[dst])
        return {"relocated": int(k), "_moved_idx": dst}

    def reset_opacities(self, max_opacity: float = 0.01) -> None:
        """Clamp opacities down (periodic reset from the 3DGS paper)."""
        logit = float(np.log(max_opacity / (1 - max_opacity)))
        self.params["opacities"] = mx.minimum(
            self.params["opacities"], mx.full(self.params["opacities"].shape, logit)
        )

from_points(points, colors=None, sh_degree=3, initial_opacity=0.1, scale_init_max_ref=10000, scale_init_chunk_size=1024, scale_init_max_scale=None) classmethod

Initialize from a point cloud (e.g. SfM points).

Scales are set from the mean distance to the 3 nearest neighbors, as in the reference implementation.

Parameters:

- scale_init_max_ref (int, default 10000): Maximum number of reference points used for the initial nearest-neighbor scale estimate. Without this cap, large COLMAP clouds spend seconds to minutes materializing huge distance tiles.
- scale_init_chunk_size (int, default 1024): Query chunk size for the scale-estimation KNN. Lower values reduce peak memory during initialization.
- scale_init_max_scale (float | None, default None): Optional cap on the initial per-axis Gaussian scale. This is useful for COLMAP point clouds with sparse outliers whose nearest-neighbor distances would otherwise create full-screen splats.
Source code in src/mlx3d/splatting/model.py (lines 36–106):
@classmethod
def from_points(
    cls,
    points: mx.array,
    colors: mx.array | None = None,
    sh_degree: int = 3,
    initial_opacity: float = 0.1,
    scale_init_max_ref: int = 10_000,
    scale_init_chunk_size: int = 1024,
    scale_init_max_scale: float | None = None,
) -> "GaussianModel":
    """Initialize from a point cloud (e.g. SfM points).

    Scales are set from the mean distance to the 3 nearest neighbors, as
    in the reference implementation. Rotations start as the identity
    quaternion, every opacity starts at ``initial_opacity`` and the
    higher-order SH coefficients start at zero.

    Args:
        points: (N, 3) positions; they become the Gaussian means.
        colors: optional (N, 3) colors converted with ``rgb_to_sh``;
            defaults to a constant 0.5.
        sh_degree: maximum SH degree of the model.
        initial_opacity: activated opacity of every new Gaussian. Must lie
            strictly between 0 and 1 (it is stored as a logit).
        scale_init_max_ref: Maximum reference points used for the initial
            nearest-neighbor scale estimate. Large COLMAP clouds otherwise
            spend seconds to minutes materializing huge distance tiles.
        scale_init_chunk_size: Query chunk size for the scale-estimation
            KNN. Lower values reduce peak memory during initialization.
        scale_init_max_scale: Optional cap for initial per-axis Gaussian
            scale. This is useful for COLMAP point clouds with sparse
            outliers whose nearest-neighbor distances would otherwise
            create full-screen splats.

    Raises:
        ValueError: if ``initial_opacity`` is not in the open interval (0, 1).
    """
    # The logit below is undefined at 0 and 1 (ZeroDivisionError / -inf).
    if not 0.0 < initial_opacity < 1.0:
        raise ValueError("initial_opacity must be strictly between 0 and 1.")

    from ..ops import knn_points

    N = points.shape[0]
    if colors is None:
        colors = mx.full((N, 3), 0.5)

    # Mean distance to the 3 nearest neighbors. For huge clouds, a random
    # reference subset keeps the O(N * R) search memory bounded without
    # meaningfully changing the scale estimate.
    ref_count = min(N, max(1, int(scale_init_max_ref)))
    if N > ref_count:
        ref = points[mx.random.permutation(N)[:ref_count]]
    else:
        ref = points

    if ref_count <= 1:
        mean_sq = mx.full((N,), 1e-8)
    else:
        k = min(4, ref_count)
        # ``d`` is treated as squared distances (it is square-rooted below).
        d, _ = knn_points(points, ref, K=k, chunk_size=scale_init_chunk_size)
        # If the reference set contains the query point, the first
        # neighbor is the point itself. Skip that zero-distance hit;
        # otherwise use the nearest available neighbors directly.
        self_hit = d[:, 0] <= 1e-12
        nearest = mx.where(self_hit[:, None], d[:, 1:k], d[:, : k - 1])
        mean_sq = mx.maximum(nearest.mean(axis=-1), 1e-8)
    scales = mx.log(mx.sqrt(mean_sq))[:, None] * mx.ones((1, 3))
    if scale_init_max_scale is not None:
        max_log_scale = float(np.log(max(scale_init_max_scale, 1e-8)))
        scales = mx.minimum(scales, mx.full(scales.shape, max_log_scale))

    # Identity rotation: w = 1, x = y = z = 0.
    quats = mx.zeros((N, 4))
    quats = quats.at[:, 0].add(mx.ones((N,)))

    K = num_sh_bases(sh_degree)
    params = {
        "means": mx.array(points),
        "scales": scales,
        "quats": quats,
        "opacities": mx.full((N,), float(np.log(initial_opacity / (1 - initial_opacity)))),
        "sh_dc": rgb_to_sh(mx.array(colors))[:, None, :],
        "sh_rest": mx.zeros((N, K - 1, 3)),
    }
    return cls(params, sh_degree=sh_degree)

apply_2dgs_constraints(max_thickness)

Constrain Gaussians to thin surfels for 2DGS-style training.

The covariance projection and rasterizer already support anisotropic oriented Gaussians. Keeping the third local scale small turns each Gaussian into an oriented disk while preserving standard 3DGS PLY compatibility.

Source code in src/mlx3d/splatting/model.py (lines 128–143):
def apply_2dgs_constraints(self, max_thickness: float) -> None:
    """Flatten every Gaussian into a thin oriented disk (2DGS-style surfel).

    Only the third local axis is touched: its log-scale is capped at
    ``log(max(max_thickness, 1e-8))``, while the two in-plane axes keep their
    values. The projection and rasterizer already handle anisotropic,
    oriented Gaussians, and the parameter layout stays the standard 3DGS one,
    so PLY export is unaffected.
    """
    log_cap = float(np.log(max(float(max_thickness), 1e-8)))
    raw = self.params["scales"]
    cap_column = mx.full((self.num_gaussians, 1), log_cap, dtype=raw.dtype)
    in_plane = raw[:, :2]
    normal_axis = mx.minimum(raw[:, 2:3], cap_column)
    self.params["scales"] = mx.concatenate([in_plane, normal_axis], axis=1)

render_features(camera, features, background=None, normalize=False, antialias=False, projection='ewa')

Render arbitrary per-Gaussian feature channels from this model.

Source code in src/mlx3d/splatting/model.py (lines 182–203):
def render_features(
    self,
    camera: Camera,
    features: mx.array,
    background: mx.array | None = None,
    normalize: bool = False,
    antialias: bool = False,
    projection: str = "ewa",
) -> dict:
    """Rasterize caller-supplied per-Gaussian feature channels.

    Geometry (means, rotations, activated scales and opacities) comes from
    this model; ``features`` and the remaining options are forwarded
    unchanged to ``render_gaussian_features``, whose result dict is returned.
    """
    p = self.params
    return render_gaussian_features(
        camera,
        p["means"],
        p["quats"],
        self.scales_act,
        self.opacities_act,
        features,
        background=background,
        normalize=normalize,
        antialias=antialias,
        projection=projection,
    )

surfel_points(min_opacity=0.01, max_points=None, orient_towards=None)

Return oriented surfel samples from the Gaussian centers.

This is intended for 2DGS-style checkpoints, where the local z axis is the surfel normal and the first two local scales span the disk. Rows are filtered by activated opacity and, when capped, ranked by opacity * scale_x * scale_y so large opaque surfels survive first.

Source code in src/mlx3d/splatting/model.py (lines 210–250):
def surfel_points(
    self,
    min_opacity: float = 0.01,
    max_points: int | None = None,
    orient_towards: tuple[float, float, float] | mx.array | None = None,
) -> tuple[mx.array, mx.array]:
    """Return oriented surfel samples from the Gaussian centers.

    This is intended for 2DGS-style checkpoints, where the local ``z`` axis
    is the surfel normal and the first two local scales span the disk. Rows
    are filtered by activated opacity and, when capped, ranked by
    ``opacity * scale_x * scale_y`` so large opaque surfels survive first.

    Args:
        min_opacity: drop rows whose activated opacity is below this value.
            If every row would be dropped, the single most important row is
            kept instead.
        max_points: optional cap on the number of returned surfels.
        orient_towards: optional 3D point; normals are flipped so they face it.

    Returns:
        ``(points, normals)``: the selected Gaussian means and unit normals,
        with rows in original model order.

    Raises:
        ValueError: if ``max_points`` is given and not positive.
    """
    if max_points is not None and max_points <= 0:
        raise ValueError("max_points must be positive when provided.")
    if self.num_gaussians == 0:
        empty = mx.zeros((0, 3), dtype=mx.float32)
        return empty, empty

    opacity = np.array(self.opacities_act)
    scales = np.array(self.scales_act)
    importance = opacity * scales[:, 0] * scales[:, 1]
    keep = opacity >= float(min_opacity)
    if not keep.any():
        # A non-empty model always yields at least one surfel.
        keep[int(np.argmax(importance))] = True
    keep_idx = np.where(keep)[0]
    if max_points is not None and keep_idx.size > max_points:
        local = np.argpartition(-importance[keep_idx], max_points - 1)[:max_points]
        keep_idx = keep_idx[local]
    # Sorting restores the original row order after argpartition.
    idx = mx.array(np.sort(keep_idx.astype(np.int32)))

    points = self.params["means"][idx]
    # The stored quaternions are unnormalized. Normalize before conversion so
    # the z column is the axis of a proper rotation, whatever
    # quaternion_to_matrix does with non-unit input.
    quats = self.params["quats"][idx]
    quats = quats / mx.maximum(mx.linalg.norm(quats, axis=-1, keepdims=True), 1e-8)
    normals = quaternion_to_matrix(quats)[:, :, 2]
    normals = normals / mx.maximum(mx.linalg.norm(normals, axis=-1, keepdims=True), 1e-8)
    if orient_towards is not None:
        target = mx.array(orient_towards, dtype=points.dtype)
        flip = mx.sum((target - points) * normals, axis=-1, keepdims=True) < 0
        normals = mx.where(flip, -normals, normals)
    return points, normals

extract_surface_mesh(resolution=64, padding=0.1, min_opacity=0.01, max_points=200000, orient_towards=None)

Reconstruct a mesh from oriented Gaussian surfels via Poisson reconstruction.

Source code in src/mlx3d/splatting/model.py
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def extract_surface_mesh(
    self,
    resolution: int = 64,
    padding: float = 0.1,
    min_opacity: float = 0.01,
    max_points: int | None = 200_000,
    orient_towards: tuple[float, float, float] | mx.array | None = None,
) -> Meshes:
    """Reconstruct a mesh from oriented Gaussian surfels via Poisson reconstruction.

    Surfels come from ``surfel_points`` (filtered by ``min_opacity``, capped
    at ``max_points``, optionally oriented towards ``orient_towards``). They
    are then handed to ``poisson_reconstruction`` with the given grid
    ``resolution`` and bounding-box ``padding``.

    Raises:
        ValueError: if ``resolution`` is not greater than 1.
    """
    if resolution <= 1:
        raise ValueError("resolution must be greater than 1.")
    surfel_xyz, surfel_normals = self.surfel_points(
        min_opacity=min_opacity,
        max_points=max_points,
        orient_towards=orient_towards,
    )
    from ..ops import poisson_reconstruction

    return poisson_reconstruction(
        surfel_xyz,
        surfel_normals,
        resolution=resolution,
        padding=padding,
    )

copy()

Return a detached copy of the Gaussian table.

Source code in src/mlx3d/splatting/model.py
273
274
275
276
277
def copy(self) -> "GaussianModel":
    """Return a detached copy of the Gaussian table.

    Every parameter array is re-wrapped with ``mx.array``. The copy shares
    ``sh_degree`` and ``active_sh_degree`` with this model.
    """
    duplicated = {name: mx.array(value) for name, value in self.params.items()}
    clone = GaussianModel(duplicated, self.sh_degree)
    clone.active_sh_degree = self.active_sh_degree
    return clone

compact(min_opacity=0.0, max_gaussians=None, target_sh_degree=None)

Return a smaller checkpoint by pruning low-importance Gaussians.

Importance is a conservative, view-independent proxy: sigmoid(opacity) * max(scale)^2. This keeps opaque large-footprint splats before transparent/subpixel ones, preserves the original order of retained rows for checkpoint diffability, and does not mutate this model.

Parameters:

- `min_opacity` (`float`, default `0.0`): prune Gaussians with activated opacity below this threshold. If the threshold removes every row, the most important Gaussian is kept so the checkpoint remains renderable.
- `max_gaussians` (`int | None`, default `None`): optional hard cap; the highest-importance rows are retained.
- `target_sh_degree` (`int | None`, default `None`): optional lower SH degree for color coefficient truncation. This reduces checkpoint size and render cost for view-dependent color at the expense of angular detail.
Source code in src/mlx3d/splatting/model.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def compact(
    self,
    min_opacity: float = 0.0,
    max_gaussians: int | None = None,
    target_sh_degree: int | None = None,
) -> "GaussianModel":
    """Return a smaller checkpoint by pruning low-importance Gaussians.

    Importance is a conservative, view-independent proxy:
    ``sigmoid(opacity) * max(scale)^2``. This keeps opaque large-footprint
    splats before transparent/subpixel ones. It preserves the original order
    of retained rows for checkpoint diffability and does not mutate this
    model.

    Args:
        min_opacity: prune Gaussians with activated opacity below this
            threshold. If the threshold removes every row, the most
            important Gaussian is kept so the checkpoint remains renderable.
        max_gaussians: optional hard cap; the highest-importance rows are
            retained.
        target_sh_degree: optional lower SH degree for color coefficient
            truncation. This reduces checkpoint size and render cost for
            view-dependent color at the expense of angular detail. It is
            applied even when the model is empty.

    Raises:
        ValueError: if ``max_gaussians`` is not positive, or
            ``target_sh_degree`` is outside ``[0, self.sh_degree]``.
    """
    if max_gaussians is not None and max_gaussians <= 0:
        raise ValueError("max_gaussians must be positive when provided.")
    if target_sh_degree is not None and not (0 <= target_sh_degree <= self.sh_degree):
        raise ValueError("target_sh_degree must be between 0 and the model sh_degree.")

    if self.num_gaussians == 0:
        # Nothing to prune. Fall through so SH truncation below still honors
        # target_sh_degree (an early return would silently skip it).
        params = {k: mx.array(v) for k, v in self.params.items()}
    else:
        opacity = np.array(self.opacities_act)
        max_scale = np.array(self.scales_act).max(axis=1)
        importance = opacity * max_scale * max_scale
        keep = opacity >= float(min_opacity)
        if not keep.any():
            keep[int(np.argmax(importance))] = True

        keep_idx = np.where(keep)[0]
        if max_gaussians is not None and keep_idx.size > max_gaussians:
            local = np.argpartition(-importance[keep_idx], max_gaussians - 1)[:max_gaussians]
            keep_idx = keep_idx[local]
        # Sort so retained rows keep their original relative order.
        idx = mx.array(np.sort(keep_idx.astype(np.int32)))
        params = {k: v[idx] for k, v in self.params.items()}

    sh_degree = self.sh_degree
    if target_sh_degree is not None:
        sh_degree = int(target_sh_degree)
        n_bases = num_sh_bases(sh_degree)
        # sh_rest holds every basis except the DC term.
        params["sh_rest"] = params["sh_rest"][:, : max(0, n_bases - 1), :]

    out = GaussianModel(params, sh_degree=sh_degree)
    out.active_sh_degree = min(self.active_sh_degree, sh_degree)
    return out

save_ply(path)

Save in the standard 3DGS PLY layout (compatible with most viewers).

Source code in src/mlx3d/splatting/model.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
def save_ply(self, path: str) -> None:
    """Save in the standard 3DGS PLY layout (compatible with most viewers).

    Writes the means as vertices with zero normals. It also writes
    ``f_dc_0..2``, ``f_rest_*`` (channel-major), ``opacity`` (logit),
    ``scale_0..2`` (log scale) and ``rot_0..3`` (raw quaternion) as extra
    per-vertex properties in binary format.

    Args:
        path: destination file path.
    """
    p = self.params
    N = self.num_gaussians
    extra: dict[str, mx.array] = {}
    sh_dc = p["sh_dc"]
    for c in range(3):
        extra[f"f_dc_{c}"] = sh_dc[:, 0, c]
    rest = np.array(p["sh_rest"])  # (N, K-1, 3)
    # Channel-major as in 3DGS: all R coefficients, then G, then B. The column
    # count is explicit because reshape(N, -1) raises for an empty model (N == 0).
    n_rest_cols = rest.shape[1] * rest.shape[2]
    rest_t = rest.transpose(0, 2, 1).reshape(N, n_rest_cols)
    for i in range(n_rest_cols):
        extra[f"f_rest_{i}"] = mx.array(rest_t[:, i])
    extra["opacity"] = p["opacities"]
    for i in range(3):
        extra[f"scale_{i}"] = p["scales"][:, i]
    for i in range(4):
        extra[f"rot_{i}"] = p["quats"][:, i]
    save_ply(
        path,
        p["means"],
        normals=mx.zeros((N, 3)),
        extra=extra,
        binary=True,
    )

load_ply(path, sh_degree=3) classmethod

Load a 3DGS-format PLY checkpoint.

Source code in src/mlx3d/splatting/model.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
@classmethod
def load_ply(cls, path: str, sh_degree: int = 3) -> "GaussianModel":
    """Load a 3DGS-format PLY checkpoint.

    Args:
        path: PLY file with ``f_dc_*``, ``f_rest_*``, ``opacity``,
            ``scale_*`` and ``rot_*`` per-vertex properties.
        sh_degree: expected SH degree. If the file's ``f_rest_*`` count does
            not match it, the degree is inferred from the file instead.

    Returns:
        A model whose ``active_sh_degree`` equals its (possibly inferred)
        ``sh_degree``.

    Raises:
        ValueError: if the number of ``f_rest_*`` properties corresponds to
            no SH degree.
    """
    data = load_ply(path)
    e = data.extra
    N = data.verts.shape[0]
    n_rest = sum(1 for key in e if key.startswith("f_rest_"))
    K = num_sh_bases(sh_degree)
    if n_rest != 3 * (K - 1):
        # Infer the degree from the file. Reject counts no degree produces so
        # a corrupt or truncated file fails here, not in a reshape below.
        inferred_k = n_rest // 3 + 1
        inferred_degree = int(np.sqrt(inferred_k)) - 1
        if n_rest % 3 != 0 or num_sh_bases(inferred_degree) != inferred_k:
            raise ValueError(
                f"PLY has {n_rest} f_rest_* properties, which matches no SH degree."
            )
        K = inferred_k
        sh_degree = inferred_degree
    sh_dc = mx.stack([e["f_dc_0"], e["f_dc_1"], e["f_dc_2"]], axis=-1)[:, None, :]
    if n_rest > 0:
        rest = np.stack([np.array(e[f"f_rest_{i}"]) for i in range(n_rest)], axis=1)
        # Undo the channel-major flattening used by the 3DGS exporter.
        rest = rest.reshape(N, 3, K - 1).transpose(0, 2, 1)
        sh_rest = mx.array(rest.astype(np.float32))
    else:
        sh_rest = mx.zeros((N, K - 1, 3))
    params = {
        "means": data.verts,
        "scales": mx.stack([e[f"scale_{i}"] for i in range(3)], axis=-1),
        "quats": mx.stack([e[f"rot_{i}"] for i in range(4)], axis=-1),
        "opacities": e["opacity"],
        "sh_dc": sh_dc,
        "sh_rest": sh_rest,
    }
    model = cls(params, sh_degree=sh_degree)
    model.active_sh_degree = sh_degree
    return model

select(keep_idx)

Keep only the Gaussians at keep_idx (in-place).

Source code in src/mlx3d/splatting/model.py
395
396
397
398
def select(self, keep_idx: np.ndarray) -> None:
    """Keep only the Gaussians at ``keep_idx`` (in-place).

    Args:
        keep_idx: integer row indices (any array-like), or a boolean mask
            over the rows. Every parameter array is gathered with the same
            indices, in the given order.
    """
    keep_idx = np.asarray(keep_idx)
    if keep_idx.dtype == np.bool_:
        # A boolean mask would otherwise be cast to 0/1 row indices.
        keep_idx = np.flatnonzero(keep_idx)
    idx = mx.array(keep_idx.astype(np.int32))
    self.params = {k: v[idx] for k, v in self.params.items()}

append(new_params)

Concatenate new Gaussians (in-place).

Source code in src/mlx3d/splatting/model.py
400
401
402
403
404
def append(self, new_params: dict[str, mx.array]) -> None:
    """Concatenate new Gaussians (in-place).

    Every key already in ``self.params`` must appear in ``new_params``; the
    new rows go after the existing ones along axis 0. Keys that exist only
    in ``new_params`` are ignored.
    """
    merged: dict[str, mx.array] = {}
    for name, current in self.params.items():
        merged[name] = mx.concatenate([current, new_params[name]], axis=0)
    self.params = merged

densify_and_prune(grad_accum, grad_count, grad_threshold=0.0002, scene_extent=1.0, percent_dense=0.01, min_opacity=0.005, return_optimizer_state=False)

Adaptive density control from the 3DGS paper.

Parameters:

- `grad_accum` (`array`, required): (N,) accumulated NDC-space positional gradient norms.
- `grad_count` (`array`, required): (N,) number of accumulation steps each Gaussian was visible.

Under-reconstructed regions (high positional gradient, small scale) are cloned; over-reconstructed ones (high gradient, large scale) are split in two. Nearly transparent or oversized Gaussians are pruned. Returns counts of cloned/split/pruned Gaussians.

Source code in src/mlx3d/splatting/model.py
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def densify_and_prune(
    self,
    grad_accum: mx.array,
    grad_count: mx.array,
    grad_threshold: float = 0.0002,
    scene_extent: float = 1.0,
    percent_dense: float = 0.01,
    min_opacity: float = 0.005,
    return_optimizer_state: bool = False,
) -> dict[str, object]:
    """Adaptive density control from the 3DGS paper.

    Under-reconstructed regions (high positional gradient, small scale)
    are cloned; over-reconstructed ones (high gradient, large scale) are
    split in two. Nearly transparent or oversized Gaussians are pruned.

    Args:
        grad_accum: (N,) accumulated NDC-space positional gradient norms.
        grad_count: (N,) number of accumulation steps each Gaussian was visible.
        grad_threshold: average gradient above which a Gaussian is densified.
        scene_extent: scene scale for the clone/split and oversize thresholds.
        percent_dense: clone when ``max_scale <= percent_dense * scene_extent``,
            split otherwise.
        min_opacity: prune original rows whose activated opacity is below this.
        return_optimizer_state: also return ``_keep_idx`` (surviving original
            rows, in order) and ``_new_count`` (rows appended after them:
            clones first, then split children) so optimizer moments can be
            remapped.

    Returns:
        Counts ``cloned``, ``split`` and ``pruned``. ``pruned`` counts every
        removed original row, including split sources.
    """
    avg_grad = np.array(grad_accum) / np.maximum(np.array(grad_count), 1.0)
    scales = np.array(self.scales_act)
    max_scale = scales.max(axis=1)
    high_grad = avg_grad > grad_threshold

    clone_mask = high_grad & (max_scale <= percent_dense * scene_extent)
    split_mask = high_grad & (max_scale > percent_dense * scene_extent)

    p_np = {k: np.array(v) for k, v in self.params.items()}

    # Clone: duplicate as-is.
    clone_idx = np.where(clone_mask)[0]
    clones = {k: v[clone_idx] for k, v in p_np.items()}

    # Split: two samples from each Gaussian, scales shrunk by 1.6.
    split_idx = np.where(split_mask)[0]
    splits: dict[str, np.ndarray] = {}
    if split_idx.size > 0:
        from ..transforms import quaternion_to_matrix

        # Normalize in NumPy (no mixed mx/np arithmetic); the epsilon keeps a
        # degenerate zero quaternion from producing NaN rotations.
        q_np = p_np["quats"][split_idx]
        q_np = q_np / np.maximum(np.linalg.norm(q_np, axis=1, keepdims=True), 1e-12)
        R = np.array(quaternion_to_matrix(mx.array(q_np)))
        s = scales[split_idx]
        n2 = split_idx.size * 2
        # Sample offsets from each source Gaussian: N(0, diag(s)) in its local
        # frame, rotated into world space.
        samples = np.random.normal(size=(n2, 3)) * np.repeat(s, 2, axis=0)
        offsets = np.einsum("nij,nj->ni", np.repeat(R, 2, axis=0), samples)
        splits = {k: np.repeat(v[split_idx], 2, axis=0) for k, v in p_np.items()}
        splits["means"] = splits["means"] + offsets.astype(np.float32)
        splits["scales"] = splits["scales"] - float(np.log(1.6))

    # Prune originals: split sources, low opacity, oversized.
    opac = 1.0 / (1.0 + np.exp(-np.clip(p_np["opacities"], -50.0, 50.0)))
    prune_mask = (opac < min_opacity) | split_mask | (max_scale > 0.5 * scene_extent)
    keep_idx = np.where(~prune_mask)[0]

    # Final row order: surviving originals, then clones, then split children.
    self.select(keep_idx)
    if clone_idx.size > 0:
        self.append({k: mx.array(v) for k, v in clones.items()})
    if split_idx.size > 0:
        self.append({k: mx.array(v) for k, v in splits.items()})

    stats: dict[str, object] = {
        "cloned": int(clone_idx.size),
        "split": int(split_idx.size),
        "pruned": int(prune_mask.sum()),
    }
    if return_optimizer_state:
        stats["_keep_idx"] = keep_idx.astype(np.int32)
        stats["_new_count"] = int(clone_idx.size + 2 * split_idx.size)
    return stats

relocate_mcmc(grad_accum, grad_count, relocate_frac=0.02, min_opacity=0.01, jitter_scale=0.25)

Fixed-budget MCMC-style relocation of underused Gaussians.

This keeps N constant: low-opacity or never-visible rows are replaced by jittered copies of high-gradient rows. It is inspired by MCMC 3DGS relocation, and is intended as an alternative to vanilla clone/split/prune density control.

Source code in src/mlx3d/splatting/model.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
def relocate_mcmc(
    self,
    grad_accum: mx.array,
    grad_count: mx.array,
    relocate_frac: float = 0.02,
    min_opacity: float = 0.01,
    jitter_scale: float = 0.25,
) -> dict[str, object]:
    """Fixed-budget MCMC-style relocation of underused Gaussians.

    This keeps ``N`` constant: low-opacity or never-visible rows are
    replaced by jittered copies of high-gradient rows. It is inspired by
    MCMC 3DGS relocation, and is intended as an alternative to vanilla
    clone/split/prune density control.

    Relocated rows get the source's parameters with means jittered by
    ``jitter_scale`` times the source's scale, scales shrunk by a factor of
    1.6, and opacity reset to 0.05. Destination rows are never used as
    sources. If more than half the rows are relocated, the best sources are
    reused cyclically.

    Args:
        grad_accum: accumulated positional gradient norms, one per row.
        grad_count: visibility counts, one per row; ``<= 0`` marks a row as
            never visible (underused).
        relocate_frac: upper bound on the fraction of rows moved per call,
            clamped to ``[0, 1]``.
        min_opacity: rows whose activated opacity is below this are underused.
        jitter_scale: multiplier on the source scale for position noise;
            ``<= 0`` disables the jitter.

    Returns:
        ``{"relocated": k}``. When ``k > 0`` it also contains ``_moved_idx``,
        the destination row indices.
    """
    n = self.num_gaussians
    max_relocate = int(max(0, min(relocate_frac, 1.0)) * n)
    if max_relocate <= 0 or n <= 1:
        return {"relocated": 0}

    avg_grad = grad_accum / mx.maximum(grad_count, 1.0)
    opac = mx.sigmoid(self.params["opacities"])
    underused = (opac < min_opacity) | (grad_count <= 0)
    underused_count = int(mx.sum(underused.astype(mx.int32)))
    k = max_relocate if underused_count == 0 else min(max_relocate, underused_count)
    k = min(k, n - 1)
    if k == 0:
        return {"relocated": 0}

    # Prefer underused rows. If none exist, fall back to the lowest-opacity
    # rows: every underused row scores in [1, 2], every other row in [-1, 0].
    target_score = mx.where(underused, 2.0 - opac, -opac)
    dst = mx.argpartition(-target_score, kth=k - 1)[:k]
    # Push destination rows far below any real gradient so they are never
    # chosen as sources.
    source_score = avg_grad.at[dst].add(-1e30 - avg_grad[dst])
    # Only n - k rows are eligible sources. When k exceeds that, reuse the
    # best ones cyclically instead of letting destinations be picked.
    n_src = min(k, n - k)
    src = mx.argpartition(-source_score, kth=n_src - 1)[:n_src]
    if n_src < k:
        src = src[mx.arange(k) % n_src]
    source_scales = mx.exp(self.params["scales"][src])
    reset_opacity = float(np.log(0.05 / 0.95))
    split_shrink = float(np.log(1.6))
    for name, arr in self.params.items():
        values = arr[src]
        if name == "means" and jitter_scale > 0:
            values = values + mx.random.normal(values.shape) * source_scales * float(
                jitter_scale
            )
        elif name == "opacities":
            values = mx.full(values.shape, reset_opacity, dtype=values.dtype)
        elif name == "scales":
            values = values - split_shrink
        # Overwrite the dst rows with `values` (expressed as a scatter-add).
        self.params[name] = arr.at[dst].add(values - arr[dst])
    return {"relocated": int(k), "_moved_idx": dst}

reset_opacities(max_opacity=0.01)

Clamp opacities down (periodic reset from the 3DGS paper).

Source code in src/mlx3d/splatting/model.py
533
534
535
536
537
538
def reset_opacities(self, max_opacity: float = 0.01) -> None:
    """Clamp opacities down (periodic reset from the 3DGS paper).

    Opacity logits above ``logit(max_opacity)`` are lowered to it; lower
    logits are left untouched. The parameter dtype is preserved.

    Args:
        max_opacity: activated-opacity ceiling, strictly between 0 and 1.

    Raises:
        ValueError: if ``max_opacity`` is not in the open interval (0, 1);
            such values would produce an infinite or NaN logit.
    """
    if not 0.0 < max_opacity < 1.0:
        raise ValueError("max_opacity must be strictly between 0 and 1.")
    logit = float(np.log(max_opacity / (1 - max_opacity)))
    opacities = self.params["opacities"]
    # Build the bound in the parameter dtype so the clamp cannot promote it.
    self.params["opacities"] = mx.minimum(
        opacities, mx.full(opacities.shape, logit, dtype=opacities.dtype)
    )

GaussianTrainer

Optimizes a `GaussianModel` against posed images.

Source code in src/mlx3d/splatting/trainer.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
class GaussianTrainer:
    """Optimizes a :class:`GaussianModel` against posed images.

    One Adam optimizer is kept per parameter group (``means``, ``scales``,
    ``quats``, ``opacities``, ``sh_dc``, ``sh_rest``). Density control is
    either vanilla clone/split/prune or fixed-budget MCMC relocation,
    selected by ``TrainerConfig.method``.
    """

    def __init__(
        self, model: GaussianModel, config: TrainerConfig | None = None, scene_extent: float = 1.0
    ):
        self.model = model
        self.config = config or TrainerConfig()
        if self.config.method not in {"vanilla", "mcmc", "2dgs"}:
            raise ValueError("TrainerConfig.method must be 'vanilla', 'mcmc', or '2dgs'")
        if self.config.projection not in {"ewa", "ut"}:
            raise ValueError("TrainerConfig.projection must be 'ewa' or 'ut'")
        self.scene_extent = scene_extent
        self.step_count = 0
        if self.config.low_memory:
            # MLX caches freed GPU buffers for reuse; under shape churn
            # (densification changes N every 100 steps) the cache can grow
            # by gigabytes. Cap it, and clear at densification events below.
            mx.set_cache_limit(int(self.config.cache_limit_gb * (1 << 30)))
        self._apply_method_constraints()
        self._build_optimizers()
        self._reset_grad_accum()

    def _build_optimizers(self) -> None:
        """Create one Adam optimizer per parameter group.

        The means learning rate is scaled by ``scene_extent``. It decays
        exponentially from ``lr_means`` to ``lr_means_final`` over
        ``lr_means_max_steps`` when all three are positive. If rebuilt
        mid-training, the schedule is offset by the current ``step_count``.
        """
        c = self.config
        lr_means = c.lr_means * self.scene_extent
        lr_means_final = c.lr_means_final * self.scene_extent
        lr_means_max_steps = int(c.lr_means_max_steps)
        if lr_means_max_steps > 0 and lr_means > 0 and lr_means_final > 0:
            means_lr = optim.exponential_decay(
                lr_means,
                (lr_means_final / lr_means) ** (1.0 / lr_means_max_steps),
            )
            if self.step_count:
                offset = self.step_count

                def means_lr_with_offset(step, schedule=means_lr, offset=offset):
                    return schedule(step + offset)

                means_lr = means_lr_with_offset
        else:
            means_lr = lr_means
        lrs = {
            "means": means_lr,
            "scales": c.lr_scales,
            "quats": c.lr_quats,
            "opacities": c.lr_opacities,
            "sh_dc": c.lr_sh_dc,
            "sh_rest": c.lr_sh_rest,
        }
        self.optimizers = {k: optim.Adam(learning_rate=lr, eps=1e-15) for k, lr in lrs.items()}

    def learning_rates(self) -> dict[str, float]:
        """Current per-parameter Adam learning rates."""
        return {k: float(opt.learning_rate) for k, opt in self.optimizers.items()}

    def _reset_grad_accum(self) -> None:
        """Zero the per-Gaussian gradient accumulators (sized to current N)."""
        n = self.model.num_gaussians
        self.grad_accum = mx.zeros((n,), dtype=mx.float32)
        self.grad_count = mx.zeros((n,), dtype=mx.float32)

    def _resize_optimizer_states_after_densify(self, result: dict[str, object]) -> None:
        """Preserve Adam moments for surviving Gaussians after ADC changes N.

        Surviving rows (``_keep_idx``) keep their moments; the ``_new_count``
        appended rows start from zero moments.
        """
        keep_idx_np = result.get("_keep_idx")
        new_count = int(result.get("_new_count", 0))
        if keep_idx_np is None:
            return
        keep_idx = mx.array(keep_idx_np)
        for name, opt in self.optimizers.items():
            state = opt.state.get(name)
            if not isinstance(state, dict):
                continue
            for slot in ("m", "v"):
                old = state.get(slot)
                if old is None:
                    continue
                kept = old[keep_idx]
                if new_count > 0:
                    zeros = mx.zeros((new_count, *old.shape[1:]), dtype=old.dtype)
                    state[slot] = mx.concatenate([kept, zeros], axis=0)
                else:
                    state[slot] = kept

    def _zero_optimizer_state_rows(self, idx_np) -> None:
        """Zero the Adam moments of the given rows in every parameter group.

        ``idx_np`` may be an ``mx.array`` or NumPy-compatible indices; ``None``
        or an empty index set is a no-op.
        """
        if idx_np is None:
            return
        if idx_np.shape[0] == 0:
            return
        idx = (
            idx_np.astype(mx.int32)
            if isinstance(idx_np, mx.array)
            else mx.array(np.asarray(idx_np, dtype=np.int32))
        )
        for name, opt in self.optimizers.items():
            state = opt.state.get(name)
            if not isinstance(state, dict):
                continue
            for slot in ("m", "v"):
                old = state.get(slot)
                if old is None:
                    continue
                state[slot] = old.at[idx].add(-old[idx])

    def _reset_optimizer_state(self, name: str) -> None:
        """Zero the Adam first/second moments of one whole parameter group.

        Used after the periodic opacity reset. Without it, momentum
        accumulated at high opacity immediately pushes the clamped logits
        back up. The reference 3DGS implementation zeroes these moments too.
        """
        opt = self.optimizers.get(name)
        if opt is None:
            return
        state = opt.state.get(name)
        if not isinstance(state, dict):
            return
        for slot in ("m", "v"):
            old = state.get(slot)
            if old is not None:
                state[slot] = mx.zeros_like(old)

    def _apply_method_constraints(self) -> None:
        """For ``method == "2dgs"``, cap the third local scale (surfel thickness)."""
        if self.config.method == "2dgs":
            thickness = self.config.two_d_thickness * max(float(self.scene_extent), 1e-8)
            self.model.apply_2dgs_constraints(thickness)

    # ------------------------------------------------------------------ losses
    def _geometry_loss_2dgs(
        self,
        params: dict[str, mx.array],
        camera: Camera,
        proj: dict[str, mx.array],
        opacities: mx.array,
        sorted_ids: mx.array,
        tile_ranges: mx.array,
        tiles_x: int,
        tiles_y: int,
    ) -> tuple[mx.array, dict[str, mx.array]]:
        """Optional 2DGS geometry regularizers.

        Rasterizes depth, depth squared and camera-facing surfel normals. It
        then penalizes per-pixel depth variance ``E[d^2] - E[d]^2`` and the
        misalignment between rendered normals and normals derived from the
        unprojected expected-depth map. Only pixels whose alpha exceeds
        ``geometry_min_alpha`` count.

        Returns:
            ``(weighted_loss, metrics)``. Both are zero unless
            ``method == "2dgs"`` and at least one of the two lambdas is positive.
        """
        cfg = self.config
        zero = mx.array(0.0, dtype=mx.float32)
        metrics = {"depth_variance": zero, "normal_consistency": zero}
        if cfg.method != "2dgs" or (
            cfg.lambda_2d_depth_variance <= 0 and cfg.lambda_2d_normal_consistency <= 0
        ):
            return zero, metrics

        depths = proj["depths"]
        normals = quaternion_to_matrix(params["quats"])[:, :, 2]
        view = camera.camera_center - params["means"]
        normals = normals / mx.maximum(mx.linalg.norm(normals, axis=-1, keepdims=True), 1e-8)
        # Flip normals to face the camera so the blended normal is not cancelled out.
        normals = mx.where(mx.sum(normals * view, axis=-1, keepdims=True) < 0, -normals, normals)
        features = mx.concatenate([depths[:, None], (depths * depths)[:, None], normals], axis=1)
        geom = rasterize_features(
            proj["means2d"],
            proj["conics"],
            features,
            opacities,
            sorted_ids,
            tile_ranges,
            camera.width,
            camera.height,
            tiles_x,
            tiles_y,
            normalize=True,
        )
        alpha = geom["alpha"]
        valid = (alpha > cfg.geometry_min_alpha).astype(mx.float32)
        denom = mx.maximum(valid.sum(), 1.0)

        expected_depth = geom["features"][..., 0]
        expected_depth2 = geom["features"][..., 1]
        depth_var = mx.maximum(expected_depth2 - expected_depth * expected_depth, 0.0)
        loss = zero
        if cfg.lambda_2d_depth_variance > 0:
            metrics["depth_variance"] = (depth_var * valid).sum() / denom
            loss = loss + float(cfg.lambda_2d_depth_variance) * metrics["depth_variance"]

        if cfg.lambda_2d_normal_consistency > 0 and camera.width > 1 and camera.height > 1:
            u = mx.arange(camera.width, dtype=mx.float32) + 0.5
            v = mx.arange(camera.height, dtype=mx.float32) + 0.5
            uu = mx.broadcast_to(u[None, :], (camera.height, camera.width))
            vv = mx.broadcast_to(v[:, None], (camera.height, camera.width))
            points = camera.unproject_points(mx.stack([uu, vv], axis=-1), expected_depth)
            # Forward differences give two tangents per pixel; their cross is the depth normal.
            tangent_x = points[:-1, 1:] - points[:-1, :-1]
            tangent_y = points[1:, :-1] - points[:-1, :-1]
            depth_normals = mx.linalg.cross(tangent_x, tangent_y)
            depth_normals = depth_normals / mx.maximum(
                mx.linalg.norm(depth_normals, axis=-1, keepdims=True), 1e-8
            )
            rendered_normals = geom["features"][:-1, :-1, 2:5]
            rendered_normals = rendered_normals / mx.maximum(
                mx.linalg.norm(rendered_normals, axis=-1, keepdims=True), 1e-8
            )
            valid4 = valid[:-1, :-1] * valid[:-1, 1:] * valid[1:, :-1] * valid[1:, 1:]
            denom4 = mx.maximum(valid4.sum(), 1.0)
            cos = mx.sum(rendered_normals * depth_normals, axis=-1)
            metrics["normal_consistency"] = ((1.0 - mx.abs(cos)) * valid4).sum() / denom4
            loss = loss + float(cfg.lambda_2d_normal_consistency) * metrics["normal_consistency"]

        return loss, metrics

    def _render_loss(
        self, params, means2d_probe, camera: Camera, target: mx.array, background: mx.array
    ):
        """Photometric loss; ``means2d_probe`` (zeros) exposes screen-space
        positional gradients for densification.

        The loss is ``(1 - lambda_dssim) * L1 + lambda_dssim * (1 - SSIM)``
        plus the optional 2DGS geometry terms.

        Returns:
            ``(loss, (image, radii, geometry_metrics))``.
        """
        project = project_gaussians_ut if self.config.projection == "ut" else project_gaussians
        proj = project(
            camera,
            params["means"],
            params["quats"],
            mx.exp(params["scales"]),
            antialias=self.config.antialias,
        )
        means2d = proj["means2d"] + means2d_probe
        opacities = mx.sigmoid(params["opacities"]) * proj["compensation"]

        sh = mx.concatenate([params["sh_dc"], params["sh_rest"]], axis=1)
        dirs = params["means"] - camera.camera_center
        dirs = dirs / mx.maximum(mx.linalg.norm(dirs, axis=-1, keepdims=True), 1e-8)
        colors = mx.maximum(eval_sh(self.model.active_sh_degree, sh, mx.stop_gradient(dirs)), 0.0)

        sorted_ids, tile_ranges, tiles_x, tiles_y = bin_gaussians(
            means2d,
            proj["radii"],
            proj["depths"],
            camera.width,
            camera.height,
        )
        out = rasterize(
            means2d,
            proj["conics"],
            colors,
            opacities,
            sorted_ids,
            tile_ranges,
            camera.width,
            camera.height,
            tiles_x,
            tiles_y,
            background=background,
        )
        img = out["image"]
        l1 = mx.abs(img - target).mean()
        c = self.config.lambda_dssim
        loss = (1.0 - c) * l1 + c * (1.0 - ssim(img, target))
        geom_loss, geom_metrics = self._geometry_loss_2dgs(
            params,
            camera,
            proj,
            opacities,
            sorted_ids,
            tile_ranges,
            tiles_x,
            tiles_y,
        )
        return loss + geom_loss, (img, proj["radii"], geom_metrics)

    # -------------------------------------------------------------------- step
    def step(self, camera: Camera, target: mx.array) -> dict[str, object]:
        """One optimization step on a single view. Returns logging info.

        In order: Adam update per group, optional MCMC position noise, method
        constraints, gradient accumulation and density control inside the
        densify window, the periodic opacity reset (which also zeroes the
        opacity Adam moments), and the SH degree increase.
        """
        self.step_count += 1
        cfg = self.config
        bg = mx.ones((3,)) if cfg.white_background else mx.zeros((3,))
        params = self.model.params
        probe = mx.zeros((self.model.num_gaussians, 2))
        densify_stats = None
        opacity_reset = False
        sh_degree_changed = False

        def loss_fn(params, probe):
            loss, aux = self._render_loss(params, probe, camera, target, bg)
            return loss, aux

        (loss, (img, radii, geom_metrics)), grads = mx.value_and_grad(loss_fn, argnums=(0, 1))(
            params, probe
        )
        param_grads, probe_grad = grads

        for k, opt in self.optimizers.items():
            self.model.params[k] = opt.apply_gradients(
                {k: param_grads[k]}, {k: self.model.params[k]}
            )[k]
        if cfg.method == "mcmc" and cfg.mcmc_noise_scale > 0:
            means_lr = max(float(self.learning_rates()["means"]), 0.0)
            sigma = cfg.mcmc_noise_scale * self.scene_extent * math.sqrt(means_lr)
            if sigma > 0:
                self.model.params["means"] = (
                    self.model.params["means"]
                    + mx.random.normal(self.model.params["means"].shape) * sigma
                )
        self._apply_method_constraints()
        # Release gradient references before evaluating so their buffers can
        # be recycled within the same step (see the MLX performance guide).
        del param_grads, grads
        mx.eval(self.model.params)

        # Accumulate NDC-space positional gradient norms for densification.
        if cfg.densify_from <= self.step_count <= cfg.densify_until:
            g = mx.stop_gradient(probe_grad)
            gx = g[:, 0] * (camera.width / 2.0)
            gy = g[:, 1] * (camera.height / 2.0)
            norms = mx.sqrt(gx * gx + gy * gy)
            visible = (mx.stop_gradient(radii) > 0).astype(mx.float32)
            self.grad_accum += norms * visible
            self.grad_count += visible

            if self.step_count % cfg.densify_every == 0 and self.step_count > cfg.densify_from:
                if cfg.method == "mcmc":
                    densify_result = self.model.relocate_mcmc(
                        self.grad_accum,
                        self.grad_count,
                        relocate_frac=cfg.mcmc_relocate_frac,
                        min_opacity=cfg.mcmc_min_opacity,
                        jitter_scale=cfg.mcmc_jitter_scale,
                    )
                    self._zero_optimizer_state_rows(densify_result.get("_moved_idx"))
                else:
                    # At the cap, disable growth (infinite threshold) but keep
                    # pruning so the count can recover below the cap.
                    at_cap = (
                        cfg.max_gaussians is not None
                        and self.model.num_gaussians >= cfg.max_gaussians
                    )
                    threshold = float("inf") if at_cap else cfg.densify_grad_threshold
                    densify_result = self.model.densify_and_prune(
                        self.grad_accum,
                        self.grad_count,
                        grad_threshold=threshold,
                        scene_extent=self.scene_extent,
                        return_optimizer_state=True,
                    )
                    self._resize_optimizer_states_after_densify(densify_result)
                self._apply_method_constraints()
                densify_stats = {
                    k: v for k, v in densify_result.items() if not str(k).startswith("_")
                }
                self._reset_grad_accum()
                if cfg.low_memory:
                    mx.eval(self.model.params)
                    mx.clear_cache()

        if self.step_count % cfg.opacity_reset_every == 0 and self.step_count <= cfg.densify_until:
            self.model.reset_opacities()
            # Stale momentum would immediately undo the reset.
            self._reset_optimizer_state("opacities")
            opacity_reset = True

        if self.step_count % cfg.sh_increase_every == 0:
            self.model.one_up_sh_degree()
            sh_degree_changed = True

        return {
            "loss": float(loss),
            "num_gaussians": self.model.num_gaussians,
            "step": self.step_count,
            "active_sh_degree": self.model.active_sh_degree,
            "lr_means": self.learning_rates()["means"],
            "method": cfg.method,
            "densify": densify_stats,
            "opacity_reset": opacity_reset,
            "sh_degree_changed": sh_degree_changed,
            "geometry": {k: float(v) for k, v in geom_metrics.items()},
        }

learning_rates()

Current per-parameter Adam learning rates.

Source code in src/mlx3d/splatting/trainer.py
130
131
132
def learning_rates(self) -> dict[str, float]:
    """Return the learning rate currently used by each per-parameter optimizer.

    Keys are the parameter names used in ``self.optimizers``; values are
    converted to plain Python floats so they can be logged directly.
    """
    rates: dict[str, float] = {}
    for name, optimizer in self.optimizers.items():
        rates[name] = float(optimizer.learning_rate)
    return rates

step(camera, target)

One optimization step on a single view. Returns logging info.

Source code in src/mlx3d/splatting/trainer.py
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
def step(self, camera: Camera, target: mx.array) -> dict[str, object]:
    """One optimization step on a single view. Returns logging info.

    The step runs in this order:

    1. Forward/backward pass of ``self._render_loss``. Gradients are taken
       w.r.t. both the model parameters and a zero-valued ``(N, 2)`` probe
       array. The probe's gradient serves as the per-Gaussian screen-space
       positional gradient for densification.
    2. One optimizer update per key of ``self.optimizers``. In ``mcmc`` mode,
       optional Gaussian noise is then added to the means.
       ``_apply_method_constraints`` runs afterwards, and the parameters and
       optimizer states are evaluated.
    3. Inside ``[densify_from, densify_until]``, probe-gradient norms of
       visible Gaussians are accumulated. Every ``densify_every`` steps this
       triggers a density event: MCMC relocation in ``mcmc`` mode, otherwise
       clone/split/prune. The optimizer states are patched to match.
    4. Periodic opacity reset (only up to ``densify_until``) and periodic
       increase of the active SH degree.

    Args:
        camera: view to render.
        target: ground-truth image for ``camera``, in whatever layout
            ``_render_loss`` expects (presumably (H, W, 3) — TODO confirm).

    Returns:
        Dict containing:

        - the scalar loss;
        - the Gaussian count and step index;
        - the active SH degree;
        - the current means learning rate;
        - the method name;
        - densification stats: ``None`` unless a density event ran this step,
          with keys starting with ``_`` stripped;
        - opacity-reset and SH-change flags;
        - the geometry metrics returned by ``_render_loss``, converted to
          floats.
    """
    self.step_count += 1
    cfg = self.config
    bg = mx.ones((3,)) if cfg.white_background else mx.zeros((3,))
    params = self.model.params
    # Zero-valued (N, 2) array differentiated alongside the parameters. Its
    # gradient is used below as the screen-space positional gradient of each
    # Gaussian (presumably it is added to the projected 2D means inside
    # ``_render_loss`` — TODO confirm).
    probe = mx.zeros((self.model.num_gaussians, 2))
    densify_stats = None
    opacity_reset = False
    sh_degree_changed = False

    def loss_fn(params, probe):
        loss, aux = self._render_loss(params, probe, camera, target, bg)
        return loss, aux

    (loss, (img, radii, geom_metrics)), grads = mx.value_and_grad(loss_fn, argnums=(0, 1))(
        params, probe
    )
    param_grads, probe_grad = grads

    # Each optimizer owns exactly one parameter key, so updates are applied
    # key by key with that key's optimizer.
    for k, opt in self.optimizers.items():
        self.model.params[k] = opt.apply_gradients(
            {k: param_grads[k]}, {k: self.model.params[k]}
        )[k]
    if cfg.method == "mcmc" and cfg.mcmc_noise_scale > 0:
        # SGLD-like exploration noise on the means. Its standard deviation
        # scales with the scene extent and with sqrt of the current means
        # learning rate.
        means_lr = max(float(self.learning_rates()["means"]), 0.0)
        sigma = cfg.mcmc_noise_scale * self.scene_extent * math.sqrt(means_lr)
        if sigma > 0:
            self.model.params["means"] = (
                self.model.params["means"]
                + mx.random.normal(self.model.params["means"].shape) * sigma
            )
    self._apply_method_constraints()
    # Release gradient references before evaluating so their buffers can
    # be recycled within the same step (see the MLX performance guide).
    del param_grads, grads
    # Evaluate the optimizer states together with the parameters. MLX
    # optimizers update their state lazily (the step counter is reassigned on
    # every call). State entries that the parameter update does not depend on
    # would otherwise keep extending an unevaluated graph from step to step.
    mx.eval(self.model.params, [opt.state for opt in self.optimizers.values()])

    # Accumulate NDC-space positional gradient norms for densification.
    if cfg.densify_from <= self.step_count <= cfg.densify_until:
        g = mx.stop_gradient(probe_grad)
        # Scale by half the image size (pixel -> NDC gradient), as in 3DGS.
        gx = g[:, 0] * (camera.width / 2.0)
        gy = g[:, 1] * (camera.height / 2.0)
        norms = mx.sqrt(gx * gx + gy * gy)
        # Only Gaussians with a non-zero projected radius contribute.
        visible = (mx.stop_gradient(radii) > 0).astype(mx.float32)
        self.grad_accum += norms * visible
        self.grad_count += visible

        # Strict ``>`` skips an event at exactly ``densify_from``, where only a
        # single step of gradients has been accumulated.
        if self.step_count % cfg.densify_every == 0 and self.step_count > cfg.densify_from:
            if cfg.method == "mcmc":
                densify_result = self.model.relocate_mcmc(
                    self.grad_accum,
                    self.grad_count,
                    relocate_frac=cfg.mcmc_relocate_frac,
                    min_opacity=cfg.mcmc_min_opacity,
                    jitter_scale=cfg.mcmc_jitter_scale,
                )
                # Relocated rows start fresh in the optimizer.
                self._zero_optimizer_state_rows(densify_result.get("_moved_idx"))
            else:
                # At the cap, disable growth (infinite threshold) but keep
                # pruning so the count can recover below the cap.
                at_cap = (
                    cfg.max_gaussians is not None
                    and self.model.num_gaussians >= cfg.max_gaussians
                )
                threshold = float("inf") if at_cap else cfg.densify_grad_threshold
                densify_result = self.model.densify_and_prune(
                    self.grad_accum,
                    self.grad_count,
                    grad_threshold=threshold,
                    scene_extent=self.scene_extent,
                    return_optimizer_state=True,
                )
                self._resize_optimizer_states_after_densify(densify_result)
            self._apply_method_constraints()
            # Keys prefixed with "_" are internal (e.g. index arrays), not stats.
            densify_stats = {
                k: v for k, v in densify_result.items() if not str(k).startswith("_")
            }
            self._reset_grad_accum()
            if cfg.low_memory:
                # The optimizer states were just resized/zeroed; materialize
                # them too so the freed buffers can actually be released.
                mx.eval(self.model.params, [opt.state for opt in self.optimizers.values()])
                mx.clear_cache()

    # NOTE(review): this runs for every method, including mcmc. Whether
    # reset_opacities also clears the opacity optimizer moments is not
    # visible here.
    if self.step_count % cfg.opacity_reset_every == 0 and self.step_count <= cfg.densify_until:
        self.model.reset_opacities()
        opacity_reset = True

    if self.step_count % cfg.sh_increase_every == 0:
        self.model.one_up_sh_degree()
        sh_degree_changed = True

    return {
        "loss": float(loss),
        "num_gaussians": self.model.num_gaussians,
        "step": self.step_count,
        "active_sh_degree": self.model.active_sh_degree,
        "lr_means": self.learning_rates()["means"],
        "method": cfg.method,
        "densify": densify_stats,
        "opacity_reset": opacity_reset,
        "sh_degree_changed": sh_degree_changed,
        "geometry": {k: float(v) for k, v in geom_metrics.items()},
    }

TrainerConfig dataclass

Source code in src/mlx3d/splatting/trainer.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
@dataclass
class TrainerConfig:
    """Hyperparameters for the Gaussian-splatting trainer.

    The trainer reads these through ``self.config`` (``cfg`` in its ``step``
    method). They cover:

    - per-parameter learning rates and loss weights;
    - the adaptive-density schedule (``densify_*``, ``opacity_reset_every``,
      ``sh_increase_every``);
    - MCMC- and 2DGS-specific knobs;
    - low-memory controls.
    """
    method: str = "vanilla"
    """Training strategy: ``vanilla``, ``mcmc``, or ``2dgs``."""
    # Per-parameter learning rates, one per optimizer key. lr_means_final and
    # lr_means_max_steps presumably define a decay schedule for the means
    # learning rate (its current value is logged as ``lr_means``) — TODO confirm.
    lr_means: float = 1.6e-4
    lr_means_final: float = 1.6e-6
    lr_means_max_steps: int = 30_000
    lr_scales: float = 5e-3
    lr_quats: float = 1e-3
    lr_opacities: float = 5e-2
    lr_sh_dc: float = 2.5e-3
    lr_sh_rest: float = 2.5e-3 / 20.0
    # Presumably the D-SSIM weight in the photometric loss
    # ((1 - lambda) * L1 + lambda * D-SSIM, as in 3DGS) — TODO confirm in _render_loss.
    lambda_dssim: float = 0.2
    antialias: bool = False
    """Use Mip-Splatting-style opacity compensation for projection blur."""
    projection: str = "ewa"
    """Gaussian projection method: ``ewa`` or 3DGUT-style ``ut``."""
    lambda_2d_depth_variance: float = 0.0
    """2DGS-only ray-depth variance penalty; encourages one surface per pixel."""
    lambda_2d_normal_consistency: float = 0.0
    """2DGS-only loss between rendered surfel normals and depth-map normals."""
    geometry_min_alpha: float = 0.05
    """Minimum alpha for pixels included in 2DGS geometry losses."""
    # Adaptive density control.
    # densify_from / densify_until: inclusive step window in which screen-space
    # gradients are accumulated. densify_every: period of density events (no
    # event fires at exactly densify_from). densify_grad_threshold: clone/split
    # threshold passed to densify_and_prune (non-mcmc methods only).
    densify_from: int = 500
    densify_until: int = 15000
    densify_every: int = 100
    densify_grad_threshold: float = 0.0002
    mcmc_relocate_frac: float = 0.02
    """Max fraction of Gaussians relocated at each MCMC density event."""
    mcmc_min_opacity: float = 0.01
    """Rows below this opacity are considered relocation targets in MCMC mode."""
    mcmc_jitter_scale: float = 0.25
    """Relocated copies are jittered by this multiple of their source scale."""
    mcmc_noise_scale: float = 0.01
    """Per-step SGLD-like xyz noise scale used in MCMC mode."""
    two_d_thickness: float = 1e-4
    """2DGS local-normal thickness as a fraction of scene extent."""
    # Opacities are reset every this many steps, but only while
    # step <= densify_until. Must be non-zero (used with %).
    opacity_reset_every: int = 3000
    # The active SH degree is raised by one every this many steps (non-zero).
    sh_increase_every: int = 1000
    # Composite renders over white (ones) instead of black (zeros).
    white_background: bool = False
    # Low-memory controls.
    # max_gaussians is checked once, before each non-mcmc density event; growth
    # is disabled when the count is >= the cap, so a single event that starts
    # below the cap may still overshoot it. The mcmc branch does not consult it.
    max_gaussians: int | None = None
    """Stop clone/split growth above this count (pruning continues)."""
    low_memory: bool = False
    """Cap MLX's buffer cache and clear it after densification events.
    Reduces peak memory noticeably on 8-16 GB machines at a small speed cost."""
    cache_limit_gb: float = 2.0
    """MLX buffer-cache cap used when ``low_memory`` is enabled."""

method = 'vanilla' class-attribute instance-attribute

Training strategy: vanilla, mcmc, or 2dgs.

antialias = False class-attribute instance-attribute

Use Mip-Splatting-style opacity compensation for projection blur.

projection = 'ewa' class-attribute instance-attribute

Gaussian projection method: ewa or 3DGUT-style ut.

lambda_2d_depth_variance = 0.0 class-attribute instance-attribute

2DGS-only ray-depth variance penalty; encourages one surface per pixel.

lambda_2d_normal_consistency = 0.0 class-attribute instance-attribute

2DGS-only loss between rendered surfel normals and depth-map normals.

geometry_min_alpha = 0.05 class-attribute instance-attribute

Minimum alpha for pixels included in 2DGS geometry losses.

mcmc_relocate_frac = 0.02 class-attribute instance-attribute

Max fraction of Gaussians relocated at each MCMC density event.

mcmc_min_opacity = 0.01 class-attribute instance-attribute

Rows below this opacity are considered relocation targets in MCMC mode.

mcmc_jitter_scale = 0.25 class-attribute instance-attribute

Relocated copies are jittered by this multiple of their source scale.

mcmc_noise_scale = 0.01 class-attribute instance-attribute

Per-step SGLD-like xyz noise scale used in MCMC mode.

two_d_thickness = 0.0001 class-attribute instance-attribute

2DGS local-normal thickness as a fraction of scene extent.

max_gaussians = None class-attribute instance-attribute

Stop clone/split growth above this count (pruning continues).

low_memory = False class-attribute instance-attribute

Cap MLX's buffer cache and clear it after densification events. Reduces peak memory noticeably on 8-16 GB machines at a small speed cost.

cache_limit_gb = 2.0 class-attribute instance-attribute

MLX buffer-cache cap used when low_memory is enabled.

project_gaussians(camera, means, quats, scales, blur=0.3, antialias=False)

Project 3D Gaussians into screen space.

Parameters:

Name Type Description Default
camera Camera

the viewing camera (`mlx3d.cameras.Camera`).

required
means array

(N, 3) Gaussian centers in world space.

required
quats array

(N, 4) rotations (w, x, y, z), need not be normalized.

required
scales array

(N, 3) per-axis standard deviations.

required
blur float

screen-space dilation added to the diagonal (0.3 px as in 3DGS, which guarantees splats cover at least about one pixel).

0.3
antialias bool

if True, also return Mip-Splatting-style opacity compensation for the added screen-space blur. The conic still uses the blurred covariance, while compensation scales opacity by sqrt(det(cov) / det(cov + blur I)) so subpixel Gaussians do not gain energy from the low-pass filter.

False

Returns:

Type Description
dict[str, array]

dict with:

- means2d (N, 2): pixel-space centers.
- conics (N, 3): upper-triangular inverse 2D covariance (a, b, c).
- depths (N,): camera-space z.
- radii (N,): conservative pixel radii (0 for culled Gaussians).
- compensation (N,): opacity multiplier for anti-aliased mode, otherwise all ones.

Source code in src/mlx3d/splatting/projection.py
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def project_gaussians(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    blur: float = 0.3,
    antialias: bool = False,
) -> dict[str, mx.array]:
    """Project 3D Gaussians into screen space.

    Args:
        camera: the viewing :class:`~mlx3d.cameras.Camera`.
        means: (N, 3) Gaussian centers in world space.
        quats: (N, 4) rotations (w, x, y, z), need not be normalized.
        scales: (N, 3) per-axis standard deviations.
        blur: screen-space dilation added to the diagonal (0.3 px as in 3DGS,
            which guarantees splats cover at least about one pixel).
        antialias: if ``True``, also return Mip-Splatting-style opacity
            compensation for the added screen-space blur. The conic still uses
            the blurred covariance, while ``compensation`` scales opacity by
            ``sqrt(det(cov) / det(cov + blur I))`` so subpixel Gaussians do not
            gain energy from the low-pass filter.

    Returns:
        dict with:
            - ``means2d`` (N, 2): pixel-space centers.
            - ``conics`` (N, 3): upper-triangular inverse 2D covariance (a, b, c).
            - ``depths`` (N,): camera-space z.
            - ``radii`` (N,): conservative pixel radii (0 for culled Gaussians).
            - ``compensation`` (N,): opacity multiplier for anti-aliased mode,
              otherwise all ones.
    """
    # World -> camera: p_cam = R @ p_world + t, expanded per component so no
    # (N, 3, 3) intermediates are built.
    R, t = camera.R, camera.t
    mx_, my_, mz_ = means[:, 0], means[:, 1], means[:, 2]
    x = mx_ * R[0, 0] + my_ * R[0, 1] + mz_ * R[0, 2] + t[0]
    y = mx_ * R[1, 0] + my_ * R[1, 1] + mz_ * R[1, 2] + t[1]
    z = mx_ * R[2, 0] + my_ * R[2, 1] + mz_ * R[2, 2] + t[2]
    # Keep the divisor positive so the math below stays finite. Gaussians with
    # z <= znear are zeroed out through ``valid`` at the end.
    z_safe = mx.maximum(z, 1e-6)

    # Project centers.
    u = camera.fx * x / z_safe + camera.cx
    v = camera.fy * y / z_safe + camera.cy
    means2d = mx.stack([u, v], axis=-1)

    # EWA Jacobian, with x/z, y/z clamped to a slightly padded frustum for
    # stability of far off-screen Gaussians (as in the reference CUDA code).
    tan_fov_x = 0.5 * camera.width / camera.fx
    tan_fov_y = 0.5 * camera.height / camera.fy
    tx = mx.clip(x / z_safe, -1.3 * tan_fov_x, 1.3 * tan_fov_x) * z_safe
    ty = mx.clip(y / z_safe, -1.3 * tan_fov_y, 1.3 * tan_fov_y) * z_safe

    inv_z = 1.0 / z_safe
    inv_z2 = inv_z * inv_z

    # Project the rotated/scaled covariance without materializing per-Gaussian
    # 3x3 and 2x3 matrices. Let T = J @ camera.R and M = R_quat @ diag(scales).
    # The screen covariance is (T @ M) @ (T @ M)^T, so only six scalar dot
    # products are needed per Gaussian.
    # Non-zero entries of the 2x3 perspective Jacobian J (J[0,1] = J[1,0] = 0).
    j00 = camera.fx * inv_z
    j02 = -camera.fx * tx * inv_z2
    j11 = camera.fy * inv_z
    j12 = -camera.fy * ty * inv_z2

    # T = J @ camera.R (2x3).
    t00 = j00 * R[0, 0] + j02 * R[2, 0]
    t01 = j00 * R[0, 1] + j02 * R[2, 1]
    t02 = j00 * R[0, 2] + j02 * R[2, 2]
    t10 = j11 * R[1, 0] + j12 * R[2, 0]
    t11 = j11 * R[1, 1] + j12 * R[2, 1]
    t12 = j11 * R[1, 2] + j12 * R[2, 2]

    # Normalize, then expand the unit-quaternion (w, x, y, z) rotation matrix.
    # NOTE(review): an all-zero quaternion yields NaN here.
    q = quats / mx.linalg.norm(quats, axis=-1, keepdims=True)
    qw, qx, qy, qz = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    two = 2.0
    r00 = 1.0 - two * (qy * qy + qz * qz)
    r01 = two * (qx * qy - qw * qz)
    r02 = two * (qx * qz + qw * qy)
    r10 = two * (qx * qy + qw * qz)
    r11 = 1.0 - two * (qx * qx + qz * qz)
    r12 = two * (qy * qz - qw * qx)
    r20 = two * (qx * qz - qw * qy)
    r21 = two * (qy * qz + qw * qx)
    r22 = 1.0 - two * (qx * qx + qy * qy)

    # (T @ R_quat @ diag(scales)): each column of R_quat is scaled by its axis std.
    sx, sy, sz = scales[:, 0], scales[:, 1], scales[:, 2]
    m00 = (t00 * r00 + t01 * r10 + t02 * r20) * sx
    m01 = (t00 * r01 + t01 * r11 + t02 * r21) * sy
    m02 = (t00 * r02 + t01 * r12 + t02 * r22) * sz
    m10 = (t10 * r00 + t11 * r10 + t12 * r20) * sx
    m11 = (t10 * r01 + t11 * r11 + t12 * r21) * sy
    m12 = (t10 * r02 + t11 * r12 + t12 * r22) * sz

    # Symmetric 2x2 screen covariance [[a0, b], [b, c0]] before dilation.
    a0 = m00 * m00 + m01 * m01 + m02 * m02
    b = m00 * m10 + m01 * m11 + m02 * m12
    c0 = m10 * m10 + m11 * m11 + m12 * m12

    a = a0 + blur
    c = c0 + blur

    # Inverse of the blurred covariance, stored as (a, b, c) of the conic.
    det = a * c - b * b
    det_safe = mx.maximum(det, 1e-12)
    conics = mx.stack([c / det_safe, -b / det_safe, a / det_safe], axis=-1)
    if antialias and blur > 0:
        # det0 is the pre-blur determinant; the ratio is the opacity factor
        # that removes the energy added by the blur dilation.
        det0 = mx.maximum(a0 * c0 - b * b, 0.0)
        compensation = mx.sqrt(det0 / det_safe)
    else:
        compensation = mx.ones_like(det)

    # Conservative radius: 3 sigma of the larger eigenvalue.
    # lam1 is the larger eigenvalue of [[a, b], [b, c]]. The 0.01 floor under
    # the square root keeps it strictly above the mean eigenvalue.
    # NOTE(review): the original 3DGS CUDA code uses 0.1 here; confirm 0.01
    # is intended.
    mid = 0.5 * (a + c)
    lam1 = mid + mx.sqrt(mx.maximum(mid * mid - det, 0.01))
    radii = mx.ceil(3.0 * mx.sqrt(mx.maximum(lam1, 0.0)))

    # Only depth and determinant are checked. Gaussians outside the image
    # bounds keep their radius (presumably discarded by tile binning — TODO
    # confirm).
    valid = (z > camera.znear) & (det > 0)
    radii = mx.where(valid, radii, mx.zeros_like(radii))
    compensation = mx.where(valid, compensation, mx.zeros_like(compensation))

    return {
        "means2d": means2d,
        "conics": conics,
        "depths": z,
        "radii": radii,
        "compensation": compensation,
    }

project_gaussians_ut(camera, means, quats, scales, blur=0.3, antialias=False)

Project Gaussians with a 3D Unscented Transform.

This 3DGUT-style projection sends six cubature sigma points through the camera's actual `mlx3d.cameras.Camera.project_points` method. As a result, Brown-Conrady distortion, fisheye distortion, and orthographic cameras are handled by the same camera model used elsewhere in the library. It is more expensive than the analytic EWA projection and is therefore opt-in.

Source code in src/mlx3d/splatting/projection.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
def project_gaussians_ut(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    blur: float = 0.3,
    antialias: bool = False,
) -> dict[str, mx.array]:
    """Project Gaussians with a 3D Unscented Transform.

    This is a 3DGUT-style alternative to the analytic EWA projection. For each
    Gaussian, six cubature sigma points (mean plus/minus sqrt(3) times each
    scaled principal axis) are pushed through
    :meth:`~mlx3d.cameras.Camera.project_points`. The screen-space mean and
    covariance are then re-estimated from the projected points. Because the
    real camera model is used, Brown-Conrady distortion, fisheye distortion,
    and orthographic cameras are all supported.

    The method is slower than EWA and is therefore opt-in. Its arguments and
    the returned dict match ``project_gaussians``.
    """
    # Scaled principal axes: column j of R @ diag(s), transposed so each axis
    # becomes a row of shape (3,).
    axes = (quaternion_to_matrix(quats) * scales[:, None, :]).swapaxes(1, 2)
    spread = mx.sqrt(mx.array(3.0, dtype=means.dtype))
    sigma_points = means[:, None, :] + spread * mx.concatenate([axes, -axes], axis=1)

    projected, _ = camera.project_points(sigma_points.reshape(-1, 3))
    projected = projected.reshape(means.shape[0], 6, 2)

    # Equal-weight sample mean and covariance of the six projected points.
    means2d = projected.mean(axis=1)
    centered = projected - means2d[:, None, :]
    du = centered[..., 0]
    dv = centered[..., 1]
    a0 = mx.mean(du * du, axis=1)
    b = mx.mean(du * dv, axis=1)
    c0 = mx.mean(dv * dv, axis=1)

    # Dilate by ``blur`` and invert the 2x2 covariance into conic form.
    a, c = a0 + blur, c0 + blur
    det = a * c - b * b
    det_safe = mx.maximum(det, 1e-12)
    conics = mx.stack([c / det_safe, -b / det_safe, a / det_safe], axis=-1)

    if not (antialias and blur > 0):
        compensation = mx.ones_like(det)
    else:
        unblurred_det = mx.maximum(a0 * c0 - b * b, 0.0)
        compensation = mx.sqrt(unblurred_det / det_safe)

    # 3-sigma radius of the larger eigenvalue.
    half_trace = 0.5 * (a + c)
    lam_max = half_trace + mx.sqrt(mx.maximum(half_trace * half_trace - det, 0.01))
    radii = mx.ceil(3.0 * mx.sqrt(mx.maximum(lam_max, 0.0)))

    # Culling uses the depth of the Gaussian center only.
    depths = camera.world_to_camera(means)[:, 2]
    valid = (depths > camera.znear) & (det > 0)
    radii = mx.where(valid, radii, mx.zeros_like(radii))
    compensation = mx.where(valid, compensation, mx.zeros_like(compensation))

    return {
        "means2d": means2d,
        "conics": conics,
        "depths": depths,
        "radii": radii,
        "compensation": compensation,
    }

quat_scale_to_cov3d(quats, scales)

Build 3D covariances R S S^T R^T from quaternions (N, 4) and scales (N, 3).

Source code in src/mlx3d/splatting/projection.py
17
18
19
20
21
def quat_scale_to_cov3d(quats: mx.array, scales: mx.array) -> mx.array:
    """Return per-Gaussian 3D covariance matrices ``R S S^T R^T``.

    Args:
        quats: (N, 4) rotation quaternions, converted with ``quaternion_to_matrix``.
        scales: (N, 3) per-axis standard deviations.

    Returns:
        Batched covariances ``(R @ diag(s)) @ (R @ diag(s))^T`` (presumably
        (N, 3, 3), given ``quaternion_to_matrix`` returns (N, 3, 3) rotations).
    """
    # Broadcasting over the last axis multiplies column j of each rotation by
    # scales[:, j], i.e. forms R @ diag(s) without building diag(s).
    rs = quaternion_to_matrix(quats) * mx.expand_dims(scales, 1)
    return rs @ mx.swapaxes(rs, -1, -2)

rasterize_depth(means2d, conics, opacities, depths, sorted_ids, tile_ranges, width, height, tiles_x, tiles_y)

Forward-only expected depth rasterization for viewer/diagnostic modes.

Source code in src/mlx3d/splatting/rasterize.py
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
def rasterize_depth(
    means2d: mx.array,
    conics: mx.array,
    opacities: mx.array,
    depths: mx.array,
    sorted_ids: mx.array,
    tile_ranges: mx.array,
    width: int,
    height: int,
    tiles_x: int,
    tiles_y: int,
) -> dict[str, mx.array]:
    """Forward-only expected depth rasterization for viewer/diagnostic modes.

    Runs ``_depth_kernel`` over a ``tiles_x * TILE_SIZE`` by
    ``tiles_y * TILE_SIZE`` thread grid in ``TILE_SIZE x TILE_SIZE``
    threadgroups. Floating-point inputs are cast to float32, and the image
    size and tile counts are packed into an int32 parameter array.

    Returns:
        Dict with ``depth`` (H, W) as written by the kernel, ``final_T``
        (H, W), and ``alpha = 1 - final_T``.
    """
    f32 = mx.float32
    kernel_inputs = [
        means2d.astype(f32),
        conics.astype(f32),
        opacities.astype(f32),
        depths.astype(f32),
        sorted_ids,
        tile_ranges,
        mx.array([width, height, tiles_x, tiles_y], dtype=mx.int32),
    ]
    image_shape = (height, width)
    depth, final_T = _depth_kernel(
        inputs=kernel_inputs,
        output_shapes=[image_shape, image_shape],
        output_dtypes=[f32, f32],
        grid=(tiles_x * TILE_SIZE, tiles_y * TILE_SIZE, 1),
        threadgroup=(TILE_SIZE, TILE_SIZE, 1),
        init_value=0,
    )
    alpha = 1.0 - final_T
    return {"depth": depth, "alpha": alpha, "final_T": final_T}

rasterize_features(means2d, conics, features, opacities, sorted_ids, tile_ranges, width, height, tiles_x, tiles_y, background=None, normalize=False)

Alpha-composite arbitrary per-Gaussian feature channels.

This reuses the optimized RGB forward/backward Metal kernels in chunks of three channels. It is intentionally memory-bounded: only one (H, W, 3) output is materialized per chunk, while MLX autodiff sums gradients from each chunk back into the shared Gaussian parameters.

Parameters:

Name Type Description Default
features array

(N, C) per-Gaussian features.

required
background array | None

optional (C,) feature background. Ignored when normalize=True because normalized features are expected values over accumulated alpha, not composited background values.

None
normalize bool

if True, return expected features sum_i alpha_i T_i feature_i / alpha with zeros where alpha == 0. This is useful for depth, normals, embeddings, and semantic logits. If False, return the usual alpha-composited feature buffer including background.

False

Returns:

Type Description
dict[str, array]

dict with features (H, W, C), alpha (H, W), and the saved auxiliary buffers final_T / n_contrib from the first chunk.

Source code in src/mlx3d/splatting/rasterize.py
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
def rasterize_features(
    means2d: mx.array,
    conics: mx.array,
    features: mx.array,
    opacities: mx.array,
    sorted_ids: mx.array,
    tile_ranges: mx.array,
    width: int,
    height: int,
    tiles_x: int,
    tiles_y: int,
    background: mx.array | None = None,
    normalize: bool = False,
) -> dict[str, mx.array]:
    """Alpha-composite arbitrary per-Gaussian feature channels.

    The channels are split into groups of three, and each group is fed
    through the RGB ``rasterize`` path; the final group is zero-padded up to
    three. Only one ``(H, W, 3)`` image is produced per group. MLX autodiff
    accumulates the gradients of every group into the shared Gaussian
    parameters.

    Args:
        features: ``(N, C)`` per-Gaussian features.
        background: optional ``(C,)`` (or scalar) feature background.
            Ignored when ``normalize=True``, because normalized features are
            expectations over accumulated alpha rather than composited values.
        normalize: if ``True``, return ``sum_i alpha_i T_i feature_i / alpha``,
            with zeros where ``alpha`` is (near) zero. This suits depth,
            normals, embeddings, and semantic logits. If ``False``, return the
            ordinary alpha-composited buffer including ``background``.

    Returns:
        dict with ``features`` (H, W, C), ``alpha`` (H, W), and the auxiliary
        buffers ``final_T`` / ``n_contrib`` taken from the first group.

    Raises:
        ValueError: on malformed ``features`` or ``background`` shapes.
    """
    if features.ndim != 2:
        raise ValueError("features must have shape (N, C).")
    num_points = features.shape[0]
    if num_points != means2d.shape[0]:
        raise ValueError("features and means2d must have the same first dimension.")
    channels = int(features.shape[1])
    if channels <= 0:
        raise ValueError("features must contain at least one channel.")
    dtype = features.dtype

    use_zero_background = normalize or background is None
    if use_zero_background:
        bg = mx.zeros((channels,), dtype=dtype)
    else:
        bg = mx.array(background, dtype=dtype)
        if bg.ndim == 0:
            bg = mx.broadcast_to(bg, (channels,))
        if bg.shape != (channels,):
            raise ValueError(f"background must have shape ({channels},).")

    pieces: list[mx.array] = []
    first_out: dict[str, mx.array] | None = None
    for lo in range(0, channels, 3):
        hi = min(lo + 3, channels)
        n_valid = hi - lo
        chunk = features[:, lo:hi]
        chunk_bg = bg[lo:hi]
        missing = 3 - n_valid
        if missing > 0:
            # Pad the trailing group to the 3 channels the RGB kernels expect.
            chunk = mx.concatenate(
                [chunk, mx.zeros((num_points, missing), dtype=dtype)], axis=1
            )
            chunk_bg = mx.concatenate([chunk_bg, mx.zeros((missing,), dtype=dtype)], axis=0)

        rendered = rasterize(
            means2d,
            conics,
            chunk,
            opacities,
            sorted_ids,
            tile_ranges,
            width,
            height,
            tiles_x,
            tiles_y,
            background=chunk_bg,
        )
        if first_out is None:
            first_out = rendered
        pieces.append(rendered["image"][..., :n_valid])

    assert first_out is not None
    feat = mx.concatenate(pieces, axis=-1)
    # Geometry and opacities are shared by every group, so any group's alpha works.
    alpha = first_out["alpha"]
    if normalize:
        alpha_b = alpha[..., None]
        feat = mx.where(alpha_b > 1e-6, feat / mx.maximum(alpha_b, 1e-6), 0.0)

    return {
        "features": feat,
        "alpha": alpha,
        "final_T": first_out["final_T"],
        "n_contrib": first_out["n_contrib"],
    }

render_gaussians_reference(camera, means, quats, scales, opacities, colors, background=None, antialias=False)

Reference renderer matching `mlx3d.splatting.render_gaussians`. It evaluates every Gaussian at every pixel, so it is only practical for small test scenes.

Source code in src/mlx3d/splatting/reference.py
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def render_gaussians_reference(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    opacities: mx.array,
    colors: mx.array,
    background: mx.array | None = None,
    antialias: bool = False,
) -> dict[str, mx.array]:
    """Reference renderer matching :func:`mlx3d.splatting.render_gaussians`.

    This is a dense, pure-MLX implementation that evaluates every Gaussian at
    every pixel center. It materializes several ``(N, H, W)`` intermediates,
    so it is intended for small scenes when checking the tiled kernels.

    Args:
        camera: viewing camera.
        means, quats, scales: Gaussian geometry, as for ``project_gaussians``.
        opacities: per-Gaussian opacities (presumably (N,) in [0, 1] — TODO confirm).
        colors: per-Gaussian colors, indexed along the first axis.
        background: color composited behind the Gaussians; defaults to zeros((3,)).
        antialias: forwarded to ``project_gaussians``; scales opacity by its
            ``compensation`` output.

    Returns:
        Dict with ``image``, ``alpha`` (= 1 - final_T) and ``final_T``.
    """
    if background is None:
        background = mx.zeros((3,))
    H, W = camera.height, camera.width

    proj = project_gaussians(camera, means, quats, scales, antialias=antialias)
    means2d, conics, depths, radii = (
        proj["means2d"],
        proj["conics"],
        proj["depths"],
        proj["radii"],
    )

    # Depth-sort all Gaussians (global front-to-back order).
    order = mx.argsort(mx.stop_gradient(depths))
    means2d = means2d[order]
    conics = conics[order]
    colors = colors[order]
    opac = (opacities * proj["compensation"])[order]
    radii = radii[order]

    # Pixel centers at (i + 0.5, j + 0.5).
    xs = mx.arange(W, dtype=mx.float32) + 0.5
    ys = mx.arange(H, dtype=mx.float32) + 0.5
    px = mx.broadcast_to(xs[None, :], (H, W))
    py = mx.broadcast_to(ys[:, None], (H, W))

    # Gaussian exponent -0.5 * d^T conic d for every (Gaussian, pixel) pair.
    dx = means2d[:, 0][:, None, None] - px[None]  # (N, H, W)
    dy = means2d[:, 1][:, None, None] - py[None]
    a = conics[:, 0][:, None, None]
    b = conics[:, 1][:, None, None]
    c = conics[:, 2][:, None, None]
    power = -0.5 * (a * dx * dx + c * dy * dy) - b * dx * dy
    alpha = mx.minimum(opac[:, None, None] * mx.exp(power), 0.99)
    # Match the kernel cutoffs exactly.
    alpha = mx.where(power > 0.0, mx.zeros_like(alpha), alpha)
    alpha = mx.where(alpha < 1.0 / 255.0, mx.zeros_like(alpha), alpha)
    # Culled Gaussians (radius 0) contribute nothing.
    alpha = alpha * (radii > 0)[:, None, None]

    # Exclusive cumulative product: trans[i] is the transmittance in front of
    # Gaussian i.
    one_minus = 1.0 - alpha
    trans = mx.cumprod(one_minus, axis=0)
    trans = mx.concatenate([mx.ones_like(trans[:1]), trans[:-1]], axis=0)
    # Early-termination threshold of the kernel: stop once T drops below 1e-4.
    # NOTE(review): a Gaussian is kept when the transmittance *before* it is
    # > 1e-4. The original 3DGS CUDA kernel instead stops when the
    # transmittance *after* it would drop below 1e-4. Confirm this matches
    # this project's Metal kernel.
    keep = trans > 1e-4
    weights = alpha * trans * keep  # (N, H, W)

    image = mx.sum(weights[..., None] * colors[:, None, None, :], axis=0)
    # Residual transmittance over the kept Gaussians only, then composite the
    # background behind them.
    final_T = mx.prod(mx.where(keep, one_minus, mx.ones_like(one_minus)), axis=0)
    image = image + final_T[..., None] * background
    return {"image": image, "alpha": 1.0 - final_T, "final_T": final_T}

render_gaussian_depth(camera, means, quats, scales, opacities, refine_tiles=False, antialias=False, projection='ewa')

Render expected depth and alpha from 3D Gaussians.

This is a forward-only diagnostic/viewer path. It uses the same projection, tile binning, and alpha compositing math as RGB splatting, but accumulates transmittance-weighted depth instead of color.

Source code in src/mlx3d/splatting/render.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def render_gaussian_depth(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    opacities: mx.array,
    refine_tiles: bool = False,
    antialias: bool = False,
    projection: str = "ewa",
) -> dict[str, mx.array]:
    """Render expected depth and alpha from 3D Gaussians.

    This is a forward-only path for viewers and diagnostics. It shares the
    RGB pipeline's projection (``projection``: ``"ewa"`` or ``"ut"``), tile
    binning, and compositing. The difference is that it accumulates
    transmittance-weighted camera-space depth instead of color.

    Returns:
        The output of ``rasterize_depth`` (``depth``, ``alpha``, ``final_T``),
        extended with the projection's ``means2d``, ``depths``, ``radii`` and
        ``compensation``.
    """
    proj = _project(camera, means, quats, scales, projection, antialias)
    means2d = proj["means2d"]
    conics = proj["conics"]
    depths = proj["depths"]
    radii = proj["radii"]
    compensation = proj["compensation"]
    weighted_opacities = opacities * compensation

    # Conics are only passed to the binner when tighter per-tile tests are requested.
    tile_conics = conics if refine_tiles else None
    sorted_ids, tile_ranges, tiles_x, tiles_y = bin_gaussians(
        means2d,
        radii,
        depths,
        camera.width,
        camera.height,
        conics=tile_conics,
    )
    result = rasterize_depth(
        means2d,
        conics,
        weighted_opacities,
        depths,
        sorted_ids,
        tile_ranges,
        camera.width,
        camera.height,
        tiles_x,
        tiles_y,
    )
    result.update(means2d=means2d, depths=depths, radii=radii, compensation=compensation)
    return result

render_gaussian_features(camera, means, quats, scales, opacities, features, background=None, normalize=False, refine_tiles=False, antialias=False, projection='ewa')

Render arbitrary per-Gaussian feature channels.

render_gaussians is specialized for RGB color. This function exposes the same projection/binning/rasterization path for any (N, C) feature tensor: depth-like scalars, normals, semantic logits, learned embeddings, or auxiliary training buffers.

Set normalize=True to return expected features divided by accumulated alpha, matching the expected-depth convention. Leave it False for ordinary alpha compositing with an optional feature-space background.

Source code in src/mlx3d/splatting/render.py
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def render_gaussian_features(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    opacities: mx.array,
    features: mx.array,
    background: mx.array | None = None,
    normalize: bool = False,
    refine_tiles: bool = False,
    antialias: bool = False,
    projection: str = "ewa",
) -> dict[str, mx.array]:
    """Alpha-composite arbitrary per-Gaussian feature channels into an image.

    ``render_gaussians`` handles RGB color only; this entry point runs the
    same project -> bin -> rasterize pipeline for any ``(N, C)`` feature
    tensor (depth-like scalars, normals, semantic logits, learned embeddings,
    auxiliary training buffers, ...).

    Args:
        camera: viewing camera.
        means, quats, scales, opacities: Gaussian parameters, passed through
            to the projection exactly as in ``render_gaussians``.
        features: per-Gaussian feature channels to composite.
        background: optional feature-space background, forwarded to
            ``rasterize_features``.
        normalize: when True, the result is the expected feature divided by
            accumulated alpha (the expected-depth convention); when False,
            ordinary alpha compositing is used.
        refine_tiles: forward the projected conics to ``bin_gaussians`` so it
            can reject tiles the ellipse cannot reach.
        antialias: forwarded to the projection step.
        projection: projection model name forwarded to ``_project``.

    Returns:
        The ``rasterize_features`` output dict, extended with the projection
        outputs ``means2d``, ``depths``, ``radii`` and ``compensation``.
    """
    projected = _project(camera, means, quats, scales, projection, antialias)
    means2d = projected["means2d"]
    conics = projected["conics"]

    # The projection's per-Gaussian compensation factor scales opacity.
    effective_opacity = opacities * projected["compensation"]

    sorted_ids, tile_ranges, tiles_x, tiles_y = bin_gaussians(
        means2d,
        projected["radii"],
        projected["depths"],
        camera.width,
        camera.height,
        conics=conics if refine_tiles else None,
    )

    rendered = rasterize_features(
        means2d,
        conics,
        features,
        effective_opacity,
        sorted_ids,
        tile_ranges,
        camera.width,
        camera.height,
        tiles_x,
        tiles_y,
        background=background,
        normalize=normalize,
    )

    # Expose projection bookkeeping alongside the rendered buffers.
    rendered.update(
        {key: projected[key] for key in ("means2d", "depths", "radii", "compensation")}
    )
    return rendered

render_gaussians(camera, means, quats, scales, opacities, colors=None, sh=None, sh_degree=3, background=None, refine_tiles=False, antialias=False, projection='ewa')

Render 3D Gaussians from a camera. Differentiable end to end.

Parameters:

Name Type Description Default
camera Camera

viewing camera.

required
means array

(N, 3) Gaussian centers.

required
quats array

(N, 4) rotations (w, x, y, z).

required
scales array

(N, 3) per-axis standard deviations (positive; apply your activation, e.g. mx.exp, before calling).

required
opacities array

(N,) in [0, 1] (apply sigmoid before calling).

required
colors array | None

(N, 3) RGB in [0, 1]. Mutually exclusive with sh.

None
sh array | None

(N, K, 3) spherical-harmonic coefficients (K >= (sh_degree+1)^2); view-dependent color is evaluated per Gaussian toward the camera.

None
background array | None

(3,) background color (default black).

None
refine_tiles bool

experimental conservative ellipse/tile rejection after square radius binning. Off by default because its extra MLX work is not faster on all scenes.

False
antialias bool

enable Mip-Splatting-style opacity compensation for the projection blur. When enabled, this reduces over-bright subpixel splats; it is off by default to keep the existing 3DGS-compatible behavior.

False
projection str

"ewa" for the fast analytic pinhole projection or "ut" for a 3DGUT-style Unscented Transform projection through the full camera model, including distortion and fisheye.

'ewa'

Returns:

Type Description
dict[str, array]

dict with image (H, W, 3), alpha (H, W), plus the projection outputs (means2d, depths, radii, compensation) for densification bookkeeping.

Source code in src/mlx3d/splatting/render.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def render_gaussians(
    camera: Camera,
    means: mx.array,
    quats: mx.array,
    scales: mx.array,
    opacities: mx.array,
    colors: mx.array | None = None,
    sh: mx.array | None = None,
    sh_degree: int = 3,
    background: mx.array | None = None,
    refine_tiles: bool = False,
    antialias: bool = False,
    projection: str = "ewa",
) -> dict[str, mx.array]:
    """Render 3D Gaussians from a camera. Differentiable end to end.

    Args:
        camera: viewing camera.
        means: (N, 3) Gaussian centers.
        quats: (N, 4) rotations (w, x, y, z).
        scales: (N, 3) per-axis standard deviations (positive; apply your
            activation, e.g. ``mx.exp``, before calling).
        opacities: (N,) in [0, 1] (apply sigmoid before calling).
        colors: (N, 3) RGB in [0, 1]. Mutually exclusive with ``sh``.
        sh: (N, K, 3) spherical-harmonic coefficients (K >= (sh_degree+1)^2);
            view-dependent color is evaluated per Gaussian toward the camera.
        sh_degree: SH degree used when ``sh`` is given. ``eval_sh`` evaluates
            degrees 0-3; larger values are evaluated as degree 3.
        background: (3,) background color (default black).
        refine_tiles: experimental conservative ellipse/tile rejection after
            square radius binning. Off by default because its extra MLX work is
            not faster on all scenes.
        antialias: enable Mip-Splatting-style opacity compensation for the
            projection blur. When enabled, this reduces over-bright subpixel
            splats; it is off by default to keep the existing
            3DGS-compatible behavior.
        projection: ``"ewa"`` for the fast analytic pinhole projection or
            ``"ut"`` for a 3DGUT-style Unscented Transform projection through
            the full camera model, including distortion and fisheye.

    Returns:
        dict with ``image`` (H, W, 3), ``alpha`` (H, W), plus the projection
        outputs (``means2d``, ``depths``, ``radii``, ``compensation``) for
        densification bookkeeping.

    Raises:
        ValueError: if not exactly one of ``colors`` / ``sh`` is given, or if
            ``sh`` has fewer coefficients along axis 1 than ``sh_degree``
            requires.
    """
    if (colors is None) == (sh is None):
        raise ValueError("Provide exactly one of `colors` or `sh`.")
    if sh is not None:
        # eval_sh reads coefficients 0 .. (d + 1)^2 - 1 with d clamped to
        # [0, 3]. MLX indexing is not bounds-checked (out-of-range reads are
        # undefined behavior, not an error), so validate K up front.
        needed = (min(max(sh_degree, 0), 3) + 1) ** 2
        if sh.ndim < 2 or sh.shape[1] < needed:
            raise ValueError(
                f"`sh` needs at least {needed} coefficients along axis 1 for "
                f"sh_degree={sh_degree}; got shape {tuple(sh.shape)}."
            )

    proj = _project(camera, means, quats, scales, projection, antialias)
    # The projection's per-Gaussian compensation factor scales opacity.
    opacities = opacities * proj["compensation"]

    if sh is not None:
        # Unit directions from the camera center toward each Gaussian (norm
        # floored at 1e-8). The direction is wrapped in stop_gradient, so the
        # SH color contributes no gradient back through the view direction.
        dirs = means - camera.camera_center
        dirs = dirs / mx.maximum(mx.linalg.norm(dirs, axis=-1, keepdims=True), 1e-8)
        # eval_sh output may be negative; clamp colors at 0.
        colors = mx.maximum(eval_sh(sh_degree, sh, mx.stop_gradient(dirs)), 0.0)

    sorted_ids, tile_ranges, tiles_x, tiles_y = bin_gaussians(
        proj["means2d"],
        proj["radii"],
        proj["depths"],
        camera.width,
        camera.height,
        conics=proj["conics"] if refine_tiles else None,
    )

    out = rasterize(
        proj["means2d"],
        proj["conics"],
        colors,
        opacities,
        sorted_ids,
        tile_ranges,
        camera.width,
        camera.height,
        tiles_x,
        tiles_y,
        background=background,
    )
    out.update(
        {
            "means2d": proj["means2d"],
            "depths": proj["depths"],
            "radii": proj["radii"],
            "compensation": proj["compensation"],
        }
    )
    return out

eval_sh(degree, sh, dirs)

Evaluate SH at unit directions.

Parameters:

Name Type Description Default
degree int

SH degree in [0, 3].

required
sh array

(N, K, C) coefficients with K >= (degree+1)^2.

required
dirs array

(N, 3) unit view directions.

required

Returns:

Type Description
array

(N, C) colors, offset by +0.5 (callers should clamp to >= 0).

Source code in src/mlx3d/splatting/sh.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def eval_sh(degree: int, sh: mx.array, dirs: mx.array) -> mx.array:
    """Evaluate real spherical harmonics along unit view directions.

    Args:
        degree: SH degree in [0, 3]. Values <= 0 use only the DC band; values
            above 3 are evaluated as degree 3.
        sh: (N, K, C) coefficients with K >= (degree+1)^2.
        dirs: (N, 3) unit view directions.

    Returns:
        (N, C) colors, offset by +0.5 (callers should clamp to >= 0).
    """
    color = _C0 * sh[:, 0]
    if degree <= 0:
        return color + 0.5

    # Slices (not integer indices) keep a trailing axis of size 1 so each
    # direction component broadcasts against the per-coefficient slices.
    x = dirs[:, 0:1]
    y = dirs[:, 1:2]
    z = dirs[:, 2:3]
    color = color - _C1 * y * sh[:, 1] + _C1 * z * sh[:, 2] - _C1 * x * sh[:, 3]
    if degree <= 1:
        return color + 0.5

    xx = x * x
    yy = y * y
    zz = z * z
    xy = x * y
    yz = y * z
    xz = x * z
    color = (
        color
        + _C2[0] * xy * sh[:, 4]
        + _C2[1] * yz * sh[:, 5]
        + _C2[2] * (2.0 * zz - xx - yy) * sh[:, 6]
        + _C2[3] * xz * sh[:, 7]
        + _C2[4] * (xx - yy) * sh[:, 8]
    )
    if degree <= 2:
        return color + 0.5

    color = (
        color
        + _C3[0] * y * (3.0 * xx - yy) * sh[:, 9]
        + _C3[1] * xy * z * sh[:, 10]
        + _C3[2] * y * (4.0 * zz - xx - yy) * sh[:, 11]
        + _C3[3] * z * (2.0 * zz - 3.0 * xx - 3.0 * yy) * sh[:, 12]
        + _C3[4] * x * (4.0 * zz - xx - yy) * sh[:, 13]
        + _C3[5] * z * (xx - yy) * sh[:, 14]
        + _C3[6] * x * (xx - 3.0 * yy) * sh[:, 15]
    )
    return color + 0.5

rgb_to_sh(rgb)

Convert RGB in [0, 1] to the DC spherical-harmonic coefficient.

Source code in src/mlx3d/splatting/sh.py
31
32
33
def rgb_to_sh(rgb: mx.array) -> mx.array:
    """Map RGB values in [0, 1] to the degree-0 (DC) SH coefficient.

    Inverse of ``sh_to_rgb``: remove the +0.5 offset, then divide by the DC
    basis constant.
    """
    centered = rgb - 0.5
    return centered / _C0

sh_to_rgb(sh_dc)

Convert the DC coefficient back to RGB.

Source code in src/mlx3d/splatting/sh.py
36
37
38
def sh_to_rgb(sh_dc: mx.array) -> mx.array:
    """Map the DC SH coefficient back to RGB (inverse of ``rgb_to_sh``)."""
    scaled = sh_dc * _C0
    return scaled + 0.5

bin_gaussians(means2d, radii, depths, width, height, conics=None)

Assign Gaussians to screen tiles and sort by (tile, depth).

Parameters:

Name Type Description Default
means2d array

(N, 2) pixel-space centers.

required
radii array

(N,) pixel radii; 0 means culled.

required
depths array

(N,) camera-space depths.

required
conics array | None

optional (N, 3) inverse 2D covariance upper-triangular coefficients (a, b, c). When provided, a duplicate tile is rejected after the radius-bbox pass if its whole rectangle lies outside the ellipse of squared Mahalanobis distance 12, which is a conservative margin beyond the Gaussian's 3-sigma ellipse.

None

Returns:

Type Description
tuple[array, array, int, int]

(sorted_ids, tile_ranges, tiles_x, tiles_y), where sorted_ids is (D,) int32 Gaussian indices for all duplicates in render order and tile_ranges is (tiles_x * tiles_y, 2) int32 [start, end) ranges into sorted_ids.

Source code in src/mlx3d/splatting/tiles.py
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def bin_gaussians(
    means2d: mx.array,
    radii: mx.array,
    depths: mx.array,
    width: int,
    height: int,
    conics: mx.array | None = None,
) -> tuple[mx.array, mx.array, int, int]:
    """Assign Gaussians to screen tiles and sort by (tile, depth).

    Each visible Gaussian is duplicated once per tile covered by the square
    bounding box of its pixel radius. The duplicates are sorted by tile id
    and, within a tile, by ascending depth. All inputs are wrapped in
    ``mx.stop_gradient``: binning is discrete and contributes no gradients.

    Args:
        means2d: (N, 2) pixel-space centers.
        radii: (N,) pixel radii; 0 means culled.
        depths: (N,) camera-space depths.
        width: image width in pixels.
        height: image height in pixels.
        conics: optional (N, 3) inverse 2D covariance upper-triangular
            coefficients ``(a, b, c)``. When provided, a duplicate tile is
            rejected after the radius-bbox pass if its whole rectangle lies
            outside the ellipse of squared Mahalanobis distance 12, a
            conservative margin beyond the 3-sigma ellipse.

    Returns:
        ``(sorted_ids, tile_ranges, tiles_x, tiles_y)`` where ``sorted_ids``
        is (D,) int32 Gaussian indices for all duplicates in render order and
        ``tile_ranges`` is (tiles_x * tiles_y, 2) int32 [start, end) ranges
        into ``sorted_ids``. Empty tiles get ``[0, 0)``.
    """
    means2d = mx.stop_gradient(means2d)
    radii = mx.stop_gradient(radii)
    depths = mx.stop_gradient(depths)
    conics = None if conics is None else mx.stop_gradient(conics)

    N = means2d.shape[0]
    tiles_x = (width + TILE_SIZE - 1) // TILE_SIZE
    tiles_y = (height + TILE_SIZE - 1) // TILE_SIZE
    num_tiles = tiles_x * tiles_y

    # Inclusive tile bounding box of each Gaussian's 3-sigma square.
    x, y = means2d[:, 0], means2d[:, 1]
    r = radii
    xmin = mx.clip(mx.floor((x - r) / TILE_SIZE), 0, tiles_x - 1).astype(mx.int32)
    xmax = mx.clip(mx.floor((x + r) / TILE_SIZE), 0, tiles_x - 1).astype(mx.int32)
    ymin = mx.clip(mx.floor((y - r) / TILE_SIZE), 0, tiles_y - 1).astype(mx.int32)
    ymax = mx.clip(mx.floor((y + r) / TILE_SIZE), 0, tiles_y - 1).astype(mx.int32)

    on_screen = (radii > 0) & (x + r >= 0) & (x - r < width) & (y + r >= 0) & (y - r < height)
    w_tiles = (xmax - xmin + 1) * on_screen
    h_tiles = (ymax - ymin + 1) * on_screen
    counts = (w_tiles * h_tiles).astype(mx.int32)  # (N,)

    offsets = mx.cumsum(counts) - counts  # exclusive prefix sum
    # Output sizes are data-dependent, so the duplicate count is read back
    # to the host.
    total = int(counts.sum().item())
    if total == 0:
        return (
            mx.zeros((0,), dtype=mx.int32),
            mx.zeros((num_tiles, 2), dtype=mx.int32),
            tiles_x,
            tiles_y,
        )

    # Expand: duplicate j belongs to Gaussian g(j). Scatter a 1 at each
    # Gaussian's first duplicate slot, cumsum, subtract 1. Zero-count
    # Gaussians scatter onto the same slot as their successor, and the
    # cumulative sum skips them correctly. Trailing zero-count Gaussians have
    # no successor (offset == total), so the marker gets one spare slot to
    # absorb them; MLX does not bounds-check scatters, and writing index
    # ``total`` of a length-``total`` array would be undefined behavior.
    marker = mx.zeros((total + 1,), dtype=mx.int32)
    marker = marker.at[offsets].add(mx.ones((N,), dtype=mx.int32))[:total]
    gauss_id = mx.cumsum(marker) - 1  # (D,)

    local = mx.arange(total, dtype=mx.int32) - offsets[gauss_id]  # rank within bbox
    gw = mx.maximum(w_tiles[gauss_id], 1)
    tile_x = xmin[gauss_id] + local % gw
    tile_y = ymin[gauss_id] + local // gw
    tile_id = (tile_y * tiles_x + tile_x).astype(mx.int32)

    if conics is not None:
        # Conservative AccuTile-style refinement: for each duplicate generated
        # by the square 3-sigma bbox, find the minimum conic distance over the
        # continuous tile rectangle. The renderer skips alpha < 1/255, so keep
        # a margin beyond 3 sigma to preserve faint high-opacity tails.
        x0 = (tile_x * TILE_SIZE).astype(mx.float32)
        y0 = (tile_y * TILE_SIZE).astype(mx.float32)
        x1 = mx.minimum((tile_x + 1) * TILE_SIZE, width).astype(mx.float32)
        y1 = mx.minimum((tile_y + 1) * TILE_SIZE, height).astype(mx.float32)

        gx = x[gauss_id]
        gy = y[gauss_id]
        dx_min = x0 - gx
        dx_max = x1 - gx
        dy_min = y0 - gy
        dy_max = y1 - gy

        co = conics[gauss_id]
        a = co[:, 0]
        b = co[:, 1]
        c = co[:, 2]

        def q(dx, dy):
            # Conic quadratic form (squared Mahalanobis distance for an
            # inverse-covariance conic).
            return a * dx * dx + 2.0 * b * dx * dy + c * dy * dy

        zero = mx.zeros_like(dx_min)
        # A rectangle containing the center has minimum distance 0.
        inside = (dx_min <= 0) & (dx_max >= 0) & (dy_min <= 0) & (dy_max >= 0)

        # Check corners plus the four edge-wise minimizers of the positive
        # definite quadratic form. This gives the exact minimum over the box.
        dx_left = dx_min
        dx_right = dx_max
        dy_bottom = dy_min
        dy_top = dy_max

        # dq/d(dy) = 0 at dy = -(b/c) dx; dq/d(dx) = 0 at dx = -(b/a) dy.
        dy_at_left = mx.clip(-(b / c) * dx_left, dy_bottom, dy_top)
        dy_at_right = mx.clip(-(b / c) * dx_right, dy_bottom, dy_top)
        dx_at_bottom = mx.clip(-(b / a) * dy_bottom, dx_left, dx_right)
        dx_at_top = mx.clip(-(b / a) * dy_top, dx_left, dx_right)

        qmin = q(dx_left, dy_bottom)
        qmin = mx.minimum(qmin, q(dx_left, dy_top))
        qmin = mx.minimum(qmin, q(dx_right, dy_bottom))
        qmin = mx.minimum(qmin, q(dx_right, dy_top))
        qmin = mx.minimum(qmin, q(dx_left, dy_at_left))
        qmin = mx.minimum(qmin, q(dx_right, dy_at_right))
        qmin = mx.minimum(qmin, q(dx_at_bottom, dy_bottom))
        qmin = mx.minimum(qmin, q(dx_at_top, dy_top))
        qmin = mx.where(inside, zero, qmin)

        # 12 > 2 * ln(255) ~= 11.09, so no tile that could see alpha >= 1/255
        # (for opacity <= 1) is dropped.
        keep = qmin <= 12.0
        keep_i = keep.astype(mx.int32)
        active_total = int(keep_i.sum().item())
        if active_total == 0:
            return (
                mx.zeros((0,), dtype=mx.int32),
                mx.zeros((num_tiles, 2), dtype=mx.int32),
                tiles_x,
                tiles_y,
            )

        # Stream-compact the kept duplicates with a scatter-add. Dropped
        # entries are routed to slot 0 with a value of 0, so they add nothing
        # and every index stays in [0, active_total).
        active_pos = mx.cumsum(keep_i) - 1
        safe_pos = mx.where(keep, active_pos, mx.zeros_like(active_pos))
        gauss_id = (
            mx.zeros((active_total,), dtype=mx.int32)
            .at[safe_pos]
            .add(mx.where(keep, gauss_id, mx.zeros_like(gauss_id)))
        )
        tile_id = (
            mx.zeros((active_total,), dtype=mx.int32)
            .at[safe_pos]
            .add(mx.where(keep, tile_id, mx.zeros_like(tile_id)))
        )
        total = active_total

    # Depth ranks (dense, < N) make an exact composite sort key.
    order = mx.argsort(depths)
    ranks = mx.zeros((N,), dtype=mx.int32).at[order].add(mx.arange(N, dtype=mx.int32))
    key = tile_id.astype(mx.int64) * N + ranks[gauss_id].astype(mx.int64)
    sort_idx = mx.argsort(key)

    sorted_ids = gauss_id[sort_idx].astype(mx.int32)
    sorted_tiles = tile_id[sort_idx].astype(mx.int32)

    # Per-tile [start, end) ranges via scatter-min / scatter-max.
    positions = mx.arange(total, dtype=mx.int32)
    starts = mx.full((num_tiles,), total, dtype=mx.int32).at[sorted_tiles].minimum(positions)
    ends = mx.zeros((num_tiles,), dtype=mx.int32).at[sorted_tiles].maximum(positions + 1)
    starts = mx.minimum(starts, ends)  # empty tiles -> start == end == 0
    tile_ranges = mx.stack([starts, ends], axis=-1)

    return sorted_ids, tile_ranges, tiles_x, tiles_y