Compute Shaders

All compute kernels are Slang modules. One global descriptor heap. No per-shader bindings. Runs on any Vulkan GPU.

Why Slang

Slang provides a true module system (import), generics, interfaces, and compiles to SPIR-V, HLSL, Metal, CUDA, and CPU targets. Ships with the Vulkan SDK. No GLSL. No .comp files.
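As a rough illustration of the interfaces and generics mentioned above, the sketch below defines an interface and a generic function constrained by it. All names here (IActivation, ReluOp, SiluOp, ApplyTwice) are hypothetical and are not taken from the library.

```slang
// Hypothetical names throughout; not part of the library.
interface IActivation
{
    float apply(float x);
}

struct ReluOp : IActivation
{
    float apply(float x) { return max(x, 0.0); }
}

struct SiluOp : IActivation
{
    // SiLU: x * sigmoid(x), written as x / (1 + e^-x)
    float apply(float x) { return x / (1.0 + exp(-x)); }
}

// Generic over any type that conforms to IActivation.
// The compiler specializes this per concrete type at the call site.
float ApplyTwice<T : IActivation>(T op, float x)
{
    return op.apply(op.apply(x));
}
```

A call such as ApplyTwice(ReluOp(), v) would resolve T at compile time, so the generic should add no runtime dispatch in the generated SPIR-V.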

Bindless Model

All shaders use a single global RWByteAddressBuffer heap[] with 65,536 slots. Buffer indices are passed via push constants. No per-shader descriptor set layouts. One pipeline layout for everything.

Copy.slang

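import oa;

// Heap slot indices and element count, supplied per dispatch.
struct PushConstants {
    uint x_idx;   // heap slot of the input buffer
    uint out_idx; // heap slot of the output buffer
    uint count;   // number of elements to copy
};
[[vk::push_constant]] PushConstants pc;

// The single global bindless heap (binding 0, set 0).
[[vk::binding(0, 0)]] RWByteAddressBuffer heap[];

[shader("compute")]
[numthreads(256, 1, 1)]
void main(uint3 tid : SV_DispatchThreadID) {
    uint idx = tid.x;
    if (idx >= pc.count) return; // guard the partial last workgroup
    OaStore(heap[pc.out_idx], idx, OaLoad(heap[pc.x_idx], idx));
}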
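To show how the same pattern extends to more buffers, here is a minimal sketch of an element-wise add kernel. The file name Add.slang and the y_idx field are assumptions made for illustration, and the sketch assumes OaLoad returns a value usable as a float. The shipped add shader may be written differently.

```slang
// Add.slang: hypothetical file, modeled on Copy.slang.
import oa;

struct PushConstants {
    uint x_idx;   // heap slot of the first input
    uint y_idx;   // heap slot of the second input (assumed field name)
    uint out_idx; // heap slot of the output
    uint count;   // number of elements
};
[[vk::push_constant]] PushConstants pc;
[[vk::binding(0, 0)]] RWByteAddressBuffer heap[];

[shader("compute")]
[numthreads(256, 1, 1)]
void main(uint3 tid : SV_DispatchThreadID) {
    uint idx = tid.x;
    if (idx >= pc.count) return;
    // Assumption: OaLoad yields a float (or a type convertible to one).
    float a = OaLoad(heap[pc.x_idx], idx);
    float b = OaLoad(heap[pc.y_idx], idx);
    OaStore(heap[pc.out_idx], idx, a + b);
}
```

Adding an input costs one more uint in the push-constant block, and the binding declaration is unchanged. Because the indices come from push constants, they are uniform across the dispatch. Assuming DivCeil means integer ceiling division, a dispatch with count = 1000 would use 4 groups (1,024 threads), and the bounds check retires the last 24 threads.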

Math Modules

Reusable math modules live in oa/include/oa/math/ — pure math with no entry points or bindings. Import them in any shader.

Example.slang

import activations; // Silu, Gelu, Swiglu
import reductions;  // ReduceSum256, InvRms
import storage;     // OaLoad, OaStore (BF16/FP32)
import complex;     // OaComplex, Cmul, Cadd

// Modules own their groupshared memory internally.
// Pure math: no bindings, no push constants.
// Inside a kernel body, with buffer, base, cols, eps, and lane supplied by the caller:
float invRms = InvRms(buffer, base, cols, eps, lane);
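For a sense of what a "pure math" module can look like, the following is a minimal sketch of a standalone Slang module. The module name my_activations and both function names are hypothetical; the library's own activations module is not shown in this document and may be organized differently.

```slang
// my_activations.slang: hypothetical module, not the library's.
module my_activations;

// No entry points, no resource bindings, no push constants:
// only functions that other shaders can import.

// SiLU: x * sigmoid(x)
public float SiluSketch(float x)
{
    return x / (1.0 + exp(-x));
}

// ReLU: max(x, 0)
public float ReluSketch(float x)
{
    return max(x, 0.0);
}
```

A kernel would pull these in with import my_activations; and call SiluSketch(v) like any local function. Keeping bindings out of the module is what lets the same code be reused by every shader that shares the global heap.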

Workgroup Conventions

| Pattern | numthreads | Groups |
| --- | --- | --- |
| Element-wise | (256, 1, 1) | DivCeil(N, 256) |
| Row-reduction | (256, 1, 1) | 1 per row |
| GEMM | (16, 16, 1) | tiled |
| Scan (SSM) | (1024, 1, 1) | B * DState |
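The row-reduction convention (256 threads, one workgroup per row) can be sketched as a row-sum kernel. This is an illustration under stated assumptions, not one of the shipped shaders. It declares its own groupshared scratch array instead of calling the library's ReduceSum256, whose signature is not shown here. The cols push-constant field is hypothetical, and OaLoad is assumed to return a float.

```slang
// RowSum.slang: hypothetical file illustrating the row-reduction layout.
import oa;

struct PushConstants {
    uint x_idx;   // heap slot of the input matrix, row-major
    uint out_idx; // heap slot of the output, one value per row
    uint cols;    // elements per row (assumed field name)
};
[[vk::push_constant]] PushConstants pc;
[[vk::binding(0, 0)]] RWByteAddressBuffer heap[];

groupshared float scratch[256];

[shader("compute")]
[numthreads(256, 1, 1)]
void main(uint3 gid : SV_GroupID, uint3 ltid : SV_GroupThreadID) {
    uint row  = gid.x;   // one workgroup per row
    uint lane = ltid.x;  // 0..255 within the workgroup
    uint base = row * pc.cols;

    // Each lane accumulates a strided slice of the row.
    float partial = 0.0;
    for (uint c = lane; c < pc.cols; c += 256)
        partial += OaLoad(heap[pc.x_idx], base + c);

    scratch[lane] = partial;
    GroupMemoryBarrierWithGroupSync();

    // Tree reduction in shared memory: 128, 64, ..., 1.
    for (uint stride = 128; stride > 0; stride >>= 1) {
        if (lane < stride)
            scratch[lane] += scratch[lane + stride];
        GroupMemoryBarrierWithGroupSync();
    }

    if (lane == 0)
        OaStore(heap[pc.out_idx], row, scratch[0]);
}
```

The host would dispatch one group per row, so no bounds check on the row index is needed. No thread returns early, which keeps every barrier in uniform control flow.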

Included Shaders

40+ compute shaders ship with the library — forward and backward kernels covering the full ML pipeline:

| Category | Shaders |
| --- | --- |
| Tensor Ops | add, sub, mul, neg, fill, scale, matmul, transpose, gather, embedding |
| Reductions | sum, max, argmax, reduce_cols, reduce_mean |
| Activations | silu, gelu, relu, softmax, swiglu |
| Normalization | rmsnorm, layernorm |
| Training | cross_entropy, conv1d, byte_embed + all _backward variants |
| Optimizers | adamw |
| SSM | ssm_scan, ssm_scan_complex, matrix_combine + _backward |