Skip to content

mlx3d.renderer

mlx3d.renderer

Renderer

Bases: Protocol

Anything callable as renderer(camera, *scene) -> RenderOutput.

Use it purely as a type hint for code that accepts a pluggable renderer::

def turntable(renderer: Renderer, cameras, scene): ...

Both library functions and your own closures satisfy it without subclassing.

Source code in src/mlx3d/renderer/protocols.py
51
52
53
54
55
56
57
58
59
60
61
62
@runtime_checkable
class Renderer(Protocol):
    """Anything callable as ``renderer(camera, *scene) -> RenderOutput``.

    Use it purely as a type hint for code that accepts a pluggable renderer::

        def turntable(renderer: Renderer, cameras, scene): ...

    Both library functions and your own closures satisfy it without subclassing.

    Note:
        Because of ``@runtime_checkable``, ``isinstance(obj, Renderer)`` is
        permitted, but it only verifies that ``obj`` has a ``__call__``
        attribute. It does not inspect the call signature or the return type,
        so any callable passes.
    """

    # The camera is always the first argument; scene data (mesh, vertices,
    # points, colors, ...) and renderer options are forwarded as extra
    # positional / keyword arguments.
    def __call__(self, camera: Camera, *args: object, **kwargs: object) -> RenderOutput: ...

RenderOutput

Bases: TypedDict

The dict every image-space renderer returns.

Keys are optional so depth-only or alpha-only passes are still valid outputs, but RGB renderers populate all three:

  • image: (H, W, 3) RGB in [0, 1].
  • alpha: (H, W) coverage / opacity in [0, 1].
  • depth: (H, W) expected depth in world units.

Renderers may attach extra keys (e.g. means2d for splatting); consumers should read by key rather than assume an exact set.

Source code in src/mlx3d/renderer/protocols.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class RenderOutput(TypedDict, total=False):
    """The dict every image-space renderer returns.

    Keys are optional so depth-only or alpha-only passes are still valid
    outputs, but RGB renderers populate all three:

    - ``image``: ``(H, W, 3)`` RGB in ``[0, 1]``.
    - ``alpha``: ``(H, W)`` coverage / opacity in ``[0, 1]``.
    - ``depth``: ``(H, W)`` expected depth in world units.

    Renderers may attach extra keys (e.g. ``means2d`` for splatting); consumers
    should read by key rather than assume an exact set.

    Since the class is declared with ``total=False``, a type checker accepts a
    dict that omits any of the keys below. When the producer is not known,
    prefer ``out.get("depth")`` (or an ``in`` check) over ``out["depth"]``.
    """

    # Final composited color image.
    image: mx.array
    # Per-pixel coverage / opacity.
    alpha: mx.array
    # Per-pixel expected depth.
    depth: mx.array

Fragments dataclass

Per-pixel rasterizer output.

Attributes:

Name Type Description
pix_to_face array

(H, W) int32 index of the nearest face, -1 if empty.

zbuf array

(H, W) interpolated camera-space depth (0 where empty).

bary array

(H, W, 3) barycentric coordinates of the hit (differentiable w.r.t. vertex positions).

valid array

(H, W) bool mask of covered pixels.

vert_ids array

(H, W, 3) vertex indices of the hit face.

Source code in src/mlx3d/renderer/rasterizer.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
@dataclass
class Fragments:
    """Per-pixel rasterizer output.

    Attributes:
        pix_to_face: ``(H, W)`` int32 index of the nearest face, ``-1`` if empty.
        zbuf: ``(H, W)`` interpolated camera-space depth (0 where empty).
        bary: ``(H, W, 3)`` barycentric coordinates of the hit (differentiable
            w.r.t. vertex positions).
        valid: ``(H, W)`` bool mask of covered pixels.
        vert_ids: ``(H, W, 3)`` vertex indices of the hit face.

    Note:
        As built by ``rasterize_meshes``, empty pixels have ``bary`` set to
        zero. Their ``vert_ids`` hold the vertex indices of face 0, which
        serves only as a safe gather placeholder. Always mask with ``valid``
        before trusting per-pixel values.
    """

    # Field order is the positional constructor order; do not reorder.
    pix_to_face: mx.array
    zbuf: mx.array
    bary: mx.array
    valid: mx.array
    vert_ids: mx.array

AmbientLights dataclass

Uniform ambient illumination.

Source code in src/mlx3d/renderer/shading.py
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@dataclass
class AmbientLights:
    """Light that adds one constant color everywhere, independent of geometry.

    It has no diffuse or specular response; its entire contribution is
    exposed through :attr:`ambient`.
    """

    color: tuple[float, float, float] = (1.0, 1.0, 1.0)

    def diffuse(self, normals: mx.array, points: mx.array) -> mx.array:
        # Direction-free light: no Lambertian term at any point.
        no_light = mx.zeros_like(normals)
        return no_light

    def specular(self, normals, points, camera_center, shininess) -> mx.array:
        # Direction-free light: no highlight at any point.
        no_light = mx.zeros_like(normals)
        return no_light

    @property
    def ambient(self) -> mx.array:
        """The ambient RGB color as an array."""
        rgb = _arr(self.color)
        return rgb

DirectionalLights dataclass

A light at infinity. direction is the direction the light travels.

Source code in src/mlx3d/renderer/shading.py
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
@dataclass
class DirectionalLights:
    """Light source at infinity; ``direction`` is the way the light travels.

    Every shaded point sees the same unit vector pointing back toward the
    light, so only the orientation of ``direction`` matters, not its length.
    """

    direction: tuple[float, float, float] = (0.0, -1.0, 0.0)
    color: tuple[float, float, float] = (1.0, 1.0, 1.0)
    ambient_color: tuple[float, float, float] = (0.0, 0.0, 0.0)

    def _to_light(self, points: mx.array) -> mx.array:
        # Flip the travel direction so it points from the surface to the light,
        # normalize it (guarding a zero vector), then tile it once per point.
        toward = -_arr(self.direction)
        length = mx.maximum(mx.linalg.norm(toward), 1e-8)
        return mx.broadcast_to(toward / length, points.shape)

    def diffuse(self, normals: mx.array, points: mx.array) -> mx.array:
        cosine = mx.sum(normals * self._to_light(points), axis=-1, keepdims=True)
        return mx.maximum(cosine, 0.0) * _arr(self.color)

    def specular(self, normals, points, camera_center, shininess) -> mx.array:
        to_light = self._to_light(points)
        tint = _arr(self.color)
        return _blinn_phong(normals, points, to_light, camera_center, shininess, tint)

    @property
    def ambient(self) -> mx.array:
        rgb = _arr(self.ambient_color)
        return rgb

PointLights dataclass

A point light at location (world space).

Source code in src/mlx3d/renderer/shading.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
@dataclass
class PointLights:
    """Omnidirectional light placed at ``location`` in world space.

    Unlike a directional light, the incoming direction varies per shaded point.
    """

    location: tuple[float, float, float] = (0.0, 1.0, -3.0)
    color: tuple[float, float, float] = (1.0, 1.0, 1.0)
    ambient_color: tuple[float, float, float] = (0.0, 0.0, 0.0)

    def _to_light(self, points: mx.array) -> mx.array:
        # Per-point unit vector from the surface toward the light; the floor on
        # the distance avoids dividing by zero for a point at the light itself.
        offset = _arr(self.location) - points
        dist = mx.linalg.norm(offset, axis=-1, keepdims=True)
        return offset / mx.maximum(dist, 1e-8)

    def diffuse(self, normals: mx.array, points: mx.array) -> mx.array:
        to_light = self._to_light(points)
        cosine = mx.sum(normals * to_light, axis=-1, keepdims=True)
        return mx.maximum(cosine, 0.0) * _arr(self.color)

    def specular(self, normals, points, camera_center, shininess) -> mx.array:
        to_light = self._to_light(points)
        tint = _arr(self.color)
        return _blinn_phong(normals, points, to_light, camera_center, shininess, tint)

    @property
    def ambient(self) -> mx.array:
        rgb = _arr(self.ambient_color)
        return rgb

render_mesh_soft(camera, mesh_or_verts, faces=None, verts_colors=None, face_colors=None, texcoords=None, faces_texcoords_idx=None, texture=None, sigma=0.01, depth_temperature=25.0, background=0.0, eps=1e-08, face_chunk_size=256)

Render a triangle mesh with a SoftRas-style differentiable rasterizer.

Fully MLX-differentiable w.r.t. vertices, vertex colors, face colors, and texture values. Topology and UV indices are treated as discrete.

The rasterizer processes faces in batches of face_chunk_size to keep memory use bounded: each chunk creates (chunk, H, W) intermediates so even large meshes can be rendered on 8-16 GB machines. Set face_chunk_size=None to disable chunking (faster for small meshes that fit comfortably in memory).

Parameters:

Name Type Description Default
camera Camera

Pinhole camera.

required
mesh_or_verts Meshes | array

A :class:~mlx3d.structures.Meshes (single mesh) or (V, 3) vertex array.

required
faces array | None

(F, 3) index array, required when mesh_or_verts is an array.

None
verts_colors array | None

(V, 3) per-vertex colors, interpolated in barycentric space.

None
face_colors array | None

(F, 3) constant color per face.

None
texcoords array | None

(VT, 2) UV coordinates for textured rendering.

None
faces_texcoords_idx array | None

(F, 3) per-corner indices into texcoords.

None
texture array | None

(H, W, 3) diffuse texture image.

None
sigma float

Soft boundary width (larger → smoother, less sharp edges).

0.01
depth_temperature float

Controls depth-based face ordering; higher values make nearer faces dominate more sharply.

25.0
background array | tuple[float, float, float] | float

Background color — scalar, (3,) array, or tuple.

0.0
eps float

Numerical epsilon for safe divisions.

1e-08
face_chunk_size int | None

Process this many faces per chunk; None = no chunking.

256

Returns:

Type Description
dict[str, array]

A dict with keys "image" (H, W, 3), "alpha" (H, W), and "depth" (H, W).

Source code in src/mlx3d/renderer/mesh.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def render_mesh_soft(
    camera: Camera,
    mesh_or_verts: Meshes | mx.array,
    faces: mx.array | None = None,
    verts_colors: mx.array | None = None,
    face_colors: mx.array | None = None,
    texcoords: mx.array | None = None,
    faces_texcoords_idx: mx.array | None = None,
    texture: mx.array | None = None,
    sigma: float = 1e-2,
    depth_temperature: float = 25.0,
    background: mx.array | tuple[float, float, float] | float = 0.0,
    eps: float = 1e-8,
    face_chunk_size: int | None = 256,
) -> dict[str, mx.array]:
    """Render a triangle mesh with a SoftRas-style differentiable rasterizer.

    Fully MLX-differentiable w.r.t. vertices, vertex colors, face colors, and
    texture values. Topology and UV indices are treated as discrete.

    The rasterizer processes faces in batches of ``face_chunk_size`` to keep
    memory use bounded: each chunk creates ``(chunk, H, W)`` intermediates so
    even large meshes can be rendered on 8-16 GB machines.  Set
    ``face_chunk_size=None`` to disable chunking (faster for small meshes that
    fit comfortably in memory).

    Color source precedence: ``texture`` (which needs ``texcoords`` and
    ``faces_texcoords_idx``), else ``verts_colors``, else ``face_colors``,
    else plain white.

    Args:
        camera: Pinhole camera.
        mesh_or_verts: A :class:`~mlx3d.structures.Meshes` (single mesh) or
            ``(V, 3)`` vertex array.
        faces: ``(F, 3)`` index array, required when ``mesh_or_verts`` is an
            array.
        verts_colors: ``(V, 3)`` per-vertex colors, interpolated in barycentric
            space.
        face_colors: ``(F, 3)`` constant color per face.
        texcoords: ``(VT, 2)`` UV coordinates for textured rendering.
        faces_texcoords_idx: ``(F, 3)`` per-corner indices into ``texcoords``.
        texture: ``(H, W, 3)`` diffuse texture image.
        sigma: Soft boundary width (larger → smoother, less sharp edges).
        depth_temperature: Controls depth-based face ordering; higher values
            make nearer faces dominate more sharply.
        background: Background color — scalar, ``(3,)`` array, or tuple.
        eps: Numerical epsilon for safe divisions.
        face_chunk_size: Process this many faces per chunk; ``None`` = no
            chunking. Values below 1 are treated as 1.

    Returns:
        A dict with keys ``"image"`` ``(H, W, 3)``, ``"alpha"`` ``(H, W)``,
        and ``"depth"`` ``(H, W)``. A mesh with no faces yields the background
        image with zero alpha.

    Raises:
        ValueError: if ``texture`` is given without both ``texcoords`` and
            ``faces_texcoords_idx``.
    """
    verts, faces_idx = _as_verts_faces(mesh_or_verts, faces)
    faces_idx = faces_idx.astype(mx.int32)
    h, w = camera.height, camera.width
    f = faces_idx.shape[0]

    xy, z = camera.project_points(verts)
    tri_xy = xy[faces_idx]  # (F, 3, 2)
    tri_z = z[faces_idx]  # (F, 3)

    # Global max inverse depth from vertex depths — O(F) space, used for
    # numerical stability of the exp-based depth weighting across chunks.
    valid_verts = tri_z > camera.znear
    if bool(valid_verts.any()):
        min_z = mx.min(mx.where(valid_verts, tri_z, mx.full(tri_z.shape, 1e9)))
        max_inv_depth = 1.0 / mx.maximum(min_z, camera.znear)
    else:
        max_inv_depth = mx.array(1.0 / float(camera.znear))

    # Pixel-centre grids: (1, H, W)
    px = mx.arange(w, dtype=mx.float32) + 0.5
    py = mx.arange(h, dtype=mx.float32) + 0.5
    gx = mx.broadcast_to(px[None, :], (h, w))[None]
    gy = mx.broadcast_to(py[:, None], (h, w))[None]

    # Optional per-face texture UV coords — indexed once for the whole batch.
    uv_tri_full: mx.array | None = None
    vc_full: mx.array | None = None
    fc_full: mx.array | None = None
    if texture is not None:
        if texcoords is None or faces_texcoords_idx is None:
            raise ValueError("texcoords and faces_texcoords_idx are required with texture.")
        uv_tri_full = texcoords[faces_texcoords_idx.astype(mx.int32)]  # (F, 3, 2)
    elif verts_colors is not None:
        vc_full = verts_colors[faces_idx]  # (F, 3, C)
    else:
        fc_full = face_colors if face_colors is not None else mx.ones((f, 3), dtype=mx.float32)

    # Accumulators for incremental sum over chunks.
    sum_w: mx.array = mx.zeros((h, w), dtype=mx.float32)
    sum_wc: mx.array = mx.zeros((h, w, 3), dtype=mx.float32)
    sum_wz: mx.array = mx.zeros((h, w), dtype=mx.float32)

    # ``range`` rejects a step of 0, so an empty mesh with chunking disabled
    # (``f == 0``) must still get a positive chunk size; the loop then simply
    # does not execute and the background is returned.
    if face_chunk_size is None:
        chunk = max(1, f)
    else:
        chunk = max(1, int(face_chunk_size))
    for start in range(0, f, chunk):
        end = min(start + chunk, f)
        sl = slice(start, end)

        txy = tri_xy[sl]  # (C, 3, 2)
        tz = tri_z[sl]  # (C, 3)
        C = end - start

        x0c = txy[:, 0, 0][:, None, None]
        y0c = txy[:, 0, 1][:, None, None]
        x1c = txy[:, 1, 0][:, None, None]
        y1c = txy[:, 1, 1][:, None, None]
        x2c = txy[:, 2, 0][:, None, None]
        y2c = txy[:, 2, 1][:, None, None]

        # Degenerate (near-zero-area) faces get a dummy denominator and are
        # removed via ``valid_area`` below.
        denom = (y1c - y2c) * (x0c - x2c) + (x2c - x1c) * (y0c - y2c)
        valid_area = mx.abs(denom) > eps
        denom = mx.where(valid_area, denom, mx.ones_like(denom))

        l0 = ((y1c - y2c) * (gx - x2c) + (x2c - x1c) * (gy - y2c)) / denom
        l1 = ((y2c - y0c) * (gx - x2c) + (x0c - x2c) * (gy - y2c)) / denom
        l2 = 1.0 - l0 - l1

        # Smallest barycentric is > 0 inside the triangle and < 0 outside; the
        # sigmoid turns it into a soft coverage in (0, 1).
        signed_inside = mx.minimum(mx.minimum(l0, l1), l2)
        coverage = mx.sigmoid(signed_inside / max(float(sigma), 1e-6))

        z_face = (
            l0 * tz[:, 0][:, None, None]
            + l1 * tz[:, 1][:, None, None]
            + l2 * tz[:, 2][:, None, None]
        )
        # Faces with any vertex at or behind the near plane are dropped entirely.
        in_front = mx.all(tz > camera.znear, axis=-1)[:, None, None]
        coverage = coverage * valid_area * in_front

        # Barycentric coords go outside [0, 1] for pixels off the triangle, so
        # ``z_face`` there can fall below ``znear`` and make ``inv_depth`` blow
        # up. Clamp the exponent at 0 (the globally nearest face gets weight 1):
        # this both matches the soft z-buffer semantics and avoids the
        # ``coverage(=0) * exp(=inf) = NaN`` that would otherwise appear where a
        # distant face projects far from the pixel.
        inv_depth = 1.0 / mx.maximum(z_face, camera.znear)
        depth_w = mx.exp(mx.minimum(depth_temperature * (inv_depth - max_inv_depth), 0.0))
        weights = coverage * depth_w  # (C, H, W)

        if texture is not None:
            uv_tri_c = uv_tri_full[sl]  # type: ignore[index]
            uv = (
                l0[..., None] * uv_tri_c[:, 0, :][:, None, None, :]
                + l1[..., None] * uv_tri_c[:, 1, :][:, None, None, :]
                + l2[..., None] * uv_tri_c[:, 2, :][:, None, None, :]
            )
            colors = sample_texture(texture, uv)
        elif vc_full is not None:
            vc = vc_full[sl]
            colors = (
                l0[..., None] * vc[:, 0, :][:, None, None, :]
                + l1[..., None] * vc[:, 1, :][:, None, None, :]
                + l2[..., None] * vc[:, 2, :][:, None, None, :]
            )
        else:
            colors = mx.broadcast_to(fc_full[sl][:, None, None, :], (C, h, w, 3))  # type: ignore[index]

        sum_w = sum_w + mx.sum(weights, axis=0)
        sum_wc = sum_wc + mx.sum(weights[..., None] * colors, axis=0)
        sum_wz = sum_wz + mx.sum(weights * z_face, axis=0)

    # Normalized weighted color/depth; alpha saturates smoothly with total weight.
    image = sum_wc / mx.maximum(sum_w[..., None], eps)
    alpha = 1.0 - mx.exp(-sum_w)
    bg = mx.array(background, dtype=mx.float32)
    if bg.ndim == 0:
        bg = mx.broadcast_to(bg, (3,))
    image = image * alpha[..., None] + bg * (1.0 - alpha[..., None])
    depth = sum_wz / mx.maximum(sum_w, eps)

    return {"image": image, "alpha": alpha, "depth": depth}

sample_texture(texture, uv)

Bilinearly sample an image texture at UV coordinates.

Parameters:

Name Type Description Default
texture array

(H, W, 3) RGB texture in [0, 1].

required
uv array

(..., 2) UV coordinates. OBJ convention is used: v=0 is the bottom of the image.

required
Source code in src/mlx3d/renderer/mesh.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def sample_texture(texture: mx.array, uv: mx.array) -> mx.array:
    """Bilinearly sample an image texture at UV coordinates.

    Args:
        texture: ``(H, W, C)`` texture, typically RGB (``C == 3``) in
            ``[0, 1]``; any channel count (e.g. RGBA) is accepted.
        uv: ``(..., 2)`` UV coordinates. OBJ convention is used: ``v=0`` is
            the bottom of the image. Coordinates are clamped to ``[0, 1]``.

    Returns:
        ``(..., C)`` float32 samples. Gradients flow to ``texture`` and, through
        the bilinear blend weights, to ``uv`` inside the clamped range.
    """
    tex = texture.astype(mx.float32)
    h, w = tex.shape[:2]
    u = mx.clip(uv[..., 0], 0.0, 1.0) * (w - 1)
    # Flip v so that v=0 maps to the last (bottom) image row.
    v = (1.0 - mx.clip(uv[..., 1], 0.0, 1.0)) * (h - 1)

    x0 = mx.floor(u).astype(mx.int32)
    y0 = mx.floor(v).astype(mx.int32)
    x1 = mx.minimum(x0 + 1, w - 1)
    y1 = mx.minimum(y0 + 1, h - 1)
    wx = (u - x0.astype(mx.float32))[..., None]
    wy = (v - y0.astype(mx.float32))[..., None]

    # Flatten to (H*W, C) so each corner is a single 1-D gather; the channel
    # count is taken from the texture instead of being hard-coded to 3.
    flat = tex.reshape(h * w, -1)
    channels = flat.shape[-1]
    lead = uv.shape[:-1]
    c00 = flat[(y0 * w + x0).reshape(-1)].reshape(*lead, channels)
    c10 = flat[(y0 * w + x1).reshape(-1)].reshape(*lead, channels)
    c01 = flat[(y1 * w + x0).reshape(-1)].reshape(*lead, channels)
    c11 = flat[(y1 * w + x1).reshape(-1)].reshape(*lead, channels)
    return (1 - wx) * (1 - wy) * c00 + wx * (1 - wy) * c10 + (1 - wx) * wy * c01 + wx * wy * c11

render_points(camera, points, colors=None, radius=2.0, window=5, depth_temperature=10.0, background=0.0, eps=1e-08)

Render a point cloud with soft Gaussian splats.

Parameters:

Name Type Description Default
camera Camera

the :class:Camera to render from.

required
points array

(P, 3) world-space positions.

required
colors array | None

(P, 3) per-point colors in [0, 1] (defaults to white).

None
radius float

Gaussian sigma in pixels.

2.0
window int

splat window size in pixels (odd).

5
depth_temperature float

sharpness of the soft depth weighting; larger values approach hard z-ordering.

10.0
background float

scalar background intensity.

0.0

Returns:

Type Description
dict[str, array]

dict with image (H, W, 3), alpha (H, W) coverage, and depth (H, W) soft depth.

Source code in src/mlx3d/renderer/points.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def render_points(
    camera: Camera,
    points: mx.array,
    colors: mx.array | None = None,
    radius: float = 2.0,
    window: int = 5,
    depth_temperature: float = 10.0,
    background: float = 0.0,
    eps: float = 1e-8,
) -> dict[str, mx.array]:
    """Render a point cloud with soft Gaussian splats.

    Args:
        camera: the :class:`Camera` to render from.
        points: (P, 3) world-space positions.
        colors: (P, 3) per-point colors in [0, 1] (defaults to white).
        radius: Gaussian sigma in pixels.
        window: splat window size in pixels (odd).
        depth_temperature: sharpness of the soft depth weighting; larger
            values approach hard z-ordering.
        background: scalar background intensity.
        eps: floor on the accumulated weight for safe normalization.

    Returns:
        dict with ``image`` (H, W, 3), ``alpha`` (H, W) coverage, and
        ``depth`` (H, W) soft depth. Points at or behind ``camera.znear``
        contribute nothing.
    """
    H, W = camera.height, camera.width
    P = points.shape[0]
    if colors is None:
        colors = mx.ones((P, 3))

    xy, z = camera.project_points(points)  # (P, 2), (P,)
    in_front = z > camera.znear

    # Window pixel offsets around each point's containing pixel.
    half = window // 2
    base = mx.stop_gradient(mx.floor(xy))  # (P, 2) gradient flows via the Gaussian
    offs = mx.arange(-half, half + 1, dtype=mx.float32)
    ox = mx.broadcast_to(offs[None, :], (window, window)).reshape(-1)
    oy = mx.broadcast_to(offs[:, None], (window, window)).reshape(-1)
    px = base[:, 0:1] + ox[None, :]  # (P, K)
    py = base[:, 1:2] + oy[None, :]

    # Gaussian weight from the continuous projected position to pixel centers.
    dx = px + 0.5 - xy[:, 0:1]
    dy = py + 0.5 - xy[:, 1:2]
    w = mx.exp(-(dx * dx + dy * dy) / (2.0 * radius * radius))  # (P, K)

    # Depth-aware weighting: nearer points dominate (soft z-buffer).
    # Normalize by the nearest *visible* point. Culled points have their depth
    # clamped to ``znear``, so including them would pin the maximum at
    # ``1/znear`` and underflow every visible point's weight to ~0 (blank
    # image). The exponent is clamped at 0 so culled points cannot yield inf,
    # and they are zeroed with ``where`` so no inf/NaN reaches the sums.
    inv_depth = 1.0 / mx.maximum(z, camera.znear)
    max_inv_depth = mx.max(mx.where(in_front, inv_depth, 0.0))
    depth_w = mx.exp(mx.minimum(depth_temperature * (inv_depth - max_inv_depth), 0.0))
    w = mx.where(in_front[:, None], w * depth_w[:, None], 0.0)

    # Mask out-of-bounds pixels and flatten indices.
    valid = (px >= 0) & (px < W) & (py >= 0) & (py < H)
    w = w * valid
    idx = (mx.clip(py, 0, H - 1) * W + mx.clip(px, 0, W - 1)).astype(mx.int32)  # (P, K)
    flat_idx = idx.reshape(-1)
    flat_w = w.reshape(-1)

    wc = (w[..., None] * colors[:, None, :]).reshape(-1, 3)  # (P*K, 3)
    wz = (w * z[:, None]).reshape(-1)

    # Scatter-add splat contributions into flat per-pixel accumulators.
    sum_wc = mx.zeros((H * W, 3)).at[flat_idx].add(wc)
    sum_w = mx.zeros((H * W,)).at[flat_idx].add(flat_w)
    sum_wz = mx.zeros((H * W,)).at[flat_idx].add(wz)

    norm = mx.maximum(sum_w, eps)[:, None]
    image = sum_wc / norm
    alpha = 1.0 - mx.exp(-sum_w)  # soft coverage
    image = image * alpha[:, None] + background * (1.0 - alpha[:, None])

    return {
        "image": image.reshape(H, W, 3),
        "alpha": alpha.reshape(H, W),
        "depth": (sum_wz / mx.maximum(sum_w, eps)).reshape(H, W),
    }

interpolate_face_attributes(frag, attrs)

Interpolate a per-vertex attribute over the rasterized fragments.

Parameters:

Name Type Description Default
frag Fragments

fragments from :func:rasterize_meshes.

required
attrs array

(V, C) per-vertex values.

required

Returns:

Type Description
array

(H, W, C) interpolated values; 0 on empty pixels. Differentiable w.r.t. both attrs and the vertex positions (through the barycentrics).

Source code in src/mlx3d/renderer/rasterizer.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
def interpolate_face_attributes(frag: Fragments, attrs: mx.array) -> mx.array:
    """Blend a per-vertex attribute across each pixel's hit triangle.

    Args:
        frag: fragments from :func:`rasterize_meshes`.
        attrs: ``(V, C)`` per-vertex values.

    Returns:
        ``(H, W, C)`` barycentric blend of the three corner values of the face
        hit at each pixel, zeroed where ``frag.valid`` is ``False``. Gradients
        reach ``attrs`` directly and the vertex positions via ``frag.bary``.
    """
    corners = attrs[frag.vert_ids]  # (H, W, 3, C): value at each corner of the hit face
    blend_w = frag.bary[..., None]  # (H, W, 3, 1)
    blended = (blend_w * corners).sum(axis=-2)
    # Empty pixels carry placeholder vertex ids, so force them to zero.
    coverage = frag.valid[..., None]
    return blended * coverage

rasterize_meshes(camera, mesh_or_verts, faces=None)

Rasterize a triangle mesh to per-pixel :class:Fragments.

Parameters:

Name Type Description Default
camera Camera

the viewing camera.

required
mesh_or_verts Meshes | array

a single-mesh :class:~mlx3d.structures.Meshes or a (V, 3) vertex array.

required
faces array | None

(F, 3) indices, required when passing a raw vertex array.

None
Source code in src/mlx3d/renderer/rasterizer.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def rasterize_meshes(
    camera: Camera, mesh_or_verts: Meshes | mx.array, faces: mx.array | None = None
) -> Fragments:
    """Rasterize a triangle mesh to per-pixel :class:`Fragments`.

    Visibility (the nearest face at each pixel) is decided by a custom tiled
    kernel on detached inputs and is therefore not differentiable. Barycentric
    coordinates are then recomputed with ordinary MLX ops, so gradients reach
    the vertex positions through ``Fragments.bary``.

    Args:
        camera: the viewing camera.
        mesh_or_verts: a single-mesh :class:`~mlx3d.structures.Meshes` or a
            ``(V, 3)`` vertex array.
        faces: ``(F, 3)`` indices, required when passing a raw vertex array.

    Returns:
        :class:`Fragments` of size ``(camera.height, camera.width)``. On empty
        pixels ``pix_to_face`` is negative, ``valid`` is ``False``, ``bary`` is
        zero, and ``vert_ids`` holds face 0's vertex indices (a gather
        placeholder), so consumers should mask with ``valid``.
    """
    verts, faces_idx = _as_verts_faces(mesh_or_verts, faces)
    faces_idx = faces_idx.astype(mx.int32)
    h, w = int(camera.height), int(camera.width)

    xy, z = camera.project_points(verts)  # (V, 2), (V,)
    tri_xy = xy[faces_idx]  # (F, 3, 2)
    tri_z = z[faces_idx]  # (F, 3)

    # Bin faces into screen tiles via a per-face bounding circle (reusing the
    # Gaussian-Splatting tiler), so each pixel only scans faces in its tile.
    sxy = mx.stop_gradient(tri_xy)
    centroid = mx.mean(sxy, axis=1)  # (F, 2)
    # Radius = distance from the centroid to the farthest corner, so the
    # circle encloses the whole triangle.
    radii = mx.sqrt(mx.max(mx.sum((sxy - centroid[:, None, :]) ** 2, axis=-1), axis=1))
    depths = mx.mean(mx.stop_gradient(tri_z), axis=1)  # (F,) for tile sort order
    sorted_ids, tile_ranges, tiles_x, tiles_y = bin_gaussians(centroid, radii, depths, w, h)

    # Scalar launch parameters: image size and tile-grid width.
    params = mx.array([w, h, tiles_x], dtype=mx.int32)
    znear_arr = mx.array([float(camera.znear)], dtype=mx.float32)

    # Visibility is discrete: detach the kernel inputs so autodiff never tries
    # to differentiate the custom kernel. Gradients reach the vertices through
    # the MLX barycentric recompute below instead.
    # One 16x16 threadgroup per tile — presumably matching the tile size used
    # by ``bin_gaussians``; TODO confirm against that helper.
    pix_to_face, zbuf = _raster_kernel(
        inputs=[
            mx.contiguous(sxy.reshape(-1).astype(mx.float32)),
            mx.contiguous(mx.stop_gradient(tri_z).reshape(-1).astype(mx.float32)),
            sorted_ids.astype(mx.int32),
            tile_ranges.astype(mx.int32),
            params,
            znear_arr,
        ],
        output_shapes=[(h * w,), (h * w,)],
        output_dtypes=[mx.int32, mx.float32],
        grid=(tiles_x * 16, tiles_y * 16, 1),
        threadgroup=(16, 16, 1),
    )
    pix_to_face = pix_to_face.reshape(h, w)
    valid = pix_to_face >= 0
    fidx = mx.where(valid, pix_to_face, 0)  # safe gather index

    vert_ids = faces_idx[fidx]  # (H, W, 3)
    # Recompute barycentrics in MLX so they are differentiable w.r.t. vertices.
    tri = xy[vert_ids]  # (H, W, 3, 2)
    ax, ay = tri[..., 0, 0], tri[..., 0, 1]
    bx, by = tri[..., 1, 0], tri[..., 1, 1]
    cxv, cyv = tri[..., 2, 0], tri[..., 2, 1]
    # Pixel centres, broadcast against the (H, W) triangle corner arrays.
    px = (mx.arange(w, dtype=mx.float32) + 0.5)[None, :]
    py = (mx.arange(h, dtype=mx.float32) + 0.5)[:, None]
    denom = (by - cyv) * (ax - cxv) + (cxv - bx) * (ay - cyv)
    # Degenerate (near-zero-area) triangles: avoid dividing by ~0.
    denom = mx.where(mx.abs(denom) < _EPS, mx.ones_like(denom), denom)
    l0 = ((by - cyv) * (px - cxv) + (cxv - bx) * (py - cyv)) / denom
    l1 = ((cyv - ay) * (px - cxv) + (ax - cxv) * (py - cyv)) / denom
    l2 = 1.0 - l0 - l1
    bary = mx.stack([l0, l1, l2], axis=-1)  # (H, W, 3)
    # Zero the placeholder barycentrics on empty pixels.
    bary = bary * valid[..., None]

    return Fragments(
        pix_to_face=pix_to_face,
        zbuf=zbuf.reshape(h, w),
        bary=bary,
        valid=valid,
        vert_ids=vert_ids,
    )

sample_along_rays(origins, directions, near, far, num_samples, stratified=True)

Sample points along rays between near and far.

Parameters:

Name Type Description Default
origins array

(R, 3) ray origins.

required
directions array

(R, 3) ray directions (need not be normalized).

required
num_samples int

samples per ray.

required
stratified bool

jitter samples within their bins (use False for eval).

True

Returns:

Type Description
tuple[array, array]

(points, t_vals) with shapes (R, S, 3) and (R, S).

Source code in src/mlx3d/renderer/rays.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def sample_along_rays(
    origins: mx.array,
    directions: mx.array,
    near: float,
    far: float,
    num_samples: int,
    stratified: bool = True,
) -> tuple[mx.array, mx.array]:
    """Place ``num_samples`` points on every ray between ``near`` and ``far``.

    Args:
        origins: (R, 3) ray origins.
        directions: (R, 3) ray directions (need not be normalized).
        near: first sample distance, in units of ``directions``.
        far: last sample distance, in units of ``directions``.
        num_samples: samples per ray.
        stratified: jitter samples within their bins (use ``False`` for eval).

    Returns:
        ``(points, t_vals)`` with shapes (R, S, 3) and (R, S).
    """
    num_rays = origins.shape[0]
    grid = mx.broadcast_to(mx.linspace(near, far, num_samples)[None], (num_rays, num_samples))

    if not stratified:
        t = grid
    else:
        # Each sample may move anywhere between the midpoints to its
        # neighbours; the first and last bins are pinned to near / far.
        mids = 0.5 * (grid[:, 1:] + grid[:, :-1])
        hi = mx.concatenate([mids, grid[:, -1:]], axis=-1)
        lo = mx.concatenate([grid[:, :1], mids], axis=-1)
        jitter = mx.random.uniform(shape=(num_rays, num_samples))
        t = lo + (hi - lo) * jitter

    points = origins[:, None, :] + t[..., None] * directions[:, None, :]
    return points, t

sample_pdf(bins, weights, num_samples, deterministic=False, eps=1e-05)

Importance-sample new t-values from a piecewise-constant PDF.

Used for the hierarchical (fine) sampling stage of NeRF.

Parameters:

Name Type Description Default
bins array

(R, S+1) bin edges (e.g. midpoints of the coarse samples).

required
weights array

(R, S) unnormalized bin weights.

required
num_samples int

new samples per ray.

required

Returns:

Type Description
array

(R, num_samples) sampled t-values. Gradients are not propagated through the sampling (treated as constants, as in NeRF).

Source code in src/mlx3d/renderer/rays.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def sample_pdf(
    bins: mx.array,
    weights: mx.array,
    num_samples: int,
    deterministic: bool = False,
    eps: float = 1e-5,
) -> mx.array:
    """Importance-sample new t-values from a piecewise-constant PDF.

    Used for the hierarchical (fine) sampling stage of NeRF.

    Args:
        bins: (R, S+1) bin edges (e.g. midpoints of the coarse samples).
        weights: (R, S) unnormalized bin weights.
        num_samples: new samples per ray.
        deterministic: if ``True``, use evenly spaced CDF levels in
            ``[0, 1 - 1e-6]`` instead of uniform random draws.
        eps: added to every weight (so an all-zero ray still has a valid PDF)
            and used as the threshold below which a CDF step is treated as flat.

    Returns:
        (R, num_samples) sampled t-values. Gradients are not propagated
        through the sampling (treated as constants, as in NeRF).
    """
    weights = mx.stop_gradient(weights) + eps
    pdf = weights / weights.sum(axis=-1, keepdims=True)
    cdf = mx.cumsum(pdf, axis=-1)
    cdf = mx.concatenate([mx.zeros_like(cdf[:, :1]), cdf], axis=-1)  # (R, S+1)

    R = cdf.shape[0]
    if deterministic:
        u = mx.linspace(0.0, 1.0 - 1e-6, num_samples)
        u = mx.broadcast_to(u[None], (R, num_samples))
    else:
        u = mx.random.uniform(shape=(R, num_samples))

    # Inverse-CDF lookup: idx[r, m] = number of cdf entries <= u (O(S*M), fine on GPU).
    idx = (cdf[:, None, :-1] <= u[..., None]).sum(axis=-1) - 1  # (R, M)
    idx = mx.clip(idx, 0, cdf.shape[-1] - 2).astype(mx.int32)

    cdf_low = mx.take_along_axis(cdf, idx, axis=-1)
    cdf_high = mx.take_along_axis(cdf, idx + 1, axis=-1)
    bins_low = mx.take_along_axis(bins, idx, axis=-1)
    bins_high = mx.take_along_axis(bins, idx + 1, axis=-1)

    # Linear interpolation inside the selected bin; flat CDF steps use a unit
    # denominator to avoid dividing by ~0.
    denom = mx.where(cdf_high - cdf_low < eps, mx.ones_like(cdf_low), cdf_high - cdf_low)
    frac = (u - cdf_low) / denom
    samples = bins_low + frac * (bins_high - bins_low)
    # ``weights`` was detached above, but ``bins`` may still carry gradients;
    # detach the result so the documented "constants" contract actually holds.
    return mx.stop_gradient(samples)

volume_render(densities, colors, t_vals, directions=None, white_background=False)

Composite densities and colors along rays with the NeRF quadrature rule.

Parameters:

Name Type Description Default
densities array

(R, S) non-negative volume densities (sigma).

required
colors array

(R, S, 3) per-sample RGB in [0, 1].

required
t_vals array

(R, S) sample distances along each ray.

required
directions array | None

optional (R, 3) ray directions; if given, deltas are scaled by their norms so densities live in world units.

None
white_background bool

composite onto white instead of black.

False

Returns:

Type Description
dict[str, array]

dict with rgb (R, 3), depth (R,), acc (R,) opacity, and weights (R, S).

Source code in src/mlx3d/renderer/rays.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def volume_render(
    densities: mx.array,
    colors: mx.array,
    t_vals: mx.array,
    directions: mx.array | None = None,
    white_background: bool = False,
) -> dict[str, mx.array]:
    """Composite densities and colors along rays with the NeRF quadrature rule.

    Sample ``i`` covers the interval ``delta_i = t_{i+1} - t_i``. The last sample
    gets an effectively infinite interval (``1e10``), so any positive density
    there absorbs all remaining transmittance. Then::

        alpha_i  = 1 - exp(-sigma_i * delta_i)
        T_i      = prod_{j<i} (1 - alpha_j)        (T_0 = 1)
        weight_i = T_i * alpha_i

    The last axis is the sample axis. All leading axes are treated as batch
    axes, so ``(R, S)`` inputs work as well as e.g. ``(H, W, S)``.

    Args:
        densities: (R, S) non-negative volume densities (sigma).
        colors: (R, S, 3) per-sample RGB in [0, 1].
        t_vals: (R, S) sample distances along each ray. ``S == 1`` is allowed.
        directions: optional (R, 3) ray directions. If given, deltas are
            scaled by their norms so densities live in world units.
        white_background: composite onto white instead of black.

    Returns:
        dict with ``rgb`` (R, 3), ``depth`` (R,), ``acc`` (R,) opacity, and
        ``weights`` (R, S).
    """
    # Build the trailing "infinite" interval from t_vals' own shape, not from
    # the differences. With a single sample the differences are empty (R, 0),
    # and padding derived from them would also be empty, which would silently
    # drop the only sample.
    far = mx.full(t_vals[..., :1].shape, 1e10)
    deltas = mx.concatenate([t_vals[..., 1:] - t_vals[..., :-1], far], axis=-1)  # (R, S)
    if directions is not None:
        # Convert parametric step lengths into world-space lengths.
        deltas = deltas * mx.linalg.norm(directions, axis=-1, keepdims=True)

    alpha = 1.0 - mx.exp(-densities * deltas)  # (R, S)
    # Transmittance T_i = prod_{j<i} (1 - alpha_j): an inclusive cumprod shifted
    # right by one, with T_0 = 1. The clip keeps the product from hitting
    # exactly zero.
    one_minus = mx.clip(1.0 - alpha, 1e-10, 1.0)
    trans = mx.cumprod(one_minus, axis=-1)
    trans = mx.concatenate([mx.ones_like(trans[..., :1]), trans[..., :-1]], axis=-1)
    weights = alpha * trans  # (R, S)

    rgb = mx.sum(weights[..., None] * colors, axis=-2)
    depth = mx.sum(weights * t_vals, axis=-1)
    acc = mx.sum(weights, axis=-1)
    if white_background:
        # Unabsorbed light (1 - acc) picks up the white background.
        rgb = rgb + (1.0 - acc[..., None])
    return {"rgb": rgb, "depth": depth, "acc": acc, "weights": weights}

pbr_shading(points, normals, albedo, camera_center, lights, roughness=0.5, metallic=0.0)

Cook-Torrance/GGX material shading.

This is a compact real-time PBR shader: GGX normal distribution, Schlick Fresnel, and Smith-Schlick geometry term. It is not a full environment-lit renderer, but it gives glTF-style base-color/metallic/roughness controls with stable MLX autodiff.

Source code in src/mlx3d/renderer/shading.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
def pbr_shading(
    points: mx.array,
    normals: mx.array,
    albedo: mx.array,
    camera_center: mx.array,
    lights: list[Light],
    roughness: float | mx.array = 0.5,
    metallic: float | mx.array = 0.0,
) -> mx.array:
    """Cook-Torrance/GGX material shading.

    A compact real-time PBR shader built from three pieces:

    - a GGX normal-distribution term;
    - Schlick Fresnel, with base reflectance 0.04 blended toward ``albedo``
      by ``metallic``;
    - a Smith-Schlick geometry term with ``k = (roughness + 1)^2 / 8``.

    It is not a full environment-lit renderer, but it gives glTF-style
    base-color/metallic/roughness controls with stable MLX autodiff.

    ``roughness`` is clamped to ``[0.04, 1]`` and ``metallic`` to ``[0, 1]``.
    Scalar values are broadcast to one value per pixel. Any light for which
    ``_direct_light_terms`` yields no direction contributes only an ambient
    term, ``ambient * albedo * (1 - metallic)``. The result is clipped to
    ``[0, 1]``.
    """

    def _unit(vec):
        # Normalize along the last axis, guarding against zero-length vectors.
        return vec / mx.maximum(mx.linalg.norm(vec, axis=-1, keepdims=True), 1e-8)

    def _dot(a, b):
        return mx.sum(a * b, axis=-1, keepdims=True)

    def _per_pixel(value):
        # Scalars become (..., 1) maps matching albedo's leading shape.
        if value.ndim == 0:
            return mx.broadcast_to(value, albedo.shape[:-1] + (1,))
        return value

    view = _unit(_arr(camera_center) - points)
    rough = _per_pixel(mx.clip(_arr(roughness), 0.04, 1.0))
    metal = _per_pixel(mx.clip(_arr(metallic), 0.0, 1.0))

    base_reflectance = 0.04 * (1.0 - metal) + albedo * metal
    ambient = mx.zeros((3,))
    radiance = mx.zeros_like(albedo)
    alpha_sq = rough * rough
    alpha2 = alpha_sq * alpha_sq
    k = ((rough + 1.0) ** 2) / 8.0
    n_dot_v = mx.maximum(_dot(normals, view), 1e-5)

    for light in lights:
        to_light, intensity = _direct_light_terms(light, points)
        if to_light is None:
            ambient = ambient + intensity
            continue

        half = _unit(to_light + view)
        n_dot_l = mx.maximum(_dot(normals, to_light), 0.0)
        n_dot_h = mx.maximum(_dot(normals, half), 0.0)
        v_dot_h = mx.maximum(_dot(view, half), 0.0)

        # GGX normal distribution.
        lobe = n_dot_h * n_dot_h * (alpha2 - 1.0) + 1.0
        ndf = alpha2 / mx.maximum(math.pi * lobe * lobe, 1e-8)
        # Smith-Schlick masking/shadowing for the view and light directions.
        g_view = n_dot_v / mx.maximum(n_dot_v * (1.0 - k) + k, 1e-8)
        g_light = n_dot_l / mx.maximum(n_dot_l * (1.0 - k) + k, 1e-8)
        # Schlick Fresnel approximation.
        fresnel = base_reflectance + (1.0 - base_reflectance) * ((1.0 - v_dot_h) ** 5)

        spec = (ndf * g_view * g_light * fresnel) / mx.maximum(4.0 * n_dot_v * n_dot_l, 1e-6)
        # Energy not reflected specularly goes to the diffuse lobe; metals have none.
        diff = (1.0 - fresnel) * (1.0 - metal) * albedo / math.pi
        radiance = radiance + (diff + spec) * intensity * n_dot_l

    radiance = radiance + ambient * albedo * (1.0 - metal)
    return mx.clip(radiance, 0.0, 1.0)

phong_shading(points, normals, albedo, camera_center, lights, shininess=32.0, specular_strength=0.3)

Blinn-Phong shade per-pixel buffers.

Parameters:

Name Type Description Default
points array

(H, W, 3) world positions.

required
normals array

(H, W, 3) unit world normals.

required
albedo array

(H, W, 3) base color.

required
camera_center array

(3,) camera position.

required
lights list[Light]

list of light sources.

required
shininess float

specular exponent.

32.0
specular_strength float

scalar weight on the specular term.

0.3
Source code in src/mlx3d/renderer/shading.py
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
def phong_shading(
    points: mx.array,
    normals: mx.array,
    albedo: mx.array,
    camera_center: mx.array,
    lights: list[Light],
    shininess: float = 32.0,
    specular_strength: float = 0.3,
) -> mx.array:
    """Shade per-pixel buffers with the Blinn-Phong model.

    Each light's ``ambient`` value, ``diffuse(...)`` term, and
    ``specular(...)`` term are summed over ``lights``. The color is then
    ``albedo * (ambient + diffuse) + specular_strength * specular``.

    Args:
        points: ``(H, W, 3)`` world positions.
        normals: ``(H, W, 3)`` unit world normals.
        albedo: ``(H, W, 3)`` base color.
        camera_center: ``(3,)`` camera position.
        lights: list of light sources.
        shininess: specular exponent.
        specular_strength: scalar weight on the specular term.

    Returns:
        The shaded color, clipped to ``[0, 1]``.
    """
    ambient_sum = mx.zeros((3,))
    diffuse_sum = mx.zeros_like(points)
    specular_sum = mx.zeros_like(points)
    for source in lights:
        ambient_sum = ambient_sum + source.ambient
        diffuse_sum = diffuse_sum + source.diffuse(normals, points)
        specular_sum = specular_sum + source.specular(
            normals, points, camera_center, shininess
        )

    lit = albedo * (ambient_sum + diffuse_sum)
    shaded = lit + specular_strength * specular_sum
    return mx.clip(shaded, 0.0, 1.0)

render_mesh(camera, mesh_or_verts, faces=None, verts_colors=None, texture=None, verts_uvs=None, faces_uvs=None, lights=None, shininess=32.0, specular_strength=0.3, roughness=0.5, metallic=0.0, background=0.0, shading='phong', ssaa=1)

Render a mesh with the hard rasterizer and Blinn-Phong (default), PBR, or unlit shading.

Albedo comes from a UV texture when texture is given, otherwise from verts_colors (default light grey, 0.7). ssaa > 1 supersamples (renders at ssaa× resolution and box-downsamples) for antialiased edges.

Besides image/alpha/depth/normals the result includes render passes (AOVs): position ((H, W, 3) world-space hit point) and face_id ((H, W) nearest-face index, -1 where empty).

Parameters:

Name Type Description Default
camera Camera

viewing camera.

required
mesh_or_verts Meshes | array

a single-mesh :class:~mlx3d.structures.Meshes or (V, 3) vertices (faces then required).

required
faces array | None

(F, 3) indices when passing raw vertices.

None
verts_colors array | None

(V, 3) albedo; defaults to a uniform light grey (0.7).

None
texture array | None

(H, W, 3) diffuse texture; requires verts_uvs and faces_uvs.

None
verts_uvs array | None

(VT, 2) UV coordinates.

None
faces_uvs array | None

(F, 3) per-corner indices into verts_uvs.

None
lights list[Light] | None

light list; defaults to one key light + ambient. shading="none" ignores lights and returns flat albedo.

None
shading str

"phong", "pbr", or "none" (unlit albedo).

'phong'
roughness float | array

scalar or (H, W, 1) material roughness for shading="pbr".

0.5
metallic float | array

scalar or (H, W, 1) material metalness for shading="pbr".

0.0
background tuple[float, float, float] | float

scalar or (3,) background color.

0.0

Returns:

Type Description
RenderOutput

{"image", "alpha", "depth", "normals", "position", "face_id"}.

Source code in src/mlx3d/renderer/shading.py
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
def render_mesh(
    camera: Camera,
    mesh_or_verts: Meshes | mx.array,
    faces: mx.array | None = None,
    verts_colors: mx.array | None = None,
    texture: mx.array | None = None,
    verts_uvs: mx.array | None = None,
    faces_uvs: mx.array | None = None,
    lights: list[Light] | None = None,
    shininess: float = 32.0,
    specular_strength: float = 0.3,
    roughness: float | mx.array = 0.5,
    metallic: float | mx.array = 0.0,
    background: tuple[float, float, float] | float = 0.0,
    shading: str = "phong",
    ssaa: int = 1,
) -> RenderOutput:
    """Render a mesh with the hard rasterizer and Blinn-Phong, PBR, or unlit shading.

    Albedo comes from a UV texture when ``texture`` is given, otherwise from
    ``verts_colors`` (default uniform 0.7 grey). ``ssaa > 1`` supersamples: it
    renders at ``ssaa``× resolution and box-downsamples for antialiased edges.

    Besides ``image``/``alpha``/``depth``/``normals``, the result includes
    render passes (AOVs): ``position`` (``(H, W, 3)`` world-space hit point)
    and ``face_id`` (``(H, W)`` nearest-face index, ``-1`` where empty).

    Args:
        camera: viewing camera.
        mesh_or_verts: a single-mesh :class:`~mlx3d.structures.Meshes` or ``(V, 3)``
            vertices (``faces`` then required).
        faces: ``(F, 3)`` indices when passing raw vertices.
        verts_colors: ``(V, 3)`` albedo; defaults to uniform 0.7 grey.
        texture: ``(H, W, 3)`` diffuse texture; requires ``verts_uvs`` and ``faces_uvs``.
        verts_uvs: ``(VT, 2)`` UV coordinates.
        faces_uvs: ``(F, 3)`` per-corner indices into ``verts_uvs``.
        lights: light list; defaults to one key light + ambient. ``shading="none"``
            ignores lights and returns flat albedo.
        shininess: specular exponent for ``shading="phong"``.
        specular_strength: weight on the specular term for ``shading="phong"``.
        roughness: scalar or ``(H, W, 1)`` material roughness for ``shading="pbr"``.
        metallic: scalar or ``(H, W, 1)`` material metalness for ``shading="pbr"``.
        background: scalar or ``(3,)`` background color.
        shading: ``"phong"``, ``"pbr"``, or ``"none"`` (unlit albedo).
        ssaa: supersampling factor; values ``<= 1`` render at native resolution.

    Returns:
        ``{"image", "alpha", "depth", "normals", "position", "face_id"}``.

    Raises:
        ValueError: if ``shading`` is unknown, raw vertices are given without
            ``faces``, or ``texture`` is given without ``verts_uvs``/``faces_uvs``.
    """
    # Validate cheap arguments before any rasterization. Otherwise a typo in
    # ``shading`` is only reported after a full (possibly ssaa-upscaled) pass.
    if shading not in ("phong", "pbr", "none"):
        raise ValueError("shading must be 'phong', 'pbr', or 'none'.")
    if not isinstance(mesh_or_verts, Meshes) and faces is None:
        raise ValueError("faces is required when mesh_or_verts is a vertex array.")
    if texture is not None and (verts_uvs is None or faces_uvs is None):
        raise ValueError("verts_uvs and faces_uvs are required with a texture.")

    if ssaa > 1:
        # Render every pass at ssaa x resolution, then box-filter back down.
        # Keyword arguments keep this call correct if the signature is reordered.
        hi = render_mesh(
            _scale_camera(camera, ssaa),
            mesh_or_verts,
            faces=faces,
            verts_colors=verts_colors,
            texture=texture,
            verts_uvs=verts_uvs,
            faces_uvs=faces_uvs,
            lights=lights,
            shininess=shininess,
            specular_strength=specular_strength,
            roughness=roughness,
            metallic=metallic,
            background=background,
            shading=shading,
            ssaa=1,
        )
        return _downsample_passes(hi, ssaa)

    mesh = mesh_or_verts if isinstance(mesh_or_verts, Meshes) else Meshes([mesh_or_verts], [faces])
    verts = mesh.verts_packed()

    frag = rasterize_meshes(camera, mesh)
    positions = interpolate_face_attributes(frag, verts)  # world-space AOV
    if texture is not None:
        # Interpolate per-corner UVs over the fragments, then sample the texture.
        # Empty pixels index face 0 to stay in bounds, then get masked out.
        fidx = mx.where(frag.valid, frag.pix_to_face, 0)
        uv_tri = verts_uvs[faces_uvs.astype(mx.int32)[fidx]]  # (H, W, 3, 2)
        uv = mx.sum(frag.bary[..., None] * uv_tri, axis=-2)  # (H, W, 2)
        albedo = sample_texture(texture, uv) * frag.valid[..., None]
    else:
        if verts_colors is None:
            verts_colors = mx.full((verts.shape[0], 3), 0.7)
        albedo = interpolate_face_attributes(frag, verts_colors)

    if shading == "none":
        image = albedo
        normals_px = mx.zeros_like(albedo)
    else:
        if lights is None:
            lights = [
                DirectionalLights(direction=(-1.0, -1.0, -0.6), color=(1.0, 1.0, 1.0)),
                AmbientLights(color=(0.25, 0.25, 0.25)),
            ]
        vnormals = mesh.verts_normals_packed()
        normals_px = interpolate_face_attributes(frag, vnormals)
        normals_px = normals_px / mx.maximum(
            mx.linalg.norm(normals_px, axis=-1, keepdims=True), 1e-8
        )
        # Two-sided shading: orient each normal toward the camera so meshes with
        # inward/inconsistent winding still light correctly (a black mesh from a
        # flipped normal is the most common surprise otherwise).
        view_dir = _arr(camera.camera_center) - positions
        facing = mx.sum(normals_px * view_dir, axis=-1, keepdims=True) < 0
        normals_px = mx.where(facing, -normals_px, normals_px)
        if shading == "phong":
            image = phong_shading(
                positions,
                normals_px,
                albedo,
                camera.camera_center,
                lights,
                shininess,
                specular_strength,
            )
        else:  # "pbr" (validated above)
            image = pbr_shading(
                positions,
                normals_px,
                albedo,
                camera.camera_center,
                lights,
                roughness=roughness,
                metallic=metallic,
            )

    # Hard coverage mask: 1 where a face was hit, 0 elsewhere.
    alpha = frag.valid.astype(mx.float32)
    bg = _arr(background)
    if bg.ndim == 0:
        bg = mx.broadcast_to(bg, (3,))
    image = image * alpha[..., None] + bg * (1.0 - alpha[..., None])
    return {
        "image": image,
        "alpha": alpha,
        "depth": frag.zbuf,
        "normals": normals_px,
        "position": positions,
        "face_id": frag.pix_to_face,
    }