Metrics Interface
Custom metrics allow you to track and aggregate any quantitative measures during training and validation. siiRL provides a distributed, Ray-based metrics system that automatically handles aggregation across all workers using various reduction operations (mean, max, min, sum).
Architecture Overview
Distributed Metrics Architecture
==============================================================================
   DAGWorker 0        DAGWorker 1        DAGWorker 2        DAGWorker N
  +-----------+      +-----------+      +-----------+      +-----------+
  |  compute  |      |  compute  |      |  compute  |      |  compute  |
  |  metrics  |      |  metrics  |      |  metrics  |      |  metrics  |
  +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
        |                  |                  |                  |
        v                  v                  v                  v
  +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
  |  Metric   |      |  Metric   |      |  Metric   |      |  Metric   |
  |  Client   |      |  Client   |      |  Client   |      |  Client   |
  +-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
        |                  |                  |                  |
        +------------------+------------------+------------------+
                                     |
                                     v
                            +-------------------+
                            |   MetricWorker    |  (Ray Actor - Singleton)
                            |   (Aggregator)    |
                            +-------------------+
                            | - Collect metrics |
                            | - Wait for all    |
                            |   workers         |
                            | - Aggregate:      |
                            |   mean/max/min/   |
                            |   sum             |
                            +--------+----------+
                                     |
                                     v
                            +-------------------+
                            |   Final Metrics   |
                            | (to Logger/WandB) |
                            +-------------------+
==============================================================================
Metrics Data Flow:
+-------------+     +----------------+     +----------------+     +--------+
| TensorDict  | --> | compute_*      | --> | MetricClient   | --> | Metric |
| (batch)     |     | _metric()      |     | .submit_metric |     | Worker |
+-------------+     +----------------+     +----------------+     +--------+
                            |
                            v
                     +-------------+
                     | Dict[str,   |
                     | float]      |
                     | {name: val} |
                     +-------------+
==============================================================================
Key Files:
- siirl/execution/metric_worker/metric_worker.py - Ray-based distributed metrics aggregation
- siirl/utils/metrics/metric_utils.py - Core metric computation functions
- siirl/execution/metric_worker/utils.py - Aggregation function utilities
Quick Start
Method 1: Extending Core Metrics Functions (Recommended)
Step 1: Create your metric computation function in metric_utils.py
# Add to siirl/utils/metrics/metric_utils.py
from typing import Dict

import torch
from tensordict import TensorDict  # assumes TensorDict comes from the tensordict package


def compute_custom_data_metrics(data: TensorDict) -> Dict[str, float]:
    """Custom metrics computed from batch data."""
    metrics = {}
    # Membership is tested against data.keys(); tensordict's TensorDict
    # does not support the bare `in` operator.
    # Token-level accuracy over non-padding positions
    if "correct_tokens" in data.keys() and "attention_mask" in data.keys():
        correct = data["correct_tokens"].float()
        mask = data["attention_mask"].float()
        # clamp guards against division by zero when the mask is all zeros
        accuracy = (correct * mask).sum() / mask.sum().clamp(min=1.0)
        metrics["custom/token_accuracy/mean"] = accuracy.item()
    # Response quality score
    if "responses" in data.keys() and "response_mask" in data.keys():
        response_quality = compute_response_quality_score(data)
        metrics["custom/response_quality/mean"] = response_quality.mean().item()
        metrics["custom/response_quality/max"] = response_quality.max().item()
        metrics["custom/response_quality/min"] = response_quality.min().item()
    return metrics


def compute_response_quality_score(data: TensorDict) -> torch.Tensor:
    """Helper function to compute response quality."""
    responses = data["responses"]
    response_mask = data["response_mask"]
    # Example: vocabulary diversity score (unique token count per response)
    unique_tokens_per_response = []
    for i in range(responses.shape[0]):
        # Keep only the non-padding tokens of response i
        response_tokens = responses[i][response_mask[i].bool()]
        unique_count = len(torch.unique(response_tokens))
        unique_tokens_per_response.append(unique_count)
    return torch.tensor(unique_tokens_per_response, device=responses.device).float()
Step 2: Submit metrics using MetricClient
# Usage in your training loop
from siirl.execution.metric_worker.metric_worker import MetricClient
# In your DAG worker or training script
custom_metrics = compute_custom_data_metrics(batch)
metric_client.submit_metric(custom_metrics, world_size)
Current Metrics System
Built-in Metrics Reference
The following tables list all built-in metrics provided by siiRL.
Data Metrics (from compute_data_metric in metric_utils.py):
| Metric Name | Description |
|---|---|
|  | Sequence-level scores from token-level scores |
|  | Sequence-level rewards from token-level rewards |
|  | Advantages (masked by response_mask) |
|  | Returns (masked by response_mask) |
|  | Value function estimates (if available) |
|  | Explained variance of value function |
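The explained-variance row can be read through the usual definition, 1 - Var(returns - values) / Var(returns). The sketch below is a minimal pure-Python version of that standard formula; the function name `explained_variance` and the flat-list inputs are illustrative assumptions, not siiRL's implementation, which works on masked tensors.

```python
from statistics import pvariance
from typing import Sequence


def explained_variance(returns: Sequence[float], values: Sequence[float]) -> float:
    """Standard definition: 1 - Var(returns - values) / Var(returns).

    Hypothetical helper for illustration. Inputs are assumed to be the
    already-masked (response-token) returns and value estimates.
    """
    if len(returns) != len(values) or len(returns) == 0:
        raise ValueError("returns and values must be non-empty and equal length")
    var_returns = pvariance(returns)
    if var_returns == 0:
        # Constant returns make the ratio undefined, so report NaN.
        return float("nan")
    residuals = [r - v for r, v in zip(returns, values)]
    return 1.0 - pvariance(residuals) / var_returns
```

A value near 1 means the value function tracks the returns closely; a value near 0 means it does no better than predicting a constant.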
| Metric Name | Description |
|---|---|
|  | Response token lengths |
|  | Proportion hitting max response length |
|  | Lengths for responses with reward > 0.5 |
|  | Lengths for responses with reward ≤ 0.5 |
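To make the response-length rows concrete, here is a small sketch of how such statistics could be derived from per-sample lengths and rewards. The function name, the key names, and the list-based inputs are all hypothetical; the built-in metric names are not reproduced here.

```python
from statistics import mean
from typing import Dict, Sequence


def response_length_stats(
    lengths: Sequence[int],
    rewards: Sequence[float],
    max_response_length: int,
    threshold: float = 0.5,
) -> Dict[str, float]:
    """Summarize response lengths (illustrative key names only)."""
    if len(lengths) != len(rewards) or len(lengths) == 0:
        raise ValueError("lengths and rewards must be non-empty and equal length")
    stats = {
        "length/mean": float(mean(lengths)),
        "length/max": float(max(lengths)),
        "length/min": float(min(lengths)),
        # Share of responses that reached the length cap
        "length/clip_ratio": sum(n >= max_response_length for n in lengths) / len(lengths),
    }
    # Split lengths by whether the sample's reward exceeds the threshold
    high = [n for n, r in zip(lengths, rewards) if r > threshold]
    low = [n for n, r in zip(lengths, rewards) if r <= threshold]
    if high:
        stats["length_high_reward/mean"] = float(mean(high))
    if low:
        stats["length_low_reward/mean"] = float(mean(low))
    return stats
```

A rising clip ratio usually signals that responses are being truncated, which is worth checking before interpreting reward trends.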
| Metric Name | Description |
|---|---|
|  | Prompt token lengths |
|  | Proportion hitting max prompt length |
| Metric Name | Description |
|---|---|
|  | CPU memory usage per process |
|  | Statistics for multi-turn conversations |
Timing Metrics (from compute_timing_metrics):
| Metric Name | Description |
|---|---|
|  | Raw timing in seconds for each stage |
|  | Per-token timing in milliseconds |
Stages: gen, ref, values, adv, update_critic, update_actor
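The relationship between the two timing rows is a unit conversion: seconds per stage divided by the tokens that stage processed, times 1000. The sketch below assumes a `timing_s/` prefix (it appears elsewhere in this document) and a hypothetical `timing_per_token_ms/` prefix; the function name and the `tokens_per_stage` argument are also illustrative assumptions.

```python
from typing import Dict


def timing_metrics(
    timing_raw: Dict[str, float],
    tokens_per_stage: Dict[str, int],
) -> Dict[str, float]:
    """Turn raw per-stage seconds into per-stage and per-token metrics."""
    # Raw timing in seconds, one entry per stage
    metrics = {f"timing_s/{stage}": seconds for stage, seconds in timing_raw.items()}
    for stage, seconds in timing_raw.items():
        n_tokens = tokens_per_stage.get(stage, 0)
        if n_tokens > 0:
            # Seconds -> milliseconds, normalized by tokens handled in the stage
            metrics[f"timing_per_token_ms/{stage}"] = seconds * 1000.0 / n_tokens
    return metrics


example = timing_metrics({"gen": 2.0, "update_actor": 1.0}, {"gen": 1000})
```

Stages without a token count simply get no per-token entry, which avoids dividing by zero.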
Throughput Metrics (from compute_throughout_metrics):
| Metric Name | Description |
|---|---|
|  | Total tokens processed |
|  | Time per training step |
|  | Tokens per second per GPU |
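The three throughput rows are related by simple arithmetic: tokens per second per GPU is the total token count divided by the step time and the number of GPUs. The following sketch shows that calculation; the function name and the `perf/...` key names are assumptions chosen to match the `perf/` prefix used in this document's naming conventions.

```python
from typing import Dict


def throughput_metrics(total_num_tokens: int, step_time_s: float, n_gpus: int) -> Dict[str, float]:
    """Compute per-GPU throughput from a step's token count and duration."""
    if step_time_s <= 0 or n_gpus <= 0:
        raise ValueError("step_time_s and n_gpus must be positive")
    return {
        "perf/total_num_tokens": float(total_num_tokens),
        "perf/time_per_step": step_time_s,
        # Tokens per second, normalized by the number of GPUs
        "perf/throughput": total_num_tokens / (step_time_s * n_gpus),
    }


example = throughput_metrics(total_num_tokens=80_000, step_time_s=10.0, n_gpus=8)
```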
Validation Metrics (from process_validation_metrics):
| Metric Name | Description |
|---|---|
|  | Mean across N samples |
|  | Bootstrap best-of-N statistics |
|  | Bootstrap worst-of-N statistics |
|  | Bootstrap majority voting statistics |
|  | Test score per data source |
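The bootstrap rows can be understood as repeated resampling: draw N scores with replacement, reduce them (max for best-of-N, min for worst-of-N), and report the mean and spread over many draws. The sketch below illustrates that idea in pure Python; `bootstrap_stat` is a hypothetical stand-in, not the library's `bootstrap_metric`, and majority voting is left out for brevity.

```python
import random
from statistics import mean, pstdev
from typing import Callable, Sequence, Tuple


def bootstrap_stat(
    scores: Sequence[float],
    subset_size: int,
    reduce_fn: Callable[[Sequence[float]], float],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> Tuple[float, float]:
    """Return (mean, std) of reduce_fn over bootstrap resamples of scores."""
    rng = random.Random(seed)  # fixed seed keeps the estimate reproducible
    estimates = []
    for _ in range(n_bootstrap):
        # Sample subset_size scores with replacement
        subset = rng.choices(scores, k=subset_size)
        estimates.append(reduce_fn(subset))
    return mean(estimates), pstdev(estimates)


scores = [0.0, 1.0, 0.0, 0.0]
best_mean, best_std = bootstrap_stat(scores, subset_size=4, reduce_fn=max)
worst_mean, worst_std = bootstrap_stat(scores, subset_size=4, reduce_fn=min)
```

Best-of-N estimates how often at least one of N samples succeeds, while worst-of-N estimates how often all of them do, so the gap between the two indicates how consistent the model is on a prompt.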
Custom Metrics Implementation
Method 1: Custom Data Metrics
Extend the data metrics computed from training batches:
# Add to metric_utils.py
def compute_custom_training_metrics(data: TensorDict) -> Dict[str, float]:
    """Custom training-specific metrics."""
    metrics = {}
    # Policy entropy (exploration measure)
    if "policy_logits" in data.keys():
        logits = data["policy_logits"]
        # log_softmax is numerically safer than log(softmax(x) + eps)
        log_probs = torch.log_softmax(logits, dim=-1)
        entropy = -torch.sum(log_probs.exp() * log_probs, dim=-1)
        response_mask = data.get("response_mask", torch.ones_like(entropy))
        # Only compute entropy for response tokens
        masked_entropy = entropy * response_mask.float()
        # clamp guards against division by zero when the mask is empty
        valid_entropy = masked_entropy.sum() / response_mask.float().sum().clamp(min=1.0)
        metrics["training/policy_entropy/mean"] = valid_entropy.item()
    # Gradient norm tracking
    if "grad_norm" in data.keys():
        metrics["training/grad_norm/mean"] = data["grad_norm"].item()
    # Loss convergence tracking
    if "loss_values" in data.keys():
        loss_values = data["loss_values"]
        metrics["training/loss/mean"] = loss_values.mean().item()
        metrics["training/loss/std"] = loss_values.std().item()
    return metrics


# Usage in MetricClient.compute_local_data_metric
def compute_local_data_metric(self, data: TensorDict, world_size: int):
    # Standard metrics
    standard_metrics = compute_data_metric(data)
    # Add custom metrics
    custom_metrics = compute_custom_training_metrics(data)
    # Combine and submit (custom keys override standard keys on collision)
    all_metrics = {**standard_metrics, **custom_metrics}
    self.submit_metric(all_metrics, world_size)
Method 2: Custom Validation Metrics
Add custom validation metrics with bootstrap sampling:
# Add to metric_utils.py
from collections import defaultdict

import numpy as np


def compute_custom_validation_metrics(
    data_sources: list[str],
    sample_inputs: list[str],
    infos_dict: dict[str, list],
    sample_turns: list[int],
) -> dict[str, float]:
    """Custom validation metrics with bootstrap analysis."""
    custom_metrics = {}
    # Extract custom fields from infos_dict
    if "custom_score" in infos_dict:
        # Group by data source
        source_scores = defaultdict(list)
        for i, source in enumerate(data_sources):
            source_scores[source].append(infos_dict["custom_score"][i])
        # Compute statistics per source
        for source, scores in source_scores.items():
            if len(scores) > 0:
                custom_metrics[f"val/custom_score/{source}/mean"] = float(np.mean(scores))
                custom_metrics[f"val/custom_score/{source}/std"] = float(np.std(scores))
            # Bootstrap sampling for confidence intervals
            if len(scores) > 1:
                bootstrap_results = bootstrap_metric(
                    data=scores,
                    subset_size=min(5, len(scores)),
                    reduce_fns=[np.mean, np.max, np.min],
                    n_bootstrap=1000,
                )
                # bootstrap_results[0] is the (mean, std) pair for np.mean
                custom_metrics[f"val/custom_score/{source}/bootstrap_mean"] = bootstrap_results[0][0]
                custom_metrics[f"val/custom_score/{source}/bootstrap_mean_std"] = bootstrap_results[0][1]
    # Conversation quality for multi-turn
    if "conversation_quality" in infos_dict and len(sample_turns) > 0:
        quality_by_turns = defaultdict(list)
        for i, turns in enumerate(sample_turns):
            if i < len(infos_dict["conversation_quality"]):
                quality_by_turns[turns].append(infos_dict["conversation_quality"][i])
        for turn_count, qualities in quality_by_turns.items():
            if len(qualities) > 0:
                custom_metrics[f"val/conversation_quality/turns_{turn_count}/mean"] = float(np.mean(qualities))
    return custom_metrics


# Usage in MetricClient.process_local_validation_metrics
def process_local_validation_metrics(self, data_sources, sample_inputs, infos_dict, sample_turns, world_size):
    # Standard validation metrics
    standard_metrics = process_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)
    # Add custom validation metrics
    custom_metrics = compute_custom_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)
    # Combine and submit
    all_metrics = {**standard_metrics, **custom_metrics}
    self.submit_metric(all_metrics, world_size)
Method 3: Custom Aggregation Logic
Create custom aggregation functions for specialized reduction operations:
# Add to execution/metric_worker/utils.py
from typing import List

import torch


def _flatten_values(metrics: List[Metric]) -> List[float]:
    """Collect scalar and list-valued submissions into one flat list."""
    return [v for metric in metrics
            for v in (metric.value if isinstance(metric.value, list) else [metric.value])]


def MedianMetric(metrics: List[Metric]):
    """Custom median aggregation."""
    values = _flatten_values(metrics)
    return float(torch.median(torch.tensor(values, dtype=torch.float32)).item())


def PercentileMetric(percentile: float):
    """Custom percentile aggregation factory."""
    def _percentile_metric(metrics: List[Metric]):
        values = _flatten_values(metrics)
        # torch.quantile requires a floating-point tensor
        tensor = torch.tensor(values, dtype=torch.float32)
        return float(torch.quantile(tensor, percentile / 100.0).item())
    return _percentile_metric


# Update MetricFunc to handle custom aggregations.
# Matching is by substring, so the more specific names are checked first.
def MetricFunc(name: str):
    if "median" in name:
        return MedianMetric
    elif "p95" in name:
        return PercentileMetric(95)
    elif "p99" in name:
        return PercentileMetric(99)
    elif "min" in name:
        return MinMetric
    elif "max" in name:
        return MaxMetric
    elif "sum" in name or "total" in name:
        return SumMetric
    else:
        return MeanMetric


# Usage: name your metrics to trigger specific aggregations
metrics = {
    "custom/latency/median": latency_values,  # Will use MedianMetric
    "custom/score/p95": score_values,         # Will use 95th percentile
}
Method 4: Complex Custom Metrics
For more sophisticated metrics requiring multiple computation steps:
# Add to metric_utils.py
def compute_advanced_metrics(data: TensorDict) -> Dict[str, float]:
    """Advanced metrics requiring complex computation."""
    metrics = {}
    # Sequence coherence analysis
    if "responses" in data.keys() and "attention_mask" in data.keys():
        coherence_scores = compute_sequence_coherence(data)
        metrics.update({
            "analysis/coherence/mean": coherence_scores.mean().item(),
            "analysis/coherence/std": coherence_scores.std().item(),
            "analysis/coherence/median": coherence_scores.median().item(),
        })
    # Token transition analysis
    if "responses" in data.keys():
        transition_metrics = analyze_token_transitions(data)
        metrics.update(transition_metrics)
    # Reward distribution analysis
    # (analyze_reward_distribution is not defined in this guide; supply your own)
    if "token_level_rewards" in data.keys():
        reward_dist_metrics = analyze_reward_distribution(data)
        metrics.update(reward_dist_metrics)
    return metrics


def compute_sequence_coherence(data: TensorDict) -> torch.Tensor:
    """Compute a coherence score for each sequence."""
    responses = data["responses"]
    attention_mask = data["attention_mask"]
    batch_size = responses.shape[0]
    coherence_scores = []
    for i in range(batch_size):
        # Number of valid tokens; cast to int so it can be used as a slice bound
        valid_length = int(attention_mask[i].sum().item())
        sequence = responses[i][:valid_length]
        # Compute local coherence (e.g., token transition smoothness)
        if len(sequence) > 1:
            # Simplified coherence: inverse of the variance in token values
            coherence = 1.0 / (1.0 + torch.var(sequence.float()).item())
        else:
            coherence = 1.0
        coherence_scores.append(coherence)
    return torch.tensor(coherence_scores, device=responses.device)


def analyze_token_transitions(data: TensorDict) -> Dict[str, float]:
    """Analyze patterns in token transitions."""
    responses = data["responses"]
    response_mask = data.get("response_mask", torch.ones_like(responses))
    # Count unique transitions (adjacent token pairs)
    unique_transitions = set()
    total_transitions = 0
    for i in range(responses.shape[0]):
        # .tolist() converts once instead of calling .item() per token
        response_tokens = responses[i][response_mask[i].bool()].tolist()
        for transition in zip(response_tokens, response_tokens[1:]):
            unique_transitions.add(transition)
            total_transitions += 1
    diversity_ratio = len(unique_transitions) / max(total_transitions, 1)
    return {
        "analysis/transition_diversity/mean": diversity_ratio,
        "analysis/unique_transitions/total": float(len(unique_transitions)),
        "analysis/total_transitions/total": float(total_transitions),
    }
Integration with Training Workflow
MetricClient Usage Pattern
The MetricClient provides the main interface for submitting metrics:
from siirl.execution.metric_worker.metric_worker import MetricClient, MetricWorker

# Initialize metric worker and client (the await must run inside an async function)
metric_worker = MetricWorker.remote()
await metric_worker.start.remote()
metric_client = MetricClient(metric_worker)

# During training loop
for step, batch in enumerate(dataloader):
    # ... training logic ...
    # Submit standard metrics
    metric_client.compute_local_data_metric(batch, world_size)
    # Submit custom metrics
    custom_metrics = compute_advanced_metrics(batch)
    metric_client.submit_metric(custom_metrics, world_size)
    # Submit timing metrics (step_time and forward_time are measured by your training code)
    timing_data = {"step": step_time, "forward": forward_time}
    metric_client.compute_local_timing_metrics(batch, timing_data, world_size)
    # Wait for metrics to be processed
    metric_client.wait_submit()
    # Get final aggregated results
    final_metrics = metric_client.wait_final_res()
Ray-based Distributed Aggregation
The system uses Ray actors for distributed metrics processing:
MetricWorker Actor:
- Runs asynchronously to collect metrics from all workers
- Aggregates metrics when all processes have submitted values
- Supports different aggregation functions (mean, max, min, sum)
- Automatically handles timing metric renaming (timing_s/ → perf/delta_time/)
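The timing rename in the last bullet amounts to a prefix substitution. A minimal sketch follows, assuming the rename applies only to names that start with `timing_s/`; the function name `rename_timing_metric` is hypothetical and the actor's exact matching rule may differ.

```python
def rename_timing_metric(name: str) -> str:
    """Map timing_s/<stage> to perf/delta_time/<stage>; leave other names alone."""
    prefix = "timing_s/"
    if name.startswith(prefix):
        return "perf/delta_time/" + name[len(prefix):]
    return name


renamed = [rename_timing_metric(n) for n in ("timing_s/gen", "custom/latency/mean")]
```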
Aggregation Logic:
- Metrics are collected in a queue until all workers (world_size) submit
- Each metric triggers computation when the expected number of submissions is reached
- Final results are stored and returned when requested
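The collect-then-reduce behavior described above can be sketched without Ray as a single-process stand-in. The class below is an illustration only: `LocalAggregator` is a hypothetical name, and it picks the reducer from the last path segment of the metric name, which is a simplification of the name-based selection in `MetricFunc`.

```python
from collections import defaultdict
from statistics import mean
from typing import Callable, Dict, List


class LocalAggregator:
    """Buffers submissions per metric and reduces once world_size values arrive."""

    def __init__(self) -> None:
        self._pending: Dict[str, List[float]] = defaultdict(list)
        self.final: Dict[str, float] = {}

    @staticmethod
    def _reducer(name: str) -> Callable[[List[float]], float]:
        # Assumption: the aggregation is encoded in the final name segment
        suffix = name.rsplit("/", 1)[-1]
        return {"max": max, "min": min, "sum": sum}.get(suffix, mean)

    def submit(self, metrics: Dict[str, float], world_size: int) -> None:
        for name, value in metrics.items():
            bucket = self._pending[name]
            bucket.append(value)
            # Reduce only when every worker has reported this metric
            if len(bucket) == world_size:
                self.final[name] = float(self._reducer(name)(bucket))
                del self._pending[name]


agg = LocalAggregator()
agg.submit({"custom/score/mean": 1.0, "custom/score/max": 1.0}, world_size=2)
partial = dict(agg.final)  # still empty: only one of two workers has reported
agg.submit({"custom/score/mean": 3.0, "custom/score/max": 5.0}, world_size=2)
```

One consequence of this design is that a metric submitted by only some workers never completes, so every worker should submit the same set of metric names each step.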
Special Metric Configurations
Some metrics require special aggregation logic:
# In metric_worker.py
Special_Metric = {
"graph_output_handling": MaxMetric, # Only rollout_tp 0 contributes
}
Custom metrics can be added to this dictionary for specialized handling.
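One way to picture this override is an exact-name lookup that runs before any default aggregation. The sketch below uses plain callables in place of the real `MaxMetric` and `MeanMetric`; the names `SPECIAL_REDUCERS`, `pick_reducer`, and the `custom/peak_queue_depth` entry are hypothetical, and the lookup order is an assumption.

```python
from statistics import mean
from typing import Callable, Dict, Sequence

Reducer = Callable[[Sequence[float]], float]

# Hypothetical stand-in for the Special_Metric dictionary
SPECIAL_REDUCERS: Dict[str, Reducer] = {
    "graph_output_handling": max,
    "custom/peak_queue_depth": max,  # hypothetical custom entry
}


def pick_reducer(name: str, default: Reducer = mean) -> Reducer:
    """Exact-name overrides win; every other metric uses the default reducer."""
    return SPECIAL_REDUCERS.get(name, default)


# With max, a single contributing worker determines the result
special = pick_reducer("graph_output_handling")([0.0, 0.0, 1.0])
regular = pick_reducer("custom/score")([1.0, 3.0])
```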
Advanced Examples
Example 1: Model Performance Analysis
def compute_model_performance_metrics(data: TensorDict, model_outputs: dict) -> Dict[str, float]:
    """Comprehensive model performance analysis."""
    metrics = {}
    # Attention pattern analysis
    if "attention_weights" in model_outputs:
        attention_weights = model_outputs["attention_weights"]
        # Attention concentration (how focused is attention)
        attention_entropy = -torch.sum(
            attention_weights * torch.log(attention_weights + 1e-9), dim=-1
        )
        metrics["model/attention_entropy/mean"] = attention_entropy.mean().item()
        # Attention on different token types (needs the response length)
        if "responses" in data.keys():
            response_len = data["responses"].shape[-1]
            prompt_attention = attention_weights[:, :, :-response_len]
            response_attention = attention_weights[:, :, -response_len:]
            metrics["model/prompt_attention_ratio/mean"] = (
                prompt_attention.sum() / attention_weights.sum()
            ).item()
    # Hidden state analysis
    if "hidden_states" in model_outputs:
        hidden_states = model_outputs["hidden_states"]
        # Representation diversity
        layer_norms = torch.norm(hidden_states, dim=-1)
        metrics["model/hidden_norm/mean"] = layer_norms.mean().item()
        metrics["model/hidden_norm/std"] = layer_norms.std().item()
    return metrics
Example 2: Conversation Quality Assessment
def compute_conversation_quality_metrics(data: TensorDict) -> Dict[str, float]:
    """Multi-dimensional conversation quality assessment."""
    metrics = {}
    if "responses" not in data.keys() or "prompts" not in data.keys():
        return metrics
    responses = data["responses"]
    prompts = data["prompts"]
    response_mask = data.get("response_mask", torch.ones_like(responses))
    batch_size = responses.shape[0]
    quality_scores = []
    for i in range(batch_size):
        # Extract actual tokens (remove padding)
        response_tokens = responses[i][response_mask[i].bool()]
        prompt_tokens = prompts[i]
        # Length appropriateness (not too short, not too long)
        response_length = len(response_tokens)
        length_score = compute_length_appropriateness(response_length)
        # Vocabulary richness
        unique_tokens = len(torch.unique(response_tokens))
        vocab_score = min(unique_tokens / response_length, 1.0) if response_length > 0 else 0.0
        # Repetition penalty
        repetition_score = compute_repetition_score(response_tokens)
        # Overall quality
        quality = 0.3 * length_score + 0.3 * vocab_score + 0.4 * repetition_score
        quality_scores.append(quality)
    quality_tensor = torch.tensor(quality_scores, device=responses.device)
    return {
        "conversation/quality/mean": quality_tensor.mean().item(),
        "conversation/quality/std": quality_tensor.std().item(),
        "conversation/quality/min": quality_tensor.min().item(),
        "conversation/quality/max": quality_tensor.max().item(),
    }


def compute_length_appropriateness(length: int, target_length: int = 50) -> float:
    """Compute how appropriate the response length is."""
    if length == 0:
        return 0.0
    ratio = length / target_length
    if ratio <= 1.0:
        return ratio  # Score grows linearly up to the target length
    else:
        return 1.0 / ratio  # Penalize overly long responses


def compute_repetition_score(tokens: torch.Tensor) -> float:
    """Compute a score based on repetition patterns."""
    if len(tokens) <= 1:
        return 1.0
    # Count repeated bigrams
    bigrams = set()
    repeated_bigrams = 0
    for i in range(len(tokens) - 1):
        bigram = (tokens[i].item(), tokens[i + 1].item())
        if bigram in bigrams:
            repeated_bigrams += 1
        else:
            bigrams.add(bigram)
    # Higher repetition = lower score
    repetition_ratio = repeated_bigrams / (len(tokens) - 1)
    return 1.0 - repetition_ratio
Configuration and Best Practices
Metric Naming Conventions
Follow these conventions for consistent metric organization:
# Training metrics
training/{category}/{metric_name}/{aggregation}
# Validation metrics
val/{category}/{data_source}/{metric_name}
val-core/{data_source}/{variable}/{metric_name}
val-aux/{category}/{metric_name}
# Performance metrics
perf/{metric_name}
# Analysis metrics
analysis/{category}/{metric_name}/{aggregation}
# Model introspection
model/{component}/{metric_name}/{aggregation}
Aggregation Selection
Choose aggregation methods based on metric semantics:
- mean: Default for most metrics (accuracy, loss, etc.)
- max: For peak values (max memory, worst-case latency)
- min: For best-case scenarios (min loss, fastest response)
- sum/total: For cumulative values (total tokens, total time)
- median: For a robust central tendency when outliers would skew the mean (requires the custom aggregation from Method 3)
- p95/p99: For percentile-based SLA metrics (requires the custom aggregation from Method 3)
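To see why the choice matters, consider a latency sample with one large outlier. The sketch below uses only the standard library; the metric names and the sample data are illustrative.

```python
from statistics import mean, median, quantiles

# Nineteen fast requests and one very slow one
latencies_ms = [10.0] * 19 + [1000.0]

summary = {
    "custom/latency/mean": mean(latencies_ms),      # pulled up by the outlier
    "custom/latency/median": median(latencies_ms),  # unaffected by the outlier
    "custom/latency/max": max(latencies_ms),        # reports the outlier itself
    # 95th percentile: index 94 of the 99 cut points between 100 equal groups
    "custom/latency/p95": quantiles(latencies_ms, n=100, method="inclusive")[94],
}
```

Here the mean is several times the median, so a dashboard showing only the mean would suggest that typical requests are much slower than they are.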
Error Handling
Always implement robust error handling:
def compute_safe_custom_metrics(data: TensorDict) -> Dict[str, float]:
    """Example of safe metric computation."""
    metrics = {}
    try:
        # Check data availability
        if "required_field" not in data.keys():
            return metrics
        # Handle empty tensors
        values = data["required_field"]
        if values.numel() == 0:
            return metrics
        # Skip NaN/inf results instead of reporting them
        mean_val = torch.mean(values.float())
        if torch.isfinite(mean_val):
            metrics["custom/metric/mean"] = mean_val.item()
    except Exception as e:
        # Log error but don't crash training
        print(f"Error computing custom metrics: {e}")
        return {}
    return metrics
Performance Considerations
- Batch Processing: Compute metrics on entire batches, not individual samples
- Device Placement: Keep tensors on the same device as input data
- Memory Management: Avoid accumulating large tensors across steps
- Async Processing: Use Ray actors for non-blocking metrics aggregation
- Selective Computation: Only compute expensive metrics when needed
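Selective computation can be as simple as gating an expensive metric function on the step counter. The helper below is a hypothetical sketch; `maybe_compute` and its parameters are not part of siiRL.

```python
from typing import Callable, Dict, List


def maybe_compute(
    step: int,
    every_n_steps: int,
    compute_fn: Callable[[], Dict[str, float]],
) -> Dict[str, float]:
    """Run an expensive metric function only on every N-th step."""
    if every_n_steps <= 0:
        raise ValueError("every_n_steps must be positive")
    if step % every_n_steps != 0:
        return {}  # skipped this step
    return compute_fn()


calls: List[int] = []


def expensive_metrics() -> Dict[str, float]:
    calls.append(1)  # records that the expensive path actually ran
    return {"analysis/expensive/mean": 1.0}


results = [maybe_compute(step, 5, expensive_metrics) for step in range(10)]
```

Gating on the step number keeps the decision identical on every worker, which matters when the aggregator waits for a submission from each of them before reducing a metric.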
Debugging Custom Metrics
import os

import numpy as np


def debug_custom_metrics(data: TensorDict, metrics: Dict[str, float]):
    """Debug utility for custom metrics (enabled with DEBUG_METRICS=1)."""
    if os.environ.get("DEBUG_METRICS", "0") == "1":
        print(f"Data keys: {list(data.keys())}")
        print(f"Data shapes: {[(k, v.shape if hasattr(v, 'shape') else type(v)) for k, v in data.items()]}")
        print(f"Computed metrics: {metrics}")
        # Check for common issues
        for name, value in metrics.items():
            if not isinstance(value, (int, float)):
                print(f"WARNING: Metric {name} has invalid type {type(value)}")
            elif not np.isfinite(value):
                print(f"WARNING: Metric {name} is not finite: {value}")
File Structure Summary
siirl/execution/metric_worker/
├── metric_worker.py                    # Ray actor for distributed aggregation
│   ├── MetricWorker                    # Ray remote actor class
│   └── MetricClient                    # Client interface
└── utils.py                            # Aggregation functions
    ├── Metric                          # Dataclass for metric values
    ├── MetricFunc                      # Function selection logic
    ├── MeanMetric                      # Mean aggregation
    ├── MaxMetric                       # Maximum aggregation
    ├── MinMetric                       # Minimum aggregation
    └── SumMetric                       # Sum aggregation
siirl/utils/metrics/
└── metric_utils.py                     # Core metric computation
    ├── compute_data_metric             # Standard training metrics
    ├── compute_timing_metrics          # Timing analysis
    ├── compute_throughout_metrics      # Throughput analysis
    ├── process_validation_metrics      # Validation with bootstrap
    ├── bootstrap_metric                # Bootstrap sampling utility
    └── aggregate_validation_metrics    # Parallel validation processing
This architecture provides a scalable, flexible foundation for comprehensive metrics collection in distributed training environments.