Reward Interface
Custom reward functions let you score model-generated responses. Write a Python function, then point the configuration at its file path and function name.
Official Example: siirl/user_interface/rewards_interface/custom_gsm8k_reward.py
Architecture Overview
Reward Computation Flow
==============================================================================
+------------------+     +-------------------+     +------------------+
|   Rollout Node   |     |    Reward Node    |     |  Advantage Node  |
|   (Generation)   |---->|     (Scoring)     |---->|  (Normalization) |
+------------------+     +-------------------+     +------------------+
                                   |
                                   v
                           +---------------+
                           | RewardManager |
                           +---------------+
                                   |
          +------------------------+------------------------+
          |                        |                        |
          v                        v                        v
  +---------------+        +---------------+        +---------------+
  | Naive Reward  |        | Batch Reward  |        | Custom Reward |
  | (Rule-based)  |        | (Model-based) |        | (User-defined)|
  +---------------+        +---------------+        +---------------+
                                                            |
                                                            v
                                                    +---------------+
                                                    | compute_score |
                                                    | (data_source, |
                                                    |  solution_str,|
                                                    |  ground_truth,|
                                                    |  extra_info)  |
                                                    +-------+-------+
                                                            |
                                                            v
                                                    +---------------+
                                                    | Returns float |
                                                    | score [0, 1]  |
                                                    +---------------+
==============================================================================
Custom Reward Function Integration:
       Configuration                            Runtime
+---------------------------+        +---------------------------+
| custom_reward_function:   |        | RewardManager loads       |
|   path: /path/to/file.py  | -----> | compute_score function    |
|   name: compute_score     |        | and calls it per sample   |
+---------------------------+        +---------------------------+
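The path-plus-name lookup in the diagram above can be pictured with the standard library alone. The sketch below is illustrative and is not siiRL's actual loader: `load_reward_fn` and `score_batch` are hypothetical helpers, and the sample dictionary keys are assumed for the example.

```python
import importlib.util


def load_reward_fn(path, name="compute_score"):
    """Import the callable `name` from the Python file at `path` (hypothetical helper)."""
    spec = importlib.util.spec_from_file_location("custom_reward_module", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a module from {path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code
    fn = getattr(module, name, None)
    if not callable(fn):
        raise AttributeError(f"{path!r} does not define a callable {name!r}")
    return fn


def score_batch(fn, samples):
    """Call the reward function once per sample, passing arguments positionally."""
    return [
        float(fn(s["data_source"], s["solution_str"], s["ground_truth"], s.get("extra_info")))
        for s in samples
    ]
```

Because the arguments are passed positionally here, the parameter order of the user function matters more than the parameter names.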
Quick Start
Step 1: Write Reward Function
Create a Python file that defines a compute_score function:
# my_reward.py
def compute_score(data_source, solution_str, ground_truth, extra_info):
    """
    Custom reward function

    Args:
        data_source (str): Dataset source identifier (e.g., "openai/gsm8k")
        solution_str (str): Model generated text
        ground_truth (str): Correct answer
        extra_info (dict): Additional information (optional)

    Returns:
        float: Score (typically 0-1)
    """
    # Your scoring logic
    if solution_str == ground_truth:
        return 1.0
    else:
        return 0.0
Step 2: Configuration
Command Line:
python -m siirl.main_dag \
    custom_reward_function.path=/path/to/my_reward.py \
    custom_reward_function.name=compute_score
Official Example: GSM8K
File: siirl/user_interface/rewards_interface/custom_gsm8k_reward.py
import re


def extract_solution(solution_str, method="strict"):
    """Extract answer from solution"""
    if method == "strict":
        # Requires #### <answer> format
        solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
        if solution is None:
            return None
        final_answer = solution.group(1).replace(",", "")
        return final_answer
    elif method == "flexible":
        # Extract last number
        answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
        if len(answer) == 0:
            return None
        for final_answer in reversed(answer):
            if final_answer not in ["", "."]:
                return final_answer
    return None


def compute_score(data_source, solution_str, ground_truth, extra_info):
    """
    GSM8K scoring function
    Checks format and compares answer
    """
    method = "strict"
    format_score = 0.0
    score = 1.0
    answer = extract_solution(solution_str, method=method)
    if answer is None:
        return 0  # Format error
    elif answer == ground_truth:
        return score  # Correct answer
    else:
        return format_score  # Correct format but wrong answer
Usage:
python -m siirl.main_dag \
    custom_reward_function.path=siirl/user_interface/rewards_interface/custom_gsm8k_reward.py \
    custom_reward_function.name=compute_score \
    data.train_files=/path/to/gsm8k.parquet
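The example above compares the extracted answer to ground_truth as strings, so "72.0" and "72" count as different answers. If that is too strict for your data, a value-based comparison is one option. The sketch below is an assumption for illustration, not part of the official example; `answers_match` is a hypothetical helper.

```python
from decimal import Decimal, InvalidOperation


def answers_match(answer, ground_truth):
    """Compare two answer strings by numeric value, else by stripped text."""
    try:
        # Thousands separators are removed before parsing, as in strict mode
        return Decimal(answer.replace(",", "")) == Decimal(ground_truth.replace(",", ""))
    except InvalidOperation:
        # At least one side is not a number: fall back to exact text match
        return answer.strip() == ground_truth.strip()
```

Inside compute_score, the check `answer == ground_truth` could then be swapped for `answers_match(answer, ground_truth)`.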
Custom Examples
Example 1: Keyword Matching
def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Keyword-based reward"""
    score = 0.0
    # Check keywords
    keywords = ["because", "therefore", "thus"]
    for keyword in keywords:
        if keyword in solution_str.lower():
            score += 0.3
    # Length check
    words = len(solution_str.split())
    if 50 <= words <= 200:
        score += 0.4
    return min(score, 1.0)
Example 2: Regex Validation
import re


def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Regex-based format validation"""
    # Extract the numeric answer. The pattern expects Chinese-language output
    # such as "答案是:42" or "答案为:42" ("The answer is: 42"), with either
    # an ASCII or a full-width colon.
    match = re.search(r"答案[是为][::]\s*(\d+)", solution_str)
    if match is None:
        return 0.0  # Incorrect format
    answer = match.group(1)
    if answer == ground_truth:
        return 1.0  # Correct
    else:
        return 0.1  # Correct format but wrong answer
Example 3: Multi-stage Scoring
import re


def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Multi-stage scoring: format + reasoning + correctness"""
    score = 0.0
    # Stage 1: Format check (0.2 points)
    if "####" in solution_str:
        score += 0.2
    # Stage 2: Reasoning steps (0.3 points)
    steps = solution_str.count('\n')
    if steps >= 3:
        score += 0.3
    # Stage 3: Answer correctness (0.5 points)
    answer_match = re.search(r"#### ([\-0-9\.]+)", solution_str)
    if answer_match:
        answer = answer_match.group(1)
        if answer == ground_truth:
            score += 0.5
    return score
Example 4: Multiple Datasets
def compute_score(data_source, solution_str, ground_truth, extra_info):
    """Route to different scoring functions based on data_source"""
    # The strings must match the data_source values stored in your dataset
    if data_source == "gsm8k":
        return score_gsm8k(solution_str, ground_truth)
    elif data_source == "math":
        return score_math(solution_str, ground_truth)
    else:
        return 0.0


def score_gsm8k(solution_str, ground_truth):
    # GSM8K specific logic goes here; it must return a float
    return 0.0  # placeholder


def score_math(solution_str, ground_truth):
    # MATH specific logic goes here; it must return a float
    return 0.0  # placeholder
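When the number of datasets grows, the if/elif chain can be replaced with a lookup table. The following is a minimal sketch of that variant; the two scorers and the `SCORERS` table are illustrative assumptions, and the keys must match the data_source values in your own dataset.

```python
import re


def score_exact(solution_str, ground_truth):
    """Full credit only when the whole response equals the reference text."""
    return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0


def score_final_number(solution_str, ground_truth):
    """Full credit when the last number in the response equals the reference."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", solution_str)
    return 1.0 if numbers and numbers[-1] == ground_truth else 0.0


# Hypothetical routing table: data_source value -> scoring function
SCORERS = {
    "gsm8k": score_final_number,
    "math": score_exact,
}


def compute_score(data_source, solution_str, ground_truth, extra_info):
    scorer = SCORERS.get(data_source)
    if scorer is None:
        return 0.0  # unknown dataset: same default as the if/elif version
    return scorer(solution_str, ground_truth)
```

Adding a dataset then means adding one entry to the table rather than another branch.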
Function Specification
Required Signature
def compute_score(data_source, solution_str, ground_truth, extra_info):
    """
    Args:
        data_source (str): Dataset source
        solution_str (str): Model generated response
        ground_truth (str): Correct answer
        extra_info (dict): Additional information

    Returns:
        float: Score value
    """
    pass
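A signature mistake is cheaper to catch before a training run than during one. The sketch below uses the standard inspect module to check a candidate function against the signature above; `check_reward_signature` is a hypothetical helper and not something siiRL provides.

```python
import inspect

EXPECTED_ORDER = ["data_source", "solution_str", "ground_truth", "extra_info"]


def check_reward_signature(fn):
    """Raise TypeError if fn cannot take four positional arguments.

    Returns True when the first four parameter names also match the
    documented order, and False when only the names differ.
    """
    sig = inspect.signature(fn)
    try:
        # bind() validates the call shape without running the function
        sig.bind("source", "solution", "truth", {})
    except TypeError as exc:
        raise TypeError(
            f"{fn.__name__} cannot be called with four positional arguments: {exc}"
        ) from exc
    return list(sig.parameters)[:4] == EXPECTED_ORDER
```

A False result is worth a second look: a function declared as (solution_str, data_source, ...) would still be callable but would receive its arguments in the wrong slots.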
Important Notes
Function Name: Can be customized; specify it via custom_reward_function.name
Return Type: Must return a float, typically in the [0, 1] range
Error Handling: Catching exceptions and returning a default value (e.g., 0.0) is recommended
Parameter Order: Must follow the signature order
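The return-type and error-handling notes can be applied in one place with a decorator. This is a sketch under stated assumptions: `safe_reward` is a hypothetical helper, and clamping to [0, 1] is a choice made for this example, since the notes above only say scores are typically in that range.

```python
import functools
import math


def safe_reward(default=0.0, low=0.0, high=1.0):
    """Wrap a reward function so it always returns a float in [low, high]."""
    def decorator(fn):
        @functools.wraps(fn)  # keeps the original name, e.g. compute_score
        def wrapper(data_source, solution_str, ground_truth, extra_info):
            try:
                score = float(fn(data_source, solution_str, ground_truth, extra_info))
            except Exception:
                return default  # scoring failed: fall back to the default
            if math.isnan(score):
                return default
            return min(max(score, low), high)
        return wrapper
    return decorator


@safe_reward()
def compute_score(data_source, solution_str, ground_truth, extra_info):
    # float() raises ValueError on non-numeric text; the wrapper absorbs it
    return 1.0 if float(solution_str) == float(ground_truth) else 0.0
```

The wrapper keeps the four-parameter signature, so the decorated function can still be referenced by name in custom_reward_function.name.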