Rules

This module defines the classes and methods used to organize rules, manage their priority, and calculate how well an agent follows them during a simulation.

The Rule class acts as the primary method to evaluate ralizations. It stores the logic for a single rule, which includes a name, a unique ID, a calculate_violation function, and an aggregation_method. When called, the Rule executes its violation function using the provided VariableHandler and time step. Any parameters passed during rule initialization or at runtime are merged and passed to this function. The Result class stores the output, tracking the total_violation score and a history of violations across the simulation. The aggregation_method typically max or sum determines how the step-by-step scores are combined into the final result.

How to Define a Rule

To create a new rule, you must define a Python function that calculates a violation score in a single realization step. This function needs to accept two required arguments: VariableHandler and step. Inside this function, you use the handler to retrieve the VariablePool for that specific time step—by calling handler(step) and then extract the specific data you need, such as vehicle velocity or object distances. If the criteria for a rule violation are met, the function should return a positive numerical score representing the magnitude of the violation; otherwise, it should return zero. For instance, if you were defining a minimum distance rule, your function would call the handler to get the current state of both the ego vehicle and the nearest adversary, calculate the distance between them, and return the difference between that distance and your safety threshold if the threshold is violated.

Once your function is defined, you pass it into the Rule class constructor along with a name, a numeric ID (that would match the ID you wrote for the corresponding rule in the .graph file), and an aggregation method. The aggregation method tells the system how to handle the sequence of scores returned over the entire simulation; if you want to know the single worst violation that occurred, you would pass max, but if you want to know the total cumulative violation, you would pass sum.

Below is an example speed limit rule where the violation increases linearly.

import numpy as np

def speed_limit_violation(handler, step, limit=20):
    # Retrieve ego state from the pool at this step
    ego_state = handler(step).ego_state
    
    # Calculate magnitude of the velocity vector
    current_speed = np.linalg.norm(ego_state.velocity)
    
    # Return the overshoot, or 0 if within the limit
    return max(0, current_speed - limit)

# Registering the rule with ID 4, this means the rule will correspond to the number 4 in your .graph file
speed_limit_rule = Rule(speed_limit_violation, max, "speed_limit", 4)

Below is a clearance rule where the violation is the difference between the minimum required distance and the actual distance to the nearest vehicle in proximity.

def proximity_violation(handler, step, safety_threshold=2.0):
    pool = handler(step)
    # Get states of vehicles within the proximity_threshold
    nearby_vehicles = pool.vehicles_in_proximity
    max_violation = 0

    for veh_state in nearby_vehicles:
        # pool.distance(state) returns the polygonal distance to ego
        dist = pool.distance(veh_state)
        
        if dist < safety_threshold:
            # Calculate how much the safety buffer was breached
            max_violation = max(max_violation, safety_threshold - dist)
            
    return max_violation

# Registering the rule with ID 5
proximity_violation_rule = Rule(proximity_violation, max, "proximity_violation", 5)

Evaluation and Scoring

The RuleEngine iterates through the realization to evaluate all rules. It applies the violation function of each rule at every time step and collects the scores in Result objects.

The Result class tracks how a rule’s violation score evolves over the simulation. As the engine processes each step, the Result.add method stores the current violation score and updates the total_violation value based on the aggregation logic. Beyond the final score, the object maintains a violation_history, which is a list containing the violation score recorded at every time step.