SRPO Code Implementation Explained
This document provides a comprehensive guide to understanding the SRPO (Self-Referential Policy Optimization) algorithm implementation in siiRL. SRPO is designed for training Vision-Language-Action (VLA) models in embodied AI scenarios.
Note
Paper Reference: SRPO: Self-Referential Policy Optimization for Vision-Language-Action Models
Overview: What is SRPO?
Self-Referential Policy Optimization (SRPO) for Vision-Language-Action Models is a novel VLA-RL framework. SRPO eliminates the need for external demonstrations or manual reward engineering by leveraging successful trajectories generated by the model within the current training batch as self-references. This enables us to assign progress-based rewards to failed attempts.
A core innovation is the use of latent world representations (V-JEPA) to robustly measure behavioral progress. Rather than relying on raw pixels or requiring domain-specific fine-tuning, we utilize compressed, transferable encodings from a world model’s latent space. These representations naturally capture progress patterns across environments, making trajectory comparison accurate and generalizable.
Empirical evaluation on the LIBERO benchmark demonstrates SRPO’s efficiency and effectiveness. Starting from a supervised baseline with a 48.9% success rate, SRPO achieves a 99.2% success rate on novel states within only 200 RL steps, representing a 103% relative improvement without any additional supervision. Furthermore, SRPO shows significant robustness on the LIBERO-Plus benchmark, achieving a 167% performance gain.
In siiRL, SRPO is implemented as the embodied_srpo_pipeline combined with the GRPO advantage estimator.
Code Architecture Overview
siiRL/
├── siirl/
│   ├── execution/
│   │   └── dag/
│   │       └── builtin_pipelines.py       # embodied_srpo_pipeline() definition
│   ├── user_interface/
│   │   └── filter_interface/
│   │       └── embodied.py                # embodied_local_rank_sampling()
│   ├── engine/
│   │   ├── rollout/
│   │   │   └── embodied_rollout.py        # EmbodiedHFRollout class
│   │   └── actor/
│   │       └── embodied_actor.py          # RobDataParallelPPOActor class
│   ├── dag_worker/
│   │   └── core_algos.py                  # GRPO advantage & PPO loss
│   ├── environment/
│   │   └── embodied/
│   │       └── adapters/                  # LIBERO environment adapter
│   └── utils/
│       ├── reward_score/
│       │   └── embodied.py                # compute_embodied_reward()
│       └── embodied/
│           └── video_emb.py               # VideoEmbeddingModel (V-JEPA)
└── examples/
    └── embodied_srpo_trainer/
        └── run_openvla_oft_libero_*.sh    # Training scripts
Training Pipeline Definition
The SRPO training pipeline is defined in siirl/execution/dag/builtin_pipelines.py using the Python Pipeline API:
def embodied_srpo_pipeline() -> TaskGraph:
    """
    Embodied AI GRPO training pipeline with data filtering and VJEPA-based reward computation.

    Workflow:
        1. rollout_actor: Environment rollout with embodied AI agent
        2. dynaminc_sampling: Data verification and filtering
        3. compute_reward: VJEPA-based reward computation
        4. calculate_advantages: Calculate advantages (GRPO group-based)
        5. actor_old_log_prob: Compute old actor log probabilities (forward only)
        6. reference_log_prob: Compute reference model log probabilities
        7. actor_train: Actor training with GRPO
    """
    pipeline = Pipeline(
        "embodied_grpo_training_pipeline",
        "Embodied AI GRPO training workflow with data filtering and VJEPA-based reward computation."
    )
    pipeline.add_node(
        "rollout_actor",
        func="siirl.dag_worker.dagworker:DAGWorker.generate",
        deps=[],
        node_type=NodeType.MODEL_INFERENCE,
        node_role=NodeRole.ROLLOUT
    ).add_node(
        "dynaminc_sampling",
        func="siirl.user_interface.filter_interface.embodied.embodied_local_rank_sampling",
        deps=["rollout_actor"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.DYNAMIC_SAMPLING
    ).add_node(
        "compute_reward",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_reward",
        deps=["dynaminc_sampling"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.REWARD
    ).add_node(
        "calculate_advantages",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_advantage",
        deps=["compute_reward"],
        node_type=NodeType.COMPUTE,
        node_role=NodeRole.ADVANTAGE
    ).add_node(
        "actor_old_log_prob",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_old_log_prob",
        deps=["calculate_advantages"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.ACTOR,
        only_forward_compute=True
    ).add_node(
        "reference_log_prob",
        func="siirl.dag_worker.dagworker:DAGWorker.compute_ref_log_prob",
        deps=["actor_old_log_prob"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.REFERENCE
    ).add_node(
        "actor_train",
        func="siirl.dag_worker.dagworker:DAGWorker.train_actor",
        deps=["reference_log_prob"],
        node_type=NodeType.MODEL_TRAIN,
        node_role=NodeRole.ACTOR
    )
    return pipeline.build()
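To make the role of the deps arguments concrete, here is a minimal sketch of how a chained add_node API can resolve into an execution order. ToyPipeline is a hypothetical stand-in written for illustration only; it is not siiRL's Pipeline class, and it ignores func, node_type, and node_role. The node ids are copied verbatim from the definition above.

```python
from graphlib import TopologicalSorter  # standard library, Python 3.9+


class ToyPipeline:
    """Hypothetical stand-in for a chained pipeline builder (not siiRL's class)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.deps: dict[str, list[str]] = {}

    def add_node(self, node_id: str, deps: list[str]) -> "ToyPipeline":
        # Record the node's predecessors and return self to allow chaining.
        self.deps[node_id] = list(deps)
        return self

    def execution_order(self) -> list[str]:
        # TopologicalSorter expects a mapping of node -> predecessors.
        return list(TopologicalSorter(self.deps).static_order())


# Nodes are registered out of order on purpose: the order comes from deps.
toy = (
    ToyPipeline("toy_embodied_pipeline")
    .add_node("actor_train", deps=["reference_log_prob"])
    .add_node("compute_reward", deps=["dynaminc_sampling"])
    .add_node("rollout_actor", deps=[])
    .add_node("reference_log_prob", deps=["actor_old_log_prob"])
    .add_node("dynaminc_sampling", deps=["rollout_actor"])
    .add_node("actor_old_log_prob", deps=["calculate_advantages"])
    .add_node("calculate_advantages", deps=["compute_reward"])
)
order = toy.execution_order()
print(order)
```

Because every node has exactly one predecessor, the graph is a linear chain and the resolved order matches the seven workflow steps listed in the docstring.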
Data Flow Diagram
SRPO Training Pipeline Data Flow
==============================================================================
DataLoader (task_id, trial_id)
|
v
+---------------------+
| rollout_actor | EmbodiedHFRollout.generate_sequences()
| (MODEL_INFERENCE) | -> VLA model + LIBERO environment interaction
+----------+----------+
| Output: {responses, input_ids, attention_mask, pixel_values,
| complete, finish_step, vjepa_embedding, task_file_name}
v
+---------------------+
| dynamic_sampling | embodied_local_rank_sampling()
| (COMPUTE) | -> verify() + _filter_batch()
+----------+----------+ Filter by accuracy bounds & truncation
| Output: filtered batch (prompts with 0.1 <= mean acc <= 0.9)
v
+---------------------+
| compute_reward | compute_embodied_reward()
| (COMPUTE) | -> VJEPA-based reward shaping
+----------+----------+ Success: reward=1.0, Failure: reward=sigmoid(distance)
| Output: + {token_level_scores, token_level_rewards}
v
+---------------------+
| calculate_advantages| compute_grpo_outcome_advantage()
| (COMPUTE) | -> Group by prompt, normalize (score - mean) / std
+----------+----------+
| Output: + {advantages, returns}
v
+---------------------+
| actor_old_log_prob | RobDataParallelPPOActor.compute_log_prob()
| (MODEL_TRAIN) | -> Forward only, no gradient
| only_forward=True |
+----------+----------+
| Output: + {old_log_probs}
v
+---------------------+
| reference_log_prob | Reference model forward pass
| (MODEL_TRAIN) |
+----------+----------+
| Output: + {ref_log_prob}
v
+---------------------+
| actor_train | RobDataParallelPPOActor.update_policy()
| (MODEL_TRAIN) | -> compute_policy_loss_vanilla() (PPO clipped loss)
+---------------------+
|
| Metrics: {pg_loss, pg_clipfrac, ppo_kl, grad_norm}
v
+---------------------+
| sync_weights | ShardingManager (if needed)
+---------------------+
Core Components Deep Dive
1. Rollout: Environment Interaction
File: siirl/engine/rollout/embodied_rollout.py
Class: EmbodiedHFRollout
This is the core component that orchestrates the interaction between the VLA model and the simulation environment (LIBERO). It handles the complete episode generation process including action prediction, environment stepping, and visual embedding extraction.
Class Initialization
class EmbodiedHFRollout(BaseRollout):
    def __init__(self, module: nn.Module, config: ActorRolloutRefArguments):
        self.model = module  # VLA model (e.g., OpenVLA-OFT)
        self.config = config

        # Initialize V-JEPA embedding model for reward computation
        self.embedding_model = VideoEmbeddingModel(
            model_path=config.embodied.video_embedding_model_path,
            img_size=config.embodied.embedding_img_size,
            enable_fp16=config.embodied.embedding_enable_fp16
        )

        # Initialize LIBERO environment adapter with parallel environments
        self.num_workers = config.embodied.env.num_envs  # e.g., 16 parallel envs
        self.adapter = LIBEROAdapter(
            env_name=config.embodied.env.env_name,  # e.g., "libero_goal"
            num_envs=self.num_workers,
            max_steps=config.embodied.env.max_steps,  # e.g., 512
            num_steps_wait=config.embodied.env.num_steps_wait,
            model_family=config.embodied.env.model_family,
            gpu_ids=[self._rank % self._num_gpus_per_node]
        )
Main Entry Point: generate_sequences()
def generate_sequences(self, prompts):
    """
    Main entry point for generating sequences.
    Splits large batches into chunks that fit the number of parallel workers.
    """
    total_batch_size = prompts.batch_size[0]
    n_samples = prompts['n_samples'] if 'n_samples' in prompts else 1

    # Each prompt needs n_samples trajectories
    batch_size_per_chunk = self.num_workers // n_samples
    num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk

    all_chunk_outputs = []
    for i in range(num_chunks):
        # Slice bounds for this chunk (reconstructed; not shown in the original excerpt)
        start_idx = i * batch_size_per_chunk
        end_idx = min(start_idx + batch_size_per_chunk, total_batch_size)
        chunk_prompts = prompts[start_idx:end_idx]
        chunk_output = self._generate_chunk_rollout(chunk_prompts)
        all_chunk_outputs.append(chunk_output)
    return torch.cat(all_chunk_outputs)
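The chunking arithmetic can be checked in isolation. The following is a minimal sketch, assuming the slice bounds follow the usual ceiling-division pattern; plan_chunks is a hypothetical helper written for this example and does not exist in siiRL.

```python
def plan_chunks(total_batch_size: int, num_workers: int, n_samples: int) -> list[tuple[int, int]]:
    """Return (start, end) index pairs over the prompt batch, one per rollout chunk."""
    if n_samples > num_workers:
        raise ValueError("n_samples must not exceed the number of parallel environments")
    # Each prompt occupies n_samples environments, so this many prompts fit at once.
    batch_size_per_chunk = num_workers // n_samples
    # Ceiling division: the last chunk may be smaller than the others.
    num_chunks = (total_batch_size + batch_size_per_chunk - 1) // batch_size_per_chunk
    return [
        (i * batch_size_per_chunk, min((i + 1) * batch_size_per_chunk, total_batch_size))
        for i in range(num_chunks)
    ]


chunks = plan_chunks(total_batch_size=5, num_workers=16, n_samples=8)
print(chunks)  # [(0, 2), (2, 4), (4, 5)]
```

With 16 environments and 8 samples per prompt, two prompts run per chunk, so five prompts need three chunks and the final chunk holds a single prompt.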
Episode Generation Loop: _generate_chunk_rollout()
This is the heart of the embodied rollout: a step-by-step interaction loop between the VLA model and the environment.
def _generate_chunk_rollout(self, prompts):
    """Generate complete episodes for a chunk of tasks."""
    task_id = prompts['task_id']
    trial_id = prompts['trial_id']
    max_steps = self.config.embodied.env.max_steps
    chunk_size = task_id.size(0)

    # Step 1: Reset all parallel environments
    init_data_list = self.adapter._blocking_reset(
        task_ids=task_id.reshape(-1).cpu().numpy().tolist(),
        trial_ids=trial_id.reshape(-1).cpu().numpy().tolist(),
    )

    # Collect initial observations
    inputs = [self._obs_to_input(init_data['obs']) for init_data in init_data_list]
    task_descriptions = [init_data["task_description"] for init_data in init_data_list]
    task_records = [{"active": d['active'], "complete": d['complete'],
                     "finish_step": d['finish_step'], "task_file_name": d['task_file_name']}
                    for d in init_data_list]

    # Step 2: Main interaction loop (up to max_steps)
    step = 0
    vla_history = []  # Store all step data for training
    while step < max_steps:
        active_indices = [i for i, r in enumerate(task_records) if r['active']]

        # Step 2a: Process observations into VLA input format
        vla_input = self.process_input(inputs, task_descriptions)

        # Step 2b: VLA model predicts actions
        vla_output = self._generate_one_step(vla_input)
        actions = vla_output["action"]

        # Store step data for later training
        vla_history.append({
            "responses": vla_output["responses"],
            "input_ids": vla_output["input_ids"],
            "attention_mask": vla_output["attention_mask"],
            "pixel_values": vla_output["pixel_values"],
            "action": actions,
            "step": step
        })

        # Step 2c: Execute actions in environment
        step_results_list = self.adapter._blocking_step({
            "indices": active_indices,
            "actions": actions,
        })

        # Step 2d: Update observations and task status
        for idx in active_indices:
            result = step_results_list[idx]
            inputs[idx] = self._obs_to_input(result['obs'])
            task_records[idx]['active'] = result['active']
            task_records[idx]['complete'] = result['complete']
            task_records[idx]['finish_step'] = result['finish_step']

        step += self.config.embodied.action_chunks_len  # e.g., += 8

    # Step 3: Post-processing - Stack history and compute embeddings
    batch = {}
    for k in ["responses", "input_ids", "attention_mask", "pixel_values"]:
        batch[k] = torch.stack([h[k] for h in vla_history], dim=1)
    batch["complete"] = torch.tensor([r["complete"] for r in task_records])
    batch["finish_step"] = torch.tensor([r["finish_step"] for r in task_records])

    # Step 4: Extract V-JEPA embeddings for reward computation
    # all_video is indexed by task_file_name and holds each episode's frames;
    # the code that fills it is not shown in this excerpt.
    batch_names, batch_frames = zip(*[(r['task_file_name'], all_video[r['task_file_name']])
                                      for r in task_records])
    vjepa_embeddings = self.embedding_model.get_embeddings(batch_names, batch_frames)
    batch["vjepa_embedding"] = torch.tensor(np.array(vjepa_embeddings))

    return TensorDict(batch, batch_size=chunk_size)
Single-Step Action Generation: _generate_one_step()
@torch.no_grad()
def _generate_one_step(self, prompts: dict):
    """Generate one action chunk from VLA model."""
    # idx, pixel_values, attention_mask, do_sample, temperature, response_length and
    # prompt_length are unpacked from `prompts` and the config (omitted in this excerpt).
    if self.config.embodied.embodied_type == "openvla-oft":
        # OpenVLA-OFT: Optimized Fine-Tuning variant of OpenVLA
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            actions, response = self.model.generate_action_verl(
                input_ids=idx,
                pixel_values=pixel_values,
                attention_mask=attention_mask,
                do_sample=do_sample,
                unnorm_key=self.config.embodied.unnorm_key,
                temperature=temperature,
            )
        # response shape: (batch_size, action_chunks_len * action_token_len)
    elif self.config.embodied.embodied_type == "openvla":
        # Standard OpenVLA: Autoregressive token generation
        output = self.model.generate(
            input_ids=idx,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            do_sample=do_sample,
            max_new_tokens=response_length,
            temperature=temperature,
        )
        # Decode action tokens to continuous actions
        predicted_action_token_ids = output.sequences[:, prompt_length:]
        discretized_actions = self.model.vocab_size - predicted_action_token_ids
        normalized_actions = self.model.bin_centers[discretized_actions]
        # The assignment of `response` and `actions` for this branch is omitted in this excerpt.

    return {
        'responses': response,
        'input_ids': idx,
        'attention_mask': attention_mask,
        'pixel_values': pixel_values,
        'action': actions,
    }
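The token-to-action decoding in the standard OpenVLA branch can be illustrated without a model. This is a minimal sketch under stated assumptions: 256 uniform bins over [-1, 1], a vocabulary size of 32000, and a clamp on the bin index added as a safeguard. None of these values is taken from siiRL, and make_bin_centers and decode_action_tokens are hypothetical helpers.

```python
def make_bin_centers(n_bins: int = 256, lo: float = -1.0, hi: float = 1.0) -> list[float]:
    """Midpoints of n_bins evenly spaced edges over [lo, hi] (assumed discretization)."""
    edges = [lo + (hi - lo) * k / (n_bins - 1) for k in range(n_bins)]
    return [(a + b) / 2 for a, b in zip(edges[:-1], edges[1:])]


def decode_action_tokens(token_ids: list[int], vocab_size: int, bin_centers: list[float]) -> list[float]:
    """Map action token ids, taken from the end of the vocabulary, to normalized actions."""
    actions = []
    for token_id in token_ids:
        discretized = vocab_size - token_id
        # Clamp to a valid bin index (assumption: added here as a safeguard).
        bin_index = min(max(discretized - 1, 0), len(bin_centers) - 1)
        actions.append(bin_centers[bin_index])
    return actions


VOCAB_SIZE = 32000  # hypothetical value for illustration
centers = make_bin_centers()
decoded = decode_action_tokens([VOCAB_SIZE - 1, VOCAB_SIZE - 128, VOCAB_SIZE - 256], VOCAB_SIZE, centers)
print(len(centers), [round(a, 4) for a in decoded])
```

Token ids closest to the end of the vocabulary map to the most negative normalized action, and ids further from the end map to increasingly positive values.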
Key Output Fields:
| Field | Shape | Description |
|---|---|---|
| responses | | Action tokens (e.g., 7-dim: xyz + 3-D rotation + gripper) |
| complete | | Boolean: task success flag |
| finish_step | | Integer: episode termination step |
| vjepa_embedding | | V-JEPA visual features for reward computation |
2. Data Filtering (Dynamic Sampling)
File: siirl/user_interface/filter_interface/embodied.py
Function: embodied_local_rank_sampling()
This step filters out “too easy” or “too hard” prompts based on the success rate within each group.
def embodied_local_rank_sampling(
    config: SiiRLArguments,
    batch: TensorDict,
    **kwargs: Any,
) -> NodeOutput:
    """
    Performs verification, metric collection, and optional filtering on a batch.
    """
    # Step 1: Verify the entire batch to get scores and enrich it with an 'acc' tensor.
    _, reward_metrics, format_metrics, reward_format_metrics = verify(batch)

    # Step 2: Conditionally filter the batch based on accuracy and truncation
    embodied_sampling = config.algorithm.embodied_sampling
    if embodied_sampling.filter_accuracy or embodied_sampling.filter_truncated:
        n_samples = config.actor_rollout_ref.rollout.n
        processed_batch = _filter_batch(batch, n_samples, config)
    else:
        processed_batch = batch

    # sample_metrics is assembled from the metrics above (construction omitted in this excerpt)
    return NodeOutput(batch=processed_batch, metrics=sample_metrics)


def _filter_batch(batch: TensorDict, n_samples: int, config: SiiRLArguments) -> TensorDict:
    """
    Filters a batch based on accuracy and truncation criteria.
    Filtering is performed at the prompt level.
    """
    num_prompts = len(batch) // n_samples
    # Device for the fallback masks (assumed; not shown in the original excerpt)
    device = batch["acc"].device

    # --- 1. Accuracy Filtering ---
    if config.algorithm.embodied_sampling.filter_accuracy:
        # Reshape flat accuracy tensor into (num_prompts, n_samples)
        acc_matrix = batch["acc"].reshape(num_prompts, n_samples)
        # Calculate mean accuracy for each prompt
        prompt_mean_acc = acc_matrix.mean(dim=-1)
        # Create a boolean mask for prompts within the desired accuracy bounds
        accuracy_lower_bound = config.algorithm.embodied_sampling.accuracy_lower_bound
        accuracy_upper_bound = config.algorithm.embodied_sampling.accuracy_upper_bound
        acc_mask = (prompt_mean_acc >= accuracy_lower_bound) & (prompt_mean_acc <= accuracy_upper_bound)
    else:
        acc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)

    # --- 2. Truncation Filtering ---
    if config.algorithm.embodied_sampling.filter_truncated:
        finish_steps = batch["finish_step"].reshape(num_prompts, n_samples)
        max_steps = config.actor_rollout_ref.embodied.env.max_steps
        # A prompt is considered truncated if *any* of its samples reached max steps
        has_truncated = (finish_steps >= max_steps).any(dim=-1)
        trunc_mask = ~has_truncated
    else:
        trunc_mask = torch.ones(num_prompts, dtype=torch.bool, device=device)

    # --- 3. Combine Masks and Apply Filter ---
    combined_mask = acc_mask & trunc_mask
    final_mask = combined_mask.repeat_interleave(n_samples)
    filtered_batch = select_idxs(batch, final_mask)
    return filtered_batch
Why Filter?
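Too easy (acc > 0.9): Nearly all samples succeed → rewards within the group are almost identical → group-normalized advantages carry little or no learning signal.
Too hard (acc < 0.1): Nearly all samples fail → the same low-variance problem, and when no sample succeeds there is no successful trajectory to serve as a reference for shaping failure rewards.
Sweet spot (0.1 ≤ acc ≤ 0.9): Diverse outcomes → meaningful advantage estimates.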
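The prompt-level filtering described above can be traced on a small batch. The sketch below mirrors the logic of _filter_batch in plain Python lists; filter_prompts is a hypothetical helper, and the accuracy flags and finish steps are made-up values chosen to exercise each case.

```python
def filter_prompts(acc, finish_steps, n_samples, lower=0.1, upper=0.9, max_steps=512):
    """Return a per-sample keep mask; keep/drop decisions are made per prompt group."""
    if len(acc) % n_samples != 0:
        raise ValueError("batch size must be a multiple of n_samples")
    keep = []
    for start in range(0, len(acc), n_samples):
        group_acc = acc[start:start + n_samples]
        group_steps = finish_steps[start:start + n_samples]
        mean_acc = sum(group_acc) / n_samples
        in_bounds = lower <= mean_acc <= upper
        # A prompt counts as truncated if any of its samples reached max_steps.
        truncated = any(step >= max_steps for step in group_steps)
        keep.extend([in_bounds and not truncated] * n_samples)
    return keep


# Four prompts with four samples each (illustrative values only).
acc = [1, 1, 1, 1,   1, 0, 1, 0,   0, 0, 0, 0,   1, 0, 1, 1]
finish_steps = [120, 96, 104, 88,   112, 400, 136, 424,
                480, 456, 440, 472,   96, 512, 104, 120]
keep = filter_prompts(acc, finish_steps, n_samples=4)
print([keep[i] for i in range(0, len(keep), 4)])  # [False, True, False, False]
```

Only the second prompt survives: the first is too easy, the third is too hard, and the fourth has an acceptable success rate but contains a sample that ran to max_steps.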
3. Reward Computation (VJEPA-based)
File: siirl/utils/reward_score/embodied.py
Function: compute_embodied_reward()
This is a key innovation of SRPO: using visual similarity to compute dense rewards for failed trajectories.
import numpy as np
from scipy import special
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler


def compute_embodied_reward(
    batch_data: TensorDict,
    **kwargs: Any,
) -> List[Dict[str, Any]]:
    """
    Computes rewards based on VJEPA embeddings and task completion status.

    Reward Formula:
    - Success: reward = 1.0
    - Failure: reward = 0.6 * sigmoid(10 * (0.5 - normalized_distance)) ∈ (0, 0.6),
      higher when closer to a success cluster
    - Failure in a task group with no successful trajectory: reward = 0.0
    """
    # --- Step 1: Data Extraction and Pre-filtering ---
    batch_size = batch_data["responses"].size(0)
    completes = np.array(batch_data["complete"].tolist())
    embeddings = batch_data["vjepa_embedding"].cpu().numpy()
    task_file_names = _tensor_to_str_list(batch_data["task_file_name"])

    # Pre-filtering: Identify invalid samples (all-zero embeddings)
    zero_embedding_mask = np.all(embeddings == 0, axis=1)
    valid_indices = np.where(~zero_embedding_mask)[0]

    # --- Step 2: Initialize rewards ---
    final_rewards = np.zeros(batch_size, dtype=float)
    task_names = [_extract_task_name(name) for name in task_file_names]

    # Group valid samples by task name
    task_to_valid_indices = {}
    for idx in valid_indices:
        task_name = task_names[idx]
        task_to_valid_indices.setdefault(task_name, []).append(idx)

    # --- Step 3: Process each task group ---
    for task_name, indices in task_to_valid_indices.items():
        indices = np.array(indices)
        task_completes = completes[indices]
        success_indices = indices[task_completes]
        fail_indices = indices[~task_completes]

        # Success trajectories get reward = 1.0
        final_rewards[success_indices] = 1.0
        if len(success_indices) == 0 or len(fail_indices) == 0:
            continue

        # a. Cluster successful embeddings using DBSCAN
        succ_embeddings = embeddings[success_indices]
        scaler = StandardScaler()
        scaled_succ_embeddings = scaler.fit_transform(succ_embeddings)
        clustering = DBSCAN(eps=0.5, min_samples=2).fit(scaled_succ_embeddings)
        cluster_centers = []
        for label in set(clustering.labels_) - {-1}:
            cluster_points = scaled_succ_embeddings[clustering.labels_ == label]
            center = scaler.inverse_transform(cluster_points.mean(axis=0, keepdims=True)).flatten()
            cluster_centers.append(center)
        if not cluster_centers:
            # No dense cluster found: fall back to the mean of all successful embeddings
            cluster_centers = [succ_embeddings.mean(axis=0)]
        cluster_centers = np.array(cluster_centers)

        # b. Compute distance from failed trajectories to nearest success cluster
        fail_embeddings = embeddings[fail_indices]
        distance_matrix = cdist(fail_embeddings, cluster_centers, "euclidean")
        min_distances = distance_matrix.min(axis=1)

        # c. Map distance to reward via sigmoid
        max_dist, min_dist = min_distances.max(), min_distances.min()
        dist_range = max_dist - min_dist
        if dist_range < 1e-6:
            normalized_dists = np.full_like(min_distances, 0.5)
        else:
            normalized_dists = (min_distances - min_dist) / dist_range
        sigmoid_steepness = 10.0
        sigmoid_offset = 0.5
        sigmoid_inputs = sigmoid_steepness * (sigmoid_offset - normalized_dists)
        reward_values = 0.6 * special.expit(sigmoid_inputs)
        final_rewards[fail_indices] = reward_values

    return [{"score": final_rewards[i]} for i in range(batch_size)]
Reward Visualization:
Reward
  ^
1.0|  ●●● (Success)
   |
0.6|───╲                       (Max for failure)
   |    ╲
   |     ╲  Sigmoid curve
   |      ╲
0.0|       ╲──────────────────▶ Distance to Success
   Near                     Far
Intuition: Failed trajectories that are “visually similar” to successful ones (low distance) receive higher rewards, encouraging the policy to explore in promising directions.
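The distance-to-reward mapping can be worked through by hand. The sketch below reproduces step (c) of compute_embodied_reward with the standard library only, using the same constants (steepness 10.0, offset 0.5, scale 0.6); failure_rewards is a hypothetical helper and the three distances are made-up inputs.

```python
import math


def failure_rewards(min_distances, steepness=10.0, offset=0.5, max_reward=0.6):
    """Map each failed trajectory's distance to its nearest success center to a reward."""
    lo, hi = min(min_distances), max(min_distances)
    span = hi - lo
    if span < 1e-6:
        # All failures are equally far away: give each the midpoint reward.
        normalized = [0.5] * len(min_distances)
    else:
        # Min-max normalize within the group so distances fall in [0, 1].
        normalized = [(d - lo) / span for d in min_distances]
    # Logistic sigmoid of steepness * (offset - normalized distance), scaled by max_reward.
    return [max_reward / (1.0 + math.exp(-steepness * (offset - n))) for n in normalized]


rewards = failure_rewards([2.0, 5.0, 8.0])
print([round(r, 3) for r in rewards])  # [0.596, 0.3, 0.004]
```

Because distances are min-max normalized per task group, the closest failure always lands near 0.6 and the farthest near 0, regardless of the absolute scale of the embedding distances.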
4. Advantage Calculation (GRPO)
File: siirl/dag_worker/core_algos.py
Function: compute_grpo_outcome_advantage()
GRPO computes advantages using group-relative normalization, eliminating the need for a Critic network.
@register_adv_est(AdvantageEstimator.GRPO)
def compute_grpo_outcome_advantage(
    token_level_rewards: torch.Tensor,  # (B, response_length)
    response_mask: torch.Tensor,        # (B, response_length)
    index: np.ndarray,                  # (B,) - prompt index for grouping
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgorithmArguments] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GRPO Advantage = (reward - group_mean) / group_std

    This is the "Self-Referential" part: the baseline is computed from
    the policy's own samples, not from a separate Value network.
    """
    # Sum rewards across response tokens to get scalar reward per sample
    scores = token_level_rewards.sum(dim=-1)  # (B,)

    # Group samples by prompt index
    id2score = defaultdict(list)
    id2mean = {}
    id2std = {}

    with torch.no_grad():
        bsz = scores.shape[0]
        for i in range(bsz):
            idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
            id2score[idx_key].append(scores[i])

        # Compute group statistics
        for idx in id2score:
            if len(id2score[idx]) == 1:
                id2mean[idx] = torch.tensor(0.0)
                id2std[idx] = torch.tensor(1.0)
            elif len(id2score[idx]) > 1:
                scores_tensor = torch.stack(id2score[idx])
                id2mean[idx] = torch.mean(scores_tensor)
                id2std[idx] = torch.std(scores_tensor)

        # Normalize: advantage = (score - mean) / std
        for i in range(bsz):
            idx_key = int(index[i].item()) if isinstance(index[i], torch.Tensor) else int(index[i])
            if norm_adv_by_std_in_grpo:
                scores[i] = (scores[i] - id2mean[idx_key]) / (id2std[idx_key] + epsilon)
            else:
                scores[i] = scores[i] - id2mean[idx_key]  # Dr.GRPO variant

        # Broadcast to token level
        scores = scores.unsqueeze(-1) * response_mask

    return scores, scores  # (advantages, returns)
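A small worked example shows how group-relative normalization turns shaped rewards into advantages. This is a minimal sketch of the same arithmetic using only the standard library; grpo_advantages is a hypothetical helper, the rewards are made-up values, and statistics.stdev is used because it applies the same n-1 correction as torch.std's default.

```python
from collections import defaultdict
from statistics import mean, stdev


def grpo_advantages(scores, group_ids, epsilon=1e-6, normalize_by_std=True):
    """Per-sample advantage: (score - group mean) / (group std + epsilon)."""
    groups = defaultdict(list)
    for score, gid in zip(scores, group_ids):
        groups[gid].append(score)
    stats = {}
    for gid, values in groups.items():
        # A single-sample group has no spread: use mean 0 and std 1, as in the source.
        stats[gid] = (0.0, 1.0) if len(values) == 1 else (mean(values), stdev(values))
    advantages = []
    for score, gid in zip(scores, group_ids):
        group_mean, group_std = stats[gid]
        centered = score - group_mean
        advantages.append(centered / (group_std + epsilon) if normalize_by_std else centered)
    return advantages


# Prompt 0: two successes and two failures with shaped rewards; prompt 1: all successes.
scores = [1.0, 1.0, 0.45, 0.05, 1.0, 1.0]
group_ids = [0, 0, 0, 0, 1, 1]
adv = grpo_advantages(scores, group_ids)
print([round(a, 2) for a in adv])
```

In prompt 0 the successes get positive advantages and the near-miss failure (0.45) is penalized less than the distant one (0.05). In prompt 1 every sample has the same reward, so all advantages are exactly zero, which is the case the dynamic sampling step is meant to filter out.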
Embodied-specific handling in compute_advantage():
def compute_advantage(data: TensorDict, adv_estimator, ...):
    if adv_estimator == AdvantageEstimator.GRPO:
        if "finish_step" in data and data["responses"].ndim == 3:
            # Embodied scenario: compute mask based on finish_step
            responses = data["responses"]
            batch_size = responses.size(0)
            response_length = responses.size(1) * responses.size(2)  # traj_len * action_token_len
            action_token_len = responses.size(2)
            finish_step = data['finish_step'] * action_token_len
            steps = torch.arange(response_length, device=responses.device)
            steps_expanded = steps.unsqueeze(0).expand(batch_size, -1)
            grpo_calculation_mask = steps_expanded < finish_step.unsqueeze(1)
        else:
            # NLP scenario: use attention_mask-based response_mask
            grpo_calculation_mask = data["response_mask"]

        advantages, returns = compute_grpo_outcome_advantage(
            token_level_rewards=data["token_level_rewards"],
            response_mask=grpo_calculation_mask,
            index=data["uid"],
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
        )
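The finish_step mask can be reproduced with plain lists to see which token positions receive an advantage. This is a minimal sketch of the comparison steps < finish_step * action_token_len; finish_step_mask is a hypothetical helper and the sizes are toy values.

```python
def finish_step_mask(finish_steps, traj_len, action_token_len):
    """Boolean mask of shape (batch, traj_len * action_token_len), True before the episode end."""
    response_length = traj_len * action_token_len
    return [
        [position < finish_step * action_token_len for position in range(response_length)]
        for finish_step in finish_steps
    ]


# Two trajectories of 3 steps with 2 action tokens per step.
mask = finish_step_mask([1, 3], traj_len=3, action_token_len=2)
for row in mask:
    print(row)
```

The first trajectory ends after one step, so only its first two token positions are True; the second runs all three steps, so all six positions are True. Positions after the episode end are zeroed out when the advantage is broadcast to the token level.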
5. Policy Update (PPO Loss)
File: siirl/engine/actor/embodied_actor.py
Class: RobDataParallelPPOActor
Method: update_policy()
The actor update uses the standard PPO clipped objective with GRPO advantages.
def update_policy(self, data: TensorDict):
    self.actor_module.train()
    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
    temperature = data['temperature']

    select_keys = ['responses', 'input_ids', 'attention_mask', 'pixel_values',
                   'old_log_probs', 'advantages', "finish_step"]
    batch = data.select(*select_keys)
    dataloader = batch.split(self.config.ppo_mini_batch_size)

    metrics = {}
    for batch_idx, data in enumerate(dataloader):
        mini_batch = data
        micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
        self.actor_optimizer.zero_grad()

        for test_idx, data in enumerate(micro_batches):
            data = data.cuda()
            responses = data['responses']
            response_length = responses.size(1) * responses.size(2)

            # Build response mask from finish_step
            finish_step = data['finish_step'] * self.config.action_token_len
            steps = torch.arange(response_length, device=responses.device)
            steps_expanded = steps.unsqueeze(0).expand(responses.size(0), -1)
            response_mask = steps_expanded < finish_step.unsqueeze(1)

            old_log_prob = data['old_log_probs']
            advantages = data['advantages']

            # Split trajectory into mini-batches for memory efficiency
            traj_len = responses.size(1)
            traj_split_num = int(traj_len / self.config.traj_mini_batch_size)
            # chunk_size, the sliced inputs, and the *_tmp tensors are per-slice views;
            # the code that builds them is omitted in this excerpt.
            for i in range(0, traj_len, int(traj_len / traj_split_num)):
                # Forward pass to get current log probs
                entropy, log_prob = self._forward_micro_batch_update(
                    input_ids=input_ids[i:i+chunk_size],
                    attention_mask=attention_mask[i:i+chunk_size],
                    pixel_values=pixel_values[i:i+chunk_size],
                    responses=responses[i:i+chunk_size],
                    temperature=temperature
                )

                # Compute PPO clipped loss
                pg_loss, pg_clipfrac, ppo_kl, _ = core_algos.compute_policy_loss_vanilla(
                    old_log_prob=old_log_prob_tmp,
                    log_prob=log_prob,
                    advantages=advantages_tmp,
                    response_mask=response_mask_tmp,
                    config=self.config
                )
                loss = pg_loss / self.gradient_accumulation
                loss.backward()

        grad_norm = self._optimizer_step()

    return metrics
PPO Loss Function (from core_algos.py):
@register_policy_loss("vanilla")
def compute_policy_loss_vanilla(
    old_log_prob: torch.Tensor,
    log_prob: torch.Tensor,
    advantages: torch.Tensor,
    response_mask: torch.Tensor,
    config: Optional[ActorArguments] = None,
    ...
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    L^CLIP(θ) = E[min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t)]
    where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t)
    """
    # clip_ratio_c, loss_agg_mode, pg_clipfrac and pg_clipfrac_lower are defined
    # in code omitted from this excerpt.
    clip_ratio = config.clip_ratio
    clip_ratio_low = config.clip_ratio_low if config.clip_ratio_low is not None else clip_ratio
    clip_ratio_high = config.clip_ratio_high if config.clip_ratio_high is not None else clip_ratio

    # Importance ratio
    negative_approx_kl = log_prob - old_log_prob
    negative_approx_kl = torch.clamp(negative_approx_kl, min=-20.0, max=20.0)  # stability
    ratio = torch.exp(negative_approx_kl)
    ppo_kl = siirl_F.masked_mean(-negative_approx_kl, response_mask)

    # Clipped objective
    pg_losses1 = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
    clip_pg_losses1 = torch.maximum(pg_losses1, pg_losses2)

    # Dual-clip for negative advantages
    pg_losses3 = -advantages * clip_ratio_c
    clip_pg_losses2 = torch.min(pg_losses3, clip_pg_losses1)
    pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)

    pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
    return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
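The clipping logic is easier to follow on a single token. The sketch below evaluates the same sequence of operations for one scalar log-probability pair; ppo_token_loss is a hypothetical helper, and the defaults (clip range 0.2 and dual-clip constant 3.0) are assumptions chosen for illustration, not values read from siiRL's configuration.

```python
import math


def ppo_token_loss(log_prob, old_log_prob, advantage, clip_low=0.2, clip_high=0.2, clip_c=3.0):
    """Dual-clipped PPO loss for one token (lower is better for the optimizer)."""
    # Importance ratio, with the log-ratio clamped for numerical stability.
    log_ratio = max(-20.0, min(20.0, log_prob - old_log_prob))
    ratio = math.exp(log_ratio)
    clipped_ratio = max(1.0 - clip_low, min(1.0 + clip_high, ratio))
    # Standard PPO: take the more pessimistic (larger) of the two losses.
    loss_unclipped = -advantage * ratio
    loss_clipped = -advantage * clipped_ratio
    ppo_loss = max(loss_unclipped, loss_clipped)
    # Dual clip: for negative advantages, cap the loss at -advantage * clip_c.
    if advantage < 0:
        return min(-advantage * clip_c, ppo_loss)
    return ppo_loss


# Positive advantage, ratio 1.5: the gain is capped at the upper clip bound 1.2.
loss_a = ppo_token_loss(math.log(1.5), 0.0, advantage=1.0)
# Negative advantage, ratio 5.0: the loss is capped by the dual-clip constant.
loss_b = ppo_token_loss(math.log(5.0), 0.0, advantage=-1.0)
print(round(loss_a, 6), round(loss_b, 6))  # -1.2 3.0
```

In the first case the ratio has already moved past 1 + ε, so the clipped term is selected and increasing the ratio further earns no extra objective. In the second case the unclipped loss would be 5.0, and the dual clip bounds it at 3.0 so a single badly off-policy token cannot dominate the update.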
Key Configuration Parameters
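| Parameter | Location | Description |
|---|---|---|
| | Training config | Set to |
| actor_rollout_ref.rollout.n | Training script | Group size (samples per prompt) |
| algorithm.embodied_sampling.filter_accuracy | Training script | Enable accuracy-based filtering |
| algorithm.embodied_sampling.accuracy_lower_bound | Training script | Min success rate (default: 0.1) |
| algorithm.embodied_sampling.accuracy_upper_bound | Training script | Max success rate (default: 0.9) |
| algorithm.embodied_sampling.filter_truncated | Training script | Filter truncated episodes |
| actor_rollout_ref.embodied.video_embedding_model_path | Training script | Path to V-JEPA model |
| actor_rollout_ref.embodied.env.num_envs | Config | Number of parallel environments |
| actor_rollout_ref.embodied.env.max_steps | Config | Maximum steps per episode |
| actor_rollout_ref.embodied.action_chunks_len | Config | Actions per VLA forward pass |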
Quick Reference: File Locations
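| Component | File Path |
|---|---|
| Training Entry | |
| Pipeline Definition | siirl/execution/dag/builtin_pipelines.py |
| Embodied Rollout | siirl/engine/rollout/embodied_rollout.py |
| Environment Adapter | siirl/environment/embodied/adapters/ |
| V-JEPA Embedding | siirl/utils/embodied/video_emb.py |
| Data Filtering | siirl/user_interface/filter_interface/embodied.py |
| VJEPA Reward | siirl/utils/reward_score/embodied.py |
| GRPO Advantage | siirl/dag_worker/core_algos.py |
| VLA Actor | siirl/engine/actor/embodied_actor.py |
| Example Scripts | examples/embodied_srpo_trainer/run_openvla_oft_libero_*.sh |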