Large transformer models have revolutionized artificial intelligence, powering everything from chatbots to code generation tools. However, their impressive capabilities come with a significant computational cost, particularly during inference. As these models continue to grow in size and complexity, optimizing their inference speed has become crucial for practical deployment in real-world applications.
The challenge of inference optimization extends beyond mere computational efficiency. Organizations deploying large language models face mounting pressure to reduce latency, minimize hardware costs, and improve user experience while maintaining model quality. This comprehensive guide explores proven strategies and cutting-edge techniques to accelerate inference for large transformer models without compromising their performance.
⚡ Inference Speed Challenge: Large transformer models can take seconds to generate responses, but users expect sub-second interactions.
Understanding the Inference Bottlenecks
Before diving into optimization techniques, it’s essential to understand where computational bottlenecks occur during transformer inference. The primary performance challenges stem from the model’s architecture itself and the sequential nature of text generation.
The attention mechanism, while powerful, requires computing attention weights for every token in the sequence against every other token. This quadratic complexity becomes increasingly problematic as sequence lengths grow. During inference, each new token must attend to all previously generated tokens, creating a cumulative computational burden that scales with output length.
Memory bandwidth presents another significant constraint. Large transformer models often exceed GPU memory capacity, necessitating frequent data transfers between different memory hierarchies. The key-value cache, which stores attention states for previously processed tokens, can consume substantial memory and become a bottleneck for longer sequences.
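For a rough sense of scale, the key-value cache size can be estimated directly from the model's dimensions. The sketch below uses illustrative numbers for a 7B-parameter-class decoder (32 layers, 32 heads, head dimension 128); the actual figures depend on the specific model, batch size, and precision.

```python
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, batch_size, bytes_per_value=2):
    """Rough size of the key-value cache: keys + values for every layer and head."""
    return 2 * num_layers * num_heads * head_dim * seq_len * batch_size * bytes_per_value

# A 7B-class model (32 layers, 32 heads, head_dim 128) serving a batch of 8
# sequences of 4,096 tokens in FP16 needs roughly:
size_gb = kv_cache_bytes(32, 32, 128, 4096, 8) / 1e9
print(f"KV cache: ~{size_gb:.1f} GB")  # ~17 GB
```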
Matrix multiplication operations dominate the computational workload in transformers. While modern GPUs excel at these operations, the sequential nature of autoregressive generation prevents full parallelization. Each token must wait for the previous token to complete before beginning computation, limiting the ability to leverage massive parallel processing capabilities.
Model Architecture Optimizations
Multi-Query Attention and Grouped Query Attention
Traditional multi-head attention mechanisms replicate key and value projections across all attention heads, consuming significant memory and computational resources. Multi-Query Attention (MQA) addresses this by sharing key and value projections across all heads while maintaining separate query projections. This modification reduces memory requirements and improves cache efficiency without substantially impacting model quality.
Grouped Query Attention (GQA) offers a middle ground between standard multi-head attention and MQA. By grouping attention heads and sharing key-value pairs within groups, GQA provides better quality retention than MQA while still achieving meaningful speed improvements. This approach has been successfully implemented in models like Llama 2, demonstrating its practical effectiveness.
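The sketch below illustrates the idea in PyTorch: queries keep one projection per head, while keys and values are projected to a smaller number of heads that are shared within each group. It is a minimal, non-optimized module for intuition rather than the implementation used in any particular model; setting num_kv_heads to 1 recovers MQA.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """GQA: queries keep all heads, keys/values use fewer heads shared within groups."""
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        assert num_heads % num_kv_heads == 0
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads          # num_kv_heads == 1 recovers MQA
        self.head_dim = d_model // num_heads
        self.q_proj = nn.Linear(d_model, num_heads * self.head_dim)
        self.k_proj = nn.Linear(d_model, num_kv_heads * self.head_dim)  # smaller KV projections
        self.v_proj = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(num_heads * self.head_dim, d_model)

    def forward(self, x):
        b, t, _ = x.shape
        q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Repeat each KV head so every query head in its group sees the same keys/values
        group = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.o_proj(out.transpose(1, 2).reshape(b, t, -1))
```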
Mixture of Experts (MoE) Architecture
MoE architectures activate only a subset of model parameters for each input, dramatically reducing computational requirements during inference. Instead of processing inputs through all model parameters, MoE models route inputs to specialized expert networks based on learned gating mechanisms. This selective activation can reduce computational load by 80% or more while maintaining model capacity.
The key to effective MoE implementation lies in efficient routing algorithms and load balancing. Modern MoE systems employ sophisticated routing strategies that consider both computational efficiency and expert specialization. Switch Transformer and GLaM represent successful implementations of this approach, achieving substantial speed improvements with minimal quality degradation.
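A minimal top-k router makes the mechanism concrete: a gating layer scores the experts for each token, and every token is processed only by its top-k experts, with the outputs mixed by the renormalized gate weights. This sketch omits the load-balancing losses and expert capacity limits that production MoE systems rely on.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoELayer(nn.Module):
    """Route each token to its top-k experts and mix their outputs by gate weight."""
    def __init__(self, d_model, d_hidden, num_experts=8, k=2):
        super().__init__()
        self.k = k
        self.gate = nn.Linear(d_model, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(num_experts)
        ])

    def forward(self, x):                       # x: (batch, seq, d_model)
        tokens = x.reshape(-1, x.size(-1))      # flatten to (num_tokens, d_model)
        gate_logits = self.gate(tokens)
        weights, expert_ids = torch.topk(F.softmax(gate_logits, dim=-1), self.k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over the chosen k

        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # Only the tokens routed to expert e pass through it
            hit_token, hit_slot = (expert_ids == e).nonzero(as_tuple=True)
            if hit_token.numel() == 0:
                continue
            out[hit_token] += weights[hit_token, hit_slot].unsqueeze(-1) * expert(tokens[hit_token])
        return out.reshape_as(x)
```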
Quantization Techniques
Post-Training Quantization
Post-training quantization converts model weights from 32-bit floating-point to lower precision formats without retraining. INT8 quantization can reduce model size by 75% while maintaining acceptable accuracy for most applications. The process involves calibrating quantization parameters using a representative dataset and applying scaling factors to preserve numerical stability.
Here’s a practical example using PyTorch’s built-in quantization:
```python
import os
import torch
import torch.quantization as quantization
from transformers import AutoModel, AutoTokenizer

# Load your transformer model
model = AutoModel.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def get_model_size(model):
    """Serialize the state dict to disk and report its size in MB."""
    torch.save(model.state_dict(), 'temp_model.pth')
    size = os.path.getsize('temp_model.pth')
    os.remove('temp_model.pth')
    return size / (1024 * 1024)

# Measure the original size before attaching observers
original_size = get_model_size(model)

# Prepare model for quantization
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibrate with representative data
sample_inputs = [
    "This is a sample sentence for calibration.",
    "Another example to help with quantization.",
    "More diverse text helps with better calibration."
]

for text in sample_inputs:
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        _ = model(**inputs)

# Convert to quantized model
quantized_model = torch.quantization.convert(model, inplace=False)

# Compare model sizes
quantized_size = get_model_size(quantized_model)

print(f"Original model: {original_size:.2f} MB")
print(f"Quantized model: {quantized_size:.2f} MB")
print(f"Compression ratio: {original_size/quantized_size:.2f}x")
```
Dynamic quantization takes this approach further by quantizing activations during inference based on their runtime distributions. This technique adapts to the specific characteristics of each input, potentially achieving better accuracy than static quantization approaches. Modern frameworks like PyTorch and TensorFlow provide built-in support for dynamic quantization, making implementation straightforward.
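In PyTorch, dynamic quantization is a one-line transformation that converts the weights of selected module types to INT8 and quantizes activations on the fly at inference time; a minimal example:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')
model.eval()

# Quantize the weights of all Linear layers to INT8; activations are
# quantized dynamically from per-batch statistics at inference time
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```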
Quantization-Aware Training
For applications requiring maximum accuracy, quantization-aware training incorporates quantization effects during the training process. This approach allows models to adapt to reduced precision representations, often achieving better results than post-training quantization. The training process simulates quantization operations using fake quantization, enabling gradients to flow through quantized operations.
```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub

class QuantizedTransformerBlock(nn.Module):
    def __init__(self, original_block):
        super().__init__()
        self.quant = QuantStub()         # marks where tensors enter the quantized region
        self.transformer_block = original_block
        self.dequant = DeQuantStub()     # marks where tensors return to floating point

    def forward(self, x):
        x = self.quant(x)
        x = self.transformer_block(x)
        x = self.dequant(x)
        return x

# Example quantization-aware training setup
def setup_qat_model(model):
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # Insert fake-quantization modules and observers
    torch.quantization.prepare_qat(model, inplace=True)
    return model

# Training loop with quantization awareness
def train_with_qat(model, train_loader, optimizer, criterion, num_epochs=3):
    model.train()
    for epoch in range(num_epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        # After a few epochs of initial training, freeze the observers so the
        # quantization parameters stop drifting while fake quantization stays active
        if epoch >= 2:
            model.apply(torch.quantization.disable_observer)

    # Convert to an actual quantized model
    model.eval()
    quantized_model = torch.quantization.convert(model, inplace=False)
    return quantized_model
```
Mixed-precision training combines different quantization levels throughout the model, using higher precision for sensitive operations while applying aggressive quantization elsewhere. This nuanced approach balances speed gains with accuracy preservation, making it particularly suitable for production deployments.
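One way to express this in PyTorch's eager-mode quantization workflow is to assign quantization configurations per module and leave sensitive modules in floating point. The sketch below is illustrative; the module-name prefixes and the choice of which layers count as sensitive are assumptions that depend on the model at hand.

```python
import torch

def apply_selective_qconfig(model, sensitive_module_prefixes):
    """Quantize most of the model but keep named sensitive modules in floating point."""
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    for name, module in model.named_modules():
        # Modules with qconfig=None are skipped by prepare/convert
        if any(name.startswith(prefix) for prefix in sensitive_module_prefixes):
            module.qconfig = None
    return model

# e.g. keep the embeddings and pooler in FP32 (illustrative module names)
# model = apply_selective_qconfig(model, ['embeddings', 'pooler'])
```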
Pruning and Sparsity
Structured Pruning
Structured pruning removes entire components of the neural network, such as attention heads, layers, or neurons, creating smaller models that maintain dense computation patterns. This approach ensures compatibility with standard hardware accelerators while reducing computational requirements. Research has shown that many transformer models contain redundant components that can be removed without significant performance degradation.
Here’s an example of implementing attention head pruning:
```python
import math
import torch
import torch.nn as nn
from transformers import BertModel

class PrunedBertSelfAttention(nn.Module):
    def __init__(self, original_attention, heads_to_prune):
        super().__init__()
        self.num_attention_heads = original_attention.num_attention_heads
        self.attention_head_size = original_attention.attention_head_size

        # Heads that survive pruning
        self.active_heads = [i for i in range(self.num_attention_heads)
                             if i not in heads_to_prune]
        self.all_head_size = len(self.active_heads) * self.attention_head_size

        # Prune query, key, value projections
        self.query = self._prune_linear_layer(original_attention.query, heads_to_prune)
        self.key = self._prune_linear_layer(original_attention.key, heads_to_prune)
        self.value = self._prune_linear_layer(original_attention.value, heads_to_prune)
        self.dropout = original_attention.dropout

    def _prune_linear_layer(self, layer, heads_to_prune):
        """Prune a linear layer by removing the output rows that feed pruned heads."""
        head_size = self.attention_head_size
        keep_indices = [idx for h in range(self.num_attention_heads)
                        if h not in heads_to_prune
                        for idx in range(h * head_size, (h + 1) * head_size)]
        new_layer = nn.Linear(layer.in_features, len(keep_indices), bias=layer.bias is not None)
        new_layer.weight.data = layer.weight.data[keep_indices, :]
        if layer.bias is not None:
            new_layer.bias.data = layer.bias.data[keep_indices]
        return new_layer

    def forward(self, hidden_states, attention_mask=None, *args, **kwargs):
        batch_size, seq_len, _ = hidden_states.size()
        num_heads = len(self.active_heads)

        def shape(x):
            return x.view(batch_size, seq_len, num_heads, self.attention_head_size).transpose(1, 2)

        query = shape(self.query(hidden_states))
        key = shape(self.key(hidden_states))
        value = shape(self.value(hidden_states))

        scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores = scores + attention_mask   # additive mask, as in BERT
        probs = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(probs, value).transpose(1, 2).reshape(batch_size, seq_len, self.all_head_size)
        return (context,)

def analyze_attention_importance(model, dataloader, device):
    """Analyze which attention heads are most important."""
    head_importance = {}
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs, output_attentions=True)
            # Score heads by the entropy of their attention distributions
            for layer_idx, attention in enumerate(outputs.attentions):
                head_importance.setdefault(layer_idx, [])
                for head_idx in range(attention.size(1)):
                    attention_head = attention[:, head_idx, :, :]
                    entropy = -torch.sum(attention_head * torch.log(attention_head + 1e-12), dim=-1)
                    head_importance[layer_idx].append(entropy.mean().item())
    return head_importance

# Example usage
def prune_model_heads(model, prune_ratio=0.3):
    """Prune least important attention heads (demo: every 3rd head)."""
    # This would normally be driven by analyze_attention_importance on a
    # calibration dataset; for demonstration, we prune every 3rd head
    for layer_idx, layer in enumerate(model.encoder.layer):
        num_heads = layer.attention.self.num_attention_heads
        head_size = layer.attention.self.attention_head_size
        heads_to_prune = list(range(2, num_heads, 3))
        if not heads_to_prune:
            continue
        layer.attention.self = PrunedBertSelfAttention(layer.attention.self, heads_to_prune)

        # The attention output projection must shrink to match the pruned context size
        keep = [idx for h in range(num_heads) if h not in heads_to_prune
                for idx in range(h * head_size, (h + 1) * head_size)]
        old_dense = layer.attention.output.dense
        new_dense = nn.Linear(len(keep), old_dense.out_features, bias=old_dense.bias is not None)
        new_dense.weight.data = old_dense.weight.data[:, keep]
        if old_dense.bias is not None:
            new_dense.bias.data = old_dense.bias.data
        layer.attention.output.dense = new_dense
        print(f"Layer {layer_idx}: Pruned heads {heads_to_prune}")
    return model
```
Layer pruning represents one of the most effective structured pruning techniques. By removing entire transformer layers, models can achieve substantial speed improvements with minimal accuracy loss. The optimal pruning strategy depends on the specific model architecture and target application, requiring careful analysis of layer contributions to overall performance.
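As a simple illustration, the sketch below keeps every other encoder layer of a BERT-style model. In practice, the layers to drop are chosen by measuring each layer's contribution on a validation set, and the pruned model is usually fine-tuned afterwards to recover accuracy.

```python
import torch.nn as nn
from transformers import AutoModel

def drop_layers(model, keep_every=2):
    """Keep every `keep_every`-th encoder layer and discard the rest."""
    kept = [layer for i, layer in enumerate(model.encoder.layer) if i % keep_every == 0]
    model.encoder.layer = nn.ModuleList(kept)
    model.config.num_hidden_layers = len(kept)
    return model

model = AutoModel.from_pretrained('bert-base-uncased')
model = drop_layers(model, keep_every=2)   # 12 layers -> 6 layers
```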
Unstructured Pruning
Unstructured pruning removes individual weights based on magnitude or importance criteria, creating sparse models that require specialized hardware or software support for optimal performance. While more flexible than structured pruning, unstructured approaches often require custom kernels or sparse computation libraries to realize speed benefits.
Magnitude-based pruning removes weights with the smallest absolute values, operating under the assumption that small weights contribute minimally to model output. More sophisticated approaches use gradient-based importance measures or second-order information to identify truly redundant parameters. These methods can achieve higher sparsity levels while maintaining model quality.
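PyTorch's pruning utilities make magnitude-based pruning straightforward to apply; the sketch below zeroes the smallest 40% of weights in every linear layer of an illustrative BERT model. As noted above, the resulting zeros only translate into speed gains when paired with sparse kernels or hardware that exploits them.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune
from transformers import AutoModel

model = AutoModel.from_pretrained('bert-base-uncased')

# Zero out the 40% of weights with the smallest absolute value in each Linear layer
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        prune.remove(module, 'weight')   # make the pruning permanent (bake in the zeros)
```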
Knowledge Distillation
Knowledge distillation creates smaller, faster models by training them to mimic the behavior of larger teacher models. This technique transfers knowledge from complex models to simpler architectures, achieving significant speed improvements while preserving much of the original model’s capability.
The distillation process involves training a compact student model to match the output distributions of a larger teacher model. By learning from soft targets rather than hard labels, student models can capture nuanced patterns that might be lost in traditional supervised learning. This approach has proven particularly effective for transformer models, where teacher-student architectures can achieve 10x speedups with modest accuracy trade-offs.
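The standard formulation blends a temperature-scaled KL-divergence term against the teacher's soft targets with ordinary cross-entropy against the hard labels; a minimal sketch, where the temperature and mixing weight are tunable hyperparameters:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Blend soft-target KL divergence (scaled by T^2) with hard-label cross-entropy."""
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```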
Progressive distillation extends this concept by gradually reducing model size through multiple distillation stages. Each stage creates a slightly smaller model that serves as the teacher for the next stage, allowing for more aggressive compression while maintaining quality. This multi-stage approach often outperforms single-stage distillation, particularly for large compression ratios.
Hardware-Specific Optimizations
GPU Optimization
Modern GPUs provide specialized tensor cores designed specifically for deep learning workloads. Leveraging these hardware features requires careful attention to data layouts, precision formats, and operation fusion. Mixed-precision training and inference can achieve 2-3x speedups on modern GPUs while maintaining numerical stability.
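A minimal sketch of mixed-precision inference with PyTorch autocast, assuming a CUDA-capable GPU and an illustrative GPT-2 model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
inputs = tokenizer("Mixed precision inference", return_tensors='pt').to('cuda')

# Run matrix multiplications in FP16 on tensor cores while autocast keeps
# numerically sensitive ops (e.g. softmax reductions) in FP32
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16):
    outputs = model(**inputs)
```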
Memory coalescing and efficient data movement patterns significantly impact GPU performance. Transformer models benefit from optimized memory access patterns that minimize cache misses and maximize bandwidth utilization. Techniques like gradient checkpointing and activation recomputation can reduce memory requirements at the cost of additional computation.
CPU Optimization
CPU inference optimization focuses on vectorization, cache efficiency, and parallel processing. Modern CPUs support advanced vector instructions (AVX-512) that can accelerate matrix operations significantly. Proper data layout and loop optimization can achieve substantial speedups for CPU-bound inference scenarios.
Thread parallelism and NUMA-aware memory allocation become crucial for multi-core CPU deployment. Transformer models can benefit from operator-level parallelism and pipeline parallelism techniques that distribute computation across multiple CPU cores efficiently.
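In PyTorch, CPU thread counts are configured explicitly; the values below are illustrative, and the right settings depend on the physical core count and NUMA layout of the host.

```python
import torch

# Match intra-op threads to the physical cores available to this process
torch.set_num_threads(8)           # threads used within a single op (e.g. one matmul)
torch.set_num_interop_threads(2)   # threads used to run independent ops in parallel

print(torch.__config__.parallel_info())  # verify the threading configuration
```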
Advanced Inference Techniques
Speculative Decoding
Speculative decoding accelerates autoregressive generation by using a smaller draft model to generate multiple candidate tokens simultaneously. The larger target model then verifies these candidates in parallel, accepting correct predictions and rejecting incorrect ones. This approach can achieve 2-3x speedup for text generation tasks while maintaining identical output quality.
Here’s a practical implementation of speculative decoding:
```python
import time
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, tokenizer):
        self.target_model = target_model
        self.draft_model = draft_model
        self.tokenizer = tokenizer
        self.target_model.eval()
        self.draft_model.eval()

    def speculative_sample(self, input_ids, max_new_tokens=50, gamma=5, temperature=1.0):
        """
        Generate tokens using speculative decoding.

        Args:
            input_ids: Input token sequence
            max_new_tokens: Maximum number of tokens to generate
            gamma: Number of tokens to speculate ahead
            temperature: Sampling temperature
        """
        generated_tokens = input_ids.clone()
        prompt_length = input_ids.size(1)

        with torch.no_grad():
            while generated_tokens.size(1) - prompt_length < max_new_tokens:
                # Draft model generates gamma candidate tokens
                draft_tokens = self._draft_generate(generated_tokens, gamma, temperature)

                # Target model verifies all candidates in parallel
                accepted_tokens = self._target_verify(generated_tokens, draft_tokens, temperature)

                if accepted_tokens.size(1) > 0:
                    # Add accepted tokens to the sequence
                    generated_tokens = torch.cat([generated_tokens, accepted_tokens], dim=1)
                else:
                    # If no tokens were accepted, fall back to one target-model token.
                    # (The full algorithm resamples from an adjusted distribution here;
                    # this simplified version just samples from the target model.)
                    next_token = self._target_generate_one(generated_tokens, temperature)
                    generated_tokens = torch.cat([generated_tokens, next_token], dim=1)

                # Check for stopping condition
                if self.tokenizer.eos_token_id in generated_tokens[0]:
                    break

        return generated_tokens

    def _draft_generate(self, input_ids, gamma, temperature):
        """Generate gamma candidate tokens using the draft model."""
        candidates = []
        current_input = input_ids
        for _ in range(gamma):
            outputs = self.draft_model(current_input)
            logits = outputs.logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            candidates.append(next_token)
            current_input = torch.cat([current_input, next_token], dim=1)
        return torch.cat(candidates, dim=1)

    def _target_verify(self, input_ids, draft_tokens, temperature):
        """Verify draft tokens using the target model."""
        # Score the whole sequence (prompt + all draft tokens) in one forward pass
        extended_sequence = torch.cat([input_ids, draft_tokens], dim=1)
        target_outputs = self.target_model(extended_sequence)
        target_logits = target_outputs.logits / temperature

        # Accept or reject each draft token in order
        accepted_tokens = []
        for i, draft_token in enumerate(draft_tokens[0]):
            position = input_ids.size(1) + i
            target_probs = F.softmax(target_logits[:, position - 1, :], dim=-1)

            # Probability the draft model assigned to its own token
            draft_logits = self.draft_model(extended_sequence[:, :position]).logits
            draft_probs = F.softmax(draft_logits[:, -1, :] / temperature, dim=-1)

            acceptance_prob = min(1.0, (target_probs[0, draft_token] /
                                        draft_probs[0, draft_token]).item())
            if torch.rand(1).item() < acceptance_prob:
                accepted_tokens.append(draft_token.unsqueeze(0))
            else:
                break  # Reject this and all subsequent draft tokens

        if accepted_tokens:
            return torch.cat(accepted_tokens, dim=0).unsqueeze(0)
        return torch.empty(1, 0, dtype=torch.long)

    def _target_generate_one(self, input_ids, temperature):
        """Generate one token using the target model."""
        outputs = self.target_model(input_ids)
        logits = outputs.logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1)

# Example usage
def benchmark_speculative_decoding():
    # Load models (the draft model should be much smaller than the target)
    target_model = AutoModelForCausalLM.from_pretrained('gpt2-medium')
    draft_model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    decoder = SpeculativeDecoder(target_model, draft_model, tokenizer)

    # Test input
    input_text = "The future of artificial intelligence is"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Benchmark regular generation
    start_time = time.time()
    regular_output = target_model.generate(input_ids, max_new_tokens=50, do_sample=True)
    regular_time = time.time() - start_time

    # Benchmark speculative decoding
    start_time = time.time()
    speculative_output = decoder.speculative_sample(input_ids, max_new_tokens=50)
    speculative_time = time.time() - start_time

    print(f"Regular generation time: {regular_time:.3f}s")
    print(f"Speculative decoding time: {speculative_time:.3f}s")
    print(f"Speedup: {regular_time/speculative_time:.2f}x")

    return regular_output, speculative_output
```
The effectiveness of speculative decoding depends on the quality of the draft model and the acceptance rate of speculated tokens. Careful selection of draft models and optimization of speculation strategies can maximize the benefits of this technique.
Continuous Batching
Traditional batching waits for all sequences in a batch to complete before processing new requests. Continuous batching enables more efficient GPU utilization by immediately processing new requests as soon as slots become available. This approach can significantly improve throughput for serving scenarios with variable sequence lengths.
Here’s an implementation of a continuous batching system:
```python
import asyncio
import time
from collections import deque
from dataclasses import dataclass
from typing import List

import torch

@dataclass
class GenerationRequest:
    id: str
    input_ids: torch.Tensor
    max_new_tokens: int
    temperature: float = 1.0
    generated_tokens: List[int] = None
    is_complete: bool = False
    created_at: float = None

class ContinuousBatcher:
    def __init__(self, model, tokenizer, max_batch_size=8, max_sequence_length=512):
        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        # Active requests being processed
        self.active_requests = {}
        # Queue for new requests
        self.request_queue = deque()
        # Completed requests
        self.completed_requests = {}
        self.model.eval()

    async def add_request(self, request: GenerationRequest):
        """Add a new generation request to the queue and wait for its result."""
        request.created_at = time.time()
        request.generated_tokens = []
        self.request_queue.append(request)
        # Wait for completion
        while not request.is_complete:
            await asyncio.sleep(0.01)
        return self.completed_requests.pop(request.id)

    async def process_requests(self):
        """Main processing loop for continuous batching."""
        while True:
            # Fill the batch with active requests and new requests from the queue
            batch_requests = self._prepare_batch()
            if not batch_requests:
                await asyncio.sleep(0.001)
                continue
            # Run one decoding step for the batch
            await self._process_batch(batch_requests)
            # Clean up completed requests
            self._cleanup_completed()
            await asyncio.sleep(0.001)  # Small delay to prevent CPU spinning

    def _prepare_batch(self) -> List[GenerationRequest]:
        """Prepare a batch of requests for processing."""
        batch = []
        # Keep active requests that aren't complete
        for request in list(self.active_requests.values()):
            if not request.is_complete and len(batch) < self.max_batch_size:
                batch.append(request)
        # Fill remaining slots with new requests from the queue
        while len(batch) < self.max_batch_size and self.request_queue:
            request = self.request_queue.popleft()
            self.active_requests[request.id] = request
            batch.append(request)
        return batch

    async def _process_batch(self, batch_requests: List[GenerationRequest]):
        """Run one generation step for a batch of requests."""
        if not batch_requests:
            return

        # Build per-request input tensors (prompt + tokens generated so far)
        input_ids_list = []
        attention_mask_list = []
        for request in batch_requests:
            full_sequence = torch.cat([
                request.input_ids,
                torch.tensor(request.generated_tokens, dtype=torch.long).unsqueeze(0)
            ], dim=1)
            input_ids_list.append(full_sequence)
            attention_mask_list.append(torch.ones_like(full_sequence))

        # Pad sequences to the same length
        max_length = max(seq.size(1) for seq in input_ids_list)
        pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        padded_input_ids = []
        padded_attention_masks = []
        for input_ids, attention_mask in zip(input_ids_list, attention_mask_list):
            pad_length = max_length - input_ids.size(1)
            if pad_length > 0:
                padded_input = torch.cat([
                    input_ids,
                    torch.full((1, pad_length), pad_token_id, dtype=torch.long)
                ], dim=1)
                padded_mask = torch.cat([
                    attention_mask,
                    torch.zeros(1, pad_length, dtype=torch.long)
                ], dim=1)
            else:
                padded_input = input_ids
                padded_mask = attention_mask
            padded_input_ids.append(padded_input)
            padded_attention_masks.append(padded_mask)

        # Stack into batch tensors
        batch_input_ids = torch.cat(padded_input_ids, dim=0)
        batch_attention_mask = torch.cat(padded_attention_masks, dim=0)

        # Compute next-token logits for every sequence in the batch
        with torch.no_grad():
            outputs = self.model(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask
            )
            # Take the logits at each sequence's last real (non-padded) position
            last_positions = batch_attention_mask.sum(dim=1) - 1
            logits = outputs.logits[torch.arange(len(batch_requests)), last_positions, :]

        # Sample the next token for each request
        for i, request in enumerate(batch_requests):
            if request.is_complete:
                continue

            # Apply temperature and sample
            scaled_logits = logits[i] / request.temperature
            probs = torch.softmax(scaled_logits, dim=-1)
            next_token = torch.multinomial(probs, 1).item()
            request.generated_tokens.append(next_token)

            # Check for completion
            if (next_token == self.tokenizer.eos_token_id or
                    len(request.generated_tokens) >= request.max_new_tokens):
                request.is_complete = True
                # Assemble the final output
                full_output = torch.cat([
                    request.input_ids,
                    torch.tensor(request.generated_tokens, dtype=torch.long).unsqueeze(0)
                ], dim=1)
                self.completed_requests[request.id] = {
                    'generated_text': self.tokenizer.decode(full_output[0]),
                    'generated_tokens': request.generated_tokens,
                    'processing_time': time.time() - request.created_at
                }

    def _cleanup_completed(self):
        """Remove completed requests from the active set."""
        completed_ids = [req_id for req_id, req in self.active_requests.items()
                         if req.is_complete]
        for req_id in completed_ids:
            del self.active_requests[req_id]

# Example usage
async def demo_continuous_batching():
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    # Create the continuous batcher
    batcher = ContinuousBatcher(model, tokenizer, max_batch_size=4)

    # Start the processing loop
    processing_task = asyncio.create_task(batcher.process_requests())

    # Submit multiple requests
    requests = [
        GenerationRequest("req1", tokenizer.encode("The future of AI is", return_tensors='pt'), 20),
        GenerationRequest("req2", tokenizer.encode("Machine learning helps", return_tensors='pt'), 25),
        GenerationRequest("req3", tokenizer.encode("Deep learning networks", return_tensors='pt'), 15),
    ]

    # Process requests concurrently
    start_time = time.time()
    results = await asyncio.gather(*[batcher.add_request(req) for req in requests])
    total_time = time.time() - start_time

    print(f"Processed {len(requests)} requests in {total_time:.3f}s")
    for i, result in enumerate(results):
        print(f"Request {i+1}: {result['processing_time']:.3f}s")
        print(f"Generated: {result['generated_text']}\n")

    processing_task.cancel()

# Run the demo
# asyncio.run(demo_continuous_batching())
```
Dynamic batching strategies adapt batch sizes based on available memory and computational resources. By monitoring GPU utilization and memory consumption, systems can optimize batch sizes in real-time to maximize throughput while maintaining acceptable latency.
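A simple heuristic along these lines might query free GPU memory and grow or shrink the batch size accordingly. The sketch below is an illustrative policy only, not a production controller; the headroom threshold and growth rule are assumptions to tune per deployment.

```python
import torch

def pick_batch_size(current_batch_size, max_batch_size, headroom=0.2):
    """Grow the batch while free GPU memory stays above the requested headroom."""
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    free_fraction = free_bytes / total_bytes
    if free_fraction > 2 * headroom:
        return min(current_batch_size * 2, max_batch_size)  # plenty of room: grow
    if free_fraction < headroom:
        return max(current_batch_size // 2, 1)               # close to the limit: shrink
    return current_batch_size
```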
Implementation Best Practices
Successful inference optimization requires systematic measurement and incremental improvement. Profiling tools help identify specific bottlenecks in model execution, enabling targeted optimization efforts. Benchmarking different techniques on representative workloads ensures that optimizations translate to real-world performance gains.
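PyTorch's built-in profiler is a convenient starting point; the sketch below profiles a single forward pass of an illustrative GPT-2 model and prints the most expensive operators.

```python
import torch
from torch.profiler import profile, ProfilerActivity
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
inputs = tokenizer("Profile this forward pass", return_tensors='pt')

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with torch.no_grad():
        model(**inputs)

# Rank operators by time to see where the forward pass actually goes
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```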
Version control and A/B testing become essential when deploying optimized models. Maintaining baseline performance metrics and comparing optimized variants helps validate that speed improvements don’t come at the cost of unacceptable accuracy degradation. Automated testing pipelines can catch performance regressions early in the development process.
Conclusion
Optimizing inference speed for large transformer models requires a comprehensive approach that combines architectural improvements, quantization techniques, pruning strategies, and hardware-specific optimizations. The most effective solutions often involve combining multiple techniques, carefully balancing speed gains against accuracy preservation.
The field of inference optimization continues to evolve rapidly, with new techniques and hardware platforms emerging regularly. Staying current with the latest developments while maintaining a solid foundation in established optimization principles ensures that deployments can benefit from both proven strategies and cutting-edge innovations.
Success in transformer inference optimization comes from understanding the specific requirements of your application and systematically applying appropriate techniques. By carefully measuring performance impacts and iteratively refining approaches, organizations can achieve the speed improvements necessary for practical deployment while maintaining the quality that makes large transformer models valuable.