Benchmark different HEALPix padding implementations

CUDA available: Yes
Current device: 0
Device name: NVIDIA RTX 6000 Ada Generation
Device count: 1
Device capability: (8, 9)


Benchmarking results neval=10 p.size()=torch.Size([1, 12, 384, 128, 128]) padding=64 dtype=torch.float32
/home/nbrenowitz/workspace/earth2grid/earth2grid/healpix/_padding/pure_python.py:131: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  return x[slicers]
Python:                        gb_per_sec=66.46 peak_memory=6399.00MB
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:3546: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  return node.target(*args, **kwargs)  # type: ignore[operator]
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/fx/interpreter.py:336: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  return target(*args, **kwargs)
/home/nbrenowitz/workspace/earth2grid/.venv/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py:1409: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
  return func(*args, **kwargs)
Python + compile:              gb_per_sec=361.24 peak_memory=2631.00MB
HEALPix Pad:                   gb_per_sec=350.54 peak_memory=2592.00MB
Zephyr pad:                    gb_per_sec=61.93 peak_memory=3744.00MB
Zephyr pad doesn't work well with torch.compile: compilation doesn't finish.
Python: channels dim last*:    gb_per_sec=61.78 peak_memory=6099.01MB
Python + torch.compile: channels dim last*: gb_per_sec=395.40 peak_memory=2631.00MB
HEALPix Pad: channels dim last: gb_per_sec=380.37 peak_memory=2592.00MB

Benchmarking results neval=10 p.size()=torch.Size([2, 12, 384, 128, 128]) padding=64 dtype=torch.float32
Python:                        gb_per_sec=67.54 peak_memory=12735.00MB
Python + compile:              gb_per_sec=345.64 peak_memory=5223.01MB
HEALPix Pad:                   gb_per_sec=350.59 peak_memory=5184.01MB
Zephyr pad:                    gb_per_sec=83.91 peak_memory=7488.01MB
Zephyr pad doesn't work well with torch.compile: compilation doesn't finish.
Python: channels dim last*:    gb_per_sec=62.17 peak_memory=12147.01MB
Python + torch.compile: channels dim last*: gb_per_sec=408.63 peak_memory=5223.01MB
HEALPix Pad: channels dim last: gb_per_sec=404.53 peak_memory=5184.01MB

* shape for Python channels dim last: torch.Size([2, 196608, 384])
* shape for HEALPix Pad channels dim last: torch.Size([2, 12, 384, 128, 128])
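
The repeated UserWarning in the output comes from indexing a tensor with a list of slice objects inside pure_python.py; as the warning text itself says, wrapping the sequence in a tuple preserves multidimensional-indexing behavior on PyTorch 2.9+. A minimal sketch of the fix (the slicers name mirrors the traceback; the surrounding code in pure_python.py is assumed, not shown):

def apply_slicers(x, slicers):
    # A bare list of slices will be treated as a tensor index from
    # PyTorch 2.9 onward; a tuple is always read as multidimensional indexing.
    return x[tuple(slicers)]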

import time

import torch

from earth2grid import healpix
from earth2grid.healpix import pad_backend

# Print GPU information
if torch.cuda.is_available():
    print("CUDA available: Yes")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name()}")
    print(f"Device count: {torch.cuda.device_count()}")
    print(f"Device capability: {torch.cuda.get_device_capability()}")
else:
    print("CUDA not available")
print("\n")

nside = 128
padding = nside // 2
channels = 384
dtype = torch.float32

neval = 10


def test_func(label, pad, compile=False):
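    """Time `pad` on the global tensor `p`, reporting throughput (GB/s) and peak CUDA memory."""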
    # Reset memory stats and clear cache
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # Optionally compile, then run once to warm up: this triggers
    # torch.compile tracing and lazy CUDA initialization before timing.
    if compile:
        pad = torch.compile(pad)
    out = pad(p, padding=padding)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(neval):
        out = pad(p, padding=padding)
    torch.cuda.synchronize()
    stop = time.time()
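    # Throughput in GB/s: bytes of the padded output produced per second.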
    gb_per_sec = out.nbytes * neval / (stop - start) / 1e9
    peak_memory = torch.cuda.max_memory_allocated() / 1024 / 1024
    label = label + ":"
    label = label + max(30 - len(label), 0) * " "
    print(f"{label} {gb_per_sec=:.2f} peak_memory={peak_memory:.2f}MB")

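# A hedged alternative (not used above): CUDA events time the GPU work
# directly and avoid host-side clock overhead. This helper is a sketch for
# illustration, not part of the benchmark.
def cuda_event_seconds(fn, *args, n=10, **kwargs):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fn(*args, **kwargs)  # warm up
    start.record()
    for _ in range(n):
        fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / 1e3  # elapsed_time is in milliseconds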

for batch_size in [1, 2]:
    p = torch.randn(size=(batch_size, 12, channels, nside, nside), dtype=dtype)
    print(f"Benchmarking results {neval=} {p.size()=} {padding=} {dtype=}")

    p = p.cuda()

    with pad_backend(healpix.PaddingBackends.indexing):
        test_func("Python", healpix.pad)
        test_func("Python + compile", healpix.pad, compile=True)

    with pad_backend(healpix.PaddingBackends.cuda):
        test_func("HEALPix Pad", healpix.pad)

    with pad_backend(healpix.PaddingBackends.zephyr):
        test_func("Zephyr pad", healpix.pad)
        print("Zephyr pad doesn't work well with torch.compile. Doesn't finish compiling.")

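    # Channels-last comparison: the pure-Python path below pads a flattened
    # (batch, 12 * nside * nside, channels) tensor along dim=1 via
    # healpix.pad_with_dim, while HEALPix Pad receives the same data reshaped
    # and permuted back to (batch, 12, channels, nside, nside).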
    p = torch.randn(size=(batch_size, 12 * nside * nside, channels), dtype=dtype).cuda()
    test_func("Python: channels dim last*", lambda x, padding: healpix.pad_with_dim(x, padding, dim=1), compile=False)
    test_func(
        "Python + torch.compile: channels dim last*",
        lambda x, padding: healpix.pad_with_dim(x, padding, dim=1),
        compile=True,
    )
    p_python_shape = p.shape

    p = p.view(batch_size, 12, nside, nside, channels).permute(0, 1, 4, 2, 3)
    with pad_backend(healpix.PaddingBackends.cuda):
        test_func("HEALPix Pad: channels dim last", healpix.pad)

    print("")


print(f"* shape for Python channels dim last: {p_python_shape}")
print(f"* shape for HEALPix Pad channels dim last: {p.shape}")

Total running time of the script: (0 minutes 18.625 seconds)
