Defeating LLM Nondeterminism in Production: Batch Invariance, Not "GPU Randomness"
TL;DR. Even with temperature=0 and top_p=1, the same prompt can produce different completions if your inference server's batch size or batch composition changes over time. The core culprit isn't "GPU race conditions" in the forward pass — it's that many kernels are not batch-invariant, so numerics shift when the server merges or slices requests (e.g., via continuous batching). Thinking Machines' new work proposes and open-sources batch-invariant kernels (RMSNorm, matmul, attention — including a fixed-size split strategy for decoding) and demonstrates a vLLM-based deterministic pipeline.
Why this matters (for anyone shipping agents)
- Auditability & regulatory pressure. In banking, healthcare, and other regulated domains, "same input → same output" is a baseline requirement.
- Reliable evals and debugging. If your outputs drift because server load changed, you can't trust eval deltas or reproduction of bugs.
- Operational sanity. With continuous batching, multi-tenant load constantly reshapes batches; if output quality depends on that, you don't really control your product.
Quick glossary
- Deterministic decoding: Greedy or temperature-0 sampling with fixed seeds and fixed runtime configuration.
- Run-to-run determinism (kernel): Invoking the same kernel on the same inputs returns bitwise-identical outputs.
- Batch invariance (kernel/system): An operation's result for item i is independent of other items in the batch and independent of the batch size and internal slicing; i.e., same numerics for item i regardless of what else is co-processed.
- Continuous batching: Inference servers (e.g., vLLM) dynamically merge/split requests across steps to maximize throughput and GPU occupancy.
- KV cache / paged attention: Long-context inference stores K/V tensors in GPU/CPU pages; attention reads a concatenation of cached past tokens plus current tokens.
The popular hypothesis — and what's actually happening
The traditional explanation
People often blame nondeterminism on floating-point non-associativity + concurrency:
- Floating point is non-associative: (a + b) + c != a + (b + c) in finite precision.
- On GPUs, parallel reductions can complete in different orders, so "race order" → different sums → different logits → different tokens.
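A two-line Python check (plain IEEE-754 doubles, no GPU involved) makes the non-associativity concrete:

```python
# Floating-point addition is not associative: the grouping determines
# which small terms get absorbed by rounding.
a, b, c = 0.1, 1e20, -1e20

left = (a + b) + c   # 0.1 is absorbed into 1e20 first, then cancels: 0.0
right = a + (b + c)  # 1e20 and -1e20 cancel first, so 0.1 survives: 0.1

print(left, right)   # 0.0 0.1
```

A parallel reduction that changes its grouping is doing exactly this at scale.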
The observation that overturns it
For the LLM forward pass, you can repeatedly run the same matmul on the same inputs and get bitwise-identical results; i.e., the common forward kernels are run-to-run deterministic under fixed shapes and code paths. The forward pass rarely needs atomics; modern kernels use techniques (e.g., split reductions combined deterministically, or semaphore ordering) that avoid nondeterminism while retaining parallel speed.
So where does user-visible nondeterminism come from?
From batch-size/shape-dependent numerics. Many kernels choose different reduction trees, tilings, or code paths depending on batch size, sequence length, and how the KV cache is partitioned — and all of those are influenced by continuous batching and prefill/decoding strategies. Your request's numbers change because other users changed the batch.
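The effect can be simulated on CPU with NumPy. The sketch below (illustrative values chosen to make the rounding visible; not any real kernel) reduces the exact same float32 data in two orders: one sequential pass, and two halves combined at the end, as a batch-size-dependent splitting strategy might do.

```python
import numpy as np

# Values chosen so float32 rounding exposes the grouping: near 1e8 the
# float32 spacing is 8, so adding 1.0 there is absorbed by rounding.
vals = np.array([1.0] * 2047 + [1e8, -1e8] + [1.0] * 2047, dtype=np.float32)

def reduce_seq(chunk):
    """Strict left-to-right float32 accumulation."""
    acc = np.float32(0.0)
    for v in chunk:
        acc = np.float32(acc + v)
    return acc

whole = reduce_seq(vals)                      # one pass over all 4096 terms
halves = np.float32(reduce_seq(vals[:2048]) + reduce_seq(vals[2048:]))

print(whole, halves)  # 4095.0 2048.0 -- same data, different reduction order
```

Same inputs, same arithmetic, different grouping: a user whose request happens to be processed under a different split sees different numbers.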
Batch invariance is the real requirement
To make an LLM server deterministic from the user's perspective, each op must be batch-invariant:
- The reduction order and tiling per item must not depend on batch size, batch composition, or chunking of that item's sequence.
- The attention reduction must be invariant to whether tokens arrive via prefill vs. decoding and to the exact KV cache boundary conditions (e.g., partial blocks).
In practice, that means fixing or constraining algorithms/parameters across three reduction-bearing ops:
1) RMSNorm, 2) Matmul, 3) Attention.
What changes under the hood (and why)
1) RMSNorm: data-parallel strategy (avoid split reductions)
- Goal: Each element's reduction over hidden_dim must have a fixed order regardless of batch size.
- Strategy: Use data parallelism across batch elements (one batch element per core / SM). When the batch is small (few rows), resist the urge to "split" reductions across cores (which would alter reduction order). Accept a small slowdown at tiny batch sizes instead of changing the reduction tree.
- Edge case: If you truly must split (extremely small batch sizes), you've lost batch invariance. The paper's recommendation is to ignore that optimization and eat the (usually minor) perf cost to maintain determinism.
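A toy NumPy sketch of the data-parallel strategy (illustrative only, not the released kernel): each row is reduced in the same fixed order no matter how many rows are in the batch, so slicing the batch cannot change row i's numerics.

```python
import numpy as np

def rmsnorm_batch_invariant(x, weight, eps=np.float32(1e-6)):
    """Toy batch-invariant RMSNorm: one 'core' per batch element, each
    reducing left-to-right over hidden_dim. The reduction tree for row i
    never depends on how many other rows are present."""
    out = np.empty_like(x)
    for i in range(x.shape[0]):               # data-parallel over the batch
        acc = np.float32(0.0)
        for v in x[i]:                        # fixed sequential reduction
            acc = np.float32(acc + np.float32(v * v))
        inv_rms = np.float32(1.0 / np.sqrt(acc / np.float32(x.shape[1]) + eps))
        out[i] = x[i] * inv_rms * weight
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 64)).astype(np.float32)
w = np.ones(64, dtype=np.float32)

full = rmsnorm_batch_invariant(x, w)      # row 0 inside a batch of 8
solo = rmsnorm_batch_invariant(x[:1], w)  # the same row alone
print(np.array_equal(full[:1], solo))     # True -- bitwise identical
```

Splitting the inner loop across cores would speed up tiny batches but change the reduction tree, which is exactly the optimization this strategy skips.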
2) Matmul: avoid Split-K and shape-adaptive tiling for determinism
- Problem: Many matmul libraries dynamically choose Split-K or different tile sizes based on shapes/batch size. That changes accumulation order.
- Fix: Compile and pin one kernel configuration (tile sizes, wave quantization, etc.) for all shapes relevant to LLM inference. Avoid Split-K in the forward pass.
- Trade-off: ~20% perf loss versus highly tuned cuBLAS at some shapes is reported, but acceptable for inference with large model dims where Split-K is least needed.
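A sketch of the "pin one configuration" idea in NumPy (tile size is an assumption of mine, not the real kernel): the K dimension is always walked in the same fixed tiles and every output row takes the same code path, so the row count (batch size) cannot change the accumulation order.

```python
import numpy as np

def matmul_pinned(a, b, tile_k=128):
    """Toy GEMM with one pinned config: fixed K tiles, same per-row code
    path at every batch size. No Split-K, no shape-adaptive tiling."""
    m, k = a.shape
    out = np.zeros((m, b.shape[1]), dtype=np.float32)
    for i in range(m):                        # each output row independently
        acc = np.zeros(b.shape[1], dtype=np.float32)
        for k0 in range(0, k, tile_k):        # identical tile walk, always
            acc += a[i, k0:k0 + tile_k] @ b[k0:k0 + tile_k]
        out[i] = acc
    return out

rng = np.random.default_rng(1)
a = rng.standard_normal((64, 512)).astype(np.float32)
b = rng.standard_normal((512, 32)).astype(np.float32)

full = matmul_pinned(a, b)        # "batch" of 64 rows
row0 = matmul_pinned(a[:1], b)    # the same row alone
print(np.array_equal(full[:1], row0))  # True -- row 0 is bitwise identical
```

A library that silently switched to Split-K or a different tiling for the single-row case would break this equality; keeping it is what costs the modest performance hit described above.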
3) Attention: two hard parts
(a) Prefill vs. decoding symmetry and KV boundaries
- Many engines handle cached K/V and current K/V in separate loops/blocks. With paged KV and block sizes (e.g., 32/64), you can create different patterns of full vs. masked blocks across the "cache" and "current" segments, which changes the reduction order for the same query token.
- Fix: Materialize/merge the KV layout before launching the attention kernel so the kernel "sees" a consistent, contiguous K/V region. Ensure the kernel's reduction order over that region is identical whether those K/V came from cache or current tokens.
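To see why looping over cache and current K/V separately changes numerics even when the math is equivalent, here is a single-query sketch in NumPy (illustrative, not the FlexAttention kernel): one pass over a contiguous K/V region versus a flash-style two-phase online softmax over a cache block and a current block.

```python
import numpy as np

def attn_one_pass(q, K, V):
    """Softmax attention over one contiguous K/V region."""
    s = K @ q
    w = np.exp(s - s.max())
    return (w @ V) / w.sum()

def attn_two_phase(q, K_cache, V_cache, K_cur, V_cur):
    """Online softmax reducing cache and current K/V in two phases,
    carrying a running (max, denominator, numerator). Mathematically
    equal to one pass, but generally not bitwise-equal."""
    m = np.float32(-np.inf)
    d = np.float32(0.0)
    o = np.zeros(V_cache.shape[1], dtype=np.float32)
    for K_blk, V_blk in ((K_cache, V_cache), (K_cur, V_cur)):
        s = K_blk @ q
        m_new = max(m, np.float32(s.max()))
        scale = np.exp(np.float32(m - m_new))  # rescale earlier partials
        w = np.exp(s - m_new)
        d = d * scale + w.sum()
        o = o * scale + w @ V_blk
        m = m_new
    return o / d

rng = np.random.default_rng(0)
q = rng.standard_normal(64).astype(np.float32)
K = rng.standard_normal((48, 64)).astype(np.float32)
V = rng.standard_normal((48, 64)).astype(np.float32)

one = attn_one_pass(q, K, V)
two = attn_two_phase(q, K[:40], V[:40], K[40:], V[40:])
print(np.max(np.abs(one - two)))  # tiny, but typically not exactly 0.0
```

Merging the KV layout first, so the kernel always reduces over one contiguous region in one order, removes this phase boundary entirely.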
(b) Decoding requires split-KV — do it in a fixed-size way
- During decoding, query length ≈ 1, so to saturate the GPU you must parallelize over the KV length (the reduction dimension). Traditional "balanced" schedulers split KV into equal-length chunks based on how many queries are present — that makes reduction order depend on batch composition.
- Fix: Use a fixed split size, not a fixed number of splits.
- Example: For KV length 1000, use 3×256 + 1×232 regardless of how many queries are in flight.
- The number of splits varies, but the order of elemental reductions per query is identical across batch compositions.
- Implementation note: This required light modifications in the FlexAttention backend to pin scheduling/tiling and to update KV page tables before the kernel.
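The fixed split-size rule is easy to state in code. A sketch (the function name is mine, not the library's):

```python
def kv_splits_fixed_size(kv_len, split_size=256):
    """Partition a KV length into fixed-size chunks (last one ragged).
    Boundaries depend only on kv_len, never on batch composition, so
    each query's reduction order is stable under any server load."""
    return [(start, min(start + split_size, kv_len))
            for start in range(0, kv_len, split_size)]

# The example from the text: KV length 1000 -> three 256-chunks + one of 232.
print(kv_splits_fixed_size(1000))
# [(0, 256), (256, 512), (512, 768), (768, 1000)]
```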
Where continuous batching and kernel fusion fit in
- Continuous batching (vLLM, SGLang, etc.). Great for throughput; bad for determinism if kernels are not batch-invariant, because the effective shapes per step (and per token!) keep changing as requests join/leave the batch. Batch-invariant kernels make continuous batching compatible with determinism.
- Kernel fusion. Fusion (e.g., FlashAttention v2) reduces memory traffic and accelerates inference. Fusion per se is not the cause of nondeterminism; the issue is how fused kernels pick reduction/tiling strategies. Fused attention that separately loops "cache" vs "current" or picks splits based on batch composition can break batch invariance; fused attention with fixed, shape-independent strategies preserves it.
Implementation: how Thinking Machines made vLLM deterministic
- Kernel library. A set of batch-invariant ops (RMSNorm, matmul, attention) wired via torch.Library so you can transparently swap them into PyTorch models with minimal code changes.
- Attention backend. The prototype integrates with PyTorch FlexAttention to control scheduling/tiling and update KV layouts before the kernel runs.
- vLLM demo / PoC. They show a vLLM configuration that routes core ops through the batch-invariant implementations to produce identical token streams for a prompt, regardless of batch size or server load.
- Performance notes.
- Matmul: ~20% slower than cuBLAS under some shapes when pinning a single config (no Split-K, fixed tiles).
- Attention: The fixed split-size FlashDecode preserves batch invariance with modest overhead; the concrete hit depends on KV length and GPU.
What this doesn't solve
- Sampling nondeterminism: If you use non-greedy decoding (temperature>0, nucleus sampling, etc.), your sampler must be controlled (PRNG seeding, device placement), and even then numerics can diverge across different hardware or libraries. The work here targets forward-pass numerics so that greedy / temp=0 runs are reproducible across batchings.
- Cross-stack drift: Determinism still assumes identical model weights, tokenizer, runtime versions, GPU drivers, and kernel choices across runs. Change any of those and you can get drift, batch invariance or not.
- Backprop/training: Backward passes regularly use atomics (e.g., FlashAttention bwd) or different algorithms; this work is about inference.
Practical checklist: making your server deterministic
1) Pin numerics + stack
- Same model weights/tokenizer, same PyTorch/CUDA, same attention backend (e.g., FlexAttention), same drivers.
- Disable opportunistic autotuning or shape-adaptive kernel selection.
2) Swap in batch-invariant kernels
- RMSNorm: Use data-parallel reductions; avoid split reductions at small batch sizes.
- Matmul: Avoid Split-K; compile and pin a single tile/wave config for your shapes.
- Attention:
- Update/merge KV page tables before kernel; don't treat cache vs current as separate reduction phases.
- Use fixed split-size FlashDecode so decoding reduction order is independent of batch composition.
3) Server integration
- vLLM or SGLang: select a backend that supports FlexAttention; insert torch.Library patches to route ops to batch-invariant versions.
- Keep continuous batching enabled — now it won't change outputs.
4) Validation
- Build a harness that:
- Replays the same prompt N times under different concurrent loads and batch sizes.
- Diffs logits and tokens step by step, not only final strings.
- Confirm bitwise or ulp-level identity for logits; strict token identity for outputs.
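A minimal harness sketch. Here `generate` is a placeholder for your server call (batch of prompts in, token ids per prompt out), and the toy "servers" below are stand-ins, not real endpoints; a production version would also diff per-step logits, not just tokens.

```python
from typing import Callable, List, Sequence

def check_batch_invariance(
    generate: Callable[[Sequence[str]], List[List[int]]],
    prompt: str,
    fillers: Sequence[str],
) -> bool:
    """Replay `prompt` alone and alongside growing filler load; require
    strictly identical token streams every time."""
    reference = generate([prompt])[0]
    for n in range(1, len(fillers) + 1):
        tokens = generate([prompt] + list(fillers[:n]))[0]
        if tokens != reference:
            print(f"mismatch with {n} fillers: {tokens} vs {reference}")
            return False
    return True

# Toy deterministic "server": output depends only on the prompt itself.
good = lambda batch: [[len(p), p.count("l")] for p in batch]
# Toy broken "server": output leaks the batch size into the tokens.
bad = lambda batch: [[len(p) + len(batch)] for p in batch]

print(check_batch_invariance(good, "hello", ["a", "bb", "ccc"]))  # True
print(check_batch_invariance(bad, "hello", ["a", "bb", "ccc"]))   # False
```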
Deeper technical notes (for infra folks)
- Why most forward ops don't need atomics. With adequate batch parallelism, reductions can be done within a threadblock/SM and combined deterministically via fixed trees or semaphores. Atomics are more common in backward passes (e.g., scatter-add).
- Why fixed tile sizes matter. GPU GEMM performance often relies on tile/wave quantization that changes with shape. Pinning a single config avoids tiling changes that subtly reorder accumulation.
- KV "boundary conditions" are the hidden landmine. With block size B and KV length L, separate loops over cache/current K/V create extra masked blocks (e.g., 5 blocks instead of 4 for the same total length), changing reduction order. "Pre-assemble" K/V and masks so the same query token always reduces over the same block sequence.
- Fixed split-size vs fixed #splits. A fixed number of splits (equal partitions) means the per-split spans depend on Q (the number of queries being processed), breaking batch invariance. A fixed split size holds spans constant; the kernel emits however many splits are needed to cover KV.
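The contrast between the two split policies in code (function names are mine; a "balanced" scheduler would derive n_splits from how many queries are in flight):

```python
def splits_fixed_count(kv_len, n_splits):
    """Fixed NUMBER of splits: span boundaries move with n_splits, which
    a balanced scheduler ties to batch composition."""
    base, rem = divmod(kv_len, n_splits)
    spans, start = [], 0
    for i in range(n_splits):
        end = start + base + (1 if i < rem else 0)
        spans.append((start, end))
        start = end
    return spans

def splits_fixed_size(kv_len, split_size=256):
    """Fixed split SIZE: spans depend only on kv_len."""
    return [(s, min(s + split_size, kv_len)) for s in range(0, kv_len, split_size)]

# Same 1000-token KV cache under different server load:
print(splits_fixed_count(1000, 4))  # [(0, 250), (250, 500), (500, 750), (750, 1000)]
print(splits_fixed_count(1000, 5))  # [(0, 200), (200, 400), (400, 600), (600, 800), (800, 1000)]
# With a fixed split size, the boundaries never move:
print(splits_fixed_size(1000))      # [(0, 256), (256, 512), (512, 768), (768, 1000)]
```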
What to watch next
- Upstreaming into inference servers. vLLM PRs and equivalent issues in SGLang/FlashInfer/FlexAttention are already in motion. Expect flags like --deterministic that pin kernels and schedulers.
- Beyond greedy: Deterministic probabilistic decoding (temperature>0) needs careful PRNG control, reproducible sorting/top-k numerics, and pinned device placement to avoid cross-device drift.
- Multi-GPU / tensor parallel. Ensuring identical all-reduce order and identical sharding layouts is the next frontier for full cluster-level determinism.
Bottom line
"Make it deterministic" isn't about forbidding GPUs from running in parallel. It's about engineering batch-invariant numerics so that continuous batching no longer changes the math. Thinking Machines' batch-invariant kernels show this is practical today — with small, predictable performance costs — and unlock the kind of reproducibility production teams have wanted for years.
Key Takeaways
Whether you're approaching this from a business or technical perspective, the core insight remains: true deterministic AI requires batch-invariant operations. The traditional focus on "GPU randomness" misses the real issue—that modern inference infrastructure changes the underlying mathematics based on server load.
Next Steps
- For Business Leaders: Assess which of your AI applications require deterministic behavior and plan a phased implementation
- For Technical Teams: Start with the validation tools to measure your current nondeterminism, then implement batch-invariant kernels for critical systems
The future of reliable AI systems depends on getting the infrastructure right. Batch invariance isn't just a technical nicety—it's the foundation for trustworthy AI in production.