mlx3d.nn
FusedMLP
Bases: Module
A small ReLU MLP with a fused Metal forward path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| layer_dims | list[int] | layer sizes | required |
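The table does not say exactly how `layer_dims` is interpreted. One plausible reading, sketched here as an assumption rather than mlx3d's confirmed behavior, is that each consecutive pair of sizes defines one dense layer, giving `len(layer_dims) - 1` layers. `layer_shapes` and `param_count` are hypothetical helpers, and the per-layer bias is also assumed.

```python
def layer_shapes(layer_dims):
    """Hypothetical: (in, out) weight shape for each consecutive pair of sizes."""
    if len(layer_dims) < 2:
        raise ValueError("layer_dims needs at least an input and an output size")
    return list(zip(layer_dims[:-1], layer_dims[1:]))


def param_count(layer_dims):
    """Weights plus one bias per output unit (the bias is an assumption)."""
    return sum(d_in * d_out + d_out for d_in, d_out in layer_shapes(layer_dims))


# 3-D input, two hidden layers of width 64, 4 outputs.
dims = [3, 64, 64, 4]
print(layer_shapes(dims))  # [(3, 64), (64, 64), (64, 4)]
print(param_count(dims))   # 4676
```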
Source code in src/mlx3d/nn/fused_mlp.py
__call__(x)
Differentiable MLX forward (use for training).
Source code in src/mlx3d/nn/fused_mlp.py
forward_fused(x)
Fused single-kernel forward; matches `__call__` exactly.
It is a correct reference implementation of the fused-MLP idea, but, as the module note explains, MLX's native matmuls are currently faster on Apple GPUs, so prefer `__call__` in practice.
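To illustrate why a fused forward can match the unfused one exactly, the plain-Python sketch below (no MLX, no Metal) evaluates the same ReLU MLP in two orders: layer by layer over the whole batch, as separate matmuls would, and sample by sample through every layer, which is roughly how a fused kernel keeps one sample's activations local. Both perform the same multiply-adds in the same order, so the results are bit-identical. The weight layout, the biases, the absence of an activation on the last layer, and the names `forward_layerwise` and `forward_row_fused` are assumptions, not the mlx3d kernel.

```python
import random


def relu(v):
    return v if v > 0.0 else 0.0


def dense_row(row, weight, bias):
    """One dense layer for one sample: out[j] = bias[j] + sum_i row[i] * weight[i][j]."""
    out = []
    for j in range(len(bias)):
        acc = bias[j]
        for i, x in enumerate(row):
            acc += x * weight[i][j]
        out.append(acc)
    return out


def make_layers(layer_dims, seed=0):
    """Random (weight, bias) pairs; weight has shape (d_in, d_out)."""
    rng = random.Random(seed)
    layers = []
    for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
        weight = [[rng.uniform(-1, 1) for _ in range(d_out)] for _ in range(d_in)]
        bias = [rng.uniform(-0.1, 0.1) for _ in range(d_out)]
        layers.append((weight, bias))
    return layers


def forward_layerwise(batch, layers):
    """Unfused: materialize the full activation matrix after every layer."""
    acts = batch
    for k, (weight, bias) in enumerate(layers):
        acts = [dense_row(row, weight, bias) for row in acts]
        if k < len(layers) - 1:  # ReLU on hidden layers only (assumed)
            acts = [[relu(v) for v in row] for row in acts]
    return acts


def forward_row_fused(batch, layers):
    """Fused-style: each sample passes through every layer before the next sample starts."""
    out = []
    for row in batch:
        h = row
        for k, (weight, bias) in enumerate(layers):
            h = dense_row(h, weight, bias)
            if k < len(layers) - 1:
                h = [relu(v) for v in h]
        out.append(h)
    return out


layers = make_layers([3, 16, 16, 4])
rng = random.Random(1)
batch = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(8)]
assert forward_layerwise(batch, layers) == forward_row_fused(batch, layers)
```

On a GPU the trade-off is memory traffic versus matmul efficiency; the docstring's observation that MLX's native matmuls currently win on Apple GPUs is consistent with that.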
Source code in src/mlx3d/nn/fused_mlp.py
HashGridEncoding
Bases: Module
Trainable multi-resolution 3D hash-grid encoder.
This follows the Instant-NGP idea: points are normalized to a unit cube, each level performs trilinear interpolation over hashed grid vertices, and all level features are concatenated.
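A minimal single-point sketch of the scheme just described, in plain Python. It assumes the spatial hash from the Instant-NGP paper (coordinates multiplied by large primes, XORed, truncated to 32 bits, then taken modulo the table size), geometrically spaced per-level resolutions between `base_resolution` and `finest_resolution`, and small random table initialization. It omits Instant-NGP's direct (unhashed) indexing at coarse levels. mlx3d's actual hash, resolution schedule, and initialization may differ; `encode_point` and the other helpers are hypothetical.

```python
import math
import random

PRIMES = (1, 2654435761, 805459861)  # spatial-hash primes from the Instant-NGP paper


def spatial_hash(ix, iy, iz, table_size):
    # XOR of coordinate * prime, truncated to 32 bits like a uint32 implementation.
    h = ((ix * PRIMES[0]) ^ (iy * PRIMES[1]) ^ (iz * PRIMES[2])) & 0xFFFFFFFF
    return h % table_size


def level_resolutions(num_levels, base_resolution, finest_resolution):
    """Geometric growth from the base to the finest resolution."""
    if num_levels == 1:
        return [base_resolution]
    growth = math.exp(math.log(finest_resolution / base_resolution) / (num_levels - 1))
    return [int(math.floor(base_resolution * growth**level)) for level in range(num_levels)]


def make_tables(num_levels, log2_hashmap_size, features_per_level, seed=0):
    """One table of 2**log2_hashmap_size feature vectors per level."""
    rng = random.Random(seed)
    size = 2**log2_hashmap_size
    return [
        [[rng.uniform(-1e-4, 1e-4) for _ in range(features_per_level)] for _ in range(size)]
        for _ in range(num_levels)
    ]


def encode_point(p, tables, resolutions):
    """p: point already normalized to [0, 1]^3. Returns concatenated per-level features."""
    features = []
    for table, res in zip(tables, resolutions):
        size = len(table)
        scaled = [c * res for c in p]
        base = [int(math.floor(s)) for s in scaled]
        frac = [s - b for s, b in zip(scaled, base)]
        level_feat = [0.0] * len(table[0])
        # Trilinear interpolation over the 8 corners of the enclosing cell.
        for corner in range(8):
            offs = [(corner >> axis) & 1 for axis in range(3)]
            w = 1.0
            for axis in range(3):
                w *= frac[axis] if offs[axis] else 1.0 - frac[axis]
            idx = spatial_hash(base[0] + offs[0], base[1] + offs[1], base[2] + offs[2], size)
            for f, v in enumerate(table[idx]):
                level_feat[f] += w * v
        features.extend(level_feat)
    return features


resolutions = level_resolutions(4, 16, 128)
tables = make_tables(4, 8, 2)
feat = encode_point([0.3, 0.5, 0.7], tables, resolutions)
print(len(feat))  # 4 levels * 2 features = 8
```

Concatenating levels means the output width is `num_levels * features_per_level`; only the interpolation weights depend on the point, so gradients flow into the eight touched table entries per level.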
Source code in src/mlx3d/nn/hashgrid.py
HashGridNeRF
Bases: Module
A compact hash-grid NeRF (Instant-NGP style).
The hash-grid hyperparameters (`num_levels`, `features_per_level`, `log2_hashmap_size`, `base_resolution`, `finest_resolution`) are forwarded to `HashGridEncoding`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bounds | tuple[float, float] | axis-aligned scene bounds the hash grid covers; sample points should lie within this cube (density is zeroed outside it). | (-1.5, 1.5) |
| geo_feat_dim | int | size of the geometry feature passed to the color MLP. | 15 |
| hidden_dim | int | width of both small MLPs. | 64 |
| dir_freqs | int | positional-encoding frequencies for the view direction. | 4 |
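The `bounds` row says density is zeroed outside the cube. A minimal sketch of that step, assuming points are mapped linearly from `[lo, hi]^3` onto the unit cube before encoding and that out-of-bounds samples simply get density 0; `normalize_to_unit_cube` and `mask_density` are hypothetical names, not mlx3d functions.

```python
def normalize_to_unit_cube(point, bounds=(-1.5, 1.5)):
    """Map each coordinate from [lo, hi] to [0, 1] (same range on every axis)."""
    lo, hi = bounds
    return [(c - lo) / (hi - lo) for c in point]


def inside_unit_cube(u):
    return all(0.0 <= c <= 1.0 for c in u)


def mask_density(point, raw_density, bounds=(-1.5, 1.5)):
    """Return raw_density inside the bounds and 0.0 outside (assumed behavior)."""
    u = normalize_to_unit_cube(point, bounds)
    return raw_density if inside_unit_cube(u) else 0.0


print(normalize_to_unit_cube([0.0, -1.5, 1.5]))  # [0.5, 0.0, 1.0]
print(mask_density([2.0, 0.0, 0.0], 3.0))        # 0.0
```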
Source code in src/mlx3d/nn/instant_ngp.py
NeRF
Bases: Module
The original NeRF MLP: density from position, color from position + view.
Source code in src/mlx3d/nn/nerf.py
__call__(points, directions)
Evaluate density and color.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| points | array | `(..., 3)` sample positions. | required |
| directions | array | `(..., 3)` normalized view directions (broadcastable). | required |
Returns:
| Type | Description |
|---|---|
| tuple[array, array] | density and color evaluated at `points`. |
Source code in src/mlx3d/nn/nerf.py
PositionalEncoding
Bases: Module
Sinusoidal positional encoding from the NeRF paper.
Maps $x$ to $[x, \sin(2^0 x), \cos(2^0 x), \ldots, \sin(2^{L-1} x), \cos(2^{L-1} x)]$, where $L$ is the number of frequencies.
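A plain-Python sketch of that mapping for a vector input, with `num_freqs` playing the role of L. Like the formula here, it uses frequencies 2^k without the extra factor of π that appears in the NeRF paper; how mlx3d orders the sin/cos terms across coordinates is an assumption, and `positional_encoding` is a hypothetical name. Each d-dimensional input becomes d(1 + 2L) values.

```python
import math


def positional_encoding(x, num_freqs):
    """x: list of floats (e.g. a 3-D point). Returns [x, sin(2^0 x), cos(2^0 x), ...]."""
    out = list(x)  # identity term first
    for k in range(num_freqs):
        scale = 2.0**k
        out.extend(math.sin(scale * c) for c in x)
        out.extend(math.cos(scale * c) for c in x)
    return out


enc = positional_encoding([0.1, -0.2, 0.3], num_freqs=4)
print(len(enc))  # 3 * (1 + 2 * 4) = 27
```

Under this sketch, a 3-D view direction encoded with `dir_freqs=4` would produce 27 values.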
Source code in src/mlx3d/nn/nerf.py
OccupancyGrid
A dense occupancy cache of `resolution`³ cells over an axis-aligned box.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| resolution | int | cells per axis. | 128 |
| bounds | tuple[float, float] | (min, max) extent of the box, applied to every axis. | (-1.5, 1.5) |
Source code in src/mlx3d/nn/occupancy.py
occupied_fraction
property
Fraction of cells currently marked occupied (useful for diagnostics).
update(density_fn, threshold=0.01, chunk=1 << 18)
Refresh occupancy by thresholding the density field at cell centers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| density_fn | Callable[[array], array] | callable mapping points to their densities. | required |
| threshold | float | cells with density above this are marked occupied. | 0.01 |
| chunk | int | points evaluated per batch (bounds memory for fine grids). | 1 << 18 |
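A plain-Python sketch of the update step: compute the center of every cell, evaluate `density_fn` on them `chunk` points at a time, and mark a cell occupied when its density exceeds `threshold`. The flat x-major cell ordering, the list-in/list-out `density_fn` signature, and the names `update_occupancy` and `ball_density` are assumptions for illustration.

```python
def cell_centers(resolution, bounds):
    lo, hi = bounds
    step = (hi - lo) / resolution
    # Flat ordering: index = (i * resolution + j) * resolution + k.
    return [
        (lo + (i + 0.5) * step, lo + (j + 0.5) * step, lo + (k + 0.5) * step)
        for i in range(resolution)
        for j in range(resolution)
        for k in range(resolution)
    ]


def update_occupancy(density_fn, resolution=8, bounds=(-1.5, 1.5), threshold=0.01, chunk=1 << 18):
    centers = cell_centers(resolution, bounds)
    occupied = []
    for start in range(0, len(centers), chunk):
        batch = centers[start:start + chunk]  # caps how many points density_fn sees per call
        densities = density_fn(batch)
        occupied.extend(d > threshold for d in densities)
    return occupied


def occupied_fraction(occupied):
    return sum(occupied) / len(occupied)


# Toy density: a solid ball of radius 0.75 centered at the origin.
def ball_density(points):
    return [1.0 if x * x + y * y + z * z < 0.75**2 else 0.0 for x, y, z in points]


occ = update_occupancy(ball_density, resolution=8, chunk=100)
print(occupied_fraction(occ))  # 0.0625 (32 of 512 cells)
```

Chunking changes only peak memory, not the result: the same cells are marked occupied for any `chunk` value.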
Source code in src/mlx3d/nn/occupancy.py
query(points)
Return a boolean mask indicating which of the `(..., 3)` points fall in occupied cells.
Points outside the grid bounds are reported empty.
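A matching single-point sketch of the lookup: map the point to integer cell coordinates and index a flat occupancy list, reporting anything outside `bounds` as empty, as the docstring states. The flat x-major layout and the name `query_point` are assumptions, not mlx3d's internals.

```python
import math


def query_point(point, occupied, resolution, bounds=(-1.5, 1.5)):
    """Return True if point lies in an occupied cell; False if empty or out of bounds."""
    lo, hi = bounds
    idx = []
    for c in point:
        if not (lo <= c <= hi):
            return False  # outside the grid: reported empty
        cell = int(math.floor((c - lo) / (hi - lo) * resolution))
        idx.append(min(cell, resolution - 1))  # c == hi falls in the last cell
    i, j, k = idx
    return occupied[(i * resolution + j) * resolution + k]


# Toy 4^3 grid where only cell (2, 2, 2), covering [0, 0.75)^3, is occupied.
res = 4
occupied = [False] * res**3
occupied[(2 * res + 2) * res + 2] = True
print(query_point((0.1, 0.2, 0.3), occupied, res))  # True
print(query_point((2.0, 0.0, 0.0), occupied, res))  # False (out of bounds)
```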
Source code in src/mlx3d/nn/occupancy.py
render_rays_occupancy(model, origins, directions, near, far, grid, num_samples=128, eval_fraction=1.0, stratified=False, white_background=False)
Render rays, evaluating model only at occupied samples.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Field | a field | required |
| origins | array | ray origins. | required |
| directions | array | ray directions. | required |
| near | float | near sampling bound. | required |
| far | float | far sampling bound. | required |
| grid | OccupancyGrid | occupancy cache identifying non-empty space (kept fixed / detached). | required |
| num_samples | int | samples per ray. | 128 |
| eval_fraction | float | fraction of all | 1.0 |
| stratified | bool | jitter samples (training) vs. deterministic (eval). | False |
| white_background | bool | composite onto white. | False |
Returns:
| Type | Description |
|---|---|
| dict[str, array] | Same dict as :func: |
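The core idea of rendering with an occupancy grid, sketched for one ray in plain Python: place `num_samples` evenly spaced samples between `near` and `far`, ask the occupancy test which ones lie in occupied space, call the model only for those, and treat the rest as zero density. The callable signatures, the fixed midpoint sampling (no stratification), and the name `render_ray_skipping` are assumptions, and `eval_fraction` is not modeled.

```python
import math


def render_ray_skipping(density_color_fn, is_occupied, origin, direction, near, far, num_samples=128):
    """Composite one ray, evaluating density_color_fn only at occupied samples.

    density_color_fn(point) -> (sigma, (r, g, b)); is_occupied(point) -> bool.
    Returns (rgb, number_of_model_evaluations).
    """
    step = (far - near) / num_samples
    transmittance = 1.0
    rgb = [0.0, 0.0, 0.0]
    evaluations = 0
    for s in range(num_samples):
        t = near + (s + 0.5) * step  # midpoint of each interval
        p = [o + t * d for o, d in zip(origin, direction)]
        if not is_occupied(p):
            continue  # empty space: sigma = 0, so alpha = 0 and nothing is added
        sigma, color = density_color_fn(p)
        evaluations += 1
        alpha = 1.0 - math.exp(-sigma * step)
        weight = transmittance * alpha
        rgb = [acc + weight * c for acc, c in zip(rgb, color)]
        transmittance *= 1.0 - alpha
    return rgb, evaluations


# Toy scene: an opaque red ball of radius 0.5, with a slightly larger occupied region.
def red_ball(p):
    return (50.0 if sum(c * c for c in p) < 0.25 else 0.0), (1.0, 0.0, 0.0)


def occupied(p):
    return sum(c * c for c in p) < 0.36


rgb, n_evals = render_ray_skipping(red_ball, occupied, (-2.0, 0.0, 0.0), (1.0, 0.0, 0.0), 0.0, 4.0)
print(n_evals)  # 38 of 128 samples evaluated
```

Because skipped samples would have contributed exactly zero, the image matches evaluating every sample as long as the grid never marks a truly dense cell as empty.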
Source code in src/mlx3d/nn/accel.py
render_rays(model, origins, directions, near, far, num_coarse=64, num_fine=0, fine_model=None, stratified=True, white_background=False)
Render a batch of rays with optional hierarchical sampling.
Returns a dict with `rgb`, `depth`, and `acc` (plus `rgb_coarse` when fine sampling is enabled).
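Whatever sampling is used, `rgb`, `depth`, and `acc` come from standard NeRF volume rendering. The plain-Python sketch below composites one ray from sample distances, densities, and colors. Treating the last interval as effectively infinite (as the original NeRF code does) and the white-background formula `rgb + (1 - acc)` are assumptions about mlx3d; `composite` is a hypothetical name, and hierarchical (fine) resampling is not shown.

```python
import math


def composite(t_vals, sigmas, colors, white_background=False):
    """Alpha-composite samples along one ray (t_vals sorted ascending).

    alpha_i = 1 - exp(-sigma_i * delta_i), T_i = prod_{j<i} (1 - alpha_j), w_i = T_i * alpha_i
    rgb = sum w_i c_i, depth = sum w_i t_i, acc = sum w_i
    """
    n = len(t_vals)
    rgb = [0.0, 0.0, 0.0]
    depth = 0.0
    acc = 0.0
    transmittance = 1.0
    for i in range(n):
        # Last interval treated as (practically) infinite.
        delta = t_vals[i + 1] - t_vals[i] if i + 1 < n else 1e10
        alpha = 1.0 - math.exp(-sigmas[i] * delta)
        weight = transmittance * alpha
        rgb = [a + weight * c for a, c in zip(rgb, colors[i])]
        depth += weight * t_vals[i]
        acc += weight
        transmittance *= 1.0 - alpha
    if white_background:
        rgb = [a + (1.0 - acc) for a in rgb]  # fill unabsorbed light with white
    return {"rgb": rgb, "depth": depth, "acc": acc}


# An opaque green surface at t = 3.
out = composite([1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1000.0, 0.0],
                [(1.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 0.0, 0.0)])
print(out)  # {'rgb': [0.0, 1.0, 0.0], 'depth': 3.0, 'acc': 1.0}
```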
Source code in src/mlx3d/nn/nerf.py