Skip to content

mlx3d.transforms

mlx3d.transforms

Transform3d

A batched rigid/affine 3D transform: x -> R @ x + t.

Stores a rotation/scale-shear block rot (..., 3, 3) and a translation trans (..., 3). Transforms compose, invert, and apply to points and normals, and are fully differentiable (handy for pose optimization).

Build one with the constructors `from_rot_trans`, `translate`, `rotate`, or `scale`, or combine existing transforms with `compose` / `@`.

Source code in src/mlx3d/transforms/se3.py
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
class Transform3d:
    """A batched rigid/affine 3D transform: ``x -> R @ x + t``.

    Stores a rotation/scale-shear block ``rot`` ``(..., 3, 3)`` and a translation
    ``trans`` ``(..., 3)``. Transforms compose, invert, and apply to points and
    normals, and are fully differentiable (handy for pose optimization).

    Build one with the constructors :meth:`from_rot_trans`, :meth:`translate`,
    :meth:`rotate`, :meth:`scale`, or compose with :meth:`compose` / ``@``.
    """

    def __init__(self, rot: mx.array | None = None, trans: mx.array | None = None):
        """Create a transform from ``rot`` ``(..., 3, 3)`` and ``trans`` ``(..., 3)``.

        Missing parts default to the identity. A missing ``rot`` becomes the
        3x3 identity, broadcast to the batch shape of ``trans`` when ``trans`` is
        given. A missing ``trans`` becomes zeros with ``rot``'s batch shape.
        """
        if rot is None:
            eye = mx.eye(3)
            # Match the batch shape of `trans` so that methods which concatenate
            # rot and trans (e.g. get_matrix) work for Transform3d(trans=batched).
            rot = eye if trans is None else mx.broadcast_to(eye, trans.shape[:-1] + (3, 3))
        if trans is None:
            trans = mx.zeros(rot.shape[:-2] + (3,))
        self.rot = rot
        self.trans = trans

    # ------------------------------------------------------------ constructors
    @classmethod
    def from_rot_trans(cls, rot: mx.array, trans: mx.array) -> "Transform3d":
        """Build a transform directly from ``rot`` ``(..., 3, 3)`` and ``trans`` ``(..., 3)``."""
        return cls(rot, trans)

    @classmethod
    def translate(cls, t: mx.array) -> "Transform3d":
        """Pure translation by ``t`` ``(..., 3)``; the rotation is the identity."""
        t = mx.array(t)
        return cls(mx.broadcast_to(mx.eye(3), t.shape[:-1] + (3, 3)), t)

    @classmethod
    def rotate(cls, rot: mx.array) -> "Transform3d":
        """Pure linear transform ``rot`` ``(..., 3, 3)`` with zero translation."""
        return cls(rot, mx.zeros(rot.shape[:-2] + (3,)))

    @classmethod
    def scale(cls, s: mx.array | float) -> "Transform3d":
        """Axis-aligned scaling with zero translation.

        ``s`` may be a scalar (uniform scale), an array with a trailing
        dimension of 1 (uniform scale per batch element), or ``(..., 3)``
        (per-axis scale). It is cast to float32.
        """
        s = mx.array(s, dtype=mx.float32)
        if s.ndim == 0:
            s = mx.broadcast_to(s, (3,))
        elif s.shape[-1] == 1:
            # Without this, (..., 1, 1) * eye(3) below would fill every entry
            # with s instead of only the diagonal.
            s = mx.broadcast_to(s, s.shape[:-1] + (3,))
        diag = s[..., :, None] * mx.eye(3)
        return cls(diag, mx.zeros(diag.shape[:-2] + (3,)))

    # ---------------------------------------------------------------- algebra
    def compose(self, other: "Transform3d") -> "Transform3d":
        """Return the transform that applies ``self`` first, then ``other``."""
        rot = other.rot @ self.rot
        trans = (other.rot @ self.trans[..., None])[..., 0] + other.trans
        return Transform3d(rot, trans)

    def __matmul__(self, other: "Transform3d") -> "Transform3d":
        """Matrix-style composition: ``self @ other`` applies ``other`` first, then ``self``."""
        if not isinstance(other, Transform3d):
            return NotImplemented
        return other.compose(self)

    def inverse(self) -> "Transform3d":
        """Return the inverse transform ``x -> R^{-1} @ (x - t)``.

        The matrix inverse runs on the CPU stream (``stream=mx.cpu``).
        """
        rinv = mx.linalg.inv(self.rot, stream=mx.cpu)
        return Transform3d(rinv, -(rinv @ self.trans[..., None])[..., 0])

    # ---------------------------------------------------------------- apply
    def transform_points(self, points: mx.array) -> mx.array:
        """Apply to points ``(..., 3)``: ``R @ x + t`` (row-vector form, broadcasts)."""
        return points @ mx.swapaxes(self.rot, -1, -2) + self.trans[..., None, :]

    def transform_normals(self, normals: mx.array) -> mx.array:
        """Apply to normals with the inverse transform (ignores translation)."""
        # Row-vector form of n' = R^{-T} n, i.e. n'^T = n^T R^{-1}.
        return normals @ mx.linalg.inv(self.rot, stream=mx.cpu)

    def get_matrix(self) -> mx.array:
        """Return the homogeneous ``(..., 4, 4)`` matrix form, in the dtype of ``rot``/``trans``."""
        top = mx.concatenate([self.rot, self.trans[..., :, None]], axis=-1)  # (...,3,4)
        # Build the [0, 0, 0, 1] row in top's dtype so float16/bfloat16
        # transforms are not silently promoted to float32.
        bottom = mx.broadcast_to(
            mx.array([0.0, 0.0, 0.0, 1.0], dtype=top.dtype), top.shape[:-2] + (1, 4)
        )
        return mx.concatenate([top, bottom], axis=-2)

    def __repr__(self) -> str:
        return f"Transform3d(rot={self.rot.shape}, trans={self.trans.shape})"

compose(other)

Return the transform that applies self first, then other.

Source code in src/mlx3d/transforms/se3.py
183
184
185
186
187
def compose(self, other: "Transform3d") -> "Transform3d":
    """Return the transform that applies ``self`` first, then ``other``."""
    # With self = (A, a) and other = (B, b):
    #   x -> B @ (A @ x + a) + b = (B @ A) @ x + (B @ a + b)
    outer_rot, outer_trans = other.rot, other.trans
    shifted = (outer_rot @ self.trans[..., None])[..., 0] + outer_trans
    return Transform3d(outer_rot @ self.rot, shifted)

transform_points(points)

Apply to points (..., 3): R @ x + t (row-vector form, broadcasts).

Source code in src/mlx3d/transforms/se3.py
198
199
200
def transform_points(self, points: mx.array) -> mx.array:
    """Apply to points ``(..., 3)``: ``R @ x + t`` (row-vector form, broadcasts)."""
    # Row vectors: x'^T = x^T R^T + t^T. The translation gets a point axis so it
    # broadcasts over the points of each batch element.
    rot_t = mx.swapaxes(self.rot, -1, -2)
    offset = mx.expand_dims(self.trans, -2)
    return points @ rot_t + offset

transform_normals(normals)

Apply to normals with the inverse transform (ignores translation).

Source code in src/mlx3d/transforms/se3.py
202
203
204
def transform_normals(self, normals: mx.array) -> mx.array:
    """Apply to normals with the inverse transform (ignores translation)."""
    # Normals map by the inverse-transpose: n' = R^{-T} n, which in row-vector
    # form is n'^T = n^T R^{-1}. The inverse is computed on the CPU stream.
    rot_inv = mx.linalg.inv(self.rot, stream=mx.cpu)
    return normals @ rot_inv

get_matrix()

Return the homogeneous (..., 4, 4) matrix form.

Source code in src/mlx3d/transforms/se3.py
206
207
208
209
210
def get_matrix(self) -> mx.array:
    """Return the homogeneous ``(..., 4, 4)`` matrix form, in the dtype of ``rot``/``trans``."""
    top = mx.concatenate([self.rot, self.trans[..., :, None]], axis=-1)  # (...,3,4)
    # Build the [0, 0, 0, 1] row in top's dtype; a hard-coded float32 row would
    # silently promote float16/bfloat16 transforms to float32.
    bottom = mx.broadcast_to(
        mx.array([0.0, 0.0, 0.0, 1.0], dtype=top.dtype), top.shape[:-2] + (1, 4)
    )
    return mx.concatenate([top, bottom], axis=-2)

axis_angle_to_matrix(axis_angle)

Convert axis-angle vectors (..., 3) to rotation matrices (..., 3, 3).

Source code in src/mlx3d/transforms/rotations.py
121
122
123
def axis_angle_to_matrix(axis_angle: mx.array) -> mx.array:
    """Convert axis-angle vectors ``(..., 3)`` to rotation matrices ``(..., 3, 3)``.

    The conversion goes through the quaternion representation.
    """
    quat = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(quat)

axis_angle_to_quaternion(axis_angle)

Convert axis-angle vectors (..., 3) to quaternions (..., 4).

Source code in src/mlx3d/transforms/rotations.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
def axis_angle_to_quaternion(axis_angle: mx.array) -> mx.array:
    """Convert axis-angle vectors ``(..., 3)`` to quaternions ``(..., 4)``.

    The output is ``[cos(angle / 2), axis_angle * sin(angle / 2) / angle]``,
    where ``angle`` is the vector norm. For angles below ``1e-6`` the ratio
    ``sin(angle / 2) / angle`` uses its Taylor expansion.
    """
    threshold = 1e-6
    theta = mx.linalg.norm(axis_angle, axis=-1, keepdims=True)
    half_theta = theta * 0.5
    # sin(x/2)/x ~= 1/2 - x^2/48 near x = 0.
    series = 0.5 - (theta * theta) / 48.0
    direct = mx.sin(half_theta) / mx.maximum(theta, threshold)
    ratio = mx.where(theta < threshold, series, direct)
    return mx.concatenate([mx.cos(half_theta), axis_angle * ratio], axis=-1)

euler_angles_to_matrix(euler_angles, convention='XYZ')

Convert Euler angles (..., 3) (radians) to rotation matrices (..., 3, 3).

convention is a 3-letter string of axes, e.g. "XYZ" applies R = R_X(a0) @ R_Y(a1) @ R_Z(a2).

Source code in src/mlx3d/transforms/rotations.py
145
146
147
148
149
150
151
152
153
154
def euler_angles_to_matrix(euler_angles: mx.array, convention: str = "XYZ") -> mx.array:
    """Convert Euler angles ``(..., 3)`` (radians) to rotation matrices ``(..., 3, 3)``.

    ``convention`` is a 3-letter string of axes, e.g. ``"XYZ"`` applies
    R = R_X(a0) @ R_Y(a1) @ R_Z(a2).

    Raises:
        ValueError: if ``convention`` is not three characters drawn from "XYZ".
    """
    is_valid = len(convention) == 3 and all(c in "XYZ" for c in convention)
    if not is_valid:
        raise ValueError(f"Invalid convention {convention!r}.")
    # One elementary rotation per (axis, angle) pair; `_axis_rotation` is
    # presumably the single-axis rotation builder defined elsewhere in the module.
    first, second, third = (
        _axis_rotation(axis, euler_angles[..., idx]) for idx, axis in enumerate(convention)
    )
    return first @ second @ third

matrix_to_axis_angle(matrix)

Convert rotation matrices (..., 3, 3) to axis-angle vectors (..., 3).

Source code in src/mlx3d/transforms/rotations.py
126
127
128
def matrix_to_axis_angle(matrix: mx.array) -> mx.array:
    """Convert rotation matrices ``(..., 3, 3)`` to axis-angle vectors ``(..., 3)``.

    The conversion goes through the quaternion representation.
    """
    quat = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quat)

matrix_to_quaternion(matrix)

Convert rotation matrices (..., 3, 3) to quaternions (..., 4) in (w, x, y, z) order.

Uses the numerically stable branch selection of Shepperd's method.

Source code in src/mlx3d/transforms/rotations.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def matrix_to_quaternion(matrix: mx.array) -> mx.array:
    """Convert rotation matrices ``(..., 3, 3)`` to quaternions ``(..., 4)`` in (w, x, y, z) order.

    Uses the numerically stable branch selection of Shepperd's method.

    All four candidate quaternions are computed for every input, and one is
    selected per element with ``mx.where``:
    - the ``w`` branch when the trace is positive;
    - otherwise the branch of the largest diagonal entry (``x``, then ``y``, else ``z``).

    The selected quaternion is normalized and passed through
    ``standardize_quaternion`` so the real part is non-negative.
    """
    m = matrix
    m00, m01, m02 = m[..., 0, 0], m[..., 0, 1], m[..., 0, 2]
    m10, m11, m12 = m[..., 1, 0], m[..., 1, 1], m[..., 1, 2]
    m20, m21, m22 = m[..., 2, 0], m[..., 2, 1], m[..., 2, 2]

    # Four candidate quaternions, one per branch of Shepperd's method. Each
    # uses 2*sqrt(t_*) as the denominator, where t_* is the (always positive
    # in its branch) diagonal combination. For a proper rotation these equal
    # 4*w^2, 4*x^2, 4*y^2 and 4*z^2 respectively.
    t_w = 1 + m00 + m11 + m22
    t_x = 1 + m00 - m11 - m22
    t_y = 1 - m00 + m11 - m22
    t_z = 1 - m00 - m11 + m22

    def _denom(t):
        # Clamp before the sqrt: mx.where below evaluates every branch, so the
        # unselected ones must not take sqrt of a negative value or divide by 0.
        return 2.0 * mx.sqrt(mx.maximum(t, 1e-12))[..., None]

    q_w = mx.stack([t_w, m21 - m12, m02 - m20, m10 - m01], axis=-1) / _denom(t_w)
    q_x = mx.stack([m21 - m12, t_x, m01 + m10, m02 + m20], axis=-1) / _denom(t_x)
    q_y = mx.stack([m02 - m20, m01 + m10, t_y, m12 + m21], axis=-1) / _denom(t_y)
    q_z = mx.stack([m10 - m01, m02 + m20, m12 + m21, t_z], axis=-1) / _denom(t_z)

    trace = m00 + m11 + m22
    cond_w = (trace > 0)[..., None]
    # cond_x / cond_y are only consulted when the trace is not positive (nested
    # where below), so together they pick the largest diagonal entry.
    cond_x = ((m00 >= m11) & (m00 >= m22))[..., None]
    cond_y = (m11 >= m22)[..., None]

    q = mx.where(cond_w, q_w, mx.where(cond_x, q_x, mx.where(cond_y, q_y, q_z)))
    return standardize_quaternion(q / mx.linalg.norm(q, axis=-1, keepdims=True))

matrix_to_rotation_6d(matrix)

Convert rotation matrices (..., 3, 3) to the 6D representation (..., 6).

Source code in src/mlx3d/transforms/rotations.py
171
172
173
def matrix_to_rotation_6d(matrix: mx.array) -> mx.array:
    """Convert rotation matrices ``(..., 3, 3)`` to the 6D representation ``(..., 6)``.

    The 6D vector is the first two rows of the matrix, flattened row-major.
    """
    batch_shape = matrix.shape[:-2]
    first_two_rows = matrix[..., :2, :]
    return first_two_rows.reshape(*batch_shape, 6)

quaternion_apply(quaternion, point)

Rotate points (..., 3) by unit quaternions (..., 4).

Source code in src/mlx3d/transforms/rotations.py
196
197
198
199
200
201
def quaternion_apply(quaternion: mx.array, point: mx.array) -> mx.array:
    """Rotate points ``(..., 3)`` by unit quaternions ``(..., 4)``.

    Computes the vector part of ``q * (0, p) * q^{-1}``.
    """
    pure = mx.concatenate([mx.zeros_like(point[..., :1]), point], axis=-1)
    conjugate = quaternion_invert(quaternion)
    rotated = quaternion_multiply(quaternion_multiply(quaternion, pure), conjugate)
    return rotated[..., 1:]

quaternion_invert(quaternion)

Inverse of unit quaternions: the conjugate.

Source code in src/mlx3d/transforms/rotations.py
191
192
193
def quaternion_invert(quaternion: mx.array) -> mx.array:
    """Inverse of unit quaternions: the conjugate (vector part negated)."""
    signs = mx.array([1.0, -1.0, -1.0, -1.0], dtype=quaternion.dtype)
    return signs * quaternion

quaternion_multiply(a, b)

Hamilton product of two quaternion arrays (..., 4).

Source code in src/mlx3d/transforms/rotations.py
176
177
178
179
180
181
182
183
184
185
186
187
188
def quaternion_multiply(a: mx.array, b: mx.array) -> mx.array:
    """Hamilton product ``a * b`` of quaternion arrays ``(..., 4)`` in (w, x, y, z) order."""
    w1, x1, y1, z1 = (a[..., i] for i in range(4))
    w2, x2, y2, z2 = (b[..., i] for i in range(4))
    real = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    i_part = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    j_part = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
    k_part = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
    return mx.stack([real, i_part, j_part, k_part], axis=-1)

quaternion_to_axis_angle(quaternions)

Convert quaternions (..., 4) to axis-angle vectors (..., 3).

Source code in src/mlx3d/transforms/rotations.py
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def quaternion_to_axis_angle(quaternions: mx.array) -> mx.array:
    """Convert quaternions ``(..., 4)`` to axis-angle vectors ``(..., 3)``.

    The input is normalized and flipped to a non-negative real part, so the
    recovered angle lies in ``[0, pi]``. Near zero angle, ``sin(angle/2)/angle``
    uses its Taylor expansion.
    """
    unit = quaternions / mx.linalg.norm(quaternions, axis=-1, keepdims=True)
    q = standardize_quaternion(unit)
    real, vec = q[..., 0:1], q[..., 1:]
    half = mx.arctan2(mx.linalg.norm(vec, axis=-1, keepdims=True), real)
    angle = 2.0 * half
    tiny = mx.abs(angle) < 1e-6
    safe_angle = mx.where(tiny, mx.ones_like(angle), angle)
    ratio = mx.where(tiny, 0.5 - (angle * angle) / 48.0, mx.sin(half) / safe_angle)
    # |vec| == sin(half), so vec / ratio == unit_axis * angle.
    return vec / ratio

quaternion_to_matrix(quaternions)

Convert quaternions (..., 4) in (w, x, y, z) order to rotation matrices (..., 3, 3).

Source code in src/mlx3d/transforms/rotations.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def quaternion_to_matrix(quaternions: mx.array) -> mx.array:
    """Convert quaternions ``(..., 4)`` in (w, x, y, z) order to rotation matrices ``(..., 3, 3)``.

    Inputs are normalized first, so non-unit quaternions are accepted.
    """
    unit = quaternions / mx.linalg.norm(quaternions, axis=-1, keepdims=True)
    w, x, y, z = (unit[..., i] for i in range(4))
    rows = [
        [1 - 2.0 * (y * y + z * z), 2.0 * (x * y - w * z), 2.0 * (x * z + w * y)],
        [2.0 * (x * y + w * z), 1 - 2.0 * (x * x + z * z), 2.0 * (y * z - w * x)],
        [2.0 * (x * z - w * y), 2.0 * (y * z + w * x), 1 - 2.0 * (x * x + y * y)],
    ]
    flat = mx.stack([entry for row in rows for entry in row], axis=-1)
    return flat.reshape(*quaternions.shape[:-1], 3, 3)

random_quaternions(n, key=None)

Sample n uniform random unit quaternions, shape (n, 4).

Source code in src/mlx3d/transforms/rotations.py
204
205
206
207
208
209
210
def random_quaternions(n: int, key: mx.array | None = None) -> mx.array:
    """Sample ``n`` uniform random unit quaternions, shape ``(n, 4)``.

    A normalized 4D Gaussian sample is uniform on the unit 3-sphere. The
    result has a non-negative real part. ``key`` is an optional PRNG key.
    """
    samples = mx.random.normal((n, 4), key=key)
    unit = samples / mx.linalg.norm(samples, axis=-1, keepdims=True)
    return standardize_quaternion(unit)

random_rotations(n, key=None)

Sample n uniform random rotation matrices, shape (n, 3, 3).

Source code in src/mlx3d/transforms/rotations.py
213
214
215
def random_rotations(n: int, key: mx.array | None = None) -> mx.array:
    """Sample ``n`` uniform random rotation matrices, shape ``(n, 3, 3)``."""
    quats = random_quaternions(n, key=key)
    return quaternion_to_matrix(quats)

rotation_6d_to_matrix(d6)

Convert 6D rotation representation (..., 6) to matrices via Gram-Schmidt.

Reference: Zhou et al., "On the Continuity of Rotation Representations in Neural Networks" (CVPR 2019).

Source code in src/mlx3d/transforms/rotations.py
157
158
159
160
161
162
163
164
165
166
167
168
def rotation_6d_to_matrix(d6: mx.array) -> mx.array:
    """Convert 6D rotation representation ``(..., 6)`` to matrices via Gram-Schmidt.

    The two 3-vectors are orthonormalized, and the third row is their cross
    product. Norms are clamped below by ``_NORM_EPS``.

    Reference: Zhou et al., "On the Continuity of Rotation Representations in
    Neural Networks" (CVPR 2019).
    """

    def _unit(vec):
        # Normalize along the last axis with a clamped norm.
        return vec / mx.maximum(mx.linalg.norm(vec, axis=-1, keepdims=True), _NORM_EPS)

    first, second = d6[..., :3], d6[..., 3:]
    e1 = _unit(first)
    # Remove the component of `second` along e1 before normalizing.
    second = second - mx.sum(e1 * second, axis=-1, keepdims=True) * e1
    e2 = _unit(second)
    e3 = mx.linalg.cross(e1, e2)
    return mx.stack([e1, e2, e3], axis=-2)

standardize_quaternion(quaternions)

Flip quaternions so the real part is non-negative.

Source code in src/mlx3d/transforms/rotations.py
87
88
89
def standardize_quaternion(quaternions: mx.array) -> mx.array:
    """Flip quaternions so the real part is non-negative.

    ``q`` and ``-q`` represent the same rotation, so this picks a canonical sign.
    """
    negative_real = quaternions[..., 0:1] < 0
    return mx.where(negative_real, -quaternions, quaternions)

hat(v)

Map (..., 3) vectors to skew-symmetric matrices (..., 3, 3).

Source code in src/mlx3d/transforms/se3.py
28
29
30
31
32
33
34
35
def hat(v: mx.array) -> mx.array:
    """Map ``(..., 3)`` vectors to skew-symmetric matrices ``(..., 3, 3)``.

    ``hat(v) @ u`` equals the cross product ``v x u``.
    """
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    o = mx.zeros_like(x)
    # Row-major entries of [[0, -z, y], [z, 0, -x], [-y, x, 0]].
    entries = [o, -z, y, z, o, -x, -y, x, o]
    return mx.stack(entries, axis=-1).reshape(*v.shape[:-1], 3, 3)

se3_exp_map(xi)

Exponential map from twists (..., 6) = [v(3), omega(3)] to transforms.

omega is the rotation part and v the translation part of the twist (PyTorch3D ordering). Returns a `Transform3d`.

Source code in src/mlx3d/transforms/se3.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def se3_exp_map(xi: mx.array) -> "Transform3d":
    """Exponential map from twists ``(..., 6)`` = ``[v(3), omega(3)]`` to transforms.

    ``omega`` is the rotation part and ``v`` the translation part of the twist
    (PyTorch3D ordering).

    The rotation is ``so3_exp_map(omega)``. The translation is ``V @ v``, where
    ``V = I + B K + C K^2`` is the left Jacobian with:
    - ``K = hat(omega)``;
    - ``B = (1 - cos theta) / theta^2``;
    - ``C = (theta - sin theta) / theta^3``.

    Both coefficients use their Taylor expansions when ``theta < 1e-4``.

    Returns a :class:`Transform3d`.
    """
    v, omega = xi[..., :3], xi[..., 3:]
    r = so3_exp_map(omega)
    theta, theta2 = _safe_theta(omega)
    small = theta < 1e-4
    # B multiplies K, which is O(theta), so B's absolute error shows up directly
    # in the translation and its gradient. In float32, cos(theta) rounds to 1
    # just above the Taylor threshold, so 1 - cos(theta) cancels to 0. The
    # identity 1 - cos(theta) = 2 sin^2(theta / 2) has no cancellation.
    sin_half = mx.sin(0.5 * theta)
    b = mx.where(
        small,
        0.5 - theta2 / 24.0,
        2.0 * sin_half * sin_half / mx.where(small, mx.ones_like(theta2), theta2),
    )
    # C multiplies K^2, which is O(theta^2), so the cancellation in
    # theta - sin(theta) only contributes an eps-level error to V.
    c = mx.where(
        small,
        1.0 / 6.0 - theta2 / 120.0,
        (theta - mx.sin(theta)) / mx.where(small, mx.ones_like(theta), theta * theta2),
    )
    k = hat(omega)
    eye = mx.broadcast_to(mx.eye(3), k.shape)
    vmat = eye + b[..., None] * k + c[..., None] * (k @ k)  # left Jacobian
    t = (vmat @ v[..., None])[..., 0]
    return Transform3d.from_rot_trans(r, t)

se3_log_map(transform)

Inverse of `se3_exp_map`: recovers twists (..., 6) = [v, omega] from a transform.

Source code in src/mlx3d/transforms/se3.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def se3_log_map(transform: "Transform3d") -> mx.array:
    """Inverse of :func:`se3_exp_map`: twists ``(..., 6)`` = ``[v, omega]`` from a transform.

    ``omega`` comes from :func:`so3_log_map`. ``v`` is recovered as
    ``V^{-1} @ t`` using the closed form
    ``V^{-1} = I - 0.5 K + (1/theta^2)(1 - (theta/2) cot(theta/2)) K^2``.
    The last coefficient uses its Taylor expansion when ``theta < 1e-4``.
    """
    omega = so3_log_map(transform.rot)
    theta = mx.linalg.norm(omega, axis=-1, keepdims=True)
    theta2 = theta * theta
    small = theta < 1e-4
    half = 0.5 * theta
    # Substitute 1 in the masked-out lanes so the unused branch never divides by 0.
    safe_half = mx.where(small, mx.ones_like(half), half)
    safe_theta2 = mx.where(small, mx.ones_like(theta2), theta2)
    exact = (1.0 - half * mx.cos(half) / mx.sin(safe_half)) / safe_theta2
    series = 1.0 / 12.0 + theta2 / 720.0
    coeff = mx.where(small, series, exact)
    k = hat(omega)
    k_sq = k @ k
    vinv = mx.broadcast_to(mx.eye(3), k.shape) - 0.5 * k + coeff[..., None] * k_sq
    v = (vinv @ transform.trans[..., None])[..., 0]
    return mx.concatenate([v, omega], axis=-1)

so3_exp_map(omega)

Rotation matrices (..., 3, 3) from axis-angle vectors (..., 3) (Rodrigues).

Source code in src/mlx3d/transforms/se3.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
def so3_exp_map(omega: mx.array) -> mx.array:
    """Rotation matrices ``(..., 3, 3)`` from axis-angle vectors ``(..., 3)`` (Rodrigues).

    ``R = I + A K + B K^2``, where:
    - ``K = hat(omega)``;
    - ``A = sin(theta)/theta``;
    - ``B = (1 - cos theta)/theta^2``.

    Both coefficients use their Taylor expansions when ``theta < 1e-4``.
    """
    theta, theta2 = _safe_theta(omega)  # (..., 1)
    near_zero = theta < 1e-4
    # Substitute 1 in the masked-out lanes so the unused branch never divides by 0.
    safe_theta = mx.where(near_zero, mx.ones_like(theta), theta)
    safe_theta2 = mx.where(near_zero, mx.ones_like(theta2), theta2)
    coef_a = mx.where(near_zero, 1.0 - theta2 / 6.0, mx.sin(theta) / safe_theta)
    coef_b = mx.where(near_zero, 0.5 - theta2 / 24.0, (1.0 - mx.cos(theta)) / safe_theta2)
    k = hat(omega)
    identity = mx.broadcast_to(mx.eye(3), k.shape)
    # Add a trailing axis so the (..., 1) coefficients broadcast over 3x3 matrices.
    return identity + coef_a[..., None] * k + coef_b[..., None] * (k @ k)

so3_log_map(r)

Axis-angle vectors (..., 3) from rotation matrices (..., 3, 3).

This is the inverse of `so3_exp_map`. It delegates to the quaternion-based conversion, which stays accurate near theta = pi, where the naive theta / (2 sin theta) form is singular.

Source code in src/mlx3d/transforms/se3.py
75
76
77
78
79
80
81
82
83
84
def so3_log_map(r: mx.array) -> mx.array:
    """Axis-angle vectors ``(..., 3)`` from rotation matrices ``(..., 3, 3)``.

    Inverse of :func:`so3_exp_map`. It goes through quaternions rather than the
    naive ``theta / (2 sin theta)`` formula, which is singular near
    ``theta = pi``.
    """
    # Imported locally, as in the original.
    from .rotations import matrix_to_axis_angle as _to_axis_angle

    return _to_axis_angle(r)

vee(m)

Inverse of `hat`: extracts the (..., 3) vectors from skew-symmetric matrices (..., 3, 3).

Source code in src/mlx3d/transforms/se3.py
38
39
40
def vee(m: mx.array) -> mx.array:
    """Inverse of :func:`hat`: extract ``(..., 3)`` from a skew matrix ``(..., 3, 3)``.

    Reads the entries ``(2, 1)``, ``(0, 2)`` and ``(1, 0)``, i.e. ``x``, ``y``, ``z``.
    """
    picks = ((2, 1), (0, 2), (1, 0))
    return mx.stack([m[..., row, col] for row, col in picks], axis=-1)