Quickstart: GRPO training on the GSM8K dataset
Post-train an LLM on the GSM8K dataset.
Introduction
In this example, we train an LLM to tackle the GSM8k task with function-based rewards.
Prerequisites:
- The latest version of siiRL and its dependencies, installed following the installation guide. Using the Docker image is recommended.
- A GPU with at least 24 GB of HBM.
Dataset Introduction
GSM8K is a dataset of math word problems. Each prompt is an elementary school math problem that the LLM is asked to solve. Below is an example:
Prompt
Katy makes coffee using teaspoons of sugar and cups of water in the ratio of 7:13. If she used a total of 120 teaspoons of sugar and cups of water, calculate the number of teaspoonfuls of sugar she used.
Solution
The total ratio representing the ingredients she used to make the coffee is 7+13 = <<7+13=20>>20 Since the fraction representing the number of teaspoons she used is 7/20, she used 7/20*120 = <<7/20*120=42>>42 #### 42
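The solution format has two machine-readable parts: calculator annotations of the form <<expression=result>> and a final answer after "####". The sketch below shows one way to pull both out with regular expressions. The function name and patterns are illustrative assumptions, not part of siiRL.

```python
import re

# A solution string in the GSM8K style, with "*" as the multiplication operator.
SOLUTION = (
    "The total ratio representing the ingredients she used to make the coffee "
    "is 7+13 = <<7+13=20>>20 Since the fraction representing the number of "
    "teaspoons she used is 7/20, she used 7/20*120 = <<7/20*120=42>>42 #### 42"
)

# <<expression=result>> calculator annotations.
ANNOTATION = re.compile(r"<<([^=<>]+)=([^<>]+)>>")
# The final answer that follows the "####" marker.
FINAL_ANSWER = re.compile(r"####\s*(-?[0-9.,]+)")


def parse_solution(text):
    """Return (list of (expression, result) pairs, final answer or None)."""
    steps = ANNOTATION.findall(text)
    match = FINAL_ANSWER.search(text)
    # Drop thousands separators so "1,000" and "1000" compare equal.
    answer = match.group(1).replace(",", "") if match else None
    return steps, answer


steps, answer = parse_solution(SOLUTION)
print(steps)
print(answer)
```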
Step 1: Prepare the dataset
We preprocess the dataset into Parquet format so that it (1) contains the fields needed to compute RL rewards and (2) is faster to read.
python3 examples/data_preprocess/gsm8k.py --local_dir ~/data/gsm8k
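To make "the fields needed to compute RL rewards" concrete, here is a rough sketch of how one raw GSM8K row could be turned into a training record. The field names (prompt, reward_model, ground_truth, extra_info) and the instruction text are assumptions for illustration; the actual schema is defined by examples/data_preprocess/gsm8k.py.

```python
def split_answer(raw_answer):
    """Split a GSM8K answer into (reasoning, final answer) at the '####' marker."""
    reasoning, _, final = raw_answer.partition("####")
    return reasoning.strip(), final.strip().replace(",", "")


def to_rl_record(question, raw_answer, split, index):
    """Build one training row. All field names here are illustrative assumptions."""
    _, ground_truth = split_answer(raw_answer)
    instruction = 'Let\'s think step by step and output the final answer after "####".'
    return {
        "data_source": "openai/gsm8k",
        # Chat-style prompt so a tokenizer chat template could be applied later.
        "prompt": [{"role": "user", "content": f"{question} {instruction}"}],
        # The reward function only needs the ground-truth final answer.
        "reward_model": {"style": "rule", "ground_truth": ground_truth},
        "extra_info": {"split": split, "index": index},
    }


record = to_rl_record(
    "If she used 120 units in the ratio 7:13, how many were sugar?",
    "7/20*120 = <<7/20*120=42>>42 #### 42",
    split="train",
    index=0,
)
print(record["reward_model"])
```

A list of such records could then be written out with a Parquet writer of your choice, for example pandas.DataFrame(rows).to_parquet(path).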
Step 2: Download a model for post-training
In this example, we start with the Qwen2.5-0.5B-Instruct model.
python3 -c "import transformers; transformers.pipeline('text-generation', model='Qwen/Qwen2.5-0.5B-Instruct')"
Step 3: Perform GRPO training with the instruct model
Reward Model/Function
We use a pre-defined rule-based reward function. The model is required to write its final answer after four “#” characters (####), as shown in the solution above. We extract the final answer from both the reference solution and the model’s output using regular expression matching. A correct answer receives a reward of 1; an incorrect or missing answer receives a reward of 0.
For more details, please refer to siirl/utils/reward_score/gsm8k.py.
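The rule described above can be sketched in a few lines. This is a simplified stand-in written for this guide, not the contents of siirl/utils/reward_score/gsm8k.py: the function names, the regular expression, and the choice of exact string comparison are all assumptions.

```python
import re

# Final answer after "####": optional minus sign, digits, separators.
ANSWER_PATTERN = re.compile(r"####\s*(-?[0-9][0-9.,]*)")


def extract_answer(text):
    """Return the last '#### <number>' answer in text, or None if absent."""
    matches = ANSWER_PATTERN.findall(text)
    if not matches:
        return None
    # Normalize "1,000." to "1000" before comparing.
    return matches[-1].replace(",", "").rstrip(".")


def compute_score(model_output, ground_truth, correct_score=1.0, wrong_score=0.0):
    """Rule-based reward: 1 for a matching answer, 0 for a wrong or missing one."""
    answer = extract_answer(model_output)
    if answer is None:
        return 0.0  # no answer in the required format
    return correct_score if answer == ground_truth else wrong_score


print(compute_score("7/20*120 = 42, so the answer is #### 42", "42"))
print(compute_score("I think it is #### 41", "42"))
print(compute_score("The answer is forty-two.", "42"))
```

Comparing normalized strings keeps the sketch short; a numeric comparison would also treat "42" and "42.0" as equal, which may or may not be desired.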
Training Script
Now let’s run GRPO training with the dataset and model above. [1]
Set data.train_files, data.val_files, and actor_rollout_ref.model.path (and critic.model.path, if your setup uses a critic) based on your dataset and model names or paths.
python3 -m siirl.client.main_dag \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=128 \
data.max_prompt_length=2048 \
data.max_response_length=4096 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=$HOME/models/Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=False \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.grad_clip=0.5 \
actor_rollout_ref.actor.clip_ratio=0.2 \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.max_model_len=8192 \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.free_cache_engine=False \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','tensorboard'] \
trainer.project_name=siirl_qwen2.5_0.5b_grpo \
trainer.experiment_name=siirl_qwen2.5_0.5b_grpo_toy \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=200 \
trainer.test_freq=10 \
trainer.total_epochs=30 \
trainer.resume_mode=auto \
trainer.max_actor_ckpt_to_keep=1 \
trainer.default_local_dir=ckpts/qwen2.5_0.5b/grpo/ \
trainer.val_before_train=True 2>&1 | tee verl_demo.log
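With algorithm.adv_estimator=grpo and actor_rollout_ref.rollout.n=8, each prompt is sampled 8 times and every response is scored relative to the other responses for the same prompt, so no learned critic is needed. The sketch below illustrates that group-relative normalization in plain Python. The use of the sample standard deviation and the epsilon value are assumptions; siiRL's own estimator may differ in these details.

```python
from statistics import mean, stdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize the rewards of one prompt's responses within their group.

    Each advantage is (reward - group mean) / (group std + eps).
    """
    if len(rewards) < 2:
        # A single response has nothing to be compared against.
        return [0.0 for _ in rewards]
    mu = mean(rewards)
    sigma = stdev(rewards)  # sample standard deviation (an assumption)
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt, 8 sampled responses, only the first one is correct.
rewards = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(group_relative_advantages(rewards))

# If every response gets the same reward, the group carries no signal.
print(group_relative_advantages([0.0] * 8))
```

The second call shows why prompts that are always solved or never solved contribute zero advantage under this scheme.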
You should see logs like the following, which indicate that training is in progress. The key metric val/test_score/openai/gsm8k is computed every trainer.test_freq steps:
step:1 - training/epoch:1.000 - training/global_step:0.000 - training/rollout_probs_diff_max:0.373 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.438 - actor/grad_norm:0.221 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.003 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.002 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.002 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.576 - perf/cpu_memory_used_gb:125.440 - perf/delta_time/actor:72.260 - perf/delta_time/actor_log_prob:10.829 - perf/delta_time/advantage:0.039 - perf/delta_time/compute_core_metrics:0.020 - perf/delta_time/data_loading:1.030 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:128.358 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.036 - perf/delta_time/ref:28.170 - perf/delta_time/reference:28.172 - perf/delta_time/reset_data_buffer:0.038 - 
perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.255 - perf/delta_time/rollout:16.797 - perf/delta_time/step:129.426 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:34.832 - perf/max_mem_rsvd_gb:39.678 - perf/max_memory_allocated_gb:34.832 - perf/max_memory_reserved_gb:39.678 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.052 - perf/mfu/ref:0.021 - perf/mfu/rollout:0.079 - response_length/clip_ratio:0.610 - response_length/max:256.000 - response_length/mean:232.029 - response_length/min:76.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:189.000 - prompt_length/mean:104.727 - prompt_length/min:66.000 - perf/total_num_tokens:431047.000 - perf/time_per_step:129.426 - perf/throughput:3330.450
step:2 - training/epoch:1.000 - training/global_step:1.000 - training/rollout_probs_diff_max:0.326 - training/rollout_probs_diff_mean:0.004 - training/rollout_probs_diff_std:0.009 - actor/entropy_loss:0.432 - actor/grad_norm:0.210 - actor/lr:0.000 - actor/pg_clipfrac:0.000 - actor/pg_clipfrac_lower:0.000 - actor/pg_loss:0.004 - actor/ppo_kl:-0.000 - critic/advantages/max:1.789 - critic/advantages/mean:-0.004 - critic/advantages/min:-0.730 - critic/returns/max:1.789 - critic/returns/mean:-0.004 - critic/returns/min:-0.730 - critic/rewards/max:1.000 - critic/rewards/mean:0.013 - critic/rewards/min:0.000 - critic/score/max:1.000 - critic/score/mean:0.013 - critic/score/min:0.000 - perf/cpu_mem_used_gb:11.589 - perf/cpu_memory_used_gb:125.617 - perf/delta_time/actor:72.457 - perf/delta_time/actor_log_prob:10.689 - perf/delta_time/advantage:0.040 - perf/delta_time/compute_core_metrics:0.001 - perf/delta_time/data_loading:0.005 - perf/delta_time/get_data_from_buffer:0.001 - perf/delta_time/get_entry_node:0.000 - perf/delta_time/get_intern_data_actor_old_log_prob:0.000 - perf/delta_time/get_intern_data_actor_train:0.000 - perf/delta_time/get_intern_data_calculate_advantages:0.000 - perf/delta_time/get_intern_data_function_reward:0.000 - perf/delta_time/get_intern_data_reference_log_prob:0.000 - perf/delta_time/get_next_node:0.000 - perf/delta_time/graph_execution:123.794 - perf/delta_time/graph_loop_management:0.001 - perf/delta_time/graph_output_handling:0.002 - perf/delta_time/put_data_to_buffer:0.001 - perf/delta_time/put_intern_data_actor_old_log_prob:0.000 - perf/delta_time/put_intern_data_actor_train:0.000 - perf/delta_time/put_intern_data_calculate_advantages:0.000 - perf/delta_time/put_intern_data_function_reward:0.000 - perf/delta_time/put_intern_data_reference_log_prob:0.000 - perf/delta_time/reduce_metrics:0.001 - perf/delta_time/ref:24.271 - perf/delta_time/reference:24.273 - perf/delta_time/reset_data_buffer:0.005 - 
perf/delta_time/reset_intern_data_buffer:0.000 - perf/delta_time/reward:0.286 - perf/delta_time/rollout:16.043 - perf/delta_time/step:123.805 - perf/delta_time/step_barrier:0.001 - perf/max_mem_alloc_gb:36.362 - perf/max_mem_rsvd_gb:41.596 - perf/max_memory_allocated_gb:36.362 - perf/max_memory_reserved_gb:41.596 - perf/mfu/actor:0.023 - perf/mfu/actor_log_prob:0.053 - perf/mfu/ref:0.024 - perf/mfu/rollout:0.082 - response_length/clip_ratio:0.595 - response_length/max:256.000 - response_length/mean:230.901 - response_length/min:20.000 - prompt_length/clip_ratio:0.000 - prompt_length/max:215.000 - prompt_length/mean:105.098 - prompt_length/min:65.000 - perf/total_num_tokens:430078.000 - perf/time_per_step:123.805 - perf/throughput:3473.837
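Each log record above is a flat list of name:value pairs separated by " - ". If you want to track a few metrics from the saved log file, a small parser is enough. The helper below is a hypothetical convenience written for this guide, not a siiRL utility.

```python
def parse_metrics(line):
    """Parse 'name:value - name:value - ...' into a dict of floats."""
    metrics = {}
    for item in line.split(" - "):
        # Tolerate a dangling separator at the end of a wrapped line.
        item = item.strip(" -")
        # Split on the last colon; metric names contain "/" but values do not contain ":".
        name, sep, value = item.rpartition(":")
        if not sep:
            continue
        try:
            metrics[name] = float(value)
        except ValueError:
            continue  # skip anything that is not numeric
    return metrics


sample = "step:1 - actor/entropy_loss:0.438 - actor/ppo_kl:-0.000 - perf/throughput:3330.450 - "
parsed = parse_metrics(sample)
print(parsed["actor/entropy_loss"], parsed["perf/throughput"])
```

Because a record may be wrapped across several lines, parse each physical line and merge the resulting dicts for a step.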
In addition, siiRL prints a formatted, easy-to-read summary of the core performance metrics on rank 0, giving a separate view of the most important indicators.
========================= RANK(0): Core Performance Metrics (Step: 1) =========================
--- ⏱️ Overall Performance ---
Step Time : 129.426 s
Throughput (tokens/s) : 3330.45
Total Tokens in Step : 431047
--- 📈 Algorithm Metrics ---
Actor Entropy : 0.4380
Critic Rewards (Mean/Min/Max): 0.013 / 0.000 / 1.000
Critic Scores (Mean/Min/Max): 0.013 / 0.000 / 1.000
--- 🔥 Model Flops Utilization (MFU) ---
Mean MFU : N/A
Actor Training MFU : 0.023
Rollout MFU : 0.079
Reference Policy MFU : 0.021
Actor LogProb MFU : 0.052
--- 💾 Memory Usage ---
Max GPU Memory Allocated : 34.83 GB
Max GPU Memory Reserved : 39.68 GB
CPU Memory Used : 11.58 GB
--- 📏 Sequence Lengths ---
Prompt Length (Mean/Max) : 104.7 / 189
Response Length (Mean/Max) : 232.0 / 256
==================================================================================
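The throughput figure in the summary can be cross-checked from the other two numbers in the same block: it matches total tokens divided by step time. The sketch below reproduces that arithmetic. Dividing by the number of GPUs is an assumption; this single-GPU run cannot confirm whether the reported value is per GPU.

```python
def tokens_per_second(total_tokens, step_time_s, n_gpus=1):
    """Tokens processed per second (per GPU, under the stated assumption)."""
    return total_tokens / (step_time_s * n_gpus)


# Values taken from the step 1 summary above.
step1 = tokens_per_second(total_tokens=431047, step_time_s=129.426)
print(round(step1, 2))
```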
Check out the Algorithm Baselines page for full training and validation logs for reference.
If you run out of GPU memory on a device with less than 32 GB of HBM, the following settings may help:
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
critic.ppo_micro_batch_size_per_gpu=1 \
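Lowering the micro batch size reduces peak memory because fewer samples go through a forward and backward pass at once; the mini batch is then covered by more gradient accumulation steps. The sketch below shows that trade-off under a simplifying assumption: the mini batch is split evenly across GPUs and counted in the same units as the micro batch. Whether siiRL also multiplies the mini batch by rollout.n is not stated here, so treat the numbers as illustrative.

```python
import math


def accumulation_steps(mini_batch_size, micro_batch_size_per_gpu, n_gpus=1):
    """Forward/backward passes per optimizer update on each GPU."""
    per_gpu = mini_batch_size / n_gpus
    return math.ceil(per_gpu / micro_batch_size_per_gpu)


# Settings from the training command: mini batch 32, micro batch 4, one GPU.
default_steps = accumulation_steps(32, 4)
# Low-memory settings: micro batch 1.
low_memory_steps = accumulation_steps(32, 1)
print(default_steps, low_memory_steps)
```

Under this assumption the optimizer sees the same mini batch either way; the low-memory setting trades speed for a smaller memory footprint.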
For the full set of configs, along with detailed explanations and performance tuning advice, please refer to config-explain-page.