Metrics Interface

Custom metrics let you track and aggregate any quantitative measure during training and validation. siiRL provides a distributed, Ray-based metrics system that automatically aggregates each metric across all workers with a reduction operation (mean, max, min, or sum).

Architecture Overview

                      Distributed Metrics Architecture
==============================================================================

DAGWorker 0        DAGWorker 1        DAGWorker 2        DAGWorker N
+-----------+      +-----------+      +-----------+      +-----------+
| compute   |      | compute   |      | compute   |      | compute   |
| metrics   |      | metrics   |      | metrics   |      | metrics   |
+-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
      |                  |                  |                  |
      v                  v                  v                  v
+-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
| Metric    |      | Metric    |      | Metric    |      | Metric    |
| Client    |      | Client    |      | Client    |      | Client    |
+-----+-----+      +-----+-----+      +-----+-----+      +-----+-----+
      |                  |                  |                  |
      +------------------+------------------+------------------+
                                 |
                                 v
                      +-------------------+
                      |   MetricWorker    |  (Ray Actor - Singleton)
                      |   (Aggregator)    |
                      +-------------------+
                      | - Collect metrics |
                      | - Wait for all    |
                      |   workers         |
                      | - Aggregate:      |
                      |   mean/max/min/   |
                      |   sum             |
                      +--------+----------+
                               |
                               v
                      +-------------------+
                      |  Final Metrics    |
                      | (to Logger/WandB) |
                      +-------------------+

==============================================================================

Metrics Data Flow:

+-------------+     +----------------+     +----------------+     +--------+
| TensorDict  | --> | compute_*      | --> | MetricClient   | --> | Metric |
| (batch)     |     | _metric()      |     | .submit_metric |     | Worker |
+-------------+     +----------------+     +----------------+     +--------+
                           |
                           v
                    +-------------+
                    | Dict[str,   |
                    |   float]    |
                    | {name: val} |
                    +-------------+

==============================================================================

Key Files:

  • siirl/execution/metric_worker/metric_worker.py - Ray-based distributed metrics aggregation

  • siirl/utils/metrics/metric_utils.py - Core metric computation functions

  • siirl/execution/metric_worker/utils.py - Aggregation function utilities

Quick Start

Current Metrics System

Built-in Metrics Reference

The following tables list all built-in metrics provided by siiRL.

Data Metrics (from compute_data_metric in metric_utils.py):

Critic Metrics

Metric Name                       Description
--------------------------------  -----------------------------------------------
critic/score/mean|max|min         Sequence-level scores from token-level scores
critic/rewards/mean|max|min       Sequence-level rewards from token-level rewards
critic/advantages/mean|max|min    Advantages (masked by response_mask)
critic/returns/mean|max|min       Returns (masked by response_mask)
critic/values/mean|max|min        Value function estimates (if available)
critic/vf_explained_var           Explained variance of the value function
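The masked statistics and the explained variance in this table can be illustrated with NumPy arrays standing in for tensors. This is a sketch: `masked_stats` and `explained_variance` are hypothetical helpers, and the formula 1 - Var(returns - values) / Var(returns) over valid tokens is the common definition of explained variance, assumed here rather than copied from compute_data_metric.

```python
import numpy as np


def masked_stats(x: np.ndarray, mask: np.ndarray) -> dict:
    """Mean/max/min of x over positions where mask is nonzero (hypothetical helper)."""
    valid = x[mask.astype(bool)]
    return {"mean": float(valid.mean()), "max": float(valid.max()), "min": float(valid.min())}


def explained_variance(returns: np.ndarray, values: np.ndarray, mask: np.ndarray,
                       eps: float = 1e-5) -> float:
    """1 - Var(returns - values) / Var(returns), restricted to valid tokens."""
    m = mask.astype(bool)
    valid_returns, valid_values = returns[m], values[m]
    return float(1.0 - np.var(valid_returns - valid_values) / (np.var(valid_returns) + eps))


# Two sequences of three response tokens; the last token of the second one is padding
mask = np.array([[1, 1, 1], [1, 1, 0]])
advantages = np.array([[0.5, -0.2, 0.1], [1.0, 0.3, 9.9]])
print(masked_stats(advantages, mask))  # the padded 9.9 is excluded
returns = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 0.0]])
print(explained_variance(returns, returns, mask))  # perfect value estimates -> 1.0
```

A value of 1.0 means the critic explains all variance in the returns; values near or below 0 mean it explains little.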

Response Analysis Metrics

Metric Name                           Description
------------------------------------  ---------------------------------------
response/length/mean|max|min          Response token lengths
response/clip_ratio/mean              Proportion hitting max response length
response/correct_length/mean|max|min  Lengths for responses with reward > 0.5
response/wrong_length/mean|max|min    Lengths for responses with reward ≤ 0.5
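The response-length rows can be reproduced from a response mask and per-sequence rewards. The NumPy sketch below assumes a 0/1 `response_mask` of shape (batch, max_response_length), treats a response as clipped when it fills every position, and splits correct/wrong responses at the 0.5 reward threshold from the table; `response_length_metrics` is a hypothetical name.

```python
import numpy as np


def response_length_metrics(response_mask: np.ndarray, seq_rewards: np.ndarray) -> dict:
    """Sketch of response/* metrics from a 0/1 mask and per-sequence rewards."""
    lengths = response_mask.sum(axis=-1).astype(float)
    max_len = response_mask.shape[-1]
    metrics = {
        "response/length/mean": float(lengths.mean()),
        "response/length/max": float(lengths.max()),
        "response/length/min": float(lengths.min()),
        # A response counts as clipped when it used the whole length budget
        "response/clip_ratio/mean": float((lengths == max_len).mean()),
    }
    correct = seq_rewards > 0.5
    for prefix, selected in (("correct_length", correct), ("wrong_length", ~correct)):
        if selected.any():  # skip a group that contains no responses
            group = lengths[selected]
            metrics[f"response/{prefix}/mean"] = float(group.mean())
            metrics[f"response/{prefix}/max"] = float(group.max())
            metrics[f"response/{prefix}/min"] = float(group.min())
    return metrics


mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0], [1, 0, 0, 0]])
rewards = np.array([1.0, 0.0, 0.8])
m = response_length_metrics(mask, rewards)
print(m["response/clip_ratio/mean"], m["response/correct_length/mean"])
```

The prompt/* rows follow the same pattern with a prompt mask and the maximum prompt length.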

Prompt Analysis Metrics

Metric Name                   Description
----------------------------  ------------------------------------
prompt/length/mean|max|min    Prompt token lengths
prompt/clip_ratio/mean        Proportion hitting max prompt length

System & Multi-turn Metrics

Metric Name                     Description
------------------------------  ---------------------------------------
perf/process_cpu_mem_used_gb    CPU memory usage per process
num_turns/min|max|mean          Statistics for multi-turn conversations

Timing Metrics (from compute_timing_metrics):

Metric Name                    Description
-----------------------------  ------------------------------------
timing_s/{stage}               Raw timing in seconds for each stage
timing_per_token_ms/{stage}    Per-token timing in milliseconds

Stages: gen, ref, values, adv, update_critic, update_actor
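Turning raw stage timings into per-token timings only requires choosing a token count per stage. The sketch below makes an assumption about that choice, modeled on the verl convention: generation is normalized by response tokens and every other stage by prompt plus response tokens. `timing_metrics` and its arguments are hypothetical.

```python
def timing_metrics(timing_raw: dict, num_prompt_tokens: int, num_response_tokens: int) -> dict:
    """Sketch: timing_s/{stage} and timing_per_token_ms/{stage} from raw seconds."""
    num_overall_tokens = num_prompt_tokens + num_response_tokens
    metrics = {}
    for stage, seconds in timing_raw.items():
        metrics[f"timing_s/{stage}"] = seconds
        # Assumption: generation cost scales with response tokens, while the other
        # stages process the full prompt + response sequence
        n_tokens = num_response_tokens if stage == "gen" else num_overall_tokens
        metrics[f"timing_per_token_ms/{stage}"] = seconds * 1000 / max(n_tokens, 1)
    return metrics


m = timing_metrics({"gen": 2.0, "update_actor": 3.0}, num_prompt_tokens=1000, num_response_tokens=1000)
print(m["timing_per_token_ms/gen"], m["timing_per_token_ms/update_actor"])  # 2.0 1.5
```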

Throughput Metrics (from compute_throughout_metrics):

Metric Name              Description
-----------------------  -------------------------
perf/total_num_tokens    Total tokens processed
perf/time_per_step       Time per training step
perf/throughput          Tokens per second per GPU
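The three throughput rows are linked by a single division. Assuming `perf/throughput` is total tokens divided by step time and by GPU count (which matches "tokens per second per GPU" above), a sketch with hypothetical argument names:

```python
def throughput_metrics(total_num_tokens: int, time_per_step_s: float, n_gpus: int) -> dict:
    """Sketch of the perf/* throughput metrics; argument names are hypothetical."""
    return {
        "perf/total_num_tokens": total_num_tokens,
        "perf/time_per_step": time_per_step_s,
        # tokens / second / GPU
        "perf/throughput": total_num_tokens / (time_per_step_s * n_gpus),
    }


# 1.6M tokens in a 20 s step on 8 GPUs
print(throughput_metrics(1_600_000, 20.0, 8)["perf/throughput"])  # 10000.0
```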

Validation Metrics (from process_validation_metrics):

Metric Name                                      Description
-----------------------------------------------  ------------------------------------
val-core/{data_source}/{var}/mean@N              Mean across N samples
val-core/{data_source}/{var}/best@N/mean|std     Bootstrap best-of-N statistics
val-core/{data_source}/{var}/worst@N/mean|std    Bootstrap worst-of-N statistics
val-core/{data_source}/{var}/maj@N/mean|std      Bootstrap majority voting statistics
val/test_score/{data_source}                     Test score per data source
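The bootstrap rows estimate how a best-, worst- or majority-of-N selection would score. The self-contained NumPy sketch below illustrates only the best-of-N and worst-of-N idea: repeatedly draw N of the available samples with replacement, keep the max (or min), and report the mean and standard deviation of those picks. It is not siiRL's process_validation_metrics or bootstrap_metric; `bootstrap_best_worst` and its defaults are hypothetical.

```python
import numpy as np


def bootstrap_best_worst(scores, n: int, n_bootstrap: int = 1000, seed: int = 42) -> dict:
    """Bootstrap estimates of best@n / worst@n from per-sample scores."""
    rng = np.random.default_rng(seed)
    scores = np.asarray(scores, dtype=float)
    # Each row is one bootstrap draw of n scores, sampled with replacement
    draws = rng.choice(scores, size=(n_bootstrap, n), replace=True)
    best, worst = draws.max(axis=1), draws.min(axis=1)
    return {
        f"best@{n}/mean": float(best.mean()), f"best@{n}/std": float(best.std()),
        f"worst@{n}/mean": float(worst.mean()), f"worst@{n}/std": float(worst.std()),
    }


# Eight samples for one prompt, 25% of them correct
stats = bootstrap_best_worst([1, 0, 0, 0, 1, 0, 0, 0], n=4)
print(stats["best@4/mean"], stats["worst@4/mean"])
```

With a 25% per-sample success rate, best@4 should land near 1 - 0.75^4 ≈ 0.68 and worst@4 near 0.25^4 ≈ 0.004.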

Custom Metrics Implementation

Method 1: Custom Data Metrics

Extend the data metrics computed from training batches:

# Add to metric_utils.py
from typing import Dict

import torch
from tensordict import TensorDict  # assumed source of the TensorDict type


def compute_custom_training_metrics(data: TensorDict) -> Dict[str, float]:
    """Custom training-specific metrics"""
    metrics = {}

    # Policy entropy (exploration measure).
    # TensorDict does not support `key in td`, so membership is checked via .keys()
    if "policy_logits" in data.keys():
        logits = data["policy_logits"]
        probs = torch.softmax(logits, dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
        response_mask = data.get("response_mask", torch.ones_like(entropy))

        # Only compute entropy for response tokens (clamp avoids 0/0 on an empty mask)
        masked_entropy = entropy * response_mask.float()
        valid_entropy = masked_entropy.sum() / response_mask.float().sum().clamp(min=1.0)

        metrics["training/policy_entropy/mean"] = valid_entropy.item()

    # Gradient norm tracking (expects a scalar tensor)
    if "grad_norm" in data.keys():
        metrics["training/grad_norm/mean"] = data["grad_norm"].item()

    # Loss convergence tracking
    if "loss_values" in data.keys():
        loss_values = data["loss_values"]
        metrics["training/loss/mean"] = loss_values.mean().item()
        # Note: a "std" name falls through to mean aggregation across workers,
        # which averages per-worker stds rather than giving the global std
        metrics["training/loss/std"] = loss_values.std().item()

    return metrics

# Usage in MetricClient.compute_local_data_metric
def compute_local_data_metric(self, data: TensorDict, world_size: int):
    # Standard metrics
    standard_metrics = compute_data_metric(data)

    # Add custom metrics
    custom_metrics = compute_custom_training_metrics(data)

    # Combine and submit
    all_metrics = {**standard_metrics, **custom_metrics}
    self.submit_metric(all_metrics, world_size)

Method 2: Custom Validation Metrics

Add custom validation metrics with bootstrap sampling:

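# Add to metric_utils.py (bootstrap_metric is defined in the same module)
from collections import defaultdict

import numpy as np


def compute_custom_validation_metrics(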
    data_sources: list[str],
    sample_inputs: list[str],
    infos_dict: dict[str, list],
    sample_turns: list[int]
) -> dict[str, float]:
    """Custom validation metrics with bootstrap analysis"""

    # Extract custom fields from infos_dict
    custom_metrics = {}

    if "custom_score" in infos_dict:
        # Group by data source
        source_scores = defaultdict(list)
        for i, source in enumerate(data_sources):
            source_scores[source].append(infos_dict["custom_score"][i])

        # Compute statistics per source
        for source, scores in source_scores.items():
            if len(scores) > 0:
                custom_metrics[f"val/custom_score/{source}/mean"] = np.mean(scores)
                custom_metrics[f"val/custom_score/{source}/std"] = np.std(scores)

                # Bootstrap sampling for confidence intervals
                if len(scores) > 1:
                    bootstrap_results = bootstrap_metric(
                        data=scores,
                        subset_size=min(5, len(scores)),
                        reduce_fns=[np.mean, np.max, np.min],
                        n_bootstrap=1000
                    )
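                    # bootstrap_metric returns one (mean, std) pair per reduce_fn: index 0 is
                    # np.mean; indices 1 and 2 hold the np.max / np.min (best/worst-of-k) estimates
                    custom_metrics[f"val/custom_score/{source}/bootstrap_mean"] = bootstrap_results[0][0]
                    custom_metrics[f"val/custom_score/{source}/bootstrap_mean_std"] = bootstrap_results[0][1]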

    # Conversation quality for multi-turn
    if "conversation_quality" in infos_dict and len(sample_turns) > 0:
        quality_by_turns = defaultdict(list)
        for i, turns in enumerate(sample_turns):
            if i < len(infos_dict["conversation_quality"]):
                quality_by_turns[turns].append(infos_dict["conversation_quality"][i])

        for turn_count, qualities in quality_by_turns.items():
            if len(qualities) > 0:
                custom_metrics[f"val/conversation_quality/turns_{turn_count}/mean"] = np.mean(qualities)

    return custom_metrics

# Usage in MetricClient.process_local_validation_metrics
def process_local_validation_metrics(self, data_sources, sample_inputs, infos_dict, sample_turns, world_size):
    # Standard validation metrics
    standard_metrics = process_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)

    # Add custom validation metrics
    custom_metrics = compute_custom_validation_metrics(data_sources, sample_inputs, infos_dict, sample_turns)

    # Combine and submit
    all_metrics = {**standard_metrics, **custom_metrics}
    self.submit_metric(all_metrics, world_size)

Method 3: Custom Aggregation Logic

Create custom aggregation functions for specialized reduction operations:

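# Add to execution/metric_worker/utils.py
# (Metric, MeanMetric, MaxMetric, MinMetric and SumMetric already live in this module)
from typing import List

import torch


def MedianMetric(metrics: List[Metric]):
    """Custom median aggregation"""
    values = [v for metric in metrics
              for v in (metric.value if isinstance(metric.value, list) else [metric.value])]
    # For an even number of values, torch.median returns the lower of the two middle values
    return float(torch.median(torch.tensor(values, dtype=torch.float64)).item())


def PercentileMetric(percentile: float):
    """Custom percentile aggregation factory"""
    def _percentile_metric(metrics: List[Metric]):
        values = [v for metric in metrics
                  for v in (metric.value if isinstance(metric.value, list) else [metric.value])]
        # torch.quantile only accepts floating-point input, so cast explicitly
        values_t = torch.tensor(values, dtype=torch.float64)
        return float(torch.quantile(values_t, percentile / 100.0).item())
    return _percentile_metric


# Update MetricFunc to handle custom aggregations.
# Match on the last path component only: testing the full name would, for example,
# send "timing_s/gen" to MinMetric because "timing" contains "min".
def MetricFunc(name: str):
    suffix = name.rsplit("/", 1)[-1]
    if "median" in suffix:
        return MedianMetric
    elif "p95" in suffix:
        return PercentileMetric(95)
    elif "p99" in suffix:
        return PercentileMetric(99)
    elif "min" in suffix:
        return MinMetric
    elif "max" in suffix:
        return MaxMetric
    elif "sum" in suffix or "total" in suffix:
        return SumMetric
    else:
        return MeanMetric


# Usage: name your metrics to trigger specific aggregations
# (latency_values and score_values are placeholders: a float or a list of floats)
metrics = {
    "custom/latency/median": latency_values,  # Will use MedianMetric
    "custom/score/p95": score_values,         # Will use the 95th percentile
}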
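To see what each aggregation produces before wiring it into MetricFunc, the reductions can be tried on values that several workers might submit for one metric name. This sketch uses NumPy and a stand-in `FakeMetric` dataclass (hypothetical; it only mimics the `.value` field used above, which may hold a float or a list). NumPy's median interpolates between the two middle values for an even count, whereas torch.median returns the lower one.

```python
from dataclasses import dataclass
from typing import List, Union

import numpy as np


@dataclass
class FakeMetric:
    """Stand-in for siiRL's Metric: only the `value` field is modeled."""
    value: Union[float, List[float]]


def flatten(metrics: List[FakeMetric]) -> np.ndarray:
    # Same flattening rule as MedianMetric / PercentileMetric above
    values = [v for m in metrics for v in (m.value if isinstance(m.value, list) else [m.value])]
    return np.asarray(values, dtype=float)


# Four workers: three submit a scalar, one submits a list of raw latencies
submissions = [FakeMetric(0.12), FakeMetric(0.15), FakeMetric([0.11, 0.13]), FakeMetric(0.90)]
values = flatten(submissions)
summary = {
    "mean": float(values.mean()),
    "median": float(np.median(values)),
    "p95": float(np.percentile(values, 95)),
    "max": float(values.max()),
}
print(summary)  # the 0.90 outlier pulls the mean up far more than the median
```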

Method 4: Complex Custom Metrics

For more sophisticated metrics requiring multiple computation steps:

# Add to metric_utils.py
from typing import Dict

import torch
from tensordict import TensorDict  # assumed source of the TensorDict type


def compute_advanced_metrics(data: TensorDict) -> Dict[str, float]:
    """Advanced metrics requiring complex computation"""
    metrics = {}
    # TensorDict does not support `key in td`; check membership via .keys()
    keys = data.keys()

    # Sequence coherence analysis
    if "responses" in keys:
        coherence_scores = compute_sequence_coherence(data)
        metrics.update({
            "analysis/coherence/mean": coherence_scores.mean().item(),
            "analysis/coherence/std": coherence_scores.std().item(),
            "analysis/coherence/median": coherence_scores.median().item(),
        })

    # Token transition analysis
    if "responses" in keys:
        transition_metrics = analyze_token_transitions(data)
        metrics.update(transition_metrics)

    # Reward distribution analysis (analyze_reward_distribution is a helper you supply)
    if "token_level_rewards" in keys:
        reward_dist_metrics = analyze_reward_distribution(data)
        metrics.update(reward_dist_metrics)

    return metrics


def compute_sequence_coherence(data: TensorDict) -> torch.Tensor:
    """Compute a (toy) coherence score for each sequence"""
    responses = data["responses"]
    # Use response_mask, not attention_mask: attention_mask typically spans
    # prompt + response, so its row sum is not a valid length for `responses`
    response_mask = data.get("response_mask", torch.ones_like(responses))
    batch_size = responses.shape[0]

    coherence_scores = []
    for i in range(batch_size):
        # Keep only the valid (unpadded) response tokens of this sequence
        sequence = responses[i][response_mask[i].bool()]

        if len(sequence) > 1:
            # Toy placeholder: inverse variance of token IDs. IDs have no ordinal
            # meaning, so substitute a real coherence measure in practice
            coherence = 1.0 / (1.0 + torch.var(sequence.float()).item())
        else:
            coherence = 1.0

        coherence_scores.append(coherence)

    return torch.tensor(coherence_scores, device=responses.device)

def analyze_token_transitions(data: TensorDict) -> Dict[str, float]:
    """Analyze patterns in token transitions"""
    responses = data["responses"]
    response_mask = data.get("response_mask", torch.ones_like(responses))

    # Count unique transitions
    unique_transitions = set()
    total_transitions = 0

    for i in range(responses.shape[0]):
        response_tokens = responses[i][response_mask[i].bool()]
        if len(response_tokens) > 1:
            for j in range(len(response_tokens) - 1):
                transition = (response_tokens[j].item(), response_tokens[j+1].item())
                unique_transitions.add(transition)
                total_transitions += 1

    diversity_ratio = len(unique_transitions) / max(total_transitions, 1)

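    return {
        "analysis/transition_diversity/mean": diversity_ratio,
        # "/total" names are summed across workers, so the global unique count is
        # only an upper bound (different workers may see the same transitions)
        "analysis/unique_transitions/total": len(unique_transitions),
        "analysis/total_transitions/total": total_transitions,
    }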
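compute_advanced_metrics calls analyze_reward_distribution, which the example never defines. One possible version is sketched below with NumPy arrays in a plain dict instead of a TensorDict (port it to torch to keep tensors on device). It summarizes the per-sequence reward, taken as the token-level rewards summed over valid response tokens; the function body and metric names are hypothetical.

```python
from typing import Dict

import numpy as np


def analyze_reward_distribution(data: dict) -> Dict[str, float]:
    """Hypothetical helper: distribution of per-sequence rewards."""
    token_rewards = np.asarray(data["token_level_rewards"], dtype=float)
    mask = np.asarray(data.get("response_mask", np.ones_like(token_rewards)), dtype=float)
    # Sequence reward = sum of token-level rewards over valid response tokens
    seq_rewards = (token_rewards * mask).sum(axis=-1)
    return {
        "analysis/reward_dist/mean": float(seq_rewards.mean()),
        "analysis/reward_dist/std": float(seq_rewards.std()),
        # Fraction of sequences with a strictly positive reward
        "analysis/reward_positive_frac/mean": float((seq_rewards > 0).mean()),
    }


batch = {
    # Outcome-style rewards: one nonzero reward on the last valid token
    "token_level_rewards": np.array([[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
    "response_mask": np.array([[1, 1, 1], [1, 1, 0], [1, 1, 0]]),
}
print(analyze_reward_distribution(batch))
```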

Integration with Training Workflow

MetricClient Usage Pattern

The MetricClient provides the main interface for submitting metrics:

from siirl.execution.metric_worker.metric_worker import MetricClient, MetricWorker
# compute_advanced_metrics is the function added to metric_utils.py in Method 4
from siirl.utils.metrics.metric_utils import compute_advanced_metrics


async def run_training(dataloader, world_size):
    # Initialize metric worker and client. Awaiting a Ray ObjectRef is only valid
    # inside a coroutine, hence the async function. (In a multi-worker run all
    # DAGWorkers share a single MetricWorker; this shows one process.)
    metric_worker = MetricWorker.remote()
    await metric_worker.start.remote()
    metric_client = MetricClient(metric_worker)

    # During training loop
    for step, batch in enumerate(dataloader):
        # ... training logic (assumed to measure step_time and forward_time) ...

        # Submit standard metrics
        metric_client.compute_local_data_metric(batch, world_size)

        # Submit custom metrics
        custom_metrics = compute_advanced_metrics(batch)
        metric_client.submit_metric(custom_metrics, world_size)

        # Submit timing metrics
        timing_data = {"step": step_time, "forward": forward_time}
        metric_client.compute_local_timing_metrics(batch, timing_data, world_size)

        # Wait for metrics to be processed
        metric_client.wait_submit()

    # Get final aggregated results
    return metric_client.wait_final_res()

Ray-based Distributed Aggregation

The system uses Ray actors for distributed metrics processing:

MetricWorker Actor:

  • Runs asynchronously to collect metrics from all workers

  • Aggregates a metric once all processes have submitted values for it

  • Supports different aggregation functions (mean, max, min, sum)

  • Automatically handles renaming of timing metrics (the timing_s/ and perf/delta_time/ prefixes)

Aggregation Logic:

  • Metrics are collected in a queue until all workers (world_size) have submitted

  • Each metric is reduced as soon as the expected number of submissions for it has arrived

  • Final results are stored and returned when requested
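The collect-until-complete behavior described above can be modeled without Ray. In the sketch below, a hypothetical `ToyAggregator` buffers submissions per metric name and reduces a metric as soon as `world_size` values for it have arrived; the real MetricWorker is a Ray actor and its internals may differ.

```python
from collections import defaultdict
from statistics import mean
from typing import Callable, Dict, List


class ToyAggregator:
    """Buffers per-name submissions and reduces once every worker has reported."""

    def __init__(self, world_size: int, choose_reducer: Callable[[str], Callable]):
        self.world_size = world_size
        self.choose_reducer = choose_reducer
        self.pending: Dict[str, List[float]] = defaultdict(list)
        self.final: Dict[str, float] = {}

    def submit(self, metrics: Dict[str, float]) -> None:
        for name, value in metrics.items():
            self.pending[name].append(value)
            # Reduce only when the expected number of submissions is reached
            if len(self.pending[name]) == self.world_size:
                self.final[name] = self.choose_reducer(name)(self.pending.pop(name))


def choose_reducer(name: str) -> Callable[[List[float]], float]:
    # Pick a reduction from the last path component of the metric name
    suffix = name.rsplit("/", 1)[-1]
    if "max" in suffix:
        return max
    if "min" in suffix:
        return min
    if "sum" in suffix or "total" in suffix:
        return sum
    return mean


agg = ToyAggregator(world_size=2, choose_reducer=choose_reducer)
agg.submit({"response/length/max": 120, "perf/total_num_tokens": 5000})
print(agg.final)  # prints {}: only one of two workers has reported
agg.submit({"response/length/max": 80, "perf/total_num_tokens": 7000})
print(agg.final)  # {'response/length/max': 120, 'perf/total_num_tokens': 12000}
```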

Special Metric Configurations

Some metrics require special aggregation logic:

# In metric_worker.py
Special_Metric = {
    "graph_output_handling": MaxMetric,  # Only rollout_tp 0 contributes
}

Custom metrics can be added to this dictionary for specialized handling.
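The source does not say exactly how Special_Metric is consulted. One plausible arrangement is an exact-name lookup that takes precedence over the name-based rules, sketched below with hypothetical stand-ins (`SPECIAL`, `by_suffix`, `select_aggregator`) and plain Python reducers in place of siiRL's MaxMetric and MeanMetric.

```python
from statistics import mean
from typing import Callable, Dict, List

Reducer = Callable[[List[float]], float]

# Exact metric names that need non-default handling (mirrors the Special_Metric idea)
SPECIAL: Dict[str, Reducer] = {
    "graph_output_handling": max,
}


def by_suffix(name: str) -> Reducer:
    """Fallback: choose a reducer from the last path component of the name."""
    suffix = name.rsplit("/", 1)[-1]
    if "max" in suffix:
        return max
    if "min" in suffix:
        return min
    return mean


def select_aggregator(name: str) -> Reducer:
    # An exact entry in SPECIAL wins over the name-based rules
    return SPECIAL[name] if name in SPECIAL else by_suffix(name)


# Assumption for illustration: non-contributing ranks submit 0, so max keeps the real value
print(select_aggregator("graph_output_handling")([0.0, 3.5, 0.0]))  # 3.5
print(select_aggregator("critic/score/mean")([1.0, 2.0, 3.0]))      # 2.0
```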

Advanced Examples

Example 1: Model Performance Analysis

from typing import Dict

import torch
from tensordict import TensorDict  # assumed source of the TensorDict type


def compute_model_performance_metrics(data: TensorDict, model_outputs: dict) -> Dict[str, float]:
    """Comprehensive model performance analysis"""
    metrics = {}

    # Attention pattern analysis. Assumes attention_weights has shape
    # (batch, heads, query_len, key_len) with the response tokens at the end of the keys
    if "attention_weights" in model_outputs:
        attention_weights = model_outputs["attention_weights"]

        # Attention entropy over keys (lower = more concentrated attention)
        attention_entropy = -torch.sum(
            attention_weights * torch.log(attention_weights + 1e-9), dim=-1
        )
        metrics["model/attention_entropy/mean"] = attention_entropy.mean().item()

        # Share of attention mass placed on prompt tokens (slice the key dimension)
        if "responses" in data.keys():
            response_length = data["responses"].shape[-1]
            prompt_attention = attention_weights[..., :-response_length]

            metrics["model/prompt_attention_ratio/mean"] = (
                prompt_attention.sum() / attention_weights.sum()
            ).item()

    # Hidden state analysis
    if "hidden_states" in model_outputs:
        hidden_states = model_outputs["hidden_states"]

        # Hidden-state magnitude (L2 norm of each token representation)
        layer_norms = torch.norm(hidden_states, dim=-1)
        metrics["model/hidden_norm/mean"] = layer_norms.mean().item()
        metrics["model/hidden_norm/std"] = layer_norms.std().item()

    return metrics

Example 2: Conversation Quality Assessment

from typing import Dict

import torch
from tensordict import TensorDict  # assumed source of the TensorDict type


def compute_conversation_quality_metrics(data: TensorDict) -> Dict[str, float]:
    """Multi-dimensional conversation quality assessment"""
    metrics = {}

    # TensorDict does not support `key in td`; check membership via .keys()
    if "responses" not in data.keys():
        return metrics

    responses = data["responses"]
    response_mask = data.get("response_mask", torch.ones_like(responses))

    batch_size = responses.shape[0]
    quality_scores = []

    for i in range(batch_size):
        # Extract actual response tokens (remove padding)
        response_tokens = responses[i][response_mask[i].bool()]

        # Length appropriateness (not too short, not too long)
        response_length = len(response_tokens)
        length_score = compute_length_appropriateness(response_length)

        # Vocabulary richness
        unique_tokens = len(torch.unique(response_tokens))
        vocab_score = min(unique_tokens / response_length, 1.0) if response_length > 0 else 0

        # Repetition penalty
        repetition_score = compute_repetition_score(response_tokens)

        # Overall quality
        quality = 0.3 * length_score + 0.3 * vocab_score + 0.4 * repetition_score
        quality_scores.append(quality)

    quality_tensor = torch.tensor(quality_scores, device=responses.device)

    return {
        "conversation/quality/mean": quality_tensor.mean().item(),
        "conversation/quality/std": quality_tensor.std().item(),
        "conversation/quality/min": quality_tensor.min().item(),
        "conversation/quality/max": quality_tensor.max().item(),
    }

def compute_length_appropriateness(length: int, target_length: int = 50) -> float:
    """Compute how appropriate the response length is"""
    if length == 0:
        return 0.0
    ratio = length / target_length
    if ratio <= 1.0:
        return ratio  # Below target: score rises linearly toward 1.0
    else:
        return 1.0 / ratio  # Above target: score decays as target_length / length

def compute_repetition_score(tokens: torch.Tensor) -> float:
    """Compute score based on repetition patterns"""
    if len(tokens) <= 1:
        return 1.0

    # Count repeated bigrams
    bigrams = set()
    repeated_bigrams = 0

    for i in range(len(tokens) - 1):
        bigram = (tokens[i].item(), tokens[i+1].item())
        if bigram in bigrams:
            repeated_bigrams += 1
        else:
            bigrams.add(bigram)

    # Higher repetition = lower score
    repetition_ratio = repeated_bigrams / (len(tokens) - 1)
    return 1.0 - repetition_ratio

Configuration and Best Practices

Metric Naming Conventions

Follow these conventions for consistent metric organization:

# Training metrics
training/{category}/{metric_name}/{aggregation}

# Validation metrics
val/{category}/{data_source}/{metric_name}
val-core/{data_source}/{variable}/{metric_name}
val-aux/{category}/{metric_name}

# Performance metrics
perf/{metric_name}

# Analysis metrics
analysis/{category}/{metric_name}/{aggregation}

# Model introspection
model/{component}/{metric_name}/{aggregation}

Aggregation Selection

Choose aggregation methods based on metric semantics:

  • mean: Default for most metrics (accuracy, loss, etc.)

  • max: For peak values (max memory, worst-case latency)

  • min: For best-case scenarios (min loss, fastest response)

  • sum/total: For cumulative values (total tokens, total time)

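  • median: For robust central tendency when outliers would distort the mean (requires the custom MedianMetric from Method 3)

  • p95/p99: For percentile-based SLA metrics such as tail latency (requires the custom PercentileMetric from Method 3)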
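One caveat with mean: assuming each worker submits a single scalar (its local mean), averaging those scalars weights every worker equally, which differs from the mean over all samples whenever workers hold different numbers of samples. A small arithmetic check with hypothetical numbers:

```python
# Worker A scored 4 samples, worker B scored 1 sample
worker_samples = {"A": [1.0, 1.0, 1.0, 1.0], "B": [0.0]}

local_means = [sum(v) / len(v) for v in worker_samples.values()]
mean_of_means = sum(local_means) / len(local_means)  # each worker weighted equally

all_samples = [x for v in worker_samples.values() for x in v]
global_mean = sum(all_samples) / len(all_samples)    # each sample weighted equally

print(mean_of_means, global_mean)  # 0.5 0.8
```

When per-worker sample counts are uneven, submitting a sum and a count (both reduced with sum) and dividing afterwards recovers the global mean.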

Error Handling

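Guard custom metric code so that a missing field, an empty tensor, or an exception cannot interrupt training:

import logging
from typing import Dict

import torch
from tensordict import TensorDict  # assumed source of the TensorDict type

logger = logging.getLogger(__name__)


def compute_safe_custom_metrics(data: TensorDict) -> Dict[str, float]:
    """Example of safe metric computation"""
    metrics = {}

    try:
        # Check data availability (TensorDict needs an explicit .keys() membership test)
        if "required_field" not in data.keys():
            return metrics

        # Handle empty tensors
        values = data["required_field"]
        if values.numel() == 0:
            return metrics

        # Compute metrics with numerical stability
        mean_val = torch.mean(values.float())
        if torch.isfinite(mean_val):
            metrics["custom/metric/mean"] = mean_val.item()

    except Exception:
        # Log the error with its traceback, but don't crash training
        logger.warning("Error computing custom metrics", exc_info=True)
        return {}

    return metrics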

Performance Considerations

  • Batch Processing: Compute metrics on entire batches, not individual samples

  • Device Placement: Keep tensors on the same device as input data

  • Memory Management: Avoid accumulating large tensors across steps

  • Async Processing: Use Ray actors for non-blocking metrics aggregation

  • Selective Computation: Only compute expensive metrics when needed
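Selective computation can be as simple as gating an expensive metric function on the step counter. Gate on a value every worker shares, such as the global step: since the MetricWorker waits for world_size submissions of a name, a metric submitted by only some workers would presumably never complete. `maybe_compute` and `expensive` below are hypothetical.

```python
from typing import Callable, Dict


def maybe_compute(step: int, every_n_steps: int,
                  fn: Callable[..., Dict[str, float]], *args) -> Dict[str, float]:
    """Run an expensive metric function only every `every_n_steps` steps."""
    if every_n_steps <= 0 or step % every_n_steps != 0:
        return {}
    return fn(*args)


calls = []


def expensive(x: int) -> Dict[str, float]:
    calls.append(x)  # record when the expensive path actually ran
    return {"analysis/expensive/mean": float(x)}


results = [maybe_compute(step, 10, expensive, step) for step in range(25)]
print(calls)  # [0, 10, 20]
```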

Debugging Custom Metrics

import os
from typing import Dict

import numpy as np
from tensordict import TensorDict  # assumed source of the TensorDict type


def debug_custom_metrics(data: TensorDict, metrics: Dict[str, float]):
    """Debug utility for custom metrics (enable with DEBUG_METRICS=1)"""
    if os.environ.get("DEBUG_METRICS", "0") == "1":
        print(f"Data keys: {list(data.keys())}")
        print(f"Data shapes: {[(k, v.shape if hasattr(v, 'shape') else type(v)) for k, v in data.items()]}")
        print(f"Computed metrics: {metrics}")

        # Check for common issues
        for name, value in metrics.items():
            if not isinstance(value, (int, float)):
                print(f"WARNING: Metric {name} has invalid type {type(value)}")
            elif not np.isfinite(value):
                print(f"WARNING: Metric {name} is not finite: {value}")

File Structure Summary

siirl/execution/metric_worker/
├── metric_worker.py          # Ray actor for distributed aggregation
│   ├── MetricWorker          # Ray remote actor class
│   └── MetricClient          # Client interface
└── utils.py                  # Aggregation functions
    ├── Metric                # Dataclass for metric values
    ├── MetricFunc            # Function selection logic
    ├── MeanMetric            # Mean aggregation
    ├── MaxMetric             # Maximum aggregation
    ├── MinMetric             # Minimum aggregation
    └── SumMetric             # Sum aggregation

siirl/utils/metrics/
└── metric_utils.py           # Core metric computation
    ├── compute_data_metric           # Standard training metrics
    ├── compute_timing_metrics        # Timing analysis
    ├── compute_throughout_metrics    # Throughput analysis
    ├── process_validation_metrics    # Validation with bootstrap
    ├── bootstrap_metric              # Bootstrap sampling utility
    └── aggregate_validation_metrics  # Parallel validation processing

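In summary, per-worker metric computation lives in metric_utils.py, cross-worker reduction happens in the MetricWorker Ray actor, and the reduction applied to each metric is chosen from its name by MetricFunc, with Special_Metric entries for names that need special handling.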