Benchmark different HEALPix padding implementations
CUDA available: Yes
Current device: 0
Device name: NVIDIA RTX 6000 Ada Generation
Device count: 1
Device capability: (8, 9)
Benchmarking results neval=10 p.size()=torch.Size([1, 12, 384, 128, 128]) padding=64 dtype=torch.float32
/home/nbrenowitz/workspace/earth2grid/earth2grid/healpix/_padding/pure_python.py:131: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
return x[slicers]
Python: gb_per_sec=66.46 peak_memory=6399.00MB
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:3546: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
return node.target(*args, **kwargs) # type: ignore[operator]
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py:336: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
return target(*args, **kwargs)
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1409: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
return func(*args, **kwargs)
Python + compile: gb_per_sec=361.24 peak_memory=2631.00MB
HEALPix Pad: gb_per_sec=350.54 peak_memory=2592.00MB
Zephyr pad: gb_per_sec=61.93 peak_memory=3744.00MB
Zephyr pad doesn't work well with torch.compile: compilation does not finish.
Python: channels dim last*: gb_per_sec=61.78 peak_memory=6099.01MB
Python + torch.compile: channels dim last*: gb_per_sec=395.40 peak_memory=2631.00MB
HEALPix Pad: channels dim last: gb_per_sec=380.37 peak_memory=2592.00MB
Benchmarking results neval=10 p.size()=torch.Size([2, 12, 384, 128, 128]) padding=64 dtype=torch.float32
Python: gb_per_sec=67.54 peak_memory=12735.00MB
Python + compile: gb_per_sec=345.64 peak_memory=5223.01MB
HEALPix Pad: gb_per_sec=350.59 peak_memory=5184.01MB
Zephyr pad: gb_per_sec=83.91 peak_memory=7488.01MB
Zephyr pad doesn't work well with torch.compile: compilation does not finish.
Python: channels dim last*: gb_per_sec=62.17 peak_memory=12147.01MB
Python + torch.compile: channels dim last*: gb_per_sec=408.63 peak_memory=5223.01MB
HEALPix Pad: channels dim last: gb_per_sec=404.53 peak_memory=5184.01MB
* shape for Python channels dim last: torch.Size([2, 196608, 384])
* shape for HEALPix Pad channels dim last: torch.Size([2, 12, 384, 128, 128])
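
On this GPU, the compiled indexing backend and the CUDA backend both sustain roughly 350-410 GB/s, while the eager Python and Zephyr backends stay below about 85 GB/s and allocate more peak memory. To pick a backend in your own code, here is a minimal sketch using the same pad_backend context manager exercised by the script below (the input shape and the choice of the cuda backend are illustrative):

import torch
from earth2grid import healpix
from earth2grid.healpix import pad_backend

x = torch.randn(1, 12, 384, 128, 128).cuda()  # (batch, face, channel, nside, nside)
with pad_backend(healpix.PaddingBackends.cuda):
    padded = healpix.pad(x, padding=64)  # pads each face using its neighbors

The full benchmark script follows.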
import time

import torch

from earth2grid import healpix
from earth2grid.healpix import pad_backend

# Print GPU information
if torch.cuda.is_available():
    print("CUDA available: Yes")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name()}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Device capability: {torch.cuda.get_device_capability()}")
else:
    print("CUDA not available")
print("\n")

nside = 128
padding = nside // 2
channels = 384
dtype = torch.float32
neval = 10


def test_func(label, pad, compile=False):
    # Reset memory stats and clear cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # warm up
    if compile:
        pad = torch.compile(pad)
    out = pad(p, padding=padding)
    torch.cuda.synchronize()

    start = time.time()
    for _ in range(neval):
        out = pad(p, padding=padding)
    torch.cuda.synchronize()
    stop = time.time()

    gb_per_sec = out.nbytes * neval / (stop - start) / 1e9
    peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
    label = label + ":"
    label = label + max(30 - len(label), 0) * " "
    print(f"{label} {gb_per_sec=:.2f} peak_memory={peak_memory:.2f}MB")
for batch_size in [1, 2]:
    p = torch.randn(size=(batch_size, 12, channels, nside, nside), dtype=dtype)
    print(f"Benchmarking results {neval=} {p.size()=} {padding=} {dtype=}")
    p = p.cuda()

    with pad_backend(healpix.PaddingBackends.indexing):
        test_func("Python", healpix.pad)
        test_func("Python + compile", healpix.pad, compile=True)

    with pad_backend(healpix.PaddingBackends.cuda):
        test_func("HEALPix Pad", healpix.pad)

    with pad_backend(healpix.PaddingBackends.zephyr):
        test_func("Zephyr pad", healpix.pad)
    print("Zephyr pad doesn't work well with torch.compile: compilation does not finish.")

    p = torch.randn(size=(batch_size, 12 * nside * nside, channels), dtype=dtype).cuda()
    test_func(
        "Python: channels dim last*",
        lambda x, padding: healpix.pad_with_dim(x, padding, dim=1),
        compile=False,
    )
    test_func(
        "Python + torch.compile: channels dim last*",
        lambda x, padding: healpix.pad_with_dim(x, padding, dim=1),
        compile=True,
    )

    p_python_shape = p.shape
    p = p.view(batch_size, 12, nside, nside, channels).permute(0, 1, 4, 2, 3)
    with pad_backend(healpix.PaddingBackends.cuda):
        test_func("HEALPix Pad: channels dim last", healpix.pad)

print("")
print(f"* shape for Python channels dim last: {p_python_shape}")
print(f"* shape for HEALPix Pad channels dim last: {p.shape}")
Total running time of the script: (0 minutes 18.625 seconds)