import hashlib
import json
import logging
import os
import re
import threading
import time
import typing as ty
from collections.abc import Sequence
from contextlib import AbstractContextManager, contextmanager, nullcontext
from pathlib import Path
from typing import Any, Callable, Optional
import pandas as pd
import torch
import torch.nn as nn
import torchvision
from packaging.version import Version
try:
# PyTorch's aoti_load_package reaches for torch._inductor.codecache without
# importing it, so register the attribute up front when available.
import torch._inductor.codecache # noqa: F401
except ImportError:
pass
from rich.progress import Progress
from nvbenjo import console
from nvbenjo.cfg import TorchModelConfig, TorchRuntimeConfig
from nvbenjo.utils import AMP_PREFIX, TRANSFER_WARNING, PrecisionType, TensorLike, progress_task, sample_gpu_memory
from nvbenjo.torch_ops import * # noqa: F403
logger = logging.getLogger(__name__)
[docs]
def get_model(
type_or_path: str, device: torch.device, runtime_config: TorchRuntimeConfig, verbose=False, **kwargs
) -> ty.Any:
"""Load PyTorch model.
Parameters
----------
type_or_path : str
Model type or path. Supports prefixes to specify the model source:
- ``torchvision:<name>`` -- Load a torchvision model (e.g. ``torchvision:resnet50``), see `torchvision.models.list_models()`
- ``huggingface:<name>`` -- Load a HuggingFace AutoModel (e.g. ``huggingface:bert-base-uncased``), see https://huggingface.co/docs/transformers/model_doc/auto
- ``jit:<path>`` -- Load a TorchScript/JIT model
- ``torchexport:<path>`` -- Load a ``torch.export`` saved model
- ``aot:<path>`` -- Load a pre-compiled AOT model
- *(no prefix)* -- Path to a model saved with ``torch.save`` or ``torch.jit.save``
device : torch.device
Device to load the model onto.
runtime_config : TorchRuntimeConfig
Runtime configuration for the model.
verbose : bool, optional
Whether to print verbose output, by default False
Returns
-------
ty.Any
Loaded model.
Examples
--------
>>> model = get_model("torchvision:resnet18", device=torch.device("cpu"), runtime_config=TorchRuntimeConfig())
>>> model = get_model("/path/to/model.pth", device=torch.device("cuda"), runtime_config=TorchRuntimeConfig())
>>> model = get_model("jit:/path/to/model.pt", device=torch.device("cuda"), runtime_config=TorchRuntimeConfig())
>>> model = get_model("torchexport:/path/to/model.pt2", device=torch.device("cuda"), runtime_config=TorchRuntimeConfig())
>>> model = get_model("aot:/path/to/model.pt2", device=torch.device("cuda"), runtime_config=TorchRuntimeConfig())
>>> model = get_model("huggingface:bert-base-uncased", device=torch.device("cpu"), runtime_config=TorchRuntimeConfig())
"""
type_or_path = os.path.expanduser(type_or_path)
if type_or_path.startswith("jit:"):
if verbose and console is not None:
console.print(f"Loading jit model {type_or_path}")
type_or_path = type_or_path[len("jit:") :]
return torch.jit.load(os.path.expanduser(type_or_path), map_location=device)
elif type_or_path.startswith("torchexport:"):
if verbose and console is not None:
console.print(f"Loading torchexport model {type_or_path}")
type_or_path = type_or_path[len("torchexport:") :]
program = torch.export.load(os.path.expanduser(type_or_path))
module = program.module()
module = module.to(device)
return module
elif type_or_path.startswith("aot:"):
if verbose and console is not None:
console.print(f"Loading AOT model {type_or_path}")
type_or_path = type_or_path[len("aot:") :]
return torch._inductor.aoti_load_package(os.path.expanduser(type_or_path))
elif os.path.isfile(type_or_path):
# Path and no prefix -> try different methods
if verbose and console is not None:
console.print(f"Loading torch model {type_or_path}")
try:
model = torch.load(os.path.expanduser(type_or_path), map_location=device, weights_only=False)
model.eval()
return model
except Exception:
try:
return torch.jit.load(os.path.expanduser(type_or_path), map_location=device)
except Exception:
if Version(torch.__version__) > Version("2.1"):
try:
program = torch.export.load(os.path.expanduser(type_or_path))
module = program.module()
module = module.to(device)
return module
except Exception:
return torch._inductor.aoti_load_package(os.path.expanduser(type_or_path))
else:
raise
if type_or_path.startswith("huggingface:"):
type_or_path = type_or_path[len("huggingface:") :]
if verbose and console is not None:
console.print(f"Loading huggingface model {type_or_path}")
from transformers import AutoModel # type: ignore
return AutoModel.from_pretrained(os.path.expanduser(type_or_path)).to(device)
elif type_or_path.startswith("torchvision:"):
type_or_path = type_or_path[len("torchvision:") :]
available_torchvision_models = torchvision.models.list_models()
if type_or_path in available_torchvision_models:
if verbose and console is not None:
console.print(f"Loading torchvision model {type_or_path}")
model = torchvision.models.get_model(type_or_path, **kwargs).to(device)
model.eval()
return model
else:
available_torchvision_models = torchvision.models.list_models()
raise ValueError(
f"Invalid model {type_or_path}. Must be: \n"
"- a valid path to a saved torch model\n"
"- 'jit:<path>' for a TorchScript/JIT model\n"
"- 'torchexport:<path>' for a torch.export model\n"
"- 'aot:<path>' for a pre-compiled AOT model\n"
"- 'huggingface:<model-name>' for a HuggingFace AutoModel\n"
f"- 'torchvision:<model-name>' from {available_torchvision_models}\n"
)
def run_model_with_input(model: nn.Module | Callable, input: TensorLike) -> TensorLike:
if isinstance(input, (list, tuple)):
return model(*input)
elif isinstance(input, dict):
# Some models take the dict as a single positional arg for these accept fallback
try:
return model(**{str(k): v for k, v in input.items()})
except TypeError:
return model(input)
else:
return model(input)
def transfer_to_device(result: ty.Any, to_device: torch.device) -> ty.Any:
if hasattr(result, "to"):
return result.to(to_device)
if isinstance(result, Sequence):
return [transfer_to_device(ri, to_device=to_device) for ri in result]
elif hasattr(result, "items"):
return {k: transfer_to_device(v, to_device=to_device) for k, v in result.items()}
else:
raise ValueError(f"Unsupported result type: {type(result)} could not transfer to {to_device}")
def apply_batch_precision(batch: TensorLike, precision: PrecisionType) -> TensorLike:
def _apply_batch_precision(batch_tensor: torch.Tensor):
if AMP_PREFIX not in precision.value:
if precision == PrecisionType.FP16:
batch_tensor = batch_tensor.half()
elif precision == PrecisionType.BFLOAT16:
batch_tensor = batch_tensor.bfloat16()
elif precision == PrecisionType.FP8_E4M3FN:
batch_tensor = batch_tensor.to(torch.float8_e4m3fn)
elif precision == PrecisionType.FP8_E5M2:
batch_tensor = batch_tensor.to(torch.float8_e5m2)
else:
if precision != PrecisionType.FP32:
raise ValueError(f"Invalid precision type {precision}.")
return batch_tensor
if isinstance(batch, torch.Tensor):
batch = _apply_batch_precision(batch)
elif isinstance(batch, (list, tuple)):
batch = tuple(_apply_batch_precision(b) for b in batch)
elif isinstance(batch, dict):
batch = {k: _apply_batch_precision(v) for k, v in batch.items()}
else:
raise ValueError(f"Unsupported batch type: {type(batch)}. Must be a Tensor, Tuple, or Dict.")
return batch
def apply_non_amp_model_precision(
model: nn.Module,
precision: PrecisionType,
) -> nn.Module:
if AMP_PREFIX not in precision.value:
if precision == PrecisionType.FP16:
model = model.half()
elif precision == PrecisionType.BFLOAT16:
model = model.bfloat16()
elif precision == PrecisionType.FP8_E4M3FN:
model = model.to(torch.float8_e4m3fn)
elif precision == PrecisionType.FP8_E5M2:
model = model.to(torch.float8_e5m2)
else:
if precision != PrecisionType.FP32:
raise ValueError(f"Invalid precision type {precision}.")
return model
@contextmanager
def matmul_precision_ctxt(precision: str | None) -> ty.Generator:
old_precision = torch.get_float32_matmul_precision()
try:
if precision is not None:
torch.set_float32_matmul_precision(precision=precision)
yield True
finally:
torch.set_float32_matmul_precision(precision=old_precision)
def get_amp_ctxt_for_precision(precision: PrecisionType, device: torch.device) -> AbstractContextManager:
if AMP_PREFIX in precision.value:
valid_values = [PrecisionType.AMP, PrecisionType.AMP_FP16, PrecisionType.AMP_BFLOAT16]
if precision not in valid_values:
raise ValueError(f"Invalid AMP precision type {precision} must be one of {valid_values}")
if precision in [PrecisionType.AMP]:
ctxt = torch.autocast(device_type=device.type, enabled=True)
elif precision in [PrecisionType.AMP_FP16]:
ctxt = torch.autocast(device_type=device.type, dtype=torch.float16, enabled=True)
elif precision in [PrecisionType.AMP_BFLOAT16]:
ctxt = torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=True)
else:
raise ValueError(f"Invalid precision type {precision}.")
else:
ctxt = nullcontext()
return ctxt
def get_model_parameters(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
[docs]
def measure_gpu_memory_allocation(
model: nn.Module | Callable, batch: TensorLike, device: torch.device, iterations: int = 3
) -> tuple[int, int]:
"""Measure peak memory usage during inference.
Returns both the PyTorch allocator peak (via torch.cuda.max_memory_allocated)
and the process-level GPU memory peak (via pynvml sampling).
Parameters
----------
model : nn.Module | Callable
The model to benchmark.
batch : TensorLike
Sample input to the model.
device : torch.device
The device where the model is located and shall be used for benchmarking.
iterations : int, optional
Number of iterations to run for measuring memory allocation, by default 3
Returns
-------
tuple[int, int]
(torch_memory_bytes, gpu_memory_bytes) — PyTorch allocator peak and
process-level GPU memory peak.
"""
is_cuda = device.type == "cuda"
if is_cuda:
torch.cuda.reset_peak_memory_stats(device=device)
max_mem = [-1]
stop_event = threading.Event()
sampler = threading.Thread(target=sample_gpu_memory, args=(device, stop_event, max_mem))
sampler.start()
time.sleep(0.01)
batch = transfer_to_device(batch, to_device=device)
if isinstance(model, nn.Module):
model = model.to(device)
try:
for _ in range(iterations):
r = run_model_with_input(model, batch)
try:
_ = transfer_to_device(r, to_device=torch.device("cpu"))
except Exception:
console.print(TRANSFER_WARNING)
finally:
if is_cuda:
stop_event.set()
sampler.join()
if is_cuda:
logger.debug(torch.cuda.memory_summary(device=device, abbreviated=True))
torch_memory = torch.cuda.max_memory_allocated(device=device)
gpu_memory = max_mem[0]
else:
torch_memory = -1
gpu_memory = -1
return torch_memory, gpu_memory
[docs]
def measure_repeated_inference_timing(
model: nn.Module,
sample: TensorLike,
batch_size: int,
model_device: torch.device,
transfer_to_device_fn: Callable = transfer_to_device,
num_runs: int = 100,
progress_callback: Optional[Callable] = None,
) -> pd.DataFrame:
"""Measure inference times.
Parameters
----------
model : nn.Module
The model to benchmark.
sample : TensorLike
Sample input to the model.
batch_size : int
The batch size of the sample.
model_device : torch.device
The device where the model is located and shall be used for benchmarking.
transfer_to_device_fn : Callable, optional
Function to transfer data to the specified device, by default transfer_to_device
num_runs : int, optional
Number of inference runs to perform, by default 100
progress_callback : Optional[Callable], optional
Callback function to report progress, by default None
Returns
-------
pd.DataFrame
DataFrame containing timing results.
Examples
--------
Measure Inference::
import torch
from nvbenjo.torch_utils import measure_repeated_inference_timing
from nvbenjo.torch_utils import get_model
from nvbenjo.cfg import TorchRuntimeConfig
model = get_model("torchvision:resnet18", device=torch.device("cpu"), runtime_config=TorchRuntimeConfig())
sample = torch.randn(2, 3, 224, 224) # batch size 2
results = measure_repeated_inference_timing(
model=model,
sample=sample,
batch_size=2,
model_device=torch.device("cpu"),
num_runs=2
)
"""
results_raw = []
for _ in range(num_runs):
start_on_cpu = time.perf_counter()
device_sample = transfer_to_device_fn(sample, model_device)
if model_device.type == "cuda":
start_event = torch.cuda.Event(enable_timing=True)
stop_event = torch.cuda.Event(enable_timing=True)
start_event.record() # For GPU timing
start_on_device = time.perf_counter() # For CPU timing
device_result = run_model_with_input(model, device_sample)
if model_device.type == "cuda":
stop_event.record()
torch.cuda.synchronize()
# elapsed_on_device = stop_event.elapsed_time(start_event)
elapsed_on_device = start_event.elapsed_time(stop_event) / 1000.0
stop_on_device = time.perf_counter()
else:
stop_on_device = time.perf_counter()
elapsed_on_device = stop_on_device - start_on_device
try:
transfer_to_device_fn(device_result, torch.device("cpu"))
except Exception:
console.print(TRANSFER_WARNING)
stop_on_cpu = time.perf_counter()
assert elapsed_on_device > 0
results_raw.append(
{
"time_cpu_to_device": start_on_device - start_on_cpu,
"time_inference": elapsed_on_device,
"time_device_to_cpu": stop_on_cpu - stop_on_device,
"time_total": stop_on_cpu - start_on_cpu,
"time_total_batch_normalized": (stop_on_cpu - start_on_cpu) / batch_size,
}
)
if progress_callback is not None:
progress_callback()
results_raw = pd.DataFrame(results_raw)
return results_raw
def _file_meta(type_or_path: str) -> Optional[dict]:
path = type_or_path
for prefix in ("jit:", "torchexport:", "aot:"):
if path.startswith(prefix):
path = path[len(prefix) :]
break
path = os.path.expanduser(path)
if os.path.isfile(path):
st = os.stat(path)
return {"name": os.path.basename(path), "size": st.st_size, "mtime": st.st_mtime}
return None
def _aot_cache_path(
cache_dir: str,
model_cfg: TorchModelConfig,
batch_size: int,
runtime_cfg: TorchRuntimeConfig,
device: torch.device,
) -> Path:
key_parts: dict = {
"torch": torch.__version__,
"cuda_version": torch.version.cuda,
"type_or_path": model_cfg.type_or_path,
"model_kwargs": sorted(model_cfg.kwargs.items()),
"file_meta": _file_meta(model_cfg.type_or_path),
"shape": list(model_cfg.shape),
"batch_size": batch_size,
"precision": runtime_cfg.precision.value,
"compile_kwargs": sorted((k, v) for k, v in runtime_cfg.compile_kwargs.items() if k != "package_path"),
"device_type": device.type,
}
if device.type == "cuda":
key_parts["sm"] = torch.cuda.get_device_capability(device)
key_parts["device_name"] = torch.cuda.get_device_name(device)
digest = hashlib.sha256(json.dumps(key_parts, default=str, sort_keys=True).encode()).hexdigest()[:16]
safe_name = re.sub(r"[^A-Za-z0-9_.-]", "_", model_cfg.name)
return Path(cache_dir).expanduser() / f"{safe_name}_{digest}.pt2"
def _export_program(model: Any, batch: TensorLike, device: torch.device) -> Any:
if not isinstance(model, nn.Module):
return model.to(device)
device_batch = transfer_to_device(batch, device)
if isinstance(device_batch, dict):
return torch.export.export(model.to(device), args=(), kwargs=device_batch)
if isinstance(device_batch, (tuple, list)):
batch_args = tuple(device_batch)
else:
batch_args = (device_batch,)
return torch.export.export(model.to(device), batch_args)
def _aot_compile_or_load(
model: Any,
batch: TensorLike,
device: torch.device,
model_cfg: TorchModelConfig,
batch_size: int,
runtime_cfg: TorchRuntimeConfig,
progress_bar: Optional[Progress],
) -> Any:
cache_path = (
_aot_cache_path(runtime_cfg.cache_dir, model_cfg, batch_size, runtime_cfg, device)
if runtime_cfg.cache_dir
else None
)
# run_single_threaded avoids internal threading that conflicts with
# external CUDA graph capture (pytorch/pytorch#158834, fixed in 2.8+).
# see https://github.com/pytorch/pytorch/commit/85467ed063d284fa21a2f1d2adfec8fda544923d
load_kwargs: dict[str, Any] = {}
if runtime_cfg.cuda_graphs:
load_kwargs["run_single_threaded"] = True
if cache_path is not None and cache_path.exists():
with progress_task(progress_bar, f" Load AOT compiled model {cache_path}...", total=None):
try:
return torch._inductor.aoti_load_package(str(cache_path), **load_kwargs)
except Exception:
console.print(f"Failed to load AOT cache {cache_path}, falling back to recompile")
console.print_exception()
program = _export_program(model, batch, device)
program = program.run_decompositions()
compile_kwargs = dict(runtime_cfg.compile_kwargs)
if cache_path is not None:
if "package_path" in compile_kwargs:
raise ValueError("Cannot set both runtime_config.cache_dir and compile_kwargs['package_path']")
cache_path.parent.mkdir(parents=True, exist_ok=True)
# Compile to a sibling tmp path, then atomic-rename. Keep the .pt2 suffix
# because aoti_compile_and_package validates package_path ends in .pt2.
tmp_path = cache_path.with_name(f"{cache_path.stem}.tmp{cache_path.suffix}")
compile_kwargs["package_path"] = str(tmp_path)
with progress_task(progress_bar, " AOT compiling...", total=None):
torch._inductor.aoti_compile_and_package(program, **compile_kwargs)
os.replace(tmp_path, cache_path)
return torch._inductor.aoti_load_package(str(cache_path), **load_kwargs)
else:
with progress_task(progress_bar, " AOT compiling...", total=None):
package_path = torch._inductor.aoti_compile_and_package(program, **compile_kwargs)
return torch._inductor.aoti_load_package(package_path, **load_kwargs)
def _copy_into(dst: TensorLike, src: TensorLike) -> None:
"""In-place copy ``src`` into ``dst`` matching nested structure."""
if dst is src:
# no cost if already moved
return
if isinstance(dst, torch.Tensor):
if not isinstance(src, torch.Tensor):
raise ValueError(f"Type mismatch copying into graph buffer: {type(dst)} vs {type(src)}")
dst.copy_(src, non_blocking=True)
return
if isinstance(dst, (list, tuple)) and isinstance(src, (list, tuple)):
for d, s in zip(dst, src):
_copy_into(d, s)
return
if isinstance(dst, dict) and isinstance(src, dict):
for k, d in dst.items():
_copy_into(d, src[k]) # type: ignore[index] # ty: ignore[invalid-argument-type]
return
raise ValueError(f"Unsupported batch type for CUDA graph copy: {type(dst)}")
class _CudaGraphedModel:
"""Callable that copies inputs into captured buffers and replays a CUDA graph.
The captured ``torch.cuda.CUDAGraph`` records device pointers for ``static_input``
and ``static_output``; this wrapper bundles all three so the graph stays valid
(replaying after the buffers are freed is undefined behavior).
"""
def __init__(
self,
graph: "torch.cuda.CUDAGraph",
static_input: TensorLike,
static_output: ty.Any,
device: torch.device,
model: ty.Any = None,
):
self.graph = graph
self.static_input = static_input
self.static_output = static_output
self.device = device
# Keep the original model alive so its CUDA resources aren't freed.
self._model = model
if isinstance(static_input, dict):
self._pick_src = lambda args, kwargs: kwargs
elif isinstance(static_input, (list, tuple)):
self._pick_src = lambda args, kwargs: args
else:
self._pick_src = lambda args, kwargs: args[0]
def __call__(self, *args, **kwargs) -> ty.Any:
_copy_into(self.static_input, self._pick_src(args, kwargs))
self.graph.replay()
return self.static_output
def transfer_to_device(self, x: ty.Any, to_device: torch.device) -> ty.Any:
"""Drop-in replacement for ``transfer_to_device`` for use in the timing loop.
For CPU→graph-device, copies into ``static_input`` and synchronizes so the
cost is reflected in ``time_cpu_to_device``. For the reverse direction
(e.g., GPU→CPU output transfer), falls back to the regular transfer.
"""
if to_device == self.device:
_copy_into(self.static_input, x)
torch.cuda.synchronize(self.device)
return self.static_input
return transfer_to_device(x, to_device)
def _cuda_graph_capture(
model: nn.Module | Callable,
batch: TensorLike,
device: torch.device,
num_warmup_iters: int,
capture_kwargs: Optional[dict] = None,
progress_bar: Optional[Progress] = None,
) -> _CudaGraphedModel:
"""Capture ``model(batch)`` as a CUDA graph and return a copy-replay callable.
Performs ``num_warmup_iters`` warmup iterations on a side stream before capture
so cuDNN autotune and lazy allocations settle, then captures one graph keyed
by the structure/dtype/shape of ``batch``.
"""
if device.type != "cuda":
raise ValueError(f"_graph_capture requires a CUDA device, got {device}")
static_input = transfer_to_device(batch, device)
with progress_task(progress_bar, " CUDA graph warm-up", total=num_warmup_iters) as task:
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(s):
for _ in range(num_warmup_iters):
static_output = run_model_with_input(model, static_input)
if progress_bar is not None and task is not None:
progress_bar.advance(task)
torch.cuda.current_stream(device).wait_stream(s)
with progress_task(progress_bar, " CUDA graph capture", total=None):
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, **(capture_kwargs or {})):
static_output = run_model_with_input(model, static_input)
return _CudaGraphedModel(graph, static_input, static_output, device, model=model)