Tensor Deduplication for Multi-Model Inference
Multi-model workloads are the norm: A/B tests, customer fine-tunes, safety variants, multi-stage pipelines. GPU memory requirements scale linearly with model count, and VRAM is the limiting resource. Tensor deduplication recovers a meaningful share of that memory by sharing bit-identical weights across models.
Summary
- Problem: Multi-model workloads are the norm: A/B tests, customer fine-tunes, safety variants, multi-stage pipelines. GPU memory scales linearly with model count, and VRAM is the limiting resource.
- Solution: Tensor deduplication automatically identifies and shares bit-identical weight tensors across models, requiring no checkpoint modifications.
- Results: Across diffusion and LLM workloads, real-world savings range from 3–32%. DeepFloyd IF stages share 18.87 GB (32% reduction). Synthetic upper bound is 50%.
- Overhead: Hashing adds <1% to model load time. Zero runtime overhead since the forward pass is unchanged.
- Compatibility: Works with HuggingFace safetensors, GGUF, and Diffusers pipelines. No changes to training or checkpoints required.
Multi-Model Memory Bloat
Modern inference deployments rarely serve a single model. Production systems routinely load:
- Multiple fine-tuned variants (chat, code, domain-specific)
- A/B test candidates running simultaneously
- Multi-stage pipelines (e.g., diffusion models with shared text encoders)
- Customer-specific adapters merged into standalone checkpoints
Each model loads its full weight set into GPU memory, even when most tensors are bit-identical to models already loaded. A deployment serving 5 variants of a 7B model allocates 35 GB, despite potentially 80%+ weight overlap.
Introduction
Tensor deduplication works like a content-addressable store for model weights: if two models contain identical tensors, I load the data once and let all models share the same GPU memory.
The mechanism is model-agnostic. It operates at the weight-loading layer, computing cryptographic fingerprints of each tensor and reusing storage when matches are found. Reused tensors share the exact same data_ptr() on GPU, and no model sees any behavioral change.
Experiments:
- DeepFloyd IF (multi-stage diffusion): 32% memory reduction (18.87 GB saved)
- SDXL + CLIP (cross-model sharing): 3.4% reduction (234 MB saved)
- Synthetic upper bound (identical models): 50% reduction
Why LoRA Isn’t Enough
vLLM’s LoRA swapping is excellent when models are structured as base + adapter. But most production fine-tunes are published as merged checkpoints, where the LoRA deltas have been baked in and the result looks like an independent model.
Merged checkpoints lose all metadata about which layers changed. Even when 90% of weights are bit-identical to the base, the system loads everything again. Such merged fine-tunes cannot participate in vLLM’s LoRA sharing mechanism.
This pattern extends beyond LLMs. Multi-stage diffusion pipelines (SDXL, DeepFloyd IF) load identical text encoders and safety modules multiple times. Each stage treats shared components as independent weights.
Tensor Deduplication Design
The mechanism operates at the weight-loading layer:
- As each tensor loads, compute a fingerprint: (shape, dtype, blake2b_hash)
- Check a process-wide registry for an existing match
- If a match is found, return the cached tensor (same data_ptr())
- If not, register the tensor for future reuse
Hashing and deduplication occur before tensors move to GPU, avoiding duplicate memory allocation in the first place. The CPU is otherwise idle during weight loading, so hashing overhead is negligible (<1% of load time).
Safety: Reused tensors are read-only. No model can overwrite shared data.
Performance: Zero runtime overhead. The forward pass sees no change; only the underlying storage is shared.
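A minimal sketch of that registry; the class and method names mirror the description above but are illustrative, not the exact code in the PR:

import hashlib
import torch


class TensorDedupRegistry:
    """Process-wide cache keyed by (shape, dtype, blake2b-of-bytes) fingerprints."""

    def __init__(self, min_tensor_bytes: int = 4 * 1024 * 1024):
        self._cache: dict[tuple, torch.Tensor] = {}
        self.min_tensor_bytes = min_tensor_bytes
        self.bytes_saved = 0

    def _fingerprint(self, t: torch.Tensor) -> tuple:
        # Hash the raw bytes while the tensor is still on CPU.
        raw = t.detach().contiguous().cpu().flatten().view(torch.uint8)
        digest = hashlib.blake2b(raw.numpy().tobytes()).hexdigest()
        return (tuple(t.shape), str(t.dtype), digest)

    def get_or_register(self, t: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        nbytes = t.numel() * t.element_size()
        if nbytes < self.min_tensor_bytes:
            return t.to(device)            # small tensors: cheaper to copy than to hash
        key = self._fingerprint(t)         # computed before any GPU allocation
        cached = self._cache.get(key)
        if cached is not None:
            self.bytes_saved += nbytes
            return cached                  # same GPU storage, same data_ptr()
        gpu_t = t.to(device)
        self._cache[key] = gpu_t
        return gpu_t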
Architecture
Tensor dedup integrates into vLLM’s weight-loading pipeline at the point where checkpoint tensors are materialized, before they are copied to the GPU.
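A minimal sketch of that integration point, reusing the TensorDedupRegistry sketch above; vLLM’s real loader is more involved, and load_weights_with_dedup is an illustrative name, not its actual API:

from safetensors import safe_open

REGISTRY = TensorDedupRegistry()   # one instance per process (sketch from the previous section)


def load_weights_with_dedup(model, checkpoint_path: str, device: str = "cuda"):
    params = dict(model.named_parameters())
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        for name in f.keys():
            cpu_tensor = f.get_tensor(name)              # materialized on CPU
            shared = REGISTRY.get_or_register(cpu_tensor, device)
            if name in params:
                # Re-point the parameter at the (possibly shared) GPU storage.
                params[name].data = shared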
Important: Single-Process Mode Required
Tensor deduplication relies on a process-wide registry (TensorDedupRegistry) that maintains references to cached tensors. This only works when all models are loaded within the same Python process. vLLM’s default V1 engine uses multiprocessing for worker isolation, which gives each model its own address space and prevents deduplication under the current design.
To enable deduplication, set:
VLLM_ENABLE_V1_MULTIPROCESSING=0
This runs all model loading and inference in a single process, allowing the TensorDedupRegistry to be shared across multiple LLM() instances.
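For example, with two hypothetical fine-tunes of the same base; the model names are placeholders, and TensorDedupConfig is the PoC’s config object described in the Configuration section below:

import os
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"   # must be set before the engine starts

from vllm import LLM
from vllm.config.load import TensorDedupConfig       # PoC branch only

dedup = TensorDedupConfig(enabled=True)

# Both models live in the same process, so they share one TensorDedupRegistry.
llm_chat = LLM(model="org/base-7b-chat", tensor_dedup=dedup, gpu_memory_utilization=0.4)
llm_code = LLM(model="org/base-7b-code", tensor_dedup=dedup, gpu_memory_utilization=0.4)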
Future: Multi-Process Memory Manager
For the prototype, I forced vLLM into single-process mode, but I think a better architecture would use CUDA IPC to build a centralized “Memory Manager”, which I plan to implement in future work.
Process A (Inferno Memory Manager):
- Loads the model weights (e.g., Llama-3-70B) onto the GPU.
- Iterates through the state dict.
- For each tensor, calls cudaIpcGetMemHandle.
- Stores these handles in a metadata store.
Processes B, C, D (vLLM workers):
- Launch as usual but with a patched model loader.
- Instead of reading .safetensors files from disk, they query the Memory Manager.
- Receive IPC handles + shape/dtype metadata.
- Call cudaIpcOpenMemHandle to get a pointer.
- Reconstruct the PyTorch tensor from that pointer.
Again, the main integration point would be vLLM’s model loader: vllm/model_executor/model_loader.py
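A rough sketch of the idea using torch.multiprocessing, which wraps cudaIpcGetMemHandle / cudaIpcOpenMemHandle when CUDA tensors cross a queue. The real design would exchange raw IPC handles plus shape/dtype metadata; the function and key names below are illustrative:

import torch
import torch.multiprocessing as mp


def memory_manager(weight_queue, n_workers: int):
    # Process A: load weights onto the GPU once and publish them.
    state_dict = {
        "embed_tokens.weight": torch.randn(32000, 4096, device="cuda"),
        # ... a real manager would load e.g. Llama-3-70B from safetensors ...
    }
    for _ in range(n_workers):
        weight_queue.put(state_dict)   # CUDA tensors travel as IPC handles, not copies


def worker(weight_queue, rank: int):
    # Processes B, C, D: receive handles instead of reading .safetensors from disk.
    state_dict = weight_queue.get()
    w = state_dict["embed_tokens.weight"]
    print(f"worker {rank}: {tuple(w.shape)} on {w.device}, ptr={w.data_ptr():#x}")


if __name__ == "__main__":
    mp.set_start_method("spawn")       # required for CUDA tensors across processes
    queue = mp.Queue()
    workers = [mp.Process(target=worker, args=(queue, r)) for r in range(2)]
    for p in workers:
        p.start()
    memory_manager(queue, n_workers=len(workers))
    for p in workers:
        p.join()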
Configuration
Configuration is exposed through TensorDedupConfig:
from vllm import LLM
from vllm.config.load import TensorDedupConfig

dedup_config = TensorDedupConfig(
    enabled=True,
    hash_algorithm="blake2b",     # or "xxhash64" for speed
    verify_bytes=False,           # optional byte-level collision check
    min_tensor_bytes=4194304,     # skip small tensors (4 MB threshold)
)

llm = LLM(model="...", tensor_dedup=dedup_config)
Synthetic Evaluation
To validate the implementation, I loaded the same model twice, a synthetic scenario that guarantees 100% weight overlap:
Qwen/Qwen2-0.5B loaded twice
Without dedup: allocated = 18.42 GB
With dedup: allocated = 9.20 GB
Shared tensors: 73
Memory saved: 858 MB (50%)
The 50% figure is the theoretical maximum when dealing with two models: every tensor in the second model is deduplicated. Real-world savings depend on actual weight overlap and the number of participating models.
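The sharing itself is easy to check at the tensor level with the registry sketch from the design section; two bit-identical tensors end up backed by a single GPU allocation:

import torch

registry = TensorDedupRegistry()        # sketch from the design section above

a = torch.randn(4096, 4096)             # stand-in for a weight in model 1
b = a.clone()                           # bit-identical copy, as in model 2

gpu_a = registry.get_or_register(a)
gpu_b = registry.get_or_register(b)

assert gpu_a.data_ptr() == gpu_b.data_ptr()       # one allocation serves both
print(torch.cuda.memory_allocated() / 2**20)      # ~64 MiB, not ~128 MiB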
Real-World Results
SDXL Base + CLIP Text Encoder
I evaluated tensor deduplication in the SDXL diffusion pipeline using the Diffusers library. SDXL Base and its CLIP text encoder share 196 bit-identical tensors.
| Metric | Value |
|---|---|
| VRAM saved | 234.72 MB |
| VRAM reduction | 3.43% |
| Shared tensors | 196 |
This demonstrates cross-model deduplication between components not originally designed to share weights. Even small savings matter in diffusion workloads, where VRAM headroom governs resolution and batch size.
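A sketch of how that overlap can be measured with Diffusers. The specific CLIP checkpoint and dtype below are assumptions; the reported numbers come from the repo’s tensor_dedup_diffusers.py script, not this snippet:

import hashlib
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPTextModel


def fingerprint(t: torch.Tensor) -> tuple:
    raw = t.detach().contiguous().flatten().view(torch.uint8)
    digest = hashlib.blake2b(raw.numpy().tobytes()).hexdigest()
    return (tuple(t.shape), str(t.dtype), digest)


pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
clip = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=torch.float16
)

# Fingerprint every tensor in the pipeline's text encoder, then count matches
# in the standalone CLIP model.
seen = {fingerprint(p) for p in pipe.text_encoder.parameters()}
shared = [(n, p) for n, p in clip.named_parameters() if fingerprint(p) in seen]
saved_mib = sum(p.numel() * p.element_size() for _, p in shared) / 2**20
print(f"shared tensors: {len(shared)}, potential savings: {saved_mib:.1f} MiB")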
DeepFloyd IF Stage I XL + Stage II L
DeepFloyd IF consists of multiple stages, each loading its own text encoder and safety modules, even though these components are identical across stages.
Per-stage VRAM (without dedup):
| Stage | VRAM |
|---|---|
| Stage I XL | 34.91 GB |
| Stage II L | 23.51 GB |
| Total | 58.42 GB |
Shared components identified:
| Component | Size |
|---|---|
| text_encoder | 17.74 GB |
| safety_checker | 1.13 GB |
| watermarker | 15 KB |
With deduplication:
| Metric | Value |
|---|---|
| Total VRAM (no dedup) | 58.42 GB |
| Total VRAM (with dedup) | 39.54 GB |
| Savings | 18.87 GB (32.31%) |
The second stage’s largest component (the T5 text encoder) becomes effectively free in terms of VRAM. Multi-stage diffusion pipelines benefit enormously from weight sharing.
Expected Savings by Workload
| Workload | Savings | Notes |
|---|---|---|
| Diffusion: DeepFloyd IF Stage I + II | 32.31% | 18.87 GB shared (text encoder, safety) |
| Diffusion: SDXL + CLIP | 3.43% | 196 shared tensors |
| LLM: Identical models | 50% | Synthetic upper bound |
| LLM: Instruction-tuned variants | 10–25% | Embeddings + early blocks preserved |
| LLM: Merged QLoRA | 5–20% | Early layers often unchanged |
| LLM: Full fine-tuning | ~0% | All weights modified |
| Unrelated models | ~0% | No shared weights |
Future Direction: Frozen Base Substitution
Tensor deduplication exploits accidental weight identity: layers that happen to be the same across models. A more deliberate approach is to engineer which layers are shared, rather than hoping they remain identical.
The Idea
Consider N fine-tuned variants of a 7B base model:
| Approach | Memory (10 variants) | Quality |
|---|---|---|
| Independent models | 70 GB | Baseline |
| INT4 quantization | 17.5 GB | Degraded uniformly |
| 1 frozen base + N heads (10%) | 14 GB | Degraded selectively |
The frozen base approach can use less memory than aggressive quantization while potentially preserving more quality. The base model runs at full precision; only the task-specific “heads” vary.
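The arithmetic behind the table, using the same convention as above (roughly 7 GB per 7B model, i.e. about one byte per parameter; all figures are illustrative):

base_gb = 7.0         # one 7B model at ~1 byte/parameter
n_variants = 10
head_fraction = 0.10  # share of layers that differ per variant

independent = n_variants * base_gb                            # 10 * 7.0   = 70.0 GB
int4 = independent / 4                                        # 70.0 / 4   = 17.5 GB
frozen_base = base_gb + n_variants * head_fraction * base_gb  # 7 + 10*0.7 = 14.0 GB

print(independent, int4, frozen_base)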
Why This Might Beat Quantization
Quantization introduces uniform error across every layer. Frozen base substitution introduces error only in replaced layers, and those layers are typically:
- The most general (early transformer blocks)
- Least specialized for downstream tasks
- Most similar across fine-tuned variants
The hypothesis: substituting early layers with a frozen base may degrade quality less than INT4 quantization for fine-tuning tasks that introduce highly localized changes (domain adaptation, safety tuning, style tuning).
This requires controlled experiments. I plan to compare frozen base substitution against the quantization ladder (INT8 → INT4 → INT3) and measure where each approach dominates on the Pareto frontier of memory vs. quality.
This section describes future research, not implemented code. The tensor deduplication above is working today.
Limitations and Caveats
- Weight identity is rare across unrelated models: Deduplication only helps when models share a common lineage or training procedure.
- Single-process requirement: The current implementation requires VLLM_ENABLE_V1_MULTIPROCESSING=0, which may not be suitable for all deployment scenarios.
- No cross-process sharing: Deduplication does not work across separate GPU processes. Future work could explore shared memory segments or memory-mapped approaches.
- Hashing overhead: While small (under 1% of load time), hash computation is non-zero. The min_tensor_bytes parameter allows skipping small tensors.
- Registry memory overhead: The TensorDedupRegistry maintains references to all cached tensors. This is negligible compared to tensor storage but present.
Future Work
- Empirical quality studies: Systematic comparison of frozen base substitution vs. quantization across model families and tasks.
- Training-time coordination: Fine-tuning recipes that explicitly freeze base layers, ensuring bit-identical weights for deduplication.
- Checkpoint metadata standards: Conventions for marking shared layers in model cards and checkpoint formats.
- Multi-process support: Shared memory or memory-mapped approaches to enable deduplication across worker processes.
Conclusion
Tensor deduplication provides automatic weight sharing across models with no changes to checkpoints, training, or the forward pass.
Results
- Multi-stage diffusion pipelines (DeepFloyd IF) achieve 32% memory reduction, with the text encoder becoming effectively free on the second stage
- Cross-model sharing (SDXL + CLIP) yields 3.4% savings even between components not designed for sharing
- LLM fine-tune variants typically share 5–25% of weights; identical models reach 50%
The implementation is ~280 lines and integrates at the weight-loading layer.
The proof-of-concept is available at inferno-sh/inferno-lab-vllm#2. Feedback welcome.
Reproducing the Results
LLM deduplication (vLLM):
git clone [email protected]:inferno-sh/inferno-lab-vllm.git
cd inferno-lab-vllm
uv venv && source .venv/bin/activate && uv sync
VLLM_ENABLE_V1_MULTIPROCESSING=0 python tensor_dedup_vllm.py
Diffusion deduplication (Diffusers):
git clone [email protected]:inferno-sh/inferno-lab-tensor-dedup.git
cd inferno-lab-tensor-dedup
uv venv && source .venv/bin/activate && uv sync
python tensor_dedup_diffusers.py
See the inferno-lab-tensor-dedup repo for SDXL and DeepFloyd IF reproduction scripts.