Compute Shaders
All compute kernels are Slang modules. One global descriptor heap. No per-shader bindings. Runs on any Vulkan GPU.
Why Slang
Slang provides a true module system (import), generics, and interfaces, and it compiles to SPIR-V, HLSL, Metal, CUDA, and CPU targets. It ships with the Vulkan SDK. No GLSL. No .comp files.
Bindless Model
All shaders use a single global RWByteAddressBuffer heap[] with 65,536 slots. Buffer indices are passed via push constants. No per-shader descriptor set layouts. One pipeline layout for everything.
Copy.slang

    import oa;

    struct PushConstants {
        uint x_idx;
        uint out_idx;
        uint count;
    };

    [[vk::push_constant]] PushConstants pc;
    [[vk::binding(0, 0)]] RWByteAddressBuffer heap[];

    [shader("compute")]
    [numthreads(256, 1, 1)]
    void main(uint3 tid : SV_DispatchThreadID) {
        uint idx = tid.x;
        if (idx >= pc.count) return;
        OaStore(heap[pc.out_idx], idx, OaLoad(heap[pc.x_idx], idx));
    }
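
On the host side, the push-constant block in Copy.slang corresponds to three tightly packed 32-bit unsigned integers. The following is a minimal C++ sketch of how those constants might be prepared and range-checked before a dispatch. The names `CopyPushConstants`, `kHeapSlots`, and `MakeCopyPushConstants` are hypothetical and not part of the library, and the Vulkan calls themselves are left out so the block stays self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side mirror of the PushConstants struct in Copy.slang.
// Field order and types must match the shader: three 32-bit unsigned ints.
struct CopyPushConstants {
    std::uint32_t x_idx;    // heap slot of the source buffer
    std::uint32_t out_idx;  // heap slot of the destination buffer
    std::uint32_t count;    // number of elements to copy
};

// 3 x 4 bytes with no padding, well inside the 128 bytes of push-constant
// space that Vulkan requires every implementation to provide.
static_assert(sizeof(CopyPushConstants) == 12, "layout must match the shader");
static_assert(offsetof(CopyPushConstants, count) == 8, "count is the third uint");

// Heap size stated in the text (65,536 slots); the constant name is assumed.
constexpr std::uint32_t kHeapSlots = 65536;

// Builds the push-constant block, rejecting slot indices outside the heap.
inline CopyPushConstants MakeCopyPushConstants(std::uint32_t src_slot,
                                               std::uint32_t dst_slot,
                                               std::uint32_t count) {
    assert(src_slot < kHeapSlots && dst_slot < kHeapSlots);
    return CopyPushConstants{src_slot, dst_slot, count};
}
```

In a real command buffer the resulting struct would typically be handed to `vkCmdPushConstants` with `VK_SHADER_STAGE_COMPUTE_BIT` and a size of `sizeof(CopyPushConstants)`. Because every shader shares one pipeline layout, only the contents of this small block change between dispatches.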
Math Modules
Reusable math modules live in oa/include/oa/math/ — pure math with no entry points or bindings. Import them in any shader.
Example.slang

    import activations;  // Silu, Gelu, Swiglu
    import reductions;   // ReduceSum256, InvRms
    import storage;      // OaLoad, OaStore (BF16/FP32)
    import complex;      // OaComplex, Cmul, Cadd

    // Modules own their groupshared memory internally.
    // Pure math: no bindings, no push constants.
    float invRms = InvRms(buffer, base, cols, eps, lane);
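
When validating a kernel built on these modules, a CPU reference can be useful. The sketch below assumes that `InvRms` returns 1 / sqrt(mean(x²) + eps) over one row of `cols` elements starting at `base`. That is the usual RMSNorm definition, but the source does not spell it out. It also mimics a 256-lane reduction: each simulated lane accumulates a strided partial sum, then the partials are combined in a tree, as a groupshared reduction commonly would. `InvRmsReference` and `kLanes` are hypothetical names, and this is not the library's implementation.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

constexpr std::size_t kLanes = 256;  // matches numthreads(256, 1, 1)

// CPU reference for an inverse RMS over one row.
// Assumed definition: 1 / sqrt(mean(x^2) + eps).
inline float InvRmsReference(const std::vector<float>& buffer,
                             std::size_t base, std::size_t cols, float eps) {
    assert(cols > 0 && base + cols <= buffer.size());

    // Step 1: each simulated lane sums the squares of a strided slice.
    std::array<float, kLanes> partial{};
    for (std::size_t lane = 0; lane < kLanes; ++lane) {
        for (std::size_t c = lane; c < cols; c += kLanes) {
            const float x = buffer[base + c];
            partial[lane] += x * x;
        }
    }

    // Step 2: tree reduction, halving the active lane count each round.
    // On the GPU a workgroup barrier would separate the rounds.
    for (std::size_t stride = kLanes / 2; stride > 0; stride /= 2) {
        for (std::size_t lane = 0; lane < stride; ++lane) {
            partial[lane] += partial[lane + stride];
        }
    }

    const float mean_sq = partial[0] / static_cast<float>(cols);
    return 1.0f / std::sqrt(mean_sq + eps);
}
```

Summing in the same strided-then-tree order as the GPU keeps floating-point rounding closer to what a shader would produce than a plain left-to-right loop, which makes tolerance-based comparisons easier to tune.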
Workgroup Conventions
| Pattern | numthreads | Groups |
|---|---|---|
| Element-wise | (256, 1, 1) | DivCeil(N, 256) |
| Row-reduction | (256, 1, 1) | 1 per row |
| GEMM | (16, 16, 1) | tiled |
| Scan (SSM) | (1024, 1, 1) | B * DState |
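
The group counts in the table can be computed on the host with a ceiling division. The sketch below is one way to do it. `DivCeil` is named in the table, but its signature here is an assumption, as are the `GroupCount` struct, the helper names, and the GEMM tiling (one 16x16 workgroup per 16x16 tile of an M x N output). The limit check uses 65,535, the minimum per-dimension dispatch size that the Vulkan specification requires.

```cpp
#include <cassert>
#include <cstdint>

// Number of workgroups per dispatch dimension (hypothetical helper type).
struct GroupCount {
    std::uint32_t x, y, z;
};

// Ceiling division written to avoid the overflow that (n + d - 1) / d can hit.
inline std::uint32_t DivCeil(std::uint32_t n, std::uint32_t d) {
    assert(d > 0);
    return n / d + (n % d != 0 ? 1u : 0u);
}

// Element-wise: numthreads(256, 1, 1), one thread per element.
inline GroupCount ElementWiseGroups(std::uint32_t n) {
    return {DivCeil(n, 256), 1, 1};
}

// Row-reduction: one 256-thread group per row.
inline GroupCount RowReductionGroups(std::uint32_t rows) {
    return {rows, 1, 1};
}

// GEMM: numthreads(16, 16, 1); assumes one group per 16x16 output tile.
inline GroupCount GemmGroups(std::uint32_t m, std::uint32_t n) {
    return {DivCeil(n, 16), DivCeil(m, 16), 1};
}

// Scan (SSM): B * DState groups of 1024 threads each.
inline GroupCount ScanGroups(std::uint32_t batch, std::uint32_t d_state) {
    return {batch * d_state, 1, 1};
}

// True if the dispatch fits the per-dimension count every Vulkan device supports.
inline bool FitsBaselineLimits(const GroupCount& g) {
    constexpr std::uint32_t kMinGuaranteed = 65535;
    return g.x <= kMinGuaranteed && g.y <= kMinGuaranteed && g.z <= kMinGuaranteed;
}
```

For example, `ElementWiseGroups(1000)` yields 4 groups, and the last group relies on the `idx >= pc.count` guard in the shader to skip its 24 surplus threads. The Vulkan specification guarantees only 128 invocations per workgroup as a baseline, so the 256- and 1024-wide groups assume the device reports a larger `maxComputeWorkGroupInvocations`, which is worth checking at startup.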
Included Shaders
40+ compute shaders ship with the library — forward and backward kernels covering the full ML pipeline:
| Category | Shaders |
|---|---|
| Tensor Ops | add, sub, mul, neg, fill, scale, matmul, transpose, gather, embedding |
| Reductions | sum, max, argmax, reduce_cols, reduce_mean |
| Activations | silu, gelu, relu, softmax, swiglu |
| Normalization | rmsnorm, layernorm |
| Training | cross_entropy, conv1d, byte_embed + all _backward variants |
| Optimizers | adamw |
| SSM | ssm_scan, ssm_scan_complex, matrix_combine + _backward |