🌍 Generate KM-Scale Weather Maps with Pre-Trained Cascaded cBottle

Step 1: Configure Data Paths

Edit the paths in cBottle/src/cbottle/config/environment.py to point to your data sources.
Set the corresponding profile variables to "" if you're using local storage.

Step 2: Run Inference with the Coarse Generator

python scripts/inference_coarse.py cBottle-3d.zip inference_output --sample.min_samples 1

Step 3 (Optional): Plot the Generated Coarse Maps

This command creates a ZIP archive in the current directory containing visualizations of all output variables.

python scripts/plot.py inference_output/0.nc coarse.zip

Step 4: Super-Resolve a Subregion of the Coarse Map

If --input-path is not provided, the script defaults to using ICON HPX64 data.

python scripts/inference_multidiffusion.py cBottle-SR.zip superres_output \
    --input-path inference_output/0.nc \
    --overlap-size 32 \
    --super-resolution-box 0 -120 50 -40

Step 5 (Optional): Plot the Super-Resolved Output

python scripts/plot.py superres_output/0.nc high_res.zip

Load and Explore Zarr Datasets

Load a Zarr Dataset and Extract a Variable

import xarray as xr
# Lazily open the HEALPix-gridded hourly precipitation Zarr store.
ds = xr.open_zarr('/global/cfs/cdirs/m4581/gsharing/hackathon/scream-cess-healpix/scream2D_hrly_pr_hp10_v7.zarr')
# Materialize only the first 10 time steps of the "pr" variable in memory.
pr = ds.pr[:10].load()

Convert to RING Order and Compute Zonal Average

from earth2grid import healpix
import torch

# Wrap the NumPy payload as a torch tensor so earth2grid can operate on it.
pr_tensor = torch.from_numpy(pr.values)
# Convert the HEALPix pixel layout from NEST to RING ordering.
pr_r = healpix.reorder(pr_tensor, healpix.PixelOrder.NEST, healpix.PixelOrder.RING)
# Average around each latitude band (zonal mean) of the RING-ordered field.
avg = healpix.zonal_average(pr_r)

Load Data with ZarrLoader

import cbottle.datasets.zarr_loader as zl

# Single-variable loader: only the 2D field "rlut" (presumably TOA outgoing
# longwave radiation — confirm against the dataset docs); no 3D variables,
# so no vertical levels are requested.
loader = zl.ZarrLoader(
    path="/global/cfs/cdirs/m4581/gsharing/hackathon/scream-cess-healpix/scream2D_hrly_rlut_hp10_v7.zarr",
    variables_3d=[],
    variables_2d=["rlut"],
    levels=[]
)

Create a Time-Chunked Dataset

import cbottle.datasets.merged_dataset as md

# Build a dataset over the loader's time axis that yields samples in
# shuffled chunks of 48 consecutive time steps (chunking presumably batches
# Zarr reads — verify against TimeMergedDataset docs).
dataset = md.TimeMergedDataset(
    loader.times,
    time_loaders=[loader],
    transform=lambda t, x: x[0],  # unwrap the single loader's payload
    chunk_size=48,
    shuffle=True
)

Train on a Custom Dataset

Step 1: Build a Dataloader

Load Multiple Zarr Datasets

# 2D variables to train on; each lives in its own single-variable Zarr store,
# so we create one ZarrLoader per variable.
variable_list_2d = ["rlut", "pr"]
loaders = [
    zl.ZarrLoader(
        path=f"/global/cfs/cdirs/m4581/gsharing/hackathon/scream-cess-healpix/scream2D_hrly_{var}_hp10_v7.zarr",
        variables_3d=[],
        variables_2d=[var],
        levels=[]
    )
    for var in variable_list_2d
]

Define a Transform Function for Each Sample

import numpy as np

def encode_task(t, d, variables=None):
    """Convert one (times, payloads) pair into a training-sample dict.

    Parameters
    ----------
    t : sequence
        Timestamps for the sample; only the first entry is used, and it
        must expose a ``.timestamp()`` method (e.g. ``datetime``/pandas).
    d : sequence of mappings
        One payload per loader; only the first is used. It is indexed by
        ``(variable_name, -1)`` keys (the ``-1`` level marks a 2D field).
    variables : list of str, optional
        Variable names to stack into the target. Defaults to the
        module-level ``variable_list_2d``.

    Returns
    -------
    dict
        ``condition`` (empty list — inferred during training),
        ``target`` (array of shape ``(n_vars, 1, ...)``), and
        ``timestamp`` (POSIX seconds as float).
    """
    if variables is None:
        variables = variable_list_2d
    t = t[0]
    d = d[0]
    condition = []  # empty; will be inferred during training
    # [None] prepends a singleton channel axis before stacking variables.
    target = [d[(var, -1)][None] for var in variables]
    return {
        "condition": condition,
        "target": np.stack(target),
        "timestamp": t.timestamp()
    }

Create a DataLoader

# Chunked, shuffled dataset over the first loader's time axis; encode_task
# merges all loaders' payloads into one training sample per time step.
dataset = md.TimeMergedDataset(
    loaders[0].times,
    time_loaders=loaders,
    transform=encode_task,
    chunk_size=48,
    shuffle=True
)

import torch
# Standard PyTorch loader: batches of 8 samples, 3 worker processes
# reading chunks in parallel.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    num_workers=3
)

Monitor I/O Throughput

import tqdm

# Pull 20 batches and report bytes/second for the stacked targets.
with tqdm.tqdm(unit='B', unit_scale=True) as pb:
    batch_count = 0
    for batch in data_loader:
        if batch_count == 20:
            break
        pb.update(batch["target"].nbytes)
        batch_count += 1

Step 2: Wrap the Dataset with a Train/Test Split

def dataset_wrapper(*, split: str = ""):
    """Build a TimeMergedDataset for the requested train/test split.

    Parameters
    ----------
    split : str
        ``"train"`` — first 75% of the available times, chunks of 48;
        ``"test"`` — only the final time step, chunk size 1;
        ``""`` — all times, chunk size 1.

    Returns
    -------
    The dataset, annotated with ``grid`` and ``fields_out`` metadata
    that the training loop expects.

    Raises
    ------
    ValueError
        If ``split`` is not one of the recognized values.
    RuntimeError
        If the chosen split contains no time steps.
    """
    valid_times = loaders[0].times
    train_times = valid_times[:int(len(valid_times) * 0.75)]
    test_times = valid_times[-1:]

    split_times = {"train": train_times, "test": test_times, "": valid_times}
    split_chunk = {"train": 48, "test": 1, "": 1}
    # Fail loudly on a typo'd split instead of surfacing a bare KeyError.
    if split not in split_times:
        raise ValueError(
            f"Unknown split {split!r}; expected 'train', 'test', or ''."
        )
    times = split_times[split]
    chunk_size = split_chunk[split]

    if times.size == 0:
        raise RuntimeError("No times are selected.")

    dataset = md.TimeMergedDataset(
        times,
        time_loaders=loaders,
        transform=encode_task,
        chunk_size=chunk_size,
        shuffle=True
    )

    # Additional metadata required for training: HEALPix level-10 grid in
    # NEST order, and the list of output field names.
    dataset.grid = healpix.Grid(level=10, pixel_order=healpix.PixelOrder.NEST)
    dataset.fields_out = variable_list_2d

    return dataset

Step 3: Train the Super-Resolution Model

Requires at least 60 GB of GPU memory.
To run on Perlmutter, set -C 'gpu&hbm80g' to request A100 80GB nodes.

from train_multidiffusion import train as train_super_resolution

# Short demonstration run: 10 optimization steps, logging every 5.
# Checkpoints and logs land under "training_output".
train_super_resolution(
    output_path="training_output",
    customized_dataset=dataset_wrapper,
    num_steps=10,
    log_freq=5
)