- NaNs don’t originate where they seem; they silently propagate through layers
- torch.autograd.set_detect_anomaly is too slow and sometimes misleading for real debugging
- A forward hook–based detector can catch NaNs at the exact layer and batch where they first occur
- Overhead is ~3–4 ms per forward pass, far lower than anomaly detection (especially on GPU)
- Gradient explosion is the true root cause in most cases; catching it early prevents NaNs entirely
- The system logs structured events (layer, batch, stats) for precise debugging
- Designed for production: thread-safe, memory-bounded, and scalable
It was batch 47,000. A ResNet variant I had been training for six hours on a custom medical imaging dataset. The loss was converging cleanly (1.4, 1.1, 0.87, 0.73) and then, nothing. Not an error. Not a crash. Just nan.
I added torch.autograd.set_detect_anomaly(True) and restarted. Training slowed to a crawl, roughly 7–10× longer per batch on CPU alone, and after three hours I finally got a stack trace pointing to a layer that, frankly, looked fine. The real culprit was a learning rate scheduler interacting badly with a custom normalization layer two layers upstream. set_detect_anomaly had pointed me at the symptom, not the source.
That debugging session cost me most of a day. So I built something better.
NaNs don’t crash your model; they quietly corrupt it. By the time you notice, you’re already debugging the wrong layer.
Full code: https://github.com/Emmimal/pytorch-nan-detector/
The Problem with set_detect_anomaly
PyTorch ships with torch.autograd.set_detect_anomaly(True), which is the standard recommendation for debugging NaN issues. It works by retaining the full computation graph and checking for anomalies during the backward pass. That is powerful, but it comes with serious costs that make it unsuitable for anything beyond a quick local sanity check.
The core issue is that it forces PyTorch’s autograd engine into a synchronous mode where it saves intermediate activations for every single operation. On GPU, this means breaking the asynchronous execution pipeline: every kernel launch has to complete before the next one begins. The result, as reported in the PyTorch documentation and widely observed in practice, is an overhead that ranges from roughly 10–15× on CPU to 50–100× on GPU for larger models [1][2].
There is a second problem: set_detect_anomaly points you at where the NaN propagated to in the backward pass, not necessarily where it originated. If a NaN enters your network at layer 3 of a 50-layer model, the backward pass will surface an error somewhere in the gradient computation for a later layer, and you are left working backward from there.
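To see this failure mode in miniature (a minimal sketch of my own, not code from the repository): sqrt of a negative input produces the nan in the forward pass, but anomaly mode only raises once the backward pass surfaces it, and the error names a backward function rather than the offending forward line.

```python
import torch

# Enable anomaly detection globally (illustration only).
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0], requires_grad=True)
y = torch.sqrt(x)   # the NaN is born here, in the forward pass
try:
    y.backward()    # ...but anomaly mode only raises during backward
except RuntimeError as err:
    # The error names the backward function (e.g. SqrtBackward0),
    # not the forward line that produced the NaN.
    print(f"anomaly detection raised: {err}")
```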
My benchmark, run on a small CPU MLP (64→256→256→10), measured:
| Method | Mean latency | Overhead vs baseline |
|---|---|---|
| No detection | ~0.60 ms | baseline |
| NaNDetector (forward hooks) | ~3–4 ms | ~5–6× |
| set_detect_anomaly | ~7–8 ms | ~12–13× |
On this small model the absolute difference is modest. At scale, on a transformer with hundreds of millions of parameters across multiple GPUs, the gap is the difference between a training run that completes and one that doesn’t.
The Approach: Forward Hooks

PyTorch’s register_forward_hook API lets you attach a callback to any nn.Module that fires every time that module completes a forward pass [3]. The callback receives the module itself, its inputs, and its outputs. This means you can inspect every tensor flowing through every layer in real time, with no impact on the computation graph, no forced synchronization, and no retained activations.
The key insight is that you only need to do the NaN check, not replay the computation. A check against torch.isnan() and torch.isinf() on an output tensor is a single CUDA kernel invocation and completes in microseconds.
def hook(module, inputs, output):
    if torch.isnan(output).any():
        print(f"NaN detected in {layer_name}")
That’s the core of the idea. What follows is the production-hardened version.
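Attaching that hook to every layer takes only a few lines. Here is a self-contained sketch (the attach_nan_hooks helper name is mine, not the library’s) that records which leaf modules produced a NaN output:

```python
import torch
import torch.nn as nn

def attach_nan_hooks(model: nn.Module, detections: list):
    """Attach a NaN-checking forward hook to every leaf submodule."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers; hook only leaf layers
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and torch.isnan(output).any():
                detections.append(name)
        handles.append(module.register_forward_hook(hook))
    return handles

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
detections: list = []
handles = attach_nan_hooks(model, detections)
model(torch.full((1, 4), float("nan")))  # NaN input propagates through both layers
print(detections)                        # ['0', '1']
for h in handles:
    h.remove()                           # always detach hooks when done
```

Note the `name=name` default argument, which binds each layer’s name at definition time instead of sharing one closure variable across all hooks.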
The Implementation
The full source is available at: https://github.com/Emmimal/pytorch-nan-detector/
I’ll walk through the four components that matter.
Component 1: The NaNEvent dataclass
When a NaN is detected, you need more than a print statement. You need a structured record you can inspect after the fact, log to disk, or send to an alerting system.
from dataclasses import dataclass, field

@dataclass
class NaNEvent:
    batch_idx: int
    layer_name: str
    module_type: str
    input_has_nan: bool
    output_has_nan: bool
    input_has_inf: bool
    output_has_inf: bool
    output_shape: tuple
    output_stats: dict = field(default_factory=dict)
    is_backward: bool = False
The output_stats field contains the min, max, and mean of the finite values in the output tensor at the moment of detection. This is surprisingly useful: a layer output where 3 values are NaN but the rest are finite tells a different story than one that is all NaN.
The is_backward flag distinguishes whether the event was caught in a forward hook or a backward hook, which matters for root cause analysis.
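Those finite-value stats can be computed with a simple boolean mask. A possible implementation (finite_stats is a hypothetical helper, not necessarily what the repository does):

```python
import torch

def finite_stats(t: torch.Tensor) -> dict:
    # Stats over only the finite entries, in the spirit of output_stats.
    finite = t[torch.isfinite(t)]
    if finite.numel() == 0:
        return {"min": None, "max": None, "mean": None}
    return {"min": finite.min().item(),
            "max": finite.max().item(),
            "mean": finite.mean().item()}

t = torch.tensor([1.0, float("nan"), 3.0])
print(finite_stats(t))   # {'min': 1.0, 'max': 3.0, 'mean': 2.0}
```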
Component 2: Thread-safe hook registration
The most important production consideration is thread safety. PyTorch’s DataLoader runs worker processes that can trigger forward hooks from background threads. If you mutate triggered = True and self.event = ev without a lock, you will get race conditions on multi-worker setups.
self._lock = threading.Lock()

def _make_fwd_hook(self, layer_name: str):
    def hook(module, inputs, output):
        with self._lock:
            if self.triggered and self.stop_on_first:
                return
            current_batch = self._batch_idx
        # ... tensor checks happen outside the lock
        if out_nan or out_inf:
            self._record_event(...)  # lock re-acquired inside
    return hook
The tensor checks themselves happen outside the lock because torch.isnan() is read-only and thread-safe. Only the shared-state mutations are locked.
Component 3: Bounded memory
A subtle issue with long training runs: if you accumulate overhead timings in an unbounded list, you will eventually exhaust memory on runs lasting millions of batches. The fix is a simple cap:
_OVERHEAD_CAP = 1000

with self._lock:
    if len(self._overhead_ms) < self._OVERHEAD_CAP:
        self._overhead_ms.append(elapsed)
The same logic applies to all_events when stop_on_first=False: a max_events parameter (default 100) prevents unbounded accumulation during pathological runs.
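An alternative design for the same bound, using only the standard library (a sketch, not what the repository does):

```python
from collections import deque

# deque(maxlen=N) keeps only the most recent N entries,
# evicting the oldest automatically instead of refusing new ones.
overhead_ms = deque(maxlen=1000)
for i in range(1500):
    overhead_ms.append(0.1 * i)
print(len(overhead_ms))   # 1000
```

The capped-list approach keeps the first N measurements; a deque keeps the last N. Which is preferable depends on whether cold-start or recent timings matter more for your monitoring.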
Component 4: Gradient norm guard
The most common real-world path to a NaN is not a bug that directly produces nan. It is a learning rate that is too high, causing gradient norms to explode to inf, which then propagates into the weights and produces NaN activations on the next forward pass. By the time your forward hook fires, you are already one step too late.
The check_grad_norms() method addresses this by walking all parameters after loss.backward() and logging a GradEvent for any parameter whose gradient norm exceeds a threshold:
def check_grad_norms(self) -> bool:
    if self.grad_norm_warn is None:
        return False
    for name, module in self.model.named_modules():
        for pname, param in module.named_parameters(recurse=False):
            if param.grad is None:
                continue
            norm = param.grad.detach().float().norm().item()
            if not math.isfinite(norm) or norm > self.grad_norm_warn:
                # log GradEvent
In the demo below, this method catches gradient explosion at batch 1, one full training step before the NaN would have appeared in the forward pass.
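The core of the check fits in a few lines. A standalone sketch (grad_norms_exceed is my own illustrative name, and the exploding gradient is simulated by scaling):

```python
import math
import torch
import torch.nn as nn

def grad_norms_exceed(model: nn.Module, threshold: float):
    """Return (name, norm) pairs for parameters whose grad norm is
    non-finite or above threshold; a sketch of the check_grad_norms idea."""
    offenders = []
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        norm = param.grad.detach().float().norm().item()
        if not math.isfinite(norm) or norm > threshold:
            offenders.append((name, norm))
    return offenders

model = nn.Linear(2, 1)
model(torch.ones(1, 2)).sum().backward()
model.weight.grad.mul_(1e6)   # simulate an exploding gradient
print(grad_norms_exceed(model, threshold=10.0))   # only 'weight' exceeds
```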

Usage
Basic: context manager
from nan_detector import NaNDetector

with NaNDetector(model) as det:
    for batch_idx, (x, y) in enumerate(loader):
        det.set_batch(batch_idx)
        loss = criterion(model(x), y)
        loss.backward()
        det.check_grad_norms()
        optimizer.step()
        if det.triggered:
            print(det.event)
            break
When the detector fires, det.event contains the full NaNEvent with layer name, module type, batch index, and output statistics.
Production: drop-in training loop
from nan_detector import train_with_nan_guard

losses, event = train_with_nan_guard(
    model, loader, criterion, optimizer,
    device="cuda",
    grad_norm_warn=50.0,
)
if event:
    print(f"NaN at batch {event.batch_idx}, layer {event.layer_name}")
Advanced: backward hooks + readable layer names
For catching gradient NaNs directly (not just norm warnings), enable check_backward=True. Use OrderedDict when building Sequential models to get readable names in all log output:
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(16, 32)),
    ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(32, 1)),
]))

with NaNDetector(model, check_backward=True, grad_norm_warn=10.0) as det:
    ...
Without OrderedDict, PyTorch names layers by index (0.weight, 2.bias). With it, you get fc1.weight, fc2.bias: a small thing that saves real time when debugging deep models.
Skipping layers
Some layer types are expected to produce non-finite outputs under normal conditions, such as certain normalization layers during the first forward pass before running stats are established. Skip them with:
det = NaNDetector(mannequin, skip_types=(nn.Dropout, nn.BatchNorm1d))
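One way such a skip list could be applied when attaching hooks (should_hook is a hypothetical helper of mine, not the repository’s code):

```python
import torch.nn as nn

SKIP_TYPES = (nn.Dropout, nn.ReLU)   # whatever the user passes as skip_types

def should_hook(module: nn.Module) -> bool:
    # Hook only leaf modules that are not in the skip list.
    is_leaf = len(list(module.children())) == 0
    return is_leaf and not isinstance(module, SKIP_TYPES)

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Dropout(), nn.Linear(4, 2))
hooked = [name for name, m in model.named_modules() if should_hook(m)]
print(hooked)   # ['0', '3']: only the Linear layers get hooks
```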
Demo Output
Running the three demos produces the following output:
────────────────────────────────────────────────────────────
Demo 1: Forward NaN detection + loss curve plot
────────────────────────────────────────────────────────────
[NaNDetector] Attached 5 hooks.
============================================================
NaN/Inf detected! [FORWARD PASS]
Batch : 12
Layer : layer4
Type : Linear
Flags : NaN in INPUT, NaN in OUTPUT
Out shape : (8, 1)
Out stats : min=n/a (all non-finite) max=n/a (all non-finite) mean=n/a (all non-finite)
============================================================
[NaNDetector] Detached. Avg overhead: 0.109 ms/forward-pass
────────────────────────────────────────────────────────────
Demo 2: Backward / grad-norm detection + grad norm plot
────────────────────────────────────────────────────────────
[NaNDetector] Attached 8 hooks (+ backward).
[GradNorm WARNING] batch=1 layer=fc1.weight norm=inf threshold=10.0
[GradNorm WARNING] batch=1 layer=fc1.bias norm=inf threshold=10.0
[GradNorm WARNING] batch=1 layer=fc2.weight norm=inf threshold=10.0
[GradNorm WARNING] batch=1 layer=fc2.bias norm=4.37e+18 threshold=10.0
Caught at batch 1

The hook overhead of 0.109 ms per forward pass in Demo 1 is the real number you can cite. The benchmark figure of ~3 ms reflects a larger model with five registered hook callbacks running concurrently, which is the more realistic production scenario.
Known Limitations
Forward hooks see activations, not all computation. If a NaN originates inside a custom torch.autograd.Function’s backward() method, or inside a C++/CUDA extension that doesn’t surface through named nn.Module submodules, the forward hook will not catch it. Use check_backward=True for gradient-side coverage, and grad_norm_warn for early warning.
Overhead scales with model depth. The benchmark was run on a 5-layer MLP. A 200-layer transformer will have 200 hook callbacks firing per forward pass. The overhead is still sub-millisecond per hook, but it accumulates. Mitigate by using skip_types to exclude non-parametric layers like ReLU, Dropout, and LayerNorm if overhead becomes a concern.
CPU benchmark ratios are noisy. The NaNDetector overhead ratio varied between 5× and 6× across runs in my testing, because CPU microbenchmarks at sub-millisecond scale are sensitive to OS scheduling and cache state. The absolute millisecond numbers are more stable. The 50–100× figure cited for GPU is drawn from the PyTorch documentation and community benchmarks [1][2], not my own GPU measurements.
What This Does Not Replace
This is a debugging and monitoring tool, not a substitute for good training hygiene. The standard recommendations still apply: gradient clipping (torch.nn.utils.clip_grad_norm_), careful learning rate scheduling, input normalization, and weight initialization. NaNDetector tells you where and when the problem occurred; it does not tell you why, and fixing the root cause still requires engineering judgment.
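For reference, gradient clipping rescales all gradients in place so their combined norm stays under a cap. A minimal sketch with a simulated exploding gradient:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
model(torch.ones(1, 4)).sum().backward()
model.weight.grad.mul_(1e8)   # simulate an exploding gradient

# clip_grad_norm_ rescales every grad so the total norm is at most max_norm;
# it returns the total norm measured before clipping.
before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
after = torch.cat([p.grad.flatten() for p in model.parameters()]).norm()
print(f"norm before: {before.item():.3g}, after: {after.item():.3g}")
```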
If you are hitting NaNs in mixed-precision training (fp16/bf16), the most common culprits are loss scaling overflow and layer norm instability, and those are worth investigating directly before reaching for a debugging hook.
Benchmark Methodology
All benchmarks were run on CPU (Windows 11, PyTorch 2.x) using a 4-layer MLP with input dimension 64, two hidden layers of 256, and output dimension 10. Batch size was 64. Each method ran 30 forward passes. The first pass was included in the mean: cold-start effects are real and should be counted. Times were measured with time.perf_counter() around the forward call only, not including data loading or loss computation.
The full benchmark function is included in the source and can be run with benchmark(n_batches=30, batch_size=64).
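The timing loop described above can be sketched as follows (my own reconstruction under the stated methodology, not the repository’s benchmark() function):

```python
import time
import torch
import torch.nn as nn

def bench_forward(model: nn.Module, x: torch.Tensor, n_batches: int = 30) -> float:
    """Mean forward latency in ms; the first (cold-start) pass is included."""
    times = []
    for _ in range(n_batches):
        t0 = time.perf_counter()
        with torch.no_grad():
            model(x)
        times.append((time.perf_counter() - t0) * 1000.0)
    return sum(times) / len(times)

# 64 -> 256 -> 256 -> 10 MLP with batch size 64, matching the setup above
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(),
                      nn.Linear(256, 256), nn.ReLU(),
                      nn.Linear(256, 10))
mean_ms = bench_forward(model, torch.randn(64, 64))
print(f"mean forward latency: {mean_ms:.3f} ms")
```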
References
[1] PyTorch Documentation. “Autograd Mechanics — Anomaly Detection.” pytorch.org. Available at: https://pytorch.org/docs/stable/autograd.html#anomaly-detection
[2] PyTorch Documentation. torch.autograd.set_detect_anomaly. pytorch.org. Available at: https://docs.pytorch.org/docs/stable/autograd.html
[3] PyTorch Documentation. torch.nn.Module.register_forward_hook. pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook
[4] PyTorch Documentation. torch.nn.Module.register_full_backward_hook. pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
[5] PyTorch Documentation. “Gradient Clipping — clip_grad_norm_.” pytorch.org. Available at: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
[6] Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., … & Chintala, S. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. arXiv preprint arXiv:1912.01703. https://doi.org/10.48550/arXiv.1912.01703
[7] Python Software Foundation. threading — Thread-based parallelism. Python 3 Documentation. Available at: https://docs.python.org/3/library/threading.html
[8] Python Software Foundation. dataclasses — Data Classes. Python 3 Documentation. Available at: https://docs.python.org/3/library/dataclasses.html
[9] Hunter, J. D. (2007). Matplotlib: A 2D graphics environment. Computing in Science & Engineering, 9(3), 90–95. https://doi.org/10.1109/MCSE.2007.55
Disclosure
I built and wrote about this tool myself. There is no sponsorship, no affiliation with PyTorch or the PyTorch Foundation, and no financial relationship with any company mentioned in this article. The benchmarks were run on my own hardware and are reproducible using the code in the repository linked above.
All code in this article is original. The tool was written from scratch; no existing open-source NaN detection library was used as a base. If you use this in your own work, attribution is appreciated but not required: the code is MIT licensed.
The benchmark comparison against set_detect_anomaly is based on my own measurements on a specific hardware configuration. Results will vary by model architecture, hardware, and PyTorch version. The 50–100× GPU overhead figure is drawn from PyTorch’s official documentation [1][2] and is not my own GPU measurement.
Full source code, including all three demos and the benchmark function: https://github.com/Emmimal/pytorch-nan-detector/

