SRPO Code Implementation Explained

This document provides a comprehensive guide to understanding the SRPO (Self-Referential Policy Optimization) algorithm implementation in siiRL. SRPO is designed for training Vision-Language-Action (VLA) models in embodied AI scenarios.

Overview: What is SRPO?

SRPO (Self-Referential Policy Optimization) is a VLA reinforcement-learning framework that needs neither external demonstrations nor manual reward engineering. Instead, it treats the successful trajectories that the model itself generates within the current training batch as self-references, which makes it possible to assign progress-based rewards to failed attempts.

A core innovation is the use of latent world representations (V-JEPA) to robustly measure behavioral progress. Rather than relying on raw pixels or requiring domain-specific fine-tuning, we utilize compressed, transferable encodings from a world model’s latent space. These representations naturally capture progress patterns across environments, making trajectory comparison accurate and generalizable.

Empirical evaluation on the LIBERO benchmark demonstrates SRPO’s efficiency and effectiveness. Starting from a supervised baseline with a 48.9% success rate, SRPO achieves a 99.2% success rate on novel states within only 200 RL steps, representing a 103% relative improvement without any additional supervision. Furthermore, SRPO shows significant robustness on the LIBERO-Plus benchmark, achieving a 167% performance gain.

In siiRL, SRPO is implemented as the embodied_srpo_pipeline DAG combined with the GRPO advantage estimator.

Code Architecture Overview

siiRL/
├── siirl/
│   ├── execution/
│   │   └── dag/
│   │       └── builtin_pipelines.py         # embodied_srpo_pipeline() definition
│   ├── user_interface/
│   │   └── filter_interface/
│   │       └── embodied.py                  # embodied_local_rank_sampling()
│   ├── engine/
│   │   ├── rollout/
│   │   │   └── embodied_rollout.py          # EmbodiedHFRollout class
│   │   └── actor/
│   │       └── embodied_actor.py            # RobDataParallelPPOActor class
│   ├── dag_worker/
│   │   └── core_algos.py                    # GRPO advantage & PPO loss
│   ├── environment/
│   │   └── embodied/
│   │       └── adapters/                    # LIBERO environment adapter
│   └── utils/
│       ├── reward_score/
│       │   └── embodied.py                  # compute_embodied_reward()
│       └── embodied/
│           └── video_emb.py                 # VideoEmbeddingModel (V-JEPA)
└── examples/
    └── embodied_srpo_trainer/
        └── run_openvla_oft_libero_*.sh      # Training scripts

Training Pipeline Definition

The SRPO training pipeline is defined in siirl/execution/dag/builtin_pipelines.py using the Python Pipeline API:

siirl/execution/dag/builtin_pipelines.py - embodied_srpo_pipeline()
def embodied_srpo_pipeline() -> TaskGraph:
    """
    Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.

    Workflow:
        1. rollout_actor: Environment rollout with embodied AI agent
        2. dynaminc_sampling: Data verification and filtering
        3. compute_reward: VJEPA-based reward computation
        4. calculate_advantages: Calculate advantages (GRPO group-based)
        5. actor_old_log_prob: Compute old actor log probabilities (forward only)
        6. reference_log_prob: Compute reference model log probabilities
        7. actor_train: Actor training with GRPO
    """
    pipeline = Pipeline(
        "embodied_grpo_training_pipeline",
        "Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation."
    )

    pipeline.add_node(
        "rollout_actor",
        func="siirl.dag_worker.dagworker:DAGWorker.generate",
        deps=[],
        node_type=NodeType.MODEL_INFERENCE,
        node_role=NodeRole.ROLLOUT
    ).add_node(
        "dynaminc_sampling",
        func="siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling",
        deps=["rollout_actor"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.DYNAMIC_SAMPLING
    ).add_node(
        "compute_reward",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
        deps=["dynaminc_sampling"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.REWARD
    ).add_node(
        "calculate_advantages",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
        deps=["compute_reward"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.ADVANTAGE
    ).add_node(
        "actor_old_log_prob",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
        deps=["calculate_advantages"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.ACTOR,
        only_forward_compute=True
    ).add_node(
        "reference_log_prob",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
        deps=["actor_old_log_prob"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.REFERENCE
    ).add_node(
        "actor_train",
        func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
        deps=["reference_log_prob"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.ACTOR
    )

    return pipeline.build()
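Every node lists exactly one predecessor in deps, so the graph is a linear chain. The sketch below recovers the execution order from those deps lists with Python's standard-library graphlib. It only illustrates the dependency structure; it is not how siiRL's own DAG scheduler is implemented.

```python
from graphlib import TopologicalSorter

# deps lists copied from the add_node(...) calls above
# (node names spelled exactly as in the pipeline code, including "dynaminc_sampling").
SRPO_DEPS: dict[str, list[str]] = {
    "rollout_actor": [],
    "dynaminc_sampling": ["rollout_actor"],
    "compute_reward": ["dynaminc_sampling"],
    "calculate_advantages": ["compute_reward"],
    "actor_old_log_prob": ["calculate_advantages"],
    "reference_log_prob": ["actor_old_log_prob"],
    "actor_train": ["reference_log_prob"],
}


def execution_order(deps: dict[str, list[str]]) -> list[str]:
    """Return one valid execution order; raises graphlib.CycleError if deps contain a cycle."""
    # TopologicalSorter expects a mapping of node -> its predecessors.
    return list(TopologicalSorter(deps).static_order())


for position, node in enumerate(execution_order(SRPO_DEPS), start=1):
    print(position, node)
```

Because the chain is linear, the order is unique; adding a node with two predecessors would turn it into a true DAG without changing this code.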

Data Flow Diagram

                           SRPO Training Pipeline Data Flow
==============================================================================

DataLoader (task_id, trial_id)
    |
    v
+---------------------+
| rollout_actor       | EmbodiedHFRollout.generate_sequences()
| (MODEL_INFERENCE)   | -> VLA model + LIBERO environment interaction
+----------+----------+
           | Output: {responses, input_ids, attention_mask, pixel_values,
           |          complete, finish_step, vjepa_embedding, task_file_name}
           v
+---------------------+
| dynamic_sampling    | embodied_local_rank_sampling()
| (COMPUTE)           | -> verify() + _filter_batch()
+----------+----------+ Filter by accuracy bounds & truncation
           | Output: filtered batch (prompt groups with 0.1 <= mean acc <= 0.9)
           v
+---------------------+
| compute_reward      | compute_embodied_reward()
| (COMPUTE)           | -> VJEPA-based reward shaping
+----------+----------+ Success: 1.0; Failure: 0.6 * sigmoid(10 * (0.5 - norm_dist))
           | Output: + {token_level_scores, token_level_rewards}
           v
+---------------------+
| calculate_advantages| compute_grpo_outcome_advantage()
| (COMPUTE)           | -> Group by prompt, normalize (score - mean) / std
+----------+----------+
           | Output: + {advantages, returns}
           v
+---------------------+
| actor_old_log_prob  | RobDataParallelPPOActor.compute_log_prob()
| (MODEL_TRAIN)       | -> Forward only, no gradient
| only_forward=True   |
+----------+----------+
           | Output: + {old_log_probs}
           v
+---------------------+
| reference_log_prob  | Reference model forward pass
| (MODEL_TRAIN)       |
+----------+----------+
           | Output: + {ref_log_prob}
           v
+---------------------+
| actor_train         | RobDataParallelPPOActor.update_policy()
| (MODEL_TRAIN)       | -> compute_policy_loss_vanilla() (PPO clipped loss)
+---------------------+
           |
           | Metrics: {pg_loss, pg_clipfrac, ppo_kl, grad_norm}
           v
+---------------------+
| sync_weights        | ShardingManager (if needed)
+---------------------+

Core Components Deep Dive

1. Rollout: Environment Interaction

File: siirl/engine/rollout/embodied_rollout.py

Class: EmbodiedHFRollout

This is the core component that orchestrates the interaction between the VLA model and the simulation environment (LIBERO). It handles the complete episode generation process including action prediction, environment stepping, and visual embedding extraction.

Class Initialization

class EmbodiedHFRollout(BaseRollout):
    def __init__(self, module: nn.Module, config: ActorRolloutRefArguments):
        self.model = module  # VLA model (e.g., OpenVLA-OFT)
        self.config = config

        # Initialize V-JEPA embedding model for reward computation
        self.embedding_model = VideoEmbeddingModel(
            model_path=config.embodied.video_embedding_model_path,
            img_size=config.embodied.embedding_img_size,
            enable_fp16=config.embodied.embedding_enable_fp16
        )

        # Initialize LIBERO environment adapter with parallel environments
        self.num_workers = config.embodied.env.num_envs  # e.g., 16 parallel envs
        self.adapter = LIBEROAdapter(
            env_name=config.embodied.env.env_name,      # e.g., "libero_goal"
            num_envs=self.num_workers,
            max_steps=config.embodied.env.max_steps,    # e.g., 512
            num_steps_wait=config.embodied.env.num_steps_wait,
            model_family=config.embodied.env.model_family,
            # self._rank and self._num_gpus_per_node are defined outside this excerpt
            gpu_ids=[self._rank % self._num_gpus_per_node]
        )

Main Entry Point: generate_sequences()

def generate_sequences(self, prompts):
    """
    Main entry point for generating sequences.
    Splits large batches into chunks that fit the number of parallel workers.
    """
    total_batch_size = prompts.batch_size[0]
    n_samples = prompts['n_samples'] if 'n_samples' in prompts else 1

    # Each prompt needs n_samples trajectories
    batch_size_per_chunk = self.num_workers // n_samples
    num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk

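    all_chunk_outputs = []
    for i in range(num_chunks):
        # Row range for this chunk; the last chunk may be smaller
        start_idx = i * batch_size_per_chunk
        end_idx = min(start_idx + batch_size_per_chunk, total_batch_size)
        chunk_prompts = prompts[start_idx:end_idx]
        chunk_output = self._generate_chunk_rollout(chunk_prompts)
        all_chunk_outputs.append(chunk_output)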

    return torch.cat(all_chunk_outputs)
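The chunking arithmetic in generate_sequences() is easy to check in isolation. A minimal sketch with a hypothetical chunk_bounds() helper that mirrors the ceiling division above; unlike the excerpt, it raises an explicit error when num_workers < n_samples instead of dividing by zero:

```python
def chunk_bounds(total_batch_size: int, num_workers: int, n_samples: int) -> list[tuple[int, int]]:
    """Hypothetical helper: (start, end) row ranges handled by each rollout chunk."""
    batch_size_per_chunk = num_workers // n_samples
    if batch_size_per_chunk == 0:
        raise ValueError("num_workers must be >= n_samples")
    # Ceiling division: a partial final chunk still needs its own rollout
    num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk
    return [
        (i * batch_size_per_chunk, min((i + 1) * batch_size_per_chunk, total_batch_size))
        for i in range(num_chunks)
    ]


print(chunk_bounds(total_batch_size=5, num_workers=16, n_samples=8))
```

With 16 parallel environments and 8 samples per prompt, each chunk covers 2 rows, so 5 rows need 3 chunks and the last one is shorter.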

Episode Generation Loop: _generate_chunk_rollout()

This is the heart of the embodied rollout - a step-by-step interaction loop between the VLA model and the environment.

def _generate_chunk_rollout(self, prompts):
    """Generate complete episodes for a chunk of tasks."""
    task_id = prompts['task_id']
    trial_id = prompts['trial_id']
    max_steps = self.config.embodied.env.max_steps
    chunk_size = task_id.size(0)

    # Step 1: Reset all parallel environments
    init_data_list = self.adapter._blocking_reset(
        task_ids=task_id.reshape(-1).cpu().numpy().tolist(),
        trial_ids=trial_id.reshape(-1).cpu().numpy().tolist(),
    )

    # Collect initial observations
    inputs = [self._obs_to_input(init_data['obs']) for init_data in init_data_list]
    task_descriptions = [init_data["task_description"] for init_data in init_data_list]
    task_records = [{"active": d['active'], "complete": d['complete'],
                     "finish_step": d['finish_step'], "task_file_name": d['task_file_name']}
                    for d in init_data_list]

    # Step 2: Main interaction loop (up to max_steps)
    step = 0
    vla_history = []  # Store all step data for training

    while step < max_steps:
        active_indices = [i for i, r in enumerate(task_records) if r['active']]

        # Step 2a: Process observations into VLA input format
        vla_input = self.process_input(inputs, task_descriptions)

        # Step 2b: VLA model predicts actions
        vla_output = self._generate_one_step(vla_input)
        actions = vla_output["action"]

        # Store step data for later training
        vla_history.append({
            "responses": vla_output["responses"],
            "input_ids": vla_output["input_ids"],
            "attention_mask": vla_output["attention_mask"],
            "pixel_values": vla_output["pixel_values"],
            "action": actions,
            "step": step
        })

        # Step 2c: Execute actions in environment
        step_results_list = self.adapter._blocking_step({
            "indices": active_indices,
            "actions": actions,
        })

        # Step 2d: Update observations and task status
        for idx in active_indices:
            result = step_results_list[idx]
            inputs[idx] = self._obs_to_input(result['obs'])
            task_records[idx]['active'] = result['active']
            task_records[idx]['complete'] = result['complete']
            task_records[idx]['finish_step'] = result['finish_step']

        step += self.config.embodied.action_chunks_len  # e.g., += 8

    # Step 3: Post-processing - Stack history and compute embeddings
    batch = {}
    for k in ["responses", "input_ids", "attention_mask", "pixel_values"]:
        batch[k] = torch.stack([h[k] for h in vla_history], dim=1)

    batch["complete"] = torch.tensor([r["complete"] for r in task_records])
    batch["finish_step"] = torch.tensor([r["finish_step"] for r in task_records])

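    # Step 4: Extract V-JEPA embeddings for reward computation.
    # all_video maps each task_file_name to the frames recorded during the loop
    # (frame collection is elided from this excerpt).
    batch_names, batch_frames = zip(*[(r['task_file_name'], all_video[r['task_file_name']])
                                       for r in task_records])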
    vjepa_embeddings = self.embedding_model.get_embeddings(batch_names, batch_frames)
    batch["vjepa_embedding"] = torch.tensor(np.array(vjepa_embeddings))

    return TensorDict(batch, batch_size=chunk_size)
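Step 3 relies on torch.stack(..., dim=1) to turn the per-step history into (batch, trajectory, tokens) tensors. A toy sketch with illustrative sizes (4 environments, 3 VLA calls, 56 tokens per call; all assumptions) shows the resulting layout and the flattened token axis that the advantage and loss code later index into:

```python
import torch

B, TRAJ_LEN, TOKENS_PER_STEP = 4, 3, 56  # illustrative sizes only

# One dict per VLA forward pass, mirroring what the loop appends to vla_history;
# each step's tokens are filled with the step number so the layout is visible.
vla_history = [
    {"responses": torch.full((B, TOKENS_PER_STEP), s, dtype=torch.long), "step": s}
    for s in range(TRAJ_LEN)
]

# dim=1 makes the trajectory axis second: (B, traj_len, tokens_per_step)
responses = torch.stack([h["responses"] for h in vla_history], dim=1)

# Flattened view used for token-level quantities (response_length = traj_len * tokens_per_step)
flat = responses.reshape(B, -1)
print(responses.shape, flat.shape)
```

Row b of the stacked tensor always belongs to environment b, and consecutive blocks of TOKENS_PER_STEP entries in the flattened view belong to consecutive steps.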

Single-Step Action Generation: _generate_one_step()

@torch.no_grad()
def _generate_one_step(self, prompts: dict):
    """Generate one action chunk from VLA model."""
    # Unpack the processed VLA inputs (key names assumed to match the dict returned below);
    # do_sample, temperature, response_length and prompt_length come from the rollout
    # config and prompt metadata, which are elided in this excerpt.
    idx = prompts['input_ids']
    attention_mask = prompts['attention_mask']
    pixel_values = prompts['pixel_values']

    if self.config.embodied.embodied_type == "openvla-oft":
        # OpenVLA-OFT (Optimized Fine-Tuning): decodes the whole action chunk in parallel
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            actions, response = self.model.generate_action_verl(
                input_ids=idx,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                do_sample=do_sample,
                unnorm_key=self.config.embodied.unnorm_key,
                temperature=temperature,
            )
        # response shape: (batch_size, action_chunks_len * action_token_len)

    elif self.config.embodied.embodied_type == "openvla":
        # Standard OpenVLA: Autoregressive token generation
        output = self.model.generate(
            input_ids=idx,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            do_sample=do_sample,
            max_new_tokens=response_length,
            temperature=temperature,
        )
        # Decode action tokens to continuous actions
        predicted_action_token_ids = output.sequences[:, prompt_length:]
        response = predicted_action_token_ids
        discretized_actions = self.model.vocab_size - predicted_action_token_ids
        normalized_actions = self.model.bin_centers[discretized_actions]
        # Un-normalizing normalized_actions (via the unnorm_key statistics) into `actions` is elided

    return {
        'responses': response,
        'input_ids': idx,
        'attention_mask': attention_mask,
        'pixel_values': pixel_values,
        'action': actions,
    }
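The openvla branch maps each generated token id back to an action bin with vocab_size - token_id and then looks up bin_centers. The self-contained sketch below pairs a matching encoder with that decoder so the round trip can be checked. It assumes OpenVLA-style discretization (256 uniform bins over [-1, 1] and a 32000-token vocabulary whose last ids serve as action tokens); the offset-by-one and clipping in decode() belong to this sketch's encoder convention, and the un-normalization step is omitted.

```python
import numpy as np

# Illustrative constants (assumptions): 256 bin edges over [-1, 1] and a
# 32000-token vocabulary whose highest ids are reused as action tokens.
N_BINS = 256
VOCAB_SIZE = 32000
bins = np.linspace(-1.0, 1.0, N_BINS)
bin_centers = (bins[:-1] + bins[1:]) / 2.0  # N_BINS - 1 centers


def encode(actions: np.ndarray) -> np.ndarray:
    """Map normalized actions in [-1, 1] to token ids at the top of the vocabulary."""
    discretized = np.digitize(np.clip(actions, -1.0, 1.0), bins)  # values in 1..N_BINS
    return VOCAB_SIZE - discretized


def decode(token_ids: np.ndarray) -> np.ndarray:
    """Invert encode(): token id -> bin index -> bin center."""
    discretized = VOCAB_SIZE - token_ids
    idx = np.clip(discretized - 1, 0, bin_centers.shape[0] - 1)
    return bin_centers[idx]


action = np.array([0.1, -0.5, 0.9, 0.0, 0.3, -0.2, 1.0])  # one 7-dim action
recovered = decode(encode(action))
print(np.abs(recovered - action).max())  # bounded by half a bin width
```

Quantization error is the price of treating actions as tokens: with 256 bins over [-1, 1], each recovered dimension is within 1/255 of the original.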

Key Output Fields:

Field            Shape                            Description
responses        (B, traj_len, action_token_len)  Action tokens per trajectory step (each action is 7-dim, e.g., xyz + 3D rotation + gripper)
complete         (B,)                             Boolean: task success flag
finish_step      (B,)                             Integer: episode termination step
vjepa_embedding  (B, embed_dim)                   V-JEPA visual features for reward computation

2. Data Filtering (Dynamic Sampling)

File: siirl/user_interface/filter_interface/embodied.py

Function: embodied_local_rank_sampling()

This step filters out “too easy” or “too hard” prompts based on the success rate within each group.

def embodied_local_rank_sampling(
    config: SiiRLArguments,
    batch: TensorDict,
    **kwargs: Any,
) -> NodeOutput:
    """
    Performs verification, metric collection, and optional filtering on a batch.
    """
    # Step 1: Verify the entire batch to get scores and enrich it with an 'acc' tensor.
    _, reward_metrics, format_metrics, reward_format_metrics = verify(batch)

    # Step 2: Conditionally filter the batch based on accuracy and truncation
    embodied_sampling = config.algorithm.embodied_sampling
    if embodied_sampling.filter_accuracy or embodied_sampling.filter_truncated:
        n_samples = config.actor_rollout_ref.rollout.n
        processed_batch = _filter_batch(batch, n_samples, config)
    else:
        processed_batch = batch

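    # Assumed: the metric dicts returned by verify() are merged into the node's metrics
    sample_metrics = {**reward_metrics, **format_metrics, **reward_format_metrics}
    return NodeOutput(batch=processed_batch, metrics=sample_metrics)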


def _filter_batch(batch: TensorDict, n_samples: int, config: SiiRLArguments) -> TensorDict:
    """
    Filters a batch based on accuracy and truncation criteria.
    Filtering is performed at the prompt level.
    """
    num_prompts = len(batch) // n_samples
    device = batch["finish_step"].device

    # --- 1. Accuracy Filtering ---
    if config.algorithm.embodied_sampling.filter_accuracy:
        # Reshape flat accuracy tensor into (num_prompts, n_samples)
        acc_matrix = batch["acc"].reshape(num_prompts, n_samples)
        # Calculate mean accuracy for each prompt
        prompt_mean_acc = acc_matrix.float().mean(dim=-1)

        # Create a boolean mask for prompts within the desired accuracy bounds
        accuracy_lower_bound = config.algorithm.embodied_sampling.accuracy_lower_bound
        accuracy_upper_bound = config.algorithm.embodied_sampling.accuracy_upper_bound
        acc_mask = (prompt_mean_acc >= accuracy_lower_bound) & (prompt_mean_acc <= accuracy_upper_bound)
    else:
        acc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)

    # --- 2. Truncation Filtering ---
    if config.algorithm.embodied_sampling.filter_truncated:
        finish_steps = batch["finish_step"].reshape(num_prompts, n_samples)
        max_steps = config.actor_rollout_ref.embodied.env.max_steps
        # A prompt is considered truncated if *any* of its samples reached max steps
        has_truncated = (finish_steps >= max_steps).any(dim=-1)
        trunc_mask = ~has_truncated
    else:
        trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)

    # --- 3. Combine Masks and Apply Filter ---
    combined_mask = acc_mask & trunc_mask
    final_mask = combined_mask.repeat_interleave(n_samples)
    filtered_batch = select_idxs(batch, final_mask)

    return filtered_batch

Why Filter?

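  • Too easy (acc > 0.9): (nearly) all samples succeed with reward 1.0 → (near-)zero variance → (near-)zero advantage → little or no learning signal.

  • Too hard (acc < 0.1): typically every sample fails, so there is no successful trajectory to serve as a self-reference; all rewards stay 0.0 and the advantage is again zero.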

  • Sweet spot (0.1 ≤ acc ≤ 0.9): Diverse outcomes → meaningful advantage estimates.
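The prompt-level logic of _filter_batch() can be reproduced on plain tensors. In this sketch, prompt_filter_mask() is a hypothetical helper (siiRL applies the resulting mask through select_idxs()), and passing max_steps turns on the truncation check:

```python
from typing import Optional

import torch


def prompt_filter_mask(acc: torch.Tensor, finish_step: torch.Tensor, n_samples: int,
                       lower: float = 0.1, upper: float = 0.9,
                       max_steps: Optional[int] = None) -> torch.Tensor:
    """Per-sample keep mask; keep/drop decisions are made per group of n_samples rows."""
    num_prompts = acc.numel() // n_samples
    mean_acc = acc.float().reshape(num_prompts, n_samples).mean(dim=-1)
    keep = (mean_acc >= lower) & (mean_acc <= upper)
    if max_steps is not None:
        # Drop the whole group if any of its samples hit the step limit
        truncated = (finish_step.reshape(num_prompts, n_samples) >= max_steps).any(dim=-1)
        keep = keep & ~truncated
    # Broadcast each prompt-level decision back to all of its samples
    return keep.repeat_interleave(n_samples)


acc = torch.tensor([1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0])
finish = torch.full((12,), 100)
print(prompt_filter_mask(acc, finish, n_samples=4))
```

Only the middle group (2 of 4 successes) survives; the all-success and all-failure groups are removed as a unit.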

3. Reward Computation (VJEPA-based)

File: siirl/utils/reward_score/embodied.py

Function: compute_embodied_reward()

This is a key innovation of SRPO: distances in the V-JEPA latent space turn failed trajectories into graded, progress-based rewards instead of a flat 0.

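from typing import Any, Dict, List

import numpy as np
from scipy import special
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from tensordict import TensorDict


def compute_embodied_reward(
    batch_data: TensorDict,
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """
    Computes rewards based on VJEPA embeddings and task completion status.

    Reward Formula:
    - Success: reward = 1.0
    - Failure: reward = 0.6 * sigmoid(10 * (0.5 - normalized_distance)) ∈ (0, 0.6),
      where normalized_distance is the min-max-normalized distance to the nearest
      success cluster of the same task (closer → higher reward)
    - Failure with no successful trajectory for the same task: reward = 0.0
    - Samples with all-zero (invalid) embeddings: reward = 0.0
    """
    # --- Step 1: Data Extraction and Pre-filtering ---
    batch_size = batch_data["responses"].size(0)
    completes = np.array(batch_data["complete"].tolist(), dtype=bool)  # boolean mask for indexing
    embeddings = batch_data["vjepa_embedding"].cpu().numpy()
    task_file_names = _tensor_to_str_list(batch_data["task_file_name"])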

    # Pre-filtering: Identify invalid samples (all-zero embeddings)
    zero_embedding_mask = np.all(embeddings == 0, axis=1)
    valid_indices = np.where(~zero_embedding_mask)[0]

    # --- Step 2: Initialize rewards ---
    final_rewards = np.zeros(batch_size, dtype=float)
    task_names = [_extract_task_name(name) for name in task_file_names]

    # Group valid samples by task name
    task_to_valid_indices = {}
    for idx in valid_indices:
        task_name = task_names[idx]
        task_to_valid_indices.setdefault(task_name, []).append(idx)

    # --- Step 3: Process each task group ---
    for task_name, indices in task_to_valid_indices.items():
        indices = np.array(indices)
        task_completes = completes[indices]

        success_indices = indices[task_completes]
        fail_indices = indices[~task_completes]

        # Success trajectories get reward = 1.0
        final_rewards[success_indices] = 1.0

        if len(success_indices) == 0 or len(fail_indices) == 0:
            continue

        # a. Cluster successful embeddings using DBSCAN
        succ_embeddings = embeddings[success_indices]
        scaler = StandardScaler()
        scaled_succ_embeddings = scaler.fit_transform(succ_embeddings)
        clustering = DBSCAN(eps=0.5, min_samples=2).fit(scaled_succ_embeddings)

        cluster_centers = []
        for label in set(clustering.labels_) - {-1}:
            cluster_points = scaled_succ_embeddings[clustering.labels_ == label]
            center = scaler.inverse_transform(cluster_points.mean(axis=0, keepdims=True)).flatten()
            cluster_centers.append(center)

        if not cluster_centers:
            cluster_centers = [succ_embeddings.mean(axis=0)]
        cluster_centers = np.array(cluster_centers)

        # b. Compute distance from failed trajectories to nearest success cluster
        fail_embeddings = embeddings[fail_indices]
        distance_matrix = cdist(fail_embeddings, cluster_centers, "euclidean")
        min_distances = distance_matrix.min(axis=1)

        # c. Map distance to reward via sigmoid
        max_dist, min_dist = min_distances.max(), min_distances.min()
        dist_range = max_dist - min_dist
        if dist_range < 1e-6:
            normalized_dists = np.full_like(min_distances, 0.5)
        else:
            normalized_dists = (min_distances - min_dist) / dist_range

        sigmoid_steepness = 10.0
        sigmoid_offset = 0.5
        sigmoid_inputs = sigmoid_steepness * (sigmoid_offset - normalized_dists)
        reward_values = 0.6 * special.expit(sigmoid_inputs)

        final_rewards[fail_indices] = reward_values

    return [{"score": final_rewards[i]} for i in range(batch_size)]

Reward Visualization:

Reward
  ^
1.0|  ●●● (Success)
   |
0.6|────╮                     (≈ max for failure)
   |     ╲
   |      ╲  Sigmoid curve
   |       ╲
0.0|        ╰────────────────▶ Distance to nearest success cluster
    Near                   Far

Intuition: Failed trajectories that are “visually similar” to successful ones (low distance) receive higher rewards, encouraging the policy to explore in promising directions.
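Step 3c can be isolated as a small function. This sketch copies the constants from compute_embodied_reward() (steepness 10, offset 0.5, scale 0.6) and uses a NumPy sigmoid in place of scipy.special.expit, which computes the same function; failure_rewards() is a hypothetical name:

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function, equivalent to scipy.special.expit."""
    return 1.0 / (1.0 + np.exp(-x))


def failure_rewards(min_distances: np.ndarray, steepness: float = 10.0,
                    offset: float = 0.5, scale: float = 0.6) -> np.ndarray:
    """Map each failure's distance to its nearest success cluster to a reward in (0, 0.6)."""
    d_min, d_max = min_distances.min(), min_distances.max()
    if d_max - d_min < 1e-6:
        normalized = np.full(min_distances.shape, 0.5)  # all failures equally far
    else:
        normalized = (min_distances - d_min) / (d_max - d_min)  # min-max to [0, 1]
    # Small normalized distance -> large sigmoid input -> reward close to `scale`
    return scale * sigmoid(steepness * (offset - normalized))


rewards = failure_rewards(np.array([2.0, 3.0, 4.0]))
print(rewards.round(3))  # [0.596 0.3   0.004]
```

Because distances are min-max normalized within each task group, the nearest failure scores about 0.596 and the farthest about 0.004 whenever the distances differ; when all failures are equally far, each gets 0.3.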

4. Advantage Calculation (GRPO)

File: siirl/dag_worker/core_algos.py

Function: compute_grpo_outcome_advantage()

GRPO computes advantages using group-relative normalization, eliminating the need for a Critic network.

@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,  # (B, response_length)
    response_mask: torch.Tensor,        # (B, response_length)
    index: np.ndarray,                  # (B,) - prompt index for grouping
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgorithmArguments] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GRPO Advantage = (reward - group_mean) / group_std

    The baseline comes from the policy's own group of samples rather than from a
    separate value network. (SRPO's self-referencing happens earlier, in the reward
    step, where successful trajectories in the batch are used to score failed ones.)
    """
    # Sum rewards across response tokens to get scalar reward per sample
    scores = token_level_rewards.sum(dim=-1)  # (B,)

    # Group samples by prompt index
    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
            id2score[idx_key].append(scores[i])

        # Compute group statistics (a singleton group keeps its raw score: mean 0, std 1)
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)

        # Normalize: advantage = (score - mean) / std
        for i in range(bsz):
            idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[idx_key]  # Dr.GRPO variant
        # Broadcast to token level
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores  # (advantages, returns)
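A compact re-implementation of the same group normalization shows how SRPO's graded failure rewards change the advantages. grpo_advantages() is a hypothetical helper that follows the logic above (singleton groups get mean 0 and std 1; torch.std is the unbiased estimator by default):

```python
from collections import defaultdict

import torch


def grpo_advantages(scores: torch.Tensor, index: list[int], eps: float = 1e-6,
                    norm_by_std: bool = True) -> torch.Tensor:
    """Group-relative advantages: (score - group mean) / (group std + eps)."""
    groups = defaultdict(list)
    for i, key in enumerate(index):
        groups[key].append(i)
    adv = torch.empty_like(scores)
    for rows in groups.values():
        group = scores[rows]
        if len(rows) == 1:
            mean, std = torch.tensor(0.0), torch.tensor(1.0)
        else:
            mean, std = group.mean(), group.std()
        adv[rows] = (group - mean) / (std + eps) if norm_by_std else group - mean
    return adv


# Group 0: two successes (1.0), a near-miss failure (0.3) and a far failure (0.0).
# Group 1: two successes only.
scores = torch.tensor([1.0, 1.0, 0.3, 0.0, 1.0, 1.0])
uids = [0, 0, 0, 0, 1, 1]
adv = grpo_advantages(scores, uids)
print(adv)
```

The near-miss failure (reward 0.3) receives a smaller penalty than the far failure (reward 0.0), while the all-success group contributes zero advantage, which is why such groups are filtered out beforehand.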

Embodied-specific handling in compute_advantage():

def compute_advantage(data: TensorDict, adv_estimator, ...):
    if adv_estimator == AdvantageEstimator.GRPO:
        if "finish_step" in data and data["responses"].ndim == 3:
            # Embodied scenario: compute mask based on finish_step
            responses = data["responses"]
            batch_size = responses.size(0)
            response_length = responses.size(1) * responses.size(2)  # traj_len * action_token_len

            action_token_len = responses.size(2)
            finish_step = data['finish_step'] * action_token_len

            steps = torch.arange(response_length, device=responses.device)
            steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)
            grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1)
        else:
            # NLP scenario: use attention_mask-based response_mask
            grpo_calculation_mask = data["response_mask"]

        advantages, returns = compute_grpo_outcome_advantage(
            token_level_rewards=data["token_level_rewards"],
            response_mask=grpo_calculation_mask,
            index=data["uid"],
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
        )
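The embodied mask marks every token that belongs to a step before finish_step. A standalone sketch, assuming (as the multiplication above implies) that finish_step counts trajectory steps and that each step contributes a fixed number of tokens; grpo_mask() is a hypothetical name and uses broadcasting instead of expand(), which yields the same result:

```python
import torch


def grpo_mask(finish_step: torch.Tensor, traj_len: int, tokens_per_step: int) -> torch.Tensor:
    """Boolean (B, traj_len * tokens_per_step) mask of tokens that precede finish_step."""
    response_length = traj_len * tokens_per_step
    cutoff = finish_step * tokens_per_step  # convert a step count into a token count
    positions = torch.arange(response_length, device=finish_step.device)
    return positions.unsqueeze(0) < cutoff.unsqueeze(1)  # broadcasts to (B, response_length)


mask = grpo_mask(torch.tensor([1, 3]), traj_len=3, tokens_per_step=2)
print(mask.int())  # first row keeps 2 tokens, second row keeps all 6
```

Tokens generated after an episode has already ended still exist in the padded tensor, so masking them keeps them out of both the advantage broadcast and the loss.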

5. Policy Update (PPO Loss)

File: siirl/engine/actor/embodied_actor.py

Class: RobDataParallelPPOActor

Method: update_policy()

The actor update uses the standard PPO clipped objective with GRPO advantages.

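def update_policy(self, data: TensorDict):
    self.actor_module.train()
    # Number of micro-batches whose gradients are accumulated per optimizer step
    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
    temperature = data['temperature']

    select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values',
                   'old_log_probs', 'advantages', "finish_step"]
    batch = data.select(*select_keys)
    mini_batches = batch.split(self.config.ppo_mini_batch_size)

    metrics = {"pg_loss": [], "pg_clipfrac": [], "ppo_kl": [], "grad_norm": []}
    for mini_batch in mini_batches:
        micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
        self.actor_optimizer.zero_grad()

        for micro_batch in micro_batches:
            micro_batch = micro_batch.cuda()
            # Assumed layout (b, traj_len, ...), as produced by torch.stack(..., dim=1) in the rollout
            responses = micro_batch['responses']
            input_ids = micro_batch['input_ids']
            attention_mask = micro_batch['attention_mask']
            pixel_values = micro_batch['pixel_values']
            response_length = responses.size(1) * responses.size(2)

            # Build response mask from finish_step (tokens after the episode ends are masked)
            finish_step = micro_batch['finish_step'] * self.config.action_token_len
            steps = torch.arange(response_length, device=responses.device)
            steps_expanded = steps.unsqueeze(0).expand(responses.size(0), -1)
            response_mask = steps_expanded < finish_step.unsqueeze(1)

            old_log_prob = micro_batch['old_log_probs']  # (b, response_length)
            advantages = micro_batch['advantages']       # (b, response_length)

            # Split the trajectory into chunks of steps for memory efficiency
            traj_len = responses.size(1)
            tokens_per_step = responses.size(2)
            traj_split_num = max(1, traj_len // self.config.traj_mini_batch_size)
            chunk_size = traj_len // traj_split_num

            for i in range(0, traj_len, chunk_size):
                # Token range covered by trajectory steps [i, i + chunk_size)
                tok = slice(i * tokens_per_step, (i + chunk_size) * tokens_per_step)

                # Forward pass to get current log probs for this chunk of steps
                entropy, log_prob = self._forward_micro_batch_update(
                    input_ids=input_ids[:, i:i + chunk_size],
                    attention_mask=attention_mask[:, i:i + chunk_size],
                    pixel_values=pixel_values[:, i:i + chunk_size],
                    responses=responses[:, i:i + chunk_size],
                    temperature=temperature,
                )

                # Compute PPO clipped loss on the matching token slice
                pg_loss, pg_clipfrac, ppo_kl, _ = core_algos.compute_policy_loss_vanilla(
                    old_log_prob=old_log_prob[:, tok],
                    log_prob=log_prob,
                    advantages=advantages[:, tok],
                    response_mask=response_mask[:, tok],
                    config=self.config,
                )

                # Scale for gradient accumulation; gradients add up until the optimizer step
                loss = pg_loss / self.gradient_accumulation
                loss.backward()

                metrics["pg_loss"].append(pg_loss.item())
                metrics["pg_clipfrac"].append(pg_clipfrac.item())
                metrics["ppo_kl"].append(ppo_kl.item())

        # One optimizer step per mini-batch, after all micro-batch gradients are accumulated
        grad_norm = self._optimizer_step()
        metrics["grad_norm"].append(float(grad_norm))
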
    return metrics
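The trajectory chunking in update_policy() does not always produce chunks of exactly traj_mini_batch_size steps: the chunk size is traj_len // (traj_len // traj_mini_batch_size). A hypothetical traj_chunks() helper that reproduces the loop bounds, with a max(1, ...) guard added here for traj_len < traj_mini_batch_size, makes this visible:

```python
from typing import Iterator


def traj_chunks(traj_len: int, traj_mini_batch_size: int,
                tokens_per_step: int) -> Iterator[tuple[slice, slice]]:
    """Yield (step_slice, token_slice) pairs in the order the update loop visits them."""
    traj_split_num = max(1, traj_len // traj_mini_batch_size)
    chunk_size = traj_len // traj_split_num
    for i in range(0, traj_len, chunk_size):
        end = min(i + chunk_size, traj_len)
        # Each trajectory step owns tokens_per_step consecutive token positions
        yield slice(i, end), slice(i * tokens_per_step, end * tokens_per_step)


for steps, tokens in traj_chunks(traj_len=10, traj_mini_batch_size=4, tokens_per_step=56):
    print(steps, tokens)
```

With traj_len = 10 and traj_mini_batch_size = 4, the loop runs two chunks of 5 steps rather than chunks of 4, and each chunk triggers its own backward pass.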

PPO Loss Function (from core_algos.py):

@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[ActorArguments] = None,
    ...
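) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    L^CLIP(θ) = E[min(r_t(θ) * A_t, clip(r_t(θ), 1-ε_low, 1+ε_high) * A_t)]

    where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t). The function returns the negated
    objective as a loss and adds a dual clip for negative advantages.
    """
    clip_ratio = config.clip_ratio
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio
    clip_ratio_c = config.clip_ratio_c  # dual-clip bound (> 1); assumed config field
    # loss_agg_mode (e.g. "token-mean") is one of the elided parameters above

    # Importance ratio
    negative_approx_kl = log_prob - old_log_prob
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)  # stability
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)

    # Clipped objective (negated, so the pessimistic term is the maximum)
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)
    # Fraction of tokens where the clipped term was active
    pg_clipfrac = siirl_F.masked_mean(torch.gt(pg_losses2, pg_losses1).float(), response_mask)

    # Dual-clip for negative advantages: cap the loss at -A_t * clip_ratio_c
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_clipfrac_lower = siirl_F.masked_mean(
        torch.gt(clip_pg_losses1, pg_losses3) * (advantages < 0).float(), response_mask
    )
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
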
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
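A small numeric check of the clipped and dual-clipped terms. This sketch reimplements the loss on plain tensors, with a token-mean standing in for agg_loss(); ppo_dual_clip_loss() is a hypothetical name and clip_ratio_c = 3.0 is an illustrative value, not a documented siiRL default:

```python
import torch


def ppo_dual_clip_loss(log_prob: torch.Tensor, old_log_prob: torch.Tensor,
                       advantages: torch.Tensor, mask: torch.Tensor,
                       clip_low: float = 0.2, clip_high: float = 0.2,
                       clip_ratio_c: float = 3.0) -> torch.Tensor:
    """Token-mean PPO loss with dual clipping for negative advantages."""
    ratio = torch.exp(torch.clamp(log_prob - old_log_prob, -20.0, 20.0))
    loss1 = -advantages * ratio
    loss2 = -advantages * torch.clamp(ratio, 1 - clip_low, 1 + clip_high)
    clipped = torch.maximum(loss1, loss2)                       # pessimistic clipped surrogate
    dual = torch.minimum(-advantages * clip_ratio_c, clipped)   # caps the loss when A < 0
    per_token = torch.where(advantages < 0, dual, clipped)
    return (per_token * mask).sum() / mask.sum()                # mean over unmasked tokens


ratios = torch.tensor([1.5, 0.5, 5.0, 1.0], dtype=torch.float64)
advantages = torch.tensor([1.0, -1.0, -1.0, 2.0], dtype=torch.float64)
mask = torch.tensor([1.0, 1.0, 1.0, 0.0], dtype=torch.float64)  # last token is padding
loss = ppo_dual_clip_loss(torch.log(ratios), torch.zeros(4, dtype=torch.float64), advantages, mask)
print(loss.item())  # ≈ 0.8667 = (-1.2 + 0.8 + 3.0) / 3
```

The first token's gain is capped at ratio 1.2, and the third token (ratio 5.0 with a negative advantage) is capped by the dual clip at 3.0 instead of contributing 5.0.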

Key Configuration Parameters

Parameter                                              Location          Description
algorithm.adv_estimator                                Training config   Set to grpo for SRPO
actor_rollout_ref.rollout.n                            Training script   Group size (samples per prompt)
algorithm.embodied_sampling.filter_accuracy            Training script   Enable accuracy-based filtering
algorithm.embodied_sampling.accuracy_lower_bound       Training script   Min success rate (default: 0.1)
algorithm.embodied_sampling.accuracy_upper_bound       Training script   Max success rate (default: 0.9)
algorithm.embodied_sampling.filter_truncated           Training script   Filter truncated episodes
actor_rollout_ref.embodied.video_embedding_model_path  Training script   Path to V-JEPA model
actor_rollout_ref.embodied.env.num_envs                Config            Number of parallel environments
actor_rollout_ref.embodied.env.max_steps               Config            Maximum steps per episode
actor_rollout_ref.embodied.action_chunks_len           Config            Actions per VLA forward pass
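Some of these parameters constrain each other through the rollout code. The hypothetical sanity check below is written against a flat dict of dotted keys rather than siiRL's real config objects. It derives two quantities implied by the settings: prompts per rollout chunk (num_envs // n, assuming rollout.n is the n_samples used by generate_sequences()) and VLA forward passes per episode (the rollout loop advances by action_chunks_len until it reaches max_steps).

```python
import math

# Illustrative values only. n = 8 is an assumed group size; num_envs, max_steps and
# action_chunks_len echo the "e.g." comments in the rollout code above.
EXAMPLE_CONFIG = {
    "algorithm.adv_estimator": "grpo",
    "actor_rollout_ref.rollout.n": 8,
    "algorithm.embodied_sampling.accuracy_lower_bound": 0.1,
    "algorithm.embodied_sampling.accuracy_upper_bound": 0.9,
    "actor_rollout_ref.embodied.env.num_envs": 16,
    "actor_rollout_ref.embodied.env.max_steps": 512,
    "actor_rollout_ref.embodied.action_chunks_len": 8,
}


def derived_rollout_sizes(cfg: dict) -> dict:
    """Hypothetical sanity check: validate a few SRPO settings and derive rollout sizes."""
    if cfg["algorithm.adv_estimator"] != "grpo":
        raise ValueError("SRPO expects algorithm.adv_estimator=grpo")
    lower = cfg["algorithm.embodied_sampling.accuracy_lower_bound"]
    upper = cfg["algorithm.embodied_sampling.accuracy_upper_bound"]
    if not 0.0 <= lower <= upper <= 1.0:
        raise ValueError("accuracy bounds must satisfy 0 <= lower <= upper <= 1")
    n = cfg["actor_rollout_ref.rollout.n"]
    num_envs = cfg["actor_rollout_ref.embodied.env.num_envs"]
    if num_envs < n:
        # generate_sequences() computes num_envs // n_samples; 0 would divide by zero
        raise ValueError("num_envs must be >= rollout.n")
    max_steps = cfg["actor_rollout_ref.embodied.env.max_steps"]
    chunk_len = cfg["actor_rollout_ref.embodied.action_chunks_len"]
    return {
        "prompts_per_chunk": num_envs // n,
        # The loop runs while step < max_steps, advancing step by action_chunks_len
        "vla_calls_per_episode": math.ceil(max_steps / chunk_len),
    }


print(derived_rollout_sizes(EXAMPLE_CONFIG))  # {'prompts_per_chunk': 2, 'vla_calls_per_episode': 64}
```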

Quick Reference: File Locations

Component            File Path
Training Entry       siirl/main_dag.py
Pipeline Definition  siirl/execution/dag/builtin_pipelines.py
Embodied Rollout     siirl/engine/rollout/embodied_rollout.py
Environment Adapter  siirl/environment/embodied/adapters/
V-JEPA Embedding     siirl/utils/embodied/video_emb.py
Data Filtering       siirl/user_interface/filter_interface/embodied.py
VJEPA Reward         siirl/utils/reward_score/embodied.py
GRPO Advantage       siirl/dag_worker/core_algos.py
VLA Actor            siirl/engine/actor/embodied_actor.py
Example Scripts      examples/embodied_srpo_trainer/run_openvla_oft_*.sh

References

  1. SRPO Paper: Self-Referential Policy Optimization for Vision-Language-Action Models

  2. V-JEPA: Video Joint Embedding Predictive Architecture 2