Topic 7: Filters
Fabric filters allow a PE to selectively accept incoming wavelets. This example demonstrates so-called range filters, which decide which wavelets are forwarded to the CE based on the upper 16 bits of the wavelet contents. Specifically, PE #0 sends all 12 wavelets to the PEs to its east while processing the first three values itself; each recipient PE accepts and processes only its quarter of the incoming wavelets. See Filter Configuration Semantics for other possible filter configurations.
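To make the index arithmetic concrete, the following host-side Python sketch (an illustration only, not part of the example's sources) mimics the wavelet layout and the range test: pack() mirrors the helper in send.csl below, and accepts() applies the min_idx/max_idx bounds configured in recv.csl.

import numpy as np

def pack(index: int, value: float) -> int:
    # Mirror pack() in send.csl: 16-bit index in the upper half,
    # IEEE float16 bits in the lower half of a 32-bit wavelet.
    return (index << 16) | int(np.array(value, dtype=np.float16).view(np.uint16))

def accepts(pe_id: int, wavelet: int) -> bool:
    # Range filter from recv.csl: accept iff the index field lies in
    # [pe_id * 3, pe_id * 3 + 2].
    index = wavelet >> 16
    return pe_id * 3 <= index <= pe_id * 3 + 2

wavelets = [pack(i, 10.0 + i) for i in range(12)]
for pe_id in (1, 2, 3):
    print(pe_id, [w >> 16 for w in wavelets if accepts(pe_id, w)])
# 1 [3, 4, 5]
# 2 [6, 7, 8]
# 3 [9, 10, 11]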
layout.csl
// color/ task ID map
//
//  ID var              ID var      ID var                 ID var
//   0                   9 STARTUP  18                     27 reserved (memcpy)
//   1 dataColor        10          19                     28 reserved (memcpy)
//   2 resultColor      11          20                     29 reserved
//   3 H2D              12          21 reserved (memcpy)   30 reserved (memcpy)
//   4 D2H              13          22 reserved (memcpy)   31 reserved
//   5                  14          23 reserved (memcpy)   32
//   6                  15          24                     33
//   7                  16          25                     34
//   8 main_task_id     17          26                     35
// +-------------+
// | north(d2h) |
// +-------------+
// | core |
// +-------------+
// | south(nop) |
// +-------------+
// IDs for memcpy streaming colors
param MEMCPYH2D_DATA_1_ID: i16;
param MEMCPYD2H_DATA_1_ID: i16;
// Colors
const MEMCPYH2D_DATA_1: color = @get_color(MEMCPYH2D_DATA_1_ID);
const MEMCPYD2H_DATA_1: color = @get_color(MEMCPYD2H_DATA_1_ID);
const dataColor: color = @get_color(1);
const resultColor: color = @get_color(2);
// Task IDs
const STARTUP: local_task_id = @get_local_task_id(9);
const main_task_id: local_task_id = @get_local_task_id(8);
const recv_task_id: data_task_id = @get_data_task_id(dataColor);
const memcpy = @import_module( "<memcpy/get_params>", .{
  .width = 4,
  .height = 3,
  .MEMCPYH2D_1 = MEMCPYH2D_DATA_1,
  .MEMCPYD2H_1 = MEMCPYD2H_DATA_1
});
layout {
  @set_rectangle(4, 3);

  for (@range(i16, 4)) |pe_x| {
    const memcpy_params = memcpy.get_params(pe_x);
    // north PE only runs d2h
    @set_tile_code(pe_x, 0, "memcpyEdge/north.csl", .{
      .memcpy_params = memcpy_params,
      .USER_OUT_1 = resultColor,
      .STARTUP = STARTUP,
    });
  }

  const memcpy_params_0 = memcpy.get_params(0);
  const memcpy_params_1 = memcpy.get_params(1);
  const memcpy_params_2 = memcpy.get_params(2);
  const memcpy_params_3 = memcpy.get_params(3);

  @set_tile_code(0, 1, "send.csl", .{
    .peId = 0,
    .memcpy_params = memcpy_params_0,
    .exchColor = dataColor,
    .resultColor = resultColor,
    .main_task_id = main_task_id
  });

  const recvStruct = .{ .recvColor = dataColor,
                        .resultColor = resultColor,
                        .recv_task_id = recv_task_id };

  @set_tile_code(1, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 1,
    .memcpy_params = memcpy_params_1,
  }));

  @set_tile_code(2, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 2,
    .memcpy_params = memcpy_params_2,
  }));

  @set_tile_code(3, 1, "recv.csl", @concat_structs(recvStruct, .{
    .peId = 3,
    .memcpy_params = memcpy_params_3,
  }));

  for (@range(i16, 4)) |pe_x| {
    const memcpy_params = memcpy.get_params(pe_x);
    // south does nothing
    @set_tile_code(pe_x, 2, "memcpyEdge/south.csl", .{
      .memcpy_params = memcpy_params,
      .STARTUP = STARTUP
    });
  }
}
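For orientation, the placement above amounts to the following 4 x 3 assignment (a plain-Python summary, matching the diagram in the comment header):

# Row 0 is the D2H halo, row 1 the compute row, row 2 an idle halo.
grid = [
    ["memcpyEdge/north.csl"] * 4,                      # y = 0: north (d2h)
    ["send.csl", "recv.csl", "recv.csl", "recv.csl"],  # y = 1: core
    ["memcpyEdge/south.csl"] * 4,                      # y = 2: south (nop)
]
for y, row in enumerate(grid):
    for x, program in enumerate(row):
        print(f"PE ({x}, {y}): {program}")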
send.csl
param memcpy_params: comptime_struct;
param peId: u16;
// Colors
param exchColor: color;
param resultColor: color;
// Task IDs
param main_task_id: local_task_id;
// ----------
// Every PE needs to import the memcpy module; otherwise the I/O cannot
// propagate the data to its destination.
// The memcpy module reserves input queue 0 and output queue 0.
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------
/// Helper function to pack a 16-bit index and a 16-bit float value into one
/// 32-bit wavelet.
fn pack(index: u16, data: f16) u32 {
  return (@as(u32, index) << 16) | @as(u32, @bitcast(u16, data));
}
const size = 12;
const data = [size]u32 {
  pack( 0, 10.0), pack( 1, 11.0), pack( 2, 12.0),
  pack( 3, 13.0), pack( 4, 14.0), pack( 5, 15.0),
  pack( 6, 16.0), pack( 7, 17.0), pack( 8, 18.0),
  pack( 9, 19.0), pack(10, 20.0), pack(11, 21.0),
};
/// Function to send all data values to the east neighbors.
fn sendDataToEastTiles() void {
  const inDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{size} -> data[i]
  });
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = size,
    .fabric_color = exchColor,
    .output_queue = @get_output_queue(2)
  });
  // WARNING: "async" is necessary; otherwise the CE has no resources
  // to run the memcpy kernel
  @mov32(outDsd, inDsd, .{ .async = true });
}
/// Function to process (divide by 2) the first three values and send the
/// result to the north neighbor (halo PE).
const num_wvlts: u16 = 3;
var buf = @zeros([num_wvlts]f16);
var ptr_buf : [*]f16 = &buf;

fn processAndSendSubset() void {
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = num_wvlts,
    .fabric_color = resultColor,
    .output_queue = @get_output_queue(1)
  });
  const bufDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{num_wvlts} -> buf[i]
  });

  var idx: u16 = 0;
  while (idx < num_wvlts) : (idx += 1) {
    const payload = @as(u16, data[idx] & 0xffff);
    const floatValue = @bitcast(f16, payload);
    buf[idx] = floatValue / 2.0;
  }

  // WARNING: "async" is necessary; otherwise the CE has no resources
  // to run the memcpy kernel
  @fmovh(outDsd, bufDsd, .{ .async = true });
}
task mainTask() void {
  sendDataToEastTiles();
  processAndSendSubset();
}

comptime {
  @activate(main_task_id);
  @bind_local_task(mainTask, main_task_id);
  @set_local_color_config(exchColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ EAST } } });
  @set_local_color_config(resultColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ NORTH } } });
}
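As a check on processAndSendSubset(), this numpy sketch (illustrative, not part of the sources) performs the same unpack-and-halve on the first three packed words; the payloads 10, 11, 12 become 5, 5.5, 6, which are the first three entries of the oracle in run.py below.

import numpy as np

def pack(index: int, value: float) -> int:
    # Same 32-bit layout as pack() in send.csl.
    return (index << 16) | int(np.array(value, dtype=np.float16).view(np.uint16))

data = [pack(i, 10.0 + i) for i in range(12)]
# Strip the index, reinterpret the low 16 bits as float16, and halve.
low16 = np.array([w & 0xffff for w in data[:3]], dtype=np.uint16)
print(low16.view(np.float16) / 2)  # [5.  5.5 6. ]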
recv.csl
param memcpy_params: comptime_struct;
param peId: u16;
// Colors
param recvColor: color;
param resultColor: color;
// Task IDs
param recv_task_id: data_task_id; // data task receives data along recvColor
// ----------
// Every PE needs to import the memcpy module; otherwise the I/O cannot
// propagate the data to its destination.
// The memcpy module reserves input queue 0 and output queue 0.
const sys_mod = @import_module( "<memcpy/memcpy>", memcpy_params);
// ----------
/// The recipient simply halves the value in the incoming wavelet and sends the
/// result to the north neighbor (halo PE).
var buf = @zeros([1]f16);
task recvTask(data: f16) void {
  @block(recvColor);
  buf[0] = data / 2.0;
  const outDsd = @get_dsd(fabout_dsd, .{
    .extent = 1,
    .fabric_color = resultColor,
    .output_queue = @get_output_queue(1)
  });
  const bufDsd = @get_dsd(mem1d_dsd, .{
    .tensor_access = |i|{1} -> buf[i]
  });
  // WARNING: "async" is necessary; otherwise the CE has no resources
  // to run the memcpy kernel
  @fmovh(outDsd, bufDsd, .{ .async = true, .unblock = recv_task_id });
}
comptime {
  @bind_data_task(recvTask, recv_task_id);

  const baseRoute = .{
    .rx = .{ WEST }
  };

  const filter = .{
    // Each PE should only accept three wavelets, starting with the one whose
    // index field contains the value peId * 3.
    .kind = .{ .range = true },
    .min_idx = peId * 3,
    .max_idx = peId * 3 + 2,
  };
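  // With peId = 1, 2, 3, the accepted index ranges are therefore 3..5, 6..8,
  // and 9..11; PE #0 (send.csl) processes the first three values itself.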
  if (peId == 3) {
    // This is the last PE; don't forward the wavelet further to the east.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP } });
    @set_local_color_config(recvColor, .{ .routes = txRoute, .filter = filter });
  } else {
    // Otherwise, forward incoming wavelets both to the CE and to the east neighbor.
    const txRoute = @concat_structs(baseRoute, .{ .tx = .{ RAMP, EAST } });
    @set_local_color_config(recvColor, .{ .routes = txRoute, .filter = filter });
  }

  // Send result wavelets to the north neighbor (i.e., the halo PEs).
  @set_local_color_config(resultColor, .{ .routes = .{ .rx = .{ RAMP }, .tx = .{ NORTH } } });
}
run.py
#!/usr/bin/env cs_python
import argparse
import json
import numpy as np
from cerebras.sdk.sdk_utils import memcpy_view
from cerebras.sdk.runtime.sdkruntimepybind import SdkRuntime, MemcpyDataType # pylint: disable=no-name-in-module
from cerebras.sdk.runtime.sdkruntimepybind import MemcpyOrder # pylint: disable=no-name-in-module
parser = argparse.ArgumentParser()
parser.add_argument('--name', help='the test name')
parser.add_argument("--cmaddr", help="IP:port for CS system")
args = parser.parse_args()
dirname = args.name
# Parse the compile metadata
with open(f"{dirname}/out.json", encoding="utf-8") as json_file:
    compile_data = json.load(json_file)
params = compile_data["params"]
MEMCPYD2H_DATA_1 = int(params["MEMCPYD2H_DATA_1_ID"])
print(f"MEMCPYD2H_DATA_1 = {MEMCPYD2H_DATA_1}")
memcpy_dtype = MemcpyDataType.MEMCPY_16BIT
runner = SdkRuntime(dirname, cmaddr=args.cmaddr)
runner.load()
runner.run()
print("step 1: streaming D2H at P0.0")
# The D2H buffer must be of type u32
out_tensors_u32 = np.zeros(4*3, np.uint32)
runner.memcpy_d2h(out_tensors_u32, MEMCPYD2H_DATA_1, 0, 0, 4, 1, 3,
                  streaming=True, data_type=memcpy_dtype,
                  order=MemcpyOrder.ROW_MAJOR, nonblock=False)
# Keep only the lower 16 bits of each u32 and view them as float16
result = memcpy_view(out_tensors_u32, np.dtype(np.float16))
runner.stop()
oracle = [5, 5.5, 6, 6.5, 7, 7.5, 8, 8.5, 9, 9.5, 10, 10.5]
np.testing.assert_allclose(result, oracle, atol=0.0001, rtol=0)
print("SUCCESS!")
commands.sh
#!/usr/bin/env bash
set -e
cslc ./layout.csl --fabric-dims=11,5 --fabric-offsets=4,1 -o out \
  --params=MEMCPYH2D_DATA_1_ID:3 \
  --params=MEMCPYD2H_DATA_1_ID:4 \
  --memcpy --channels=1 --width-west-buf=0 --width-east-buf=0
cs_python run.py --name out