Multi-task learning trains a single model on multiple tasks simultaneously, sharing representations across them. Done well, it produces models that are more parameter-efficient than a separate specialist per task, more robust to distribution shift on any single task, and sometimes genuinely better on all tasks than any single-task baseline. Done poorly, it produces models that underperform on everything because tasks interfere with each other’s gradients. Understanding the mechanisms behind task interference and the architectural choices that mitigate it is what separates practitioners who get value from multi-task setups from those who revert to single-task models after a frustrating experiment.
Hard Parameter Sharing
Hard parameter sharing is the simplest multi-task architecture: a shared backbone encoder processes all inputs, and task-specific heads branch off at the top. All parameters in the backbone are updated by gradients from every task, while each head is updated only by its own task’s loss. This is the default approach in most multi-task implementations and works well when tasks are semantically related — sharing a representation learned for sentiment classification and topic classification is straightforward because both tasks benefit from similar features.
import random

import torch
import torch.nn as nn
from transformers import AutoModel
from typing import Optional


class HardSharingMultiTaskModel(nn.Module):
    """Shared encoder with separate classification heads per task.

    Classic hard parameter sharing: backbone shared, heads independent.
    """

    def __init__(
        self,
        encoder_name: str = "distilbert-base-uncased",
        task_configs: Optional[dict] = None,  # {"task_name": num_classes}
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size
        # One classification head per task
        task_configs = task_configs or {"sentiment": 3, "topic": 8, "intent": 12}
        self.heads = nn.ModuleDict({
            task: nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(hidden_size, hidden_size // 2),
                nn.GELU(),
                nn.Linear(hidden_size // 2, n_classes),
            )
            for task, n_classes in task_configs.items()
        })

    def forward(self, input_ids, attention_mask, task: str):
        # Shared encoder — all tasks use the same weights
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Mean pool over non-padding tokens
        mask_expanded = attention_mask.unsqueeze(-1).float()
        pooled = (outputs.last_hidden_state * mask_expanded).sum(1) / mask_expanded.sum(1)
        return self.heads[task](pooled)


class MultiTaskTrainer:
    """Simple multi-task training loop with task sampling."""

    def __init__(self, model: HardSharingMultiTaskModel,
                 task_dataloaders: dict, lr: float = 2e-5):
        self.model = model
        self.task_dataloaders = task_dataloaders
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
        self.criterion = nn.CrossEntropyLoss()
        # Task sampling weights (proportional to dataset size or tunable)
        sizes = {t: len(dl.dataset) for t, dl in task_dataloaders.items()}
        total = sum(sizes.values())
        self.task_weights = {t: s / total for t, s in sizes.items()}

    def train_step(self, task: str, batch: dict) -> float:
        self.model.train()
        logits = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            task=task,
        )
        loss = self.criterion(logits, batch["labels"])
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        return loss.item()

    def sample_task(self) -> str:
        """Sample next task proportionally to dataset size."""
        tasks = list(self.task_weights.keys())
        weights = [self.task_weights[t] for t in tasks]
        return random.choices(tasks, weights=weights)[0]
Task sampling strategy has a meaningful effect on hard-sharing performance. Sampling proportionally to dataset size keeps every training example equally likely to be seen, but it lets large-dataset tasks dominate training and starve small-dataset tasks of updates. Temperature-based sampling — raising dataset sizes to a power less than 1 before normalising — smooths out size imbalances and tends to improve small-task performance at a small cost to large-task performance. The right sampling temperature is dataset-dependent; 0.75 is a reasonable default to start with and tune from there.
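A minimal sketch of temperature-smoothed sampling weights, reusing the per-task dataset sizes from the trainer above; the example sizes and the `temperature_sampling_weights` helper name are illustrative, not from any library.

import random

def temperature_sampling_weights(sizes: dict, temperature: float = 0.75) -> dict:
    """Raise dataset sizes to a power < 1, then normalise.

    temperature=1.0 recovers proportional sampling; temperature=0.0 is uniform.
    """
    smoothed = {t: s ** temperature for t, s in sizes.items()}
    total = sum(smoothed.values())
    return {t: v / total for t, v in smoothed.items()}

# Example: 100k sentiment examples would swamp 5k intent examples under
# proportional sampling; smoothing narrows the gap.
sizes = {"sentiment": 100_000, "topic": 20_000, "intent": 5_000}
weights = temperature_sampling_weights(sizes, temperature=0.75)
tasks = list(weights.keys())
next_task = random.choices(tasks, weights=[weights[t] for t in tasks])[0]

Swapping this weight computation into MultiTaskTrainer in place of the proportional weights is a one-line change; logging how often each task is actually sampled over the first few hundred steps is the quickest way to confirm the temperature is doing what you expect.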
Soft Parameter Sharing
Soft parameter sharing maintains separate sets of parameters for each task but regularises them to remain similar to each other. Each task has its own encoder, but an L2 regularisation term in the loss penalises divergence between corresponding parameters across tasks. This gives each task the flexibility to develop specialised representations while still benefiting from the regularisation pressure toward shared structure. It is more parameter-expensive than hard sharing — you have T copies of the encoder for T tasks — but less prone to negative transfer because tasks cannot directly interfere through shared gradients.
import torch
import torch.nn as nn
from transformers import AutoModel


class SoftSharingMultiTaskModel(nn.Module):
    """Soft parameter sharing: separate encoders regularised toward each other."""

    def __init__(self, encoder_name: str, task_names: list,
                 n_classes_per_task: dict, sharing_lambda: float = 0.01):
        super().__init__()
        self.task_names = task_names
        self.sharing_lambda = sharing_lambda
        # One full encoder per task
        self.encoders = nn.ModuleDict({
            task: AutoModel.from_pretrained(encoder_name)
            for task in task_names
        })
        hidden = self.encoders[task_names[0]].config.hidden_size
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden, n_classes_per_task[task])
            for task in task_names
        })

    def forward(self, input_ids, attention_mask, task: str):
        enc = self.encoders[task]
        out = enc(input_ids=input_ids, attention_mask=attention_mask)
        mask_exp = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask_exp).sum(1) / mask_exp.sum(1)
        return self.heads[task](pooled)

    def soft_sharing_loss(self) -> torch.Tensor:
        """L2 regularisation penalising divergence between task encoders."""
        total = 0.0  # accumulates into a tensor on whatever device the parameters live on
        tasks = self.task_names
        for i in range(len(tasks)):
            for j in range(i + 1, len(tasks)):
                # Corresponding parameters are paired by position, since every
                # encoder is an identical copy of the same architecture
                for p1, p2 in zip(
                    self.encoders[tasks[i]].parameters(),
                    self.encoders[tasks[j]].parameters(),
                ):
                    total = total + (p1 - p2).pow(2).sum()
        return self.sharing_lambda * total

    def total_loss(self, task_loss: torch.Tensor) -> torch.Tensor:
        return task_loss + self.soft_sharing_loss()
Task Gradient Interference and How to Detect It
The core failure mode of multi-task learning is negative transfer: a task’s gradients update shared parameters in a direction that hurts another task’s performance. This happens when tasks pull representations in conflicting directions — for example, a sentiment task benefits from representations that sharply distinguish positive from negative language, while a formality classification task benefits from representations that are invariant to sentiment polarity, so the two objectives push the shared encoder in different directions. Detecting negative transfer early is important because it determines whether multi-task learning is worth pursuing at all for a given task combination.
import torch
from typing import Dict


def compute_gradient_cosine_similarity(
    model: torch.nn.Module,
    task_losses: Dict[str, torch.Tensor],
) -> Dict[str, float]:
    """Measure pairwise gradient cosine similarity between tasks.

    Negative similarity indicates task interference (negative transfer).
    Near-zero similarity indicates orthogonal tasks (neutral).
    Positive similarity indicates aligned tasks (positive transfer).
    """
    task_grads = {}
    for task, loss in task_losses.items():
        model.zero_grad()
        loss.backward(retain_graph=True)
        # Collect shared encoder gradients as a flat vector
        grads = []
        for name, param in model.named_parameters():
            if 'head' not in name and param.grad is not None:
                grads.append(param.grad.detach().flatten())
        task_grads[task] = torch.cat(grads)
    # Compute pairwise cosine similarity
    tasks = list(task_grads.keys())
    similarities = {}
    for i in range(len(tasks)):
        for j in range(i + 1, len(tasks)):
            g1, g2 = task_grads[tasks[i]], task_grads[tasks[j]]
            cos_sim = torch.nn.functional.cosine_similarity(
                g1.unsqueeze(0), g2.unsqueeze(0)).item()
            pair = f"{tasks[i]}_vs_{tasks[j]}"
            similarities[pair] = cos_sim
            label = ('positive transfer' if cos_sim > 0.1
                     else 'negative transfer' if cos_sim < -0.1 else 'neutral')
            print(f"{pair}: {cos_sim:+.4f} ({label})")
    return similarities

# Rule of thumb:
#   cos_sim > +0.1: tasks help each other, good candidate for hard sharing
#   -0.1 < cos_sim < +0.1: neutral, soft sharing or separate heads preferred
#   cos_sim < -0.1: tasks interfere, consider auxiliary weighting or PCGrad
If gradient cosine similarity is consistently negative between two tasks, hard parameter sharing will cause negative transfer. The options are: switch to soft sharing, use PCGrad (which projects conflicting gradients to remove the interfering component), weight down the problematic task's loss, or simply not include the harmful task in the multi-task setup. Not every combination of tasks benefits from multi-task training, and the gradient similarity diagnostic tells you quickly which combinations are worth pursuing before committing to a full training run.
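A minimal sketch of the PCGrad projection step (Yu et al., 2020), assuming per-task gradients have already been flattened into vectors as in the diagnostic above; the `pcgrad_combine` name is illustrative, not from any particular library, and a real implementation also has to scatter the combined vector back into the individual parameter tensors.

import random
import torch

def pcgrad_combine(task_grads: list[torch.Tensor]) -> torch.Tensor:
    """Project each task gradient away from the conflicting component of
    every other task gradient, then sum the projected gradients.
    """
    projected = [g.clone() for g in task_grads]
    for i, g_i in enumerate(projected):
        others = [g for j, g in enumerate(task_grads) if j != i]
        random.shuffle(others)  # PCGrad projects against the others in random order
        for g_j in others:
            dot = torch.dot(g_i, g_j)
            if dot < 0:  # conflict: remove the component of g_i along g_j
                g_i -= (dot / g_j.norm().pow(2).clamp_min(1e-12)) * g_j
    return torch.stack(projected).sum(dim=0)

The combined gradient would then be written back into the shared parameters, split to match each parameter's shape, before the optimiser step; the task-specific head gradients are left untouched because only shared parameters suffer from interference.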
Loss Weighting Strategies
When tasks have losses at very different scales — a regression task with MSE loss in the thousands and a classification task with cross-entropy near 1.0 — naive summation causes the large-loss task to dominate gradient updates and effectively trains a single-task model. Loss weighting is essential to get balanced multi-task training, and the right weighting strategy depends on whether you want to optimise all tasks equally or prioritise specific ones.
import torch
import torch.nn as nn


class UncertaintyWeighting(nn.Module):
    """Homoscedastic uncertainty weighting (Kendall et al. 2018).

    Learns log(sigma^2) per task; tasks with high uncertainty get lower weight.
    Gradient automatically balances task contributions without manual tuning.
    """

    def __init__(self, n_tasks: int):
        super().__init__()
        # Initialise log variances to 0 (sigma=1 for each task initially)
        self.log_vars = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, losses: list[torch.Tensor]) -> torch.Tensor:
        total = 0.0  # accumulates into a tensor on the losses' device
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            # Weighted loss + regularisation term penalising large variances
            total = total + precision * loss + self.log_vars[i]
        return total

# Usage: include log_vars in optimizer param groups
# model_params = list(model.parameters())
# weighting = UncertaintyWeighting(n_tasks=3)
# optimizer = AdamW(model_params + list(weighting.parameters()), lr=2e-5)
When Multi-Task Learning Beats Single-Task
Multi-task learning reliably outperforms single-task baselines in three situations: when one task has very limited labelled data and a related task has abundant data (the data-rich task acts as a regulariser and provides auxiliary signal); when tasks share linguistic or structural features that benefit from joint representation learning (NER and relation extraction, for example, both benefit from entity-aware representations); and when you need to deploy a single model endpoint that handles multiple tasks without maintaining separate inference servers per task.
It reliably underperforms single-task baselines when tasks have conflicting objectives (summarisation and factual QA often conflict because summarisation rewards conciseness while QA rewards completeness), when the tasks operate at very different granularities (token-level labelling and document-level classification share little useful structure), or when one task's dataset is so large it drowns out signal from smaller tasks regardless of sampling strategy. The practical test is always empirical: train single-task baselines for each task, then compare against your multi-task model on each task's validation set. If the multi-task model matches or beats all single-task baselines, it is worth deploying. If it underperforms on any critical task, single-task is the safer choice.
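A small sketch of that empirical comparison, assuming you have collected each task's validation metric (higher is better) for the single-task baselines and the multi-task model; the metric values below are placeholders, not real results.

def multitask_is_deployable(single_task: dict, multi_task: dict,
                            tolerance: float = 0.0) -> bool:
    """Return True only if the multi-task model matches or beats every
    single-task baseline on that task's own validation metric.
    """
    return all(multi_task[t] >= single_task[t] - tolerance for t in single_task)

# Placeholder numbers for illustration only
single_task = {"sentiment": 0.912, "topic": 0.874, "intent": 0.935}
multi_task = {"sentiment": 0.918, "topic": 0.871, "intent": 0.940}
print(multitask_is_deployable(single_task, multi_task))  # False: topic regressed

The `tolerance` argument encodes how much regression you are willing to accept on non-critical tasks in exchange for operational simplicity; for a critical task it should stay at zero.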
For LLMs specifically, multi-task fine-tuning — training on a mixture of instruction-following tasks simultaneously — is the standard approach to instruction tuning and consistently outperforms fine-tuning on a single task type. The diversity of the task mixture is more important than its size: a model fine-tuned on 10,000 examples spanning 50 task types generalises better than one fine-tuned on 50,000 examples of a single task type. This is why datasets like FLAN, Super-NaturalInstructions, and the Orca datasets are constructed as broad multi-task mixtures rather than single-task corpora.
Auxiliary Tasks: Using Related Tasks to Improve a Primary Task
A common and effective application of multi-task learning is using auxiliary tasks — tasks you do not need predictions for at inference time — purely to improve the primary task's representation learning. The classic example is adding a language modelling objective as an auxiliary task when fine-tuning a classifier: the auxiliary LM loss keeps the encoder from forgetting the pretraining distribution and acts as a regulariser that prevents overfitting on small labelled datasets. Adding part-of-speech tagging or dependency parsing as auxiliary tasks for NER consistently improves NER performance because the auxiliary tasks force the encoder to develop syntactically aware representations that are also useful for entity recognition.
The key design choice for auxiliary tasks is weight: the auxiliary loss should be weighted low enough that it does not overpower the primary task's gradient signal. A weight of 0.1 to 0.3 on the auxiliary loss relative to the primary loss is a reasonable starting range. Start at 0.1 and increase until auxiliary task performance saturates; if primary task performance starts to degrade as you increase auxiliary weight, you have found the interference threshold. For language modelling as an auxiliary task specifically, a weight of 0.1 is usually sufficient to capture the regularisation benefit without meaningfully slowing convergence on the primary task.
import torch
import torch.nn as nn
from transformers import AutoModel


class AuxiliaryLMTrainer:
    """Fine-tune a classifier with auxiliary language modelling loss.

    Primary task: classification (e.g. sentiment, intent)
    Auxiliary task: masked or causal language modelling
    The auxiliary loss acts as regularisation, preventing forgetting.
    """

    def __init__(self, model_name: str, n_classes: int,
                 aux_weight: float = 0.1, lr: float = 2e-5):
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size
        vocab = self.encoder.config.vocab_size
        self.classifier = nn.Linear(hidden, n_classes)
        # LM head for auxiliary LM objective
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.aux_weight = aux_weight
        params = (list(self.encoder.parameters()) +
                  list(self.classifier.parameters()) +
                  list(self.lm_head.parameters()))
        self.optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.01)

    def step(self, batch):
        self.encoder.train()
        out = self.encoder(input_ids=batch['input_ids'],
                           attention_mask=batch['attention_mask'])
        hidden = out.last_hidden_state
        # Primary: classification loss on [CLS] token
        cls_logits = self.classifier(hidden[:, 0])
        cls_loss = nn.functional.cross_entropy(cls_logits, batch['labels'])
        # Auxiliary: next-token LM loss. With a bidirectional encoder the model
        # can attend to future tokens, so masked LM is the stricter choice there;
        # next-token prediction fits causal (decoder-style) backbones.
        lm_logits = self.lm_head(hidden[:, :-1])
        lm_targets = batch['input_ids'][:, 1:]
        lm_loss = nn.functional.cross_entropy(
            lm_logits.reshape(-1, lm_logits.size(-1)),
            lm_targets.reshape(-1),
            ignore_index=0,  # padding token id; adjust to tokenizer.pad_token_id
        )
        total_loss = cls_loss + self.aux_weight * lm_loss
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            list(self.encoder.parameters()) + list(self.classifier.parameters()) +
            list(self.lm_head.parameters()), 1.0)
        self.optimizer.step()
        return {'cls_loss': cls_loss.item(), 'lm_loss': lm_loss.item(),
                'total': total_loss.item()}
Mixture-of-Experts as Multi-Task Scaling
Mixture-of-experts (MoE) can be viewed as a form of learned soft parameter sharing where different subsets of parameters (experts) specialise on different input distributions. In a multi-task MoE setup, the router naturally learns to direct different tasks to different experts, achieving soft separation without explicit task labels — the model learns which parameters are task-specific versus shared through the routing distribution. Switch Transformer and Mixtral both exhibit this behaviour: analysis of their routing shows that certain experts specialise on specific syntactic constructions, linguistic registers, or code versus natural language, mirroring what manual soft-sharing architectures achieve through explicit regularisation. For teams building large-scale multi-task models, the MoE architecture amortises the cost of soft sharing — you get task-specific specialisation at inference time without paying the full parameter cost of separate encoders during deployment, since only a fraction of experts is active per token.
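A minimal sketch of a token-level top-k MoE feed-forward layer, to make the routing mechanism concrete; this is an illustrative toy under assumed dimensions, not the Switch Transformer or Mixtral implementation, and it omits load-balancing losses, capacity limits, and sparse dispatch.

import torch
import torch.nn as nn


class TinyMoELayer(nn.Module):
    """Toy mixture-of-experts feed-forward layer with top-k routing."""

    def __init__(self, hidden: int, n_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.router = nn.Linear(hidden, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(),
                          nn.Linear(4 * hidden, hidden))
            for _ in range(n_experts)
        )
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, hidden)
        logits = self.router(x)                                    # (B, S, n_experts)
        weights, idx = logits.softmax(dim=-1).topk(self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)      # renormalise top-k
        out = torch.zeros_like(x)
        # Dense toy implementation: run every expert, mask by routing choice
        for e, expert in enumerate(self.experts):
            gate = ((idx == e).float() * weights).sum(dim=-1, keepdim=True)
            out = out + gate * expert(x)
        return out

A real sparse implementation runs only the selected experts per token, which is where the inference-time compute savings come from; inspecting the routing distribution per task is what reveals the emergent task specialisation described above.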
Practical Checklist for Multi-Task Setup
Before committing to a multi-task training run, work through these checks: compute gradient cosine similarity on a small batch to confirm tasks are not strongly negatively correlated; normalise loss scales so no single task dominates by raw magnitude; choose a sampling strategy (proportional, temperature-smoothed, or round-robin) and validate it produces balanced updates across tasks in the first 100 steps by logging per-task loss curves; add the primary task's validation metric as the early stopping criterion rather than total loss, since total loss can decrease while primary task performance plateaus or regresses; and always compare against single-task baselines before concluding that multi-task training is beneficial. These checks take a few hours and prevent the most common failure mode: discovering after a long training run that multi-task learning hurt the primary task you actually care about.
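A small sketch of the early balance check, reusing the MultiTaskTrainer defined earlier; the 100-step budget mirrors the checklist and the `log_early_task_balance` helper name is illustrative.

from collections import defaultdict

def log_early_task_balance(trainer, n_steps: int = 100) -> dict:
    """Run the first n_steps and record how often each task was sampled and
    its mean loss, to catch obvious imbalance before a long run.
    """
    counts, losses = defaultdict(int), defaultdict(list)
    iters = {t: iter(dl) for t, dl in trainer.task_dataloaders.items()}
    for _ in range(n_steps):
        task = trainer.sample_task()
        batch = next(iters[task], None)
        if batch is None:  # restart an exhausted dataloader
            iters[task] = iter(trainer.task_dataloaders[task])
            batch = next(iters[task])
        counts[task] += 1
        losses[task].append(trainer.train_step(task, batch))
    return {t: {"steps": counts[t],
                "mean_loss": sum(losses[t]) / max(len(losses[t]), 1)}
            for t in trainer.task_dataloaders}

A task that receives almost no updates in the first hundred steps, or whose loss barely moves while another task's loss collapses, is an early warning that the sampling strategy or loss weighting needs adjustment before the full run.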
Task Routing and Conditional Computation
A natural extension of hard parameter sharing is to condition the shared backbone's behaviour on the current task, giving the model information about which task it is solving without fully separating parameters. The simplest form is task embedding: a learned embedding vector per task is added to the input representation or injected into each layer via a learned scale and shift, similar to how conditional normalisation works in style transfer. This allows the backbone to adapt its computation to the task context while maintaining full parameter sharing across tasks.
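A minimal sketch of task-conditioned scale-and-shift modulation applied to a shared hidden state; the embedding dimension and the `TaskConditionedAdapter` name are illustrative, and in practice one such adapter would sit after each backbone layer.

import torch
import torch.nn as nn


class TaskConditionedAdapter(nn.Module):
    """Per-task learned scale and shift applied to a shared hidden state."""

    def __init__(self, hidden: int, n_tasks: int, task_emb_dim: int = 64):
        super().__init__()
        self.task_embeddings = nn.Embedding(n_tasks, task_emb_dim)
        self.to_scale = nn.Linear(task_emb_dim, hidden)
        self.to_shift = nn.Linear(task_emb_dim, hidden)

    def forward(self, hidden_states: torch.Tensor, task_id: torch.Tensor):
        # hidden_states: (batch, seq, hidden); task_id: (batch,) long tensor
        emb = self.task_embeddings(task_id)               # (batch, task_emb_dim)
        scale = 1.0 + self.to_scale(emb).unsqueeze(1)     # centre the scale around 1
        shift = self.to_shift(emb).unsqueeze(1)
        return hidden_states * scale + shift

Because the backbone weights themselves remain fully shared, the only task-specific parameters are the embedding table and the two small projections, so the per-task cost stays negligible even as the number of tasks grows.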
Prefix tuning applied to multi-task settings achieves a similar effect: each task gets its own learned prefix tokens prepended to the input, which steer the shared transformer's attention toward task-relevant features without modifying the backbone weights. This is particularly efficient because the backbone remains frozen after pretraining — only the task-specific prefixes are trained, which adds negligible parameters per task while maintaining the full shared representation capacity. The tradeoff relative to full fine-tuning is that prefix-based task conditioning works well when the base model's representations are already well-suited to the task distribution, but may underperform when tasks require significant redistribution of the pretrained representations. For instruction-following tasks built on a strong pretrained LLM, prefix-based multi-task conditioning is typically within a few percentage points of full fine-tuning performance at a fraction of the parameter cost, and is worth evaluating before committing to a full multi-task fine-tuning run that modifies all shared parameters.
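A sketch of per-task prefix conditioning over a frozen shared encoder, assuming a backbone that accepts `inputs_embeds` (as BERT-style models in transformers do); the prefix length and class name are illustrative, and full prefix tuning typically injects key/value prefixes at every attention layer rather than only at the input.

import torch
import torch.nn as nn
from transformers import AutoModel


class PrefixConditionedModel(nn.Module):
    """Frozen shared encoder; only per-task prefix embeddings and heads train."""

    def __init__(self, encoder_name: str, task_configs: dict, prefix_len: int = 16):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        for p in self.encoder.parameters():
            p.requires_grad = False                      # backbone stays frozen
        hidden = self.encoder.config.hidden_size
        self.prefix_len = prefix_len
        self.prefixes = nn.ParameterDict({
            task: nn.Parameter(torch.randn(prefix_len, hidden) * 0.02)
            for task in task_configs
        })
        self.heads = nn.ModuleDict({
            task: nn.Linear(hidden, n_classes)
            for task, n_classes in task_configs.items()
        })

    def forward(self, input_ids, attention_mask, task: str):
        embeds = self.encoder.get_input_embeddings()(input_ids)
        prefix = self.prefixes[task].unsqueeze(0).expand(embeds.size(0), -1, -1)
        embeds = torch.cat([prefix, embeds], dim=1)
        prefix_mask = attention_mask.new_ones(embeds.size(0), self.prefix_len)
        mask = torch.cat([prefix_mask, attention_mask], dim=1)
        out = self.encoder(inputs_embeds=embeds, attention_mask=mask)
        pooled = out.last_hidden_state[:, 0]  # first prefix position as a pooled summary
        return self.heads[task](pooled)

Only the prefixes and heads carry gradients, so adding a new task later means training a few thousand extra parameters against the frozen backbone rather than re-running a full multi-task fine-tune.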
Architecture Decision Summary
Hard parameter sharing: use when tasks are semantically related, dataset sizes are similar, and you want the simplest possible implementation. Start with proportional task sampling and uniform loss weights, then tune from there.
Soft parameter sharing: use when tasks are moderately related but you expect some divergence in optimal representations, or when gradient cosine similarity analysis shows partial interference. The parameter overhead is roughly T times the encoder size for T tasks, so it is only practical when T is small (2–4 tasks).
Task-conditioned hard sharing (task embeddings or prefixes): use when you want hard sharing efficiency but suspect the backbone needs some task-specific steering — this is often the right balance between simplicity and performance for 4–10 tasks over a strong pretrained base.
Auxiliary tasks with weighted loss: use when your primary task has limited labelled data and you have access to related tasks with abundant labels or self-supervised signal. The auxiliary task adds no inference cost and consistently improves primary task generalisation for small-data regimes.