Skip to content

mlx3d.losses

mlx3d.losses

LPIPS

Bases: Module

Learned Perceptual Image Patch Similarity (VGG-16).

Call lpips(pred, target) with images shaped (H, W, 3) or (N, H, W, 3) in [0, 1]; returns a scalar perceptual distance (lower = more similar). Differentiable w.r.t. the images.

Source code in src/mlx3d/losses/lpips.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
class LPIPS(nn.Module):
    """LPIPS perceptual distance on top of a VGG-16 style feature stack.

    Usage: ``lpips(pred, target)`` where both images are ``(H, W, 3)`` or
    ``(N, H, W, 3)`` with values in ``[0, 1]``. The result is a scalar
    perceptual distance (smaller = more similar) and is differentiable
    w.r.t. the images.
    """

    def __init__(self):
        super().__init__()
        # One list of 3x3 convolutions per stage; consecutive entries of a
        # stage description give the (in, out) channels of each convolution.
        self.blocks = [
            [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
                for c_in, c_out in zip(stage[:-1], stage[1:])
            ]
            for stage in _VGG_BLOCKS
        ]
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # One per-channel "lin" weight vector per tapped stage. The forward
        # pass takes their absolute value, so the weighted sum stays
        # non-negative like the constrained LPIPS linear heads.
        self.lin_weights = [
            mx.random.uniform(shape=(channels,)) * 0.1 for channels in _TAP_CHANNELS
        ]

    def _features(self, x: mx.array) -> list[mx.array]:
        """Run the conv stack and return the activation at the end of each stage."""
        taps = []
        last_stage = len(self.blocks) - 1
        act = x
        for stage_idx, stage in enumerate(self.blocks):
            for conv in stage:
                act = nn.relu(conv(act))
            taps.append(act)
            # Downsample between stages, but not after the final one.
            if stage_idx != last_stage:
                act = self.pool(act)
        return taps

    def __call__(self, pred: mx.array, target: mx.array) -> mx.array:
        # Promote single images to a batch of one.
        if pred.ndim == 3:
            pred = pred[None]
            target = target[None]
        shift = mx.array(_SHIFT)
        scale = mx.array(_SCALE)

        def _preprocess(img: mx.array) -> mx.array:
            # [0, 1] -> [-1, 1] -> LPIPS scaling.
            return (img * 2.0 - 1.0 - shift) / scale

        def _unit(feat: mx.array) -> mx.array:
            # Unit-normalize along the channel axis, guarding tiny norms.
            length = mx.linalg.norm(feat, axis=-1, keepdims=True)
            return feat / mx.maximum(length, 1e-10)

        feats_pred = self._features(_preprocess(pred))
        feats_target = self._features(_preprocess(target))

        distance = mx.zeros(())
        for f_pred, f_target, lin in zip(feats_pred, feats_target, self.lin_weights):
            sq_diff = (_unit(f_pred) - _unit(f_target)) ** 2  # (N, H, W, C)
            distance = distance + mx.mean(mx.sum(sq_diff * mx.abs(lin), axis=-1))
        return distance

    def load_weights_file(self, path: str) -> None:
        """Load converted VGG-16 + lin weights from a ``.safetensors``/``.npz`` file.

        The stored arrays must be MLX arrays laid out like this module's
        parameter tree (see the conversion recipe in the docs). Loading is
        delegated to ``self.load_weights``.
        """
        self.load_weights(path)

load_weights_file(path)

Load converted VGG-16 + lin weights from a .safetensors/.npz file.

The file must contain MLX arrays matching this module's parameter tree (see the conversion recipe in the docs).

Source code in src/mlx3d/losses/lpips.py
92
93
94
95
96
97
98
def load_weights_file(self, path: str) -> None:
    """Load converted VGG-16 + lin weights from a ``.safetensors``/``.npz`` file.

    The file must contain MLX arrays matching this module's parameter tree
    (see the conversion recipe in the docs).

    Args:
        path: location of the weights file to load.
    """
    # Thin convenience wrapper: all the work is done by ``self.load_weights``.
    self.load_weights(path)

chamfer_distance(x, y, x_normals=None, y_normals=None, single_directional=False)

Bidirectional (squared) chamfer distance between batches of point clouds.

Parameters:

Name Type Description Default
x array

(N, P1, 3) or (P1, 3).

required
y array

(N, P2, 3) or (P2, 3).

required
x_normals array | None

optional (N, P1, 3) normals for x.

None
y_normals array | None

optional (N, P2, 3) normals for y; if both are given, a normal-consistency term (1 - |cos|) is also returned.

None
single_directional bool

only use the x -> y direction.

False

Returns:

Type Description
tuple[array, array | None]

(loss, loss_normals); loss_normals is None if normals were not provided.

Source code in src/mlx3d/losses/chamfer.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
def chamfer_distance(
    x: mx.array,
    y: mx.array,
    x_normals: mx.array | None = None,
    y_normals: mx.array | None = None,
    single_directional: bool = False,
) -> tuple[mx.array, mx.array | None]:
    """Squared chamfer distance between (batches of) point clouds.

    Each point is matched to its nearest neighbor in the other cloud and the
    nearest-neighbor distances are averaged; the x -> y and y -> x means are
    summed unless ``single_directional`` is set.

    Args:
        x: (N, P1, 3) or (P1, 3).
        y: (N, P2, 3) or (P2, 3).
        x_normals: optional (N, P1, 3) normals for ``x``.
        y_normals: optional (N, P2, 3) normals for ``y``; when both normal
            sets are supplied, a normal-consistency term (1 - |cos|) is
            returned as well. The cosine is the plain dot product of the
            normals as given (they are not re-normalized here).
        single_directional: only use the x -> y direction.

    Returns:
        ``(loss, loss_normals)``; ``loss_normals`` is ``None`` unless both
        normal arrays were provided.
    """
    # Unbatched input: add a leading batch axis to everything that was given.
    if x.ndim == 2:
        x = x[None]
        y = y[None]
        if x_normals is not None:
            x_normals = x_normals[None]
        if y_normals is not None:
            y_normals = y_normals[None]

    dists_xy, nn_xy = knn_points(x, y, K=1)
    loss = dists_xy[..., 0].mean()
    nn_yx = None
    if not single_directional:
        dists_yx, nn_yx = knn_points(y, x, K=1)
        loss = loss + dists_yx[..., 0].mean()

    if x_normals is None or y_normals is None:
        return loss, None

    def _normal_term(query_normals, ref_normals, nn_idx, query_shape):
        # Gather, for every query point, the normal of its nearest neighbor
        # and score the misalignment as 1 - |dot product|.
        gather_idx = mx.broadcast_to(nn_idx[..., 0:1].astype(mx.int32), query_shape)
        matched = mx.take_along_axis(ref_normals, gather_idx, axis=1)
        cosine = mx.sum(query_normals * matched, axis=-1)
        return (1.0 - mx.abs(cosine)).mean()

    loss_normals = _normal_term(x_normals, y_normals, nn_xy, x.shape)
    if not single_directional:
        loss_normals = loss_normals + _normal_term(y_normals, x_normals, nn_yx, y.shape)

    return loss, loss_normals

ms_ssim(pred, target, max_val=1.0, window_size=11, sigma=1.5, weights=_MSSSIM_WEIGHTS)

Multi-scale SSIM (Wang et al., 2003) — a perceptual image-quality metric.

Evaluates SSIM across len(weights) scales (2x average-pooling between them), combining the contrast-structure term at coarse scales with the full SSIM at the finest. More correlated with perceived quality than single-scale SSIM, and unlike LPIPS needs no pretrained network. Fully differentiable; use 1 - ms_ssim(...) as a loss.

Parameters:

Name Type Description Default
pred array

(H, W, C) or (N, H, W, C) image in [0, max_val].

required
target array

image with the same shape as pred.

required
weights tuple[float, ...]

per-scale weights; each spatial dimension of the image must be at least window_size * 2**(len(weights) - 1).

_MSSSIM_WEIGHTS

Returns:

Type Description
array

Scalar MS-SSIM in [0, 1].

Source code in src/mlx3d/losses/image_metrics.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
def ms_ssim(
    pred: mx.array,
    target: mx.array,
    max_val: float = 1.0,
    window_size: int = 11,
    sigma: float = 1.5,
    weights: tuple[float, ...] = _MSSSIM_WEIGHTS,
) -> mx.array:
    """Multi-scale SSIM (Wang et al., 2003) — a perceptual image-quality metric.

    Evaluates SSIM across ``len(weights)`` scales (2x average-pooling between
    them), combining the contrast-structure term at coarse scales with the full
    SSIM at the finest. More correlated with perceived quality than single-scale
    SSIM, and unlike LPIPS needs no pretrained network. Fully differentiable;
    use ``1 - ms_ssim(...)`` as a loss.

    Args:
        pred: (H, W, C) or (N, H, W, C) image in [0, max_val].
        target: image with the same shape as ``pred``.
        max_val: upper end of the image value range; forwarded to the
            per-scale SSIM computation.
        window_size: size of the SSIM window; forwarded to the per-scale SSIM
            computation and used to derive the minimum image size.
        sigma: forwarded to the per-scale SSIM computation (presumably the
            standard deviation of its Gaussian window).
        weights: per-scale exponents, one per scale. Must be non-empty, and
            both spatial dimensions of the image must be at least
            ``window_size * 2**(len(weights) - 1)``.

    Returns:
        Scalar MS-SSIM in [0, 1].

    Raises:
        ValueError: if ``weights`` is empty, or if the image is too small for
            the requested number of scales.
    """
    if pred.ndim == 3:
        pred, target = pred[None], target[None]
    n_scales = len(weights)
    if n_scales == 0:
        # Without this guard the function falls through to ``weights[-1]``
        # and fails with an unrelated-looking IndexError.
        raise ValueError("ms_ssim needs at least one scale; got empty weights.")
    min_hw = window_size * (2 ** (n_scales - 1))
    if min(pred.shape[1], pred.shape[2]) < min_hw:
        raise ValueError(
            f"ms_ssim needs spatial dims >= {min_hw} for {n_scales} scales; "
            f"got {pred.shape[1]}x{pred.shape[2]}. Use fewer weights or a larger image."
        )

    cs_factors = []
    last_ssim = None
    for i in range(n_scales):
        ssim_mean, cs_mean = _ssim_maps(pred, target, max_val, window_size, sigma)
        if i < n_scales - 1:
            # Coarser scales contribute only the contrast-structure term.
            # Clamp at zero so the fractional power below is well defined.
            cs_factors.append(mx.maximum(cs_mean, 0.0))
            pred, target = _avgpool2(pred), _avgpool2(target)
        else:
            # The last scale evaluated contributes the full SSIM.
            last_ssim = mx.maximum(ssim_mean, 0.0)

    # Weighted geometric combination of the per-scale terms.
    out = last_ssim ** weights[-1]
    for cs, w in zip(cs_factors, weights[:-1]):
        out = out * cs**w
    return out

psnr(pred, target, max_val=1.0)

Peak signal-to-noise ratio in dB. Images in [0, max_val].

Source code in src/mlx3d/losses/image_metrics.py
17
18
19
20
def psnr(pred: mx.array, target: mx.array, max_val: float = 1.0) -> mx.array:
    """Peak signal-to-noise ratio in dB for images valued in [0, max_val].

    The mean squared error is floored at 1e-12 so identical images give a
    large finite value instead of a division by zero.
    """
    squared_error = (pred - target) ** 2
    floored_mse = mx.maximum(squared_error.mean(), 1e-12)
    peak_power = max_val * max_val
    return 10.0 * mx.log10(peak_power / floored_mse)

ssim(pred, target, max_val=1.0, window_size=11, sigma=1.5)

Structural similarity (mean SSIM) with a Gaussian window.

Parameters:

Name Type Description Default
pred array

(H, W, C) or (N, H, W, C) image in [0, max_val].

required
target array

image with the same shape as pred.

required

Returns:

Type Description
array

Scalar mean SSIM. Use 1 - ssim(...) as a loss (the Gaussian filtering is fully differentiable).

Source code in src/mlx3d/losses/image_metrics.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def ssim(
    pred: mx.array,
    target: mx.array,
    max_val: float = 1.0,
    window_size: int = 11,
    sigma: float = 1.5,
) -> mx.array:
    """Mean structural similarity (SSIM) computed with a Gaussian window.

    Args:
        pred: (H, W, C) or (N, H, W, C) image in [0, max_val].
        target: image with the same shape as ``pred``.

    Returns:
        Scalar mean SSIM. ``1 - ssim(...)`` can be used as a loss, since the
        Gaussian filtering is fully differentiable.
    """
    # A single image is treated as a batch of one.
    if pred.ndim == 3:
        pred = pred[None]
        target = target[None]
    maps = _ssim_maps(pred, target, max_val, window_size, sigma)
    # The first entry is the mean SSIM; the remaining output is not needed here.
    return maps[0]

mesh_edge_loss(meshes, target_length=0.0)

Mean squared deviation of edge lengths from target_length.

Source code in src/mlx3d/losses/mesh_losses.py
11
12
13
14
15
16
17
18
19
20
def mesh_edge_loss(meshes: Meshes, target_length: float = 0.0) -> mx.array:
    """Average squared difference between each edge length and ``target_length``.

    Returns a zero scalar when the meshes have no edges.
    """
    verts = meshes.verts_packed()
    edges = meshes.edges_packed()
    if edges.shape[0] == 0:
        return mx.array(0.0)
    # Endpoints of every edge, gathered from the packed vertex array.
    start_pts, end_pts = verts[edges[:, 0]], verts[edges[:, 1]]
    edge_lengths = mx.linalg.norm(start_pts - end_pts, axis=-1)
    residual = edge_lengths - target_length
    return (residual ** 2).mean()

mesh_laplacian_smoothing(meshes, method='uniform')

Laplacian smoothing loss: mean norm of the uniform graph Laplacian.

For each vertex, measures the distance to the centroid of its neighbors; minimizing it pulls the surface toward locally smooth configurations.

Source code in src/mlx3d/losses/mesh_losses.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def mesh_laplacian_smoothing(meshes: Meshes, method: str = "uniform") -> mx.array:
    """Uniform-Laplacian smoothing loss.

    For every vertex, takes the offset from the vertex to the centroid of its
    edge-connected neighbors and returns the mean length of those offsets.
    Minimizing it pulls the surface toward locally smooth configurations.
    Only ``method="uniform"`` is supported; anything else raises
    ``NotImplementedError``. Returns a zero scalar when there are no edges.
    """
    if method != "uniform":
        raise NotImplementedError("Only the 'uniform' Laplacian is implemented.")
    verts = meshes.verts_packed()
    edges = meshes.edges_packed()
    if edges.shape[0] == 0:
        return mx.array(0.0)

    src, dst = edges[:, 0], edges[:, 1]
    # Every edge contributes in both directions: dst is a neighbor of src and
    # src is a neighbor of dst.
    neighbor_total = mx.zeros_like(verts).at[src].add(verts[dst]).at[dst].add(verts[src])
    per_edge_one = mx.ones((edges.shape[0],))
    degree = mx.zeros((verts.shape[0],)).at[src].add(per_edge_one).at[dst].add(per_edge_one)
    # Vertices with no edges keep a divisor of 1 instead of 0.
    degree = mx.maximum(degree, 1.0)[:, None]
    offset = neighbor_total / degree - verts
    return mx.linalg.norm(offset, axis=-1).mean()

mesh_normal_consistency(meshes)

Penalty on the angle between normals of faces sharing an edge.

Returns the mean of 1 - cos(n_a, n_b) over all interior edges. The face-adjacency structure depends only on the mesh topology and is computed on the CPU (with NumPy) on each call.

Source code in src/mlx3d/losses/mesh_losses.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def mesh_normal_consistency(meshes: Meshes) -> mx.array:
    """Penalize the angle between the normals of faces that share an edge.

    Returns the mean of ``1 - cos(n_a, n_b)`` over the face pairs found to
    share an edge. The pairing depends only on the face indices and is built
    with NumPy on the CPU each time the function is called. Returns a zero
    scalar when there are no faces or no shared edges.
    """
    faces_np = np.array(meshes.faces_packed())
    num_faces = faces_np.shape[0]
    if num_faces == 0:
        return mx.array(0.0)

    # The three edges of every face, stacked edge-slot by edge-slot, with each
    # edge's endpoints ordered so that (i, j) and (j, i) compare equal.
    corner_pairs = ([0, 1], [1, 2], [2, 0])
    edges = np.sort(
        np.concatenate([faces_np[:, pair] for pair in corner_pairs], axis=0), axis=1
    )
    owner_face = np.tile(np.arange(num_faces), 3)

    # Sort edges lexicographically; identical edges end up adjacent, so faces
    # sharing an edge are found by comparing neighboring rows.
    order = np.lexsort((edges[:, 1], edges[:, 0]))
    sorted_edges = edges[order]
    sorted_owner = owner_face[order]
    is_shared = np.all(sorted_edges[1:] == sorted_edges[:-1], axis=1)
    first_face = sorted_owner[:-1][is_shared]
    second_face = sorted_owner[1:][is_shared]
    if first_face.size == 0:
        return mx.array(0.0)

    face_normals = meshes.faces_normals_packed()
    lengths = mx.linalg.norm(face_normals, axis=-1, keepdims=True)
    face_normals = face_normals / mx.maximum(lengths, 1e-12)
    normals_a = face_normals[mx.array(first_face.astype(np.int32))]
    normals_b = face_normals[mx.array(second_face.astype(np.int32))]
    cosine = mx.sum(normals_a * normals_b, axis=-1)
    return (1.0 - cosine).mean()

closest_point_on_triangle(p, a, b, c)

Closest point on triangle (a, b, c) to each query p.

All inputs broadcast to a common (..., 3) shape. Uses the Voronoi-region method (Ericson, Real-Time Collision Detection), fully vectorized.

Source code in src/mlx3d/losses/point_mesh.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def closest_point_on_triangle(p: mx.array, a: mx.array, b: mx.array, c: mx.array) -> mx.array:
    """Closest point on triangle ``(a, b, c)`` to each query ``p``.

    All inputs broadcast to a common ``(..., 3)`` shape. Uses the Voronoi-region
    method (Ericson, *Real-Time Collision Detection*), fully vectorized.

    Instead of branching per query, a candidate point is computed for every
    region (three vertices, three edges, face interior) and the applicable one
    is picked with ``mx.where`` masks.

    Args:
        p: query points.
        a: first triangle vertex.
        b: second triangle vertex.
        c: third triangle vertex.

    Returns:
        The selected closest point for every broadcast (query, triangle) pair.
    """
    ab = b - a
    ac = c - a
    ap = p - a
    # d1..d6: dot products of the two edge vectors (ab, ac) with the offset of
    # the query from each vertex. keepdims leaves a trailing axis of size 1 so
    # they broadcast against the point arrays below.
    d1 = mx.sum(ab * ap, axis=-1, keepdims=True)
    d2 = mx.sum(ac * ap, axis=-1, keepdims=True)

    bp = p - b
    d3 = mx.sum(ab * bp, axis=-1, keepdims=True)
    d4 = mx.sum(ac * bp, axis=-1, keepdims=True)

    cp = p - c
    d5 = mx.sum(ab * cp, axis=-1, keepdims=True)
    d6 = mx.sum(ac * cp, axis=-1, keepdims=True)

    # Cross terms used twice: their signs feed the edge-region tests, and
    # divided by their sum they give the weights of the face-interior candidate.
    vc = d1 * d4 - d3 * d2
    vb = d5 * d2 - d1 * d6
    va = d3 * d6 - d5 * d4

    # Denominators with magnitude below eps are swapped for 1 so that no lane
    # divides by zero; every candidate is computed for every lane regardless of
    # which region is finally selected.
    eps = 1e-12
    # Edge AB: a + v*ab
    v_ab = d1 / mx.where(mx.abs(d1 - d3) < eps, mx.ones_like(d1), d1 - d3)
    pt_ab = a + mx.clip(v_ab, 0.0, 1.0) * ab
    # Edge AC: a + w*ac
    w_ac = d2 / mx.where(mx.abs(d2 - d6) < eps, mx.ones_like(d2), d2 - d6)
    pt_ac = a + mx.clip(w_ac, 0.0, 1.0) * ac
    # Edge BC: b + w*(c-b)
    denom_bc = (d4 - d3) + (d5 - d6)
    w_bc = (d4 - d3) / mx.where(mx.abs(denom_bc) < eps, mx.ones_like(denom_bc), denom_bc)
    pt_bc = b + mx.clip(w_bc, 0.0, 1.0) * (c - b)
    # Face interior (barycentric)
    denom = va + vb + vc
    denom = mx.where(mx.abs(denom) < eps, mx.ones_like(denom), denom)
    v = vb / denom
    w = vc / denom
    pt_face = a + ab * v + ac * w

    # Region selection by priority (vertices -> edges -> face).
    in_a = (d1 <= 0) & (d2 <= 0)
    in_b = (d3 >= 0) & (d4 <= d3)
    in_c = (d6 >= 0) & (d5 <= d6)
    in_ab = (vc <= 0) & (d1 >= 0) & (d3 <= 0)
    in_ac = (vb <= 0) & (d2 >= 0) & (d6 <= 0)
    in_bc = (va <= 0) & ((d4 - d3) >= 0) & ((d5 - d6) >= 0)

    # Start from the face candidate and overwrite it mask by mask. A later
    # mx.where overrides an earlier one, so the effective precedence is
    # a, b, c, ab, ac, bc, with the face interior as the fallback.
    out = pt_face
    out = mx.where(in_bc, pt_bc, out)
    out = mx.where(in_ac, pt_ac, out)
    out = mx.where(in_ab, pt_ab, out)
    out = mx.where(in_c, c, out)
    out = mx.where(in_b, b, out)
    out = mx.where(in_a, a, out)
    return out

point_mesh_face_distance(meshes, points, face_chunk_size=2048)

Mean squared distance from each point to the nearest mesh face.

Parameters:

Name Type Description Default
meshes Meshes

a single-mesh mlx3d.structures.Meshes.

required
points array

(P, 3) query points.

required
face_chunk_size int

faces processed per chunk to bound the (P, F) memory. Lower it for very large meshes.

2048

Returns:

Type Description
array

Scalar mean over points of the squared distance to the closest triangle. Differentiable w.r.t. both points and mesh vertices.

Source code in src/mlx3d/losses/point_mesh.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def point_mesh_face_distance(
    meshes: Meshes,
    points: mx.array,
    face_chunk_size: int = 2048,
) -> mx.array:
    """Mean squared distance from each point to the nearest mesh face.

    Args:
        meshes: a single-mesh :class:`~mlx3d.structures.Meshes`.
        points: ``(P, 3)`` query points.
        face_chunk_size: faces processed per chunk to bound the ``(P, F)``
            memory. Lower it for very large meshes. Must be at least 1.

    Returns:
        Scalar mean over points of the squared distance to the closest triangle.
        Differentiable w.r.t. both points and mesh vertices. A mesh without
        faces yields a zero scalar, like the other mesh losses do for empty
        topology.

    Raises:
        ValueError: if ``face_chunk_size`` is smaller than 1.
    """
    if face_chunk_size < 1:
        # A negative step would make the chunk loop below run zero times.
        raise ValueError(f"face_chunk_size must be >= 1; got {face_chunk_size}.")

    verts = meshes.verts_packed()
    faces = meshes.faces_packed().astype(mx.int32)
    if faces.shape[0] == 0:
        # No triangle to measure against: without this guard the running
        # minimum would stay None and the final mean would fail.
        return mx.array(0.0)

    tri = verts[faces]  # (F, 3, 3)
    p = points[:, None, :]  # (P, 1, 3)

    best = None
    for start in range(0, tri.shape[0], face_chunk_size):
        chunk = tri[start : start + face_chunk_size]  # (C, 3, 3)
        a = chunk[:, 0, :][None]  # (1, C, 3)
        b = chunk[:, 1, :][None]
        c = chunk[:, 2, :][None]
        closest = closest_point_on_triangle(p, a, b, c)  # (P, C, 3)
        d = mx.sum((p - closest) ** 2, axis=-1)  # (P, C)
        chunk_min = mx.min(d, axis=-1)  # (P,)
        # Keep a running per-point minimum across chunks.
        best = chunk_min if best is None else mx.minimum(best, chunk_min)
    return mx.mean(best)