Tutorial 8: Data Parallel MLP#
Note: While async concepts are taught using the
tokio runtime, any async runtime can be used.
In this tutorial we show how to build a single-layer MLP, copy it to multiple GPUs, and execute distinct batches of data on each instance:
Input → Linear → ReLU → Output
Where:
Linear: hidden = input @ weights
ReLU: output = max(0, hidden)
The Code#
#[cutile::module]
mod data_parallel_module {
    use cutile::core::*;

    // mat-mat.
    //
    // Each tile block loads its [BM, BN] tile of `z`, walks the K dimension in
    // steps of BK, and feeds the matching tiles of `x` and `y` through `mma`
    // with the running `z` tile as the third operand. The result is stored
    // back into `z`.
    //
    // NOTE(review): the trip count is `K / BK` (integer division), so this
    // presumably expects K to be a multiple of BK -- confirm against callers.
    #[cutile::entry()]
    fn gemm<const BM: i32, const BN: i32, const BK: i32, const K: i32>(
        z: &mut Tensor<f32, { [BM, BN] }>,
        x: &Tensor<f32, { [-1, K] }>,
        y: &Tensor<f32, { [K, -1] }>,
    ) {
        let part_x = x.partition(const_shape![BM, BK]);
        let part_y = y.partition(const_shape![BK, BN]);
        let pid: (i32, i32, i32) = get_tile_block_id();
        let mut tile_z = load_tile_mut(z);
        for i in 0i32..(K / BK) {
            // Row-block pid.0 of x against column-block pid.1 of y.
            let tile_x = part_x.load([pid.0, i]);
            let tile_y = part_y.load([i, pid.1]);
            tile_z = mma(tile_x, tile_y, tile_z);
        }
        z.store(tile_z);
    }

    // mat-vec.
    //
    // Same structure as `gemm`, but `y` and `z` are rank-1. Both are reshaped
    // to single-column tiles ([BK, 1] and [BM, 1]) so the same `mma` call can
    // be used, and the accumulated tile is reshaped back to [BM] before the
    // store.
    //
    // NOTE(review): as in `gemm`, the loop runs `K / BK` times, so K is
    // presumably a multiple of BK -- confirm against callers.
    #[cutile::entry()]
    pub fn matvec<const BM: i32, const BK: i32, const K: i32>(
        z: &mut Tensor<f32, { [BM] }>,
        x: &Tensor<f32, { [-1, K] }>,
        y: &Tensor<f32, { [K] }>,
    ) {
        let part_x = x.partition(const_shape![BM, BK]);
        let part_y = y.partition(const_shape![BK]);
        let pid: (i32, i32, i32) = get_tile_block_id();
        let mut tile_z = z.load().reshape(const_shape![BM, 1]);
        for i in 0i32..(K / BK) {
            let tile_x = part_x.load([pid.0, i]);
            let tile_y = part_y.load([i]).reshape(const_shape![BK, 1]);
            tile_z = mma(tile_x, tile_y, tile_z);
        }
        z.store(tile_z.reshape(const_shape![BM]));
    }

    // ReLU.
    //
    // In place: loads the [D] tile, takes the element-wise maximum against a
    // tile of zeros, and stores the result back into the same tensor.
    #[cutile::entry()]
    fn relu<const D: i32>(input_output: &mut Tensor<f32, { [D] }>) {
        let zero_tile: Tile<f32, { [D] }> = constant(0.0f32, const_shape![D]);
        let input = input_output.load();
        input_output.store(max_tile(zero_tile, input));
    }
}
use data_parallel_module::{gemm_apply, relu_apply, matvec_apply};
#[tokio::main]
async fn main() {
use cuda_async::device_operation::*;
use data_parallel_module::{gemm_apply, relu_apply, matvec_apply};
use cutile::api;
use cutile::tensor::{Unpartition, Partition, Tensor, ToHostVec};
use cutile::tile_kernel::{IntoDeviceOperationPartition, TileKernel};
use cuda_async::device_context::global_policy;
use cutile::api::copy;
use tokio::task::JoinHandle;
// Get device scheduling policies.
let num_devices = 2;
let devices = {
let mut r = vec![];
for _ in 0..num_devices {
// Pretend we have multiple devices...
// If you actually do have multiple devices, use i in place of 0.
r.push(global_policy(0)?);
}
r
};
let dim = 16;
let block_dim = 4;
let fully_connected_layer = [
block_dim.to_string(),
block_dim.to_string(),
block_dim.to_string(),
dim.to_string(),
];
let output_layer = [
block_dim.to_string(),
block_dim.to_string(),
dim.to_string(),
];
let w0 = api::randn(0.0f32, 1.0, [dim, dim]); // impl DeviceOperation
let w1 = api::randn(0.0f32, 1.0, [dim]); // impl DeviceOperation
let w = zip!(w0.arc(), w1.arc()).schedule(&devices[0])?.await?;
let mut joins = vec![];
for i in 1..num_devices {
let w_copy = tokio::spawn(zip!(copy(&w.0).arc(), copy(&w.1).arc()).schedule(&devices[i])?);
joins.push(w_copy);
}
let mut model_weights = vec![w];
for join in joins {
model_weights.push(join.await.unwrap()?);
}
// Asynchronously compute forward pass for each batch of data on each device.
let mut futures: Vec<JoinHandle<Result<Partition<Tensor<f32>>, cuda_async::error::DeviceError>>> = vec![];
for i in 0..num_devices {
let w = &model_weights[i];
let (w0, w1) = (w.0.clone(), w.1.clone());
let data = api::randn(0.0, 1.0, [dim, dim]).arc();
let out0 = api::zeros::<2, f32>([dim, dim]).partition([block_dim, block_dim]);
let (out0, _, _) = zip!(out0, data, value(w0))
.apply(|args| gemm_apply(args).generics(fully_connected_layer.to_vec()))
.unzip();
let out1 = api::zeros::<1, f32>([dim]).partition([block_dim]);
let (out1, _, _) = zip!(out1, out0.unpartition().arc(), value(w1))
.apply(|args| matvec_apply(args).generics(output_layer.to_vec()))
.unzip();
let (out1,) = out1.and_then(|out1| value((out1,))).apply(relu_apply).unzip();
futures.push(tokio::spawn(out1.schedule(&devices[i])?));
}
// Wait on results.
let mut outputs: Vec<Tensor<f32>> = vec![];
for future in futures.into_iter() {
let tensor = future.await.unwrap()?.unpartition();
outputs.push(tensor);
}
for output in outputs {
println!("{:?}", output.to_host_vec().await?);
}
}
Key Pattern: Compose Device Operations, Then Spawn#
Every device operation in the loop below is non-blocking. The loop itself is non-blocking:
// One task handle per device. Each handle resolves to the partitioned output
// tensor of that device's forward pass, or to a DeviceError.
let mut futures: Vec<JoinHandle<Result<Partition<Tensor<f32>>, cuda_async::error::DeviceError>>> = vec![];
for i in 0..num_devices {
// Obtain a reference to the model weights on device i.
let w = &model_weights[i];
// Clone both weight handles so they can be moved into the operations below.
let (w0, w1) = (w.0.clone(), w.1.clone());
// Sample random data. Although the sampling procedure is a simulation,
// this can be replaced with a procedure that actually samples a batch of data.
let data = api::randn(0.0, 1.0, [dim, dim]).arc();
// Construct the intermediate output buffer and partition, since we'll be writing to it.
let out0 = api::zeros::<2, f32>([dim, dim]).partition([block_dim, block_dim]);
// Execute GEMM. Only the first element of the unzipped tuple (the output
// buffer) is kept; the two inputs are discarded.
let (out0, _, _) = zip!(out0, data, value(w0))
.apply(|args| gemm_apply(args).generics(fully_connected_layer.to_vec()))
.unzip();
// Construct the final output buffer and partition.
let out1 = api::zeros::<1, f32>([dim]).partition([block_dim]);
// Execute MatVec. The GEMM output is unpartitioned first because it is now
// passed as an input rather than written to.
let (out1, _, _) = zip!(out1, out0.unpartition().arc(), value(w1))
.apply(|args| matvec_apply(args).generics(output_layer.to_vec()))
.unzip();
// Apply ReLU and unzip. We need to unzip here since arguments to kernels
// are always packed into a tuple.
let (out1,) = out1.and_then(|out1| value((out1,))).apply(relu_apply).unzip();
// out1 now contains the work we would like to schedule on device i.
// By invoking schedule on device i, we generate a device future which is
// ready to execute on device i. By spawning a task for the device future,
// we submit the work for execution to the async runtime (tokio). We then
// collect the task handle into the futures vec.
futures.push(tokio::spawn(out1.schedule(&devices[i])?));
}
After spawning tasks for each forward pass on each device, we wait on the results before proceeding:
// Collect the results in the same order the tasks were spawned.
let mut outputs: Vec<Tensor<f32>> = vec![];
for future in futures.into_iter() {
// `unwrap` panics if the spawned task itself did not complete (for example,
// it panicked); `?` propagates an error reported by the device operation.
// The partitioned result is then converted back into a plain tensor.
let tensor = future.await.unwrap()?.unpartition();
outputs.push(tensor);
}
Key Takeaways#
Concept |
What It Means |
|---|---|
Device operations |
Chainable, resource-agnostic DAGs |
tokio::spawn |
Run batches concurrently |
schedule(device) |
Target a specific GPU |
Lazy execution |
Pipeline is built first, then executed once it is scheduled and awaited |
Exercise 1: Fuse the Kernel#
How might we fuse the above kernels into a single kernel? Would this reduce the memory footprint of our computation?
Exercise 2: Overlapping Data Movement with Computation#
What would we need to change to construct a pipeline that overlaps data movement with computation?