Introduction

Reinforcement learning (RL) is a subfield of artificial intelligence (AI) in which an agent learns to maximize a numerical reward through direct interaction with its environment. This contrasts with supervised learning, which trains a model on labeled data to make predictions on unseen data, and unsupervised learning, which discovers hidden patterns in unlabeled data [1]. RL has been applied in many contexts, including reinforcement learning from human feedback (RLHF), which is used to align large language models’ responses with users’ intent [3]; autonomous driving [4]; and superhuman play in complex games such as Go [5].

In this tutorial, we will explore the fundamentals of RL, including game states, legal actions, and rewards and penalties, by building a Q-learning agent from scratch that plays the classic Snake game.

RL Core Components

Game Rules

Snake game rules illustration
Figure 1. The snake can move in four directions: up, down, left, and right. If it eats the food, it grows longer. The snake dies if it hits the wall or itself (self-collision). The snake can grow up to 9 cells long.

Valid Game States

When playing Snake, at each time step we want to know:

  • Where is the head?
  • Which direction is the head facing?
  • Which cells does the body occupy?
  • Where is the food? (It never overlaps the snake’s body.)

Together, these pieces of information constitute the game state at a given discrete time step. Knowing the state helps the snake decide what to do next, and it lets us design a reward system that rewards and penalizes the learning agent. For example, if the snake is at least two cells long and its head is facing east (→), reversing direction (←) moves the head into its own body, so that action is penalized as a self-collision. Similarly, knowing the head position lets us detect when the snake hits the wall.

Importantly, the snake’s body shape and its head’s direction should be valid. Specifically, the snake’s body shape should consist of adjacent segments, and the snake’s head direction should be valid for its location.

Invalid game states illustration
Figure 2. Examples of invalid game states. The snake's body is broken: its cells do not form a continuous chain of adjacent segments.
Invalid game states illustration
Figure 3. Examples of invalid game states. The red cells represent the snake’s head, and the arrows show an invalid direction. In the first grid, if the snake's head is on cell (1,1) for the given shape, then the only valid direction for its head is → (rightward). Note that the snake can turn its head in the arrow direction in the next step, but at the given time step it cannot be facing that way: the direction must be consistent with where the head came from in the previous step. Similar logic applies to the other two grids.
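
The adjacency rule illustrated in Figure 2 can be checked mechanically. The sketch below assumes the body is stored as an ordered list of (row, col) cells, head first. The function name `is_contiguous_body` is hypothetical and is not part of the tutorial's code.

```python
from typing import List, Tuple


def is_contiguous_body(body: List[Tuple[int, int]]) -> bool:
    """Return True if no cell repeats and consecutive segments touch orthogonally."""
    # A cell may appear only once in the body
    if len(set(body)) != len(body):
        return False

    # Each segment must be exactly one step (up/down/left/right) from the previous one
    for (r1, c1), (r2, c2) in zip(body, body[1:]):
        if abs(r1 - r2) + abs(c1 - c2) != 1:
            return False
    return True


print(is_contiguous_body([(1, 1), (1, 0), (0, 0)]))  # True
print(is_contiguous_body([(1, 1), (0, 0)]))          # False: diagonal jump
```

A check like this covers only the body shape. Validating the head direction additionally requires comparing it with the position of the second segment.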

Q-Learning Algorithm

Reward and Penalty (Q-Table)

Suppose you are tasked with teaching the snake to navigate this environment so that it learns which moves are good or bad in its current state. What would you do? Because the snake lives in a computational world, we need to express reward and penalty as numbers. It's like giving grades to students in a class.

Reward and penalty illustration
Figure 4. Reward and penalty for the snake. The snake is rewarded for eating the food and penalized for dying by hitting the wall or itself. If we keep iterating this process for all valid game states, we will have a table of numerical values for each state-action pair. We call those values Q-values and the table Q-table.

This table acts as the snake's memory. To build that memory, the snake first has to learn through a lot of trial and error. Once the memory is built, the snake can look up the table to decide which action to take next.
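
To make the "memory lookup" idea concrete, here is a minimal sketch of a Q-table as a nested dictionary. The state labels (`"food_ahead"`, `"wall_ahead"`) and the numbers are invented for illustration only; they are not trained values, and the real states in this tutorial are tuples rather than strings.

```python
from typing import Dict

# Hypothetical Q-table: state -> action -> Q-value (illustrative numbers only)
q_table: Dict[str, Dict[str, float]] = {
    "food_ahead": {"turn_left": -0.4, "go_straight": 6.2, "turn_right": -0.1},
    "wall_ahead": {"turn_left": 1.3, "go_straight": -9.8, "turn_right": 0.9},
}


def best_action(table: Dict[str, Dict[str, float]], state: str) -> str:
    """Look up the state and return the action with the highest Q-value."""
    actions = table[state]
    return max(actions, key=actions.get)


print(best_action(q_table, "food_ahead"))  # go_straight
print(best_action(q_table, "wall_ahead"))  # turn_left
```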

Training Algorithm

The algorithm we use to train the snake through trial and error is called Q-learning:

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha[R_{t+1} + \gamma \max_{a} Q(S_{t+1}, a) - Q(S_t, A_t)]$$

Yes, it's this scary formula that helps the snake learn! Let's unpack each symbol one by one:

  1. $Q(S_t, A_t)$ is the current estimate of how good it is to take action $A_t$ in state $S_t$; the Q-table stores one such value for every state-action pair (i.e., the snake's memory)
  2. $\alpha$ (0–1) is the learning rate (i.e., how much we update the value at each step)
  3. $R_{t+1}$ is the reward (or penalty) received after taking action $A_t$ in state $S_t$
  4. $\gamma$ (0–1) is the discount factor (i.e., how much we value future rewards compared to immediate rewards; higher $\gamma$ = more patient/strategic, lower $\gamma$ = more greedy/shortsighted)
  5. $\max_{a} Q(S_{t+1}, a)$ is the highest Q-value among all actions available in the next state $S_{t+1}$
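
A single application of the update rule is easier to follow with numbers. The sketch below assumes α = 0.2 and γ = 0.5 (the values used by the training script later in this tutorial) and a reward of 10 for eating food. `q_update` is a hypothetical helper name used only for this illustration.

```python
def q_update(q_sa: float, reward: float, max_next_q: float,
             alpha: float, gamma: float) -> float:
    """Apply one Q-learning update and return the new Q(S_t, A_t)."""
    td_target = reward + gamma * max_next_q   # what the value "should" be
    td_error = td_target - q_sa               # how far off the old estimate is
    return q_sa + alpha * td_error            # move a fraction alpha toward the target


# The snake eats food (reward 10) from a state-action pair it has never tried
first = q_update(q_sa=0.0, reward=10, max_next_q=0.0, alpha=0.2, gamma=0.5)
print(f"{first:.2f}")   # 2.00

# Same transition again: the estimate moves closer to 10, but not all the way
second = q_update(q_sa=first, reward=10, max_next_q=0.0, alpha=0.2, gamma=0.5)
print(f"{second:.2f}")  # 3.60
```

Each update closes only a fraction α of the gap to the target, which is why the agent needs many repetitions before its estimates settle.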

You might ask: why don't we simply assign values like -10 for a penalty and +10 for a reward at every time step? Why bother with this more complex algorithm?

It's all about the expected future return, not just the immediate reward. At every step, the snake needs to answer the question: "If I choose this action now, how good will my future be on average?"
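
The difference between immediate reward and future return can be sketched with a discounted sum. The example assumes γ = 0.5 and the reward values used in this tutorial's training loop (10 for food, -0.1 per ordinary step, -10 for dying). The two reward sequences are made up for illustration, and `discounted_return` is a hypothetical helper.

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """Sum rewards, weighting the reward k steps ahead by gamma**k."""
    total = 0.0
    for k, reward in enumerate(rewards):
        total += (gamma ** k) * reward
    return total


# Both paths start with the same immediate reward (-0.1) ...
toward_food = [-0.1, -0.1, 10]   # two ordinary steps, then food
into_wall = [-0.1, -10]          # one ordinary step, then death

# ... but their futures differ, and the discounted return captures that
print(f"{discounted_return(toward_food, gamma=0.5):.2f}")  # 2.35
print(f"{discounted_return(into_wall, gamma=0.5):.2f}")    # -5.10
```

Q-learning never computes these sums explicitly; the repeated updates propagate the information backward from rewarding or fatal states, one step at a time.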

Implementation

1. Generate all valid game states

In order to enumerate all possible valid game states, we will follow these steps:

  1. For each snake length, compute the number of possible shapes (S).
  2. For each shape, enumerate all valid placements (P).
  3. For each placement, consider all possible head positions and orientations (O), as well as all valid food placements (F).
  4. Sum over every snake length $i$ to get the total number of valid states (T):

$$ T = \sum_{i=1}^{9} (S_{i} \times P_i \times O_i \times F_i) $$
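
As a quick sanity check on the formula, consider only the length-1 term, assuming a 3×3 grid (consistent with the maximum length of 9). A one-cell snake has a single shape, nine possible placements, four head orientations, and eight free cells left for the food:

$$ S_1 \times P_1 \times O_1 \times F_1 = 1 \times 9 \times 4 \times 8 = 288 $$

The terms for longer snakes follow the same pattern, but their counts depend on the shape, so they are easier to enumerate in code than by hand.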


We start with type imports and a helper that yields valid neighboring cells (up, down, left, right) within the grid. This will be used to explore connected placements.

from typing import Dict, Iterator, List, Set, Tuple

def neighbors(cell: Tuple[int, int], grid_size: int) -> Iterator[Tuple[int, int]]:
    """Generate valid neighboring cells within the grid.

    Args:
        cell: (row, col) tuple representing the current cell
        grid_size: size of the grid

    Yields:
        Tuple[int, int]: valid neighboring cells
    """
    r, c = cell

    # Loop through all possible movements: up, down, left, right
    for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        nr, nc = r + dr, c + dc

        # Check if the new position is within the grid boundaries
        if 0 <= nr < grid_size and 0 <= nc < grid_size:
            yield (nr, nc)

Given two adjacent cells, we need the direction from the first to the second (e.g. upward, downward). This is used when enumerating which way the snake's head is facing for each shape.

def dir_from(from_pos: Tuple[int, int], to_pos: Tuple[int, int]) -> str:
    """Determine direction from one position to another.

    Args:
        from_pos: the coordinates of the starting position
        to_pos: the coordinates of the target position

    Returns:
        str: the direction from from_pos to to_pos ('upward', 'downward', 'rightward', 'leftward')
    """
    dr, dc = to_pos[0] - from_pos[0], to_pos[1] - from_pos[1]

    if dr == -1:
        return 'upward'
    if dr == 1:
        return 'downward'
    if dc == 1:
        return 'rightward'
    return 'leftward'

Using depth-first search (DFS), we enumerate all connected placements of a given length: each placement is a set of distinct, adjacent cells. We start from every cell and extend paths until we reach the desired length.

def generate_connected_placements(length: int, grid_size: int) -> List[Tuple]:
    """Generate all connected placements of a snake of given length on the grid.

    Args:
        length: the length of the snake
        grid_size: size of the grid

    Returns:
        List[Tuple]: list of unique snake placements (as sorted tuples of cells)
    """
    cells = [(r, c) for r in range(grid_size) for c in range(grid_size)]
    placements = set()

    def dfs(path):
        # If the path has reached the desired length, record it
        if len(path) == length:
            placements.add(tuple(sorted(path)))
            return

        # If the path is not yet complete, extend it
        tail = path[-1]
        for nb in neighbors(tail, grid_size):
            if nb not in path:
                dfs(path + [nb])

    # Start DFS from each cell in the grid
    for cell in cells:
        dfs([cell])

    # Each placement is already a sorted tuple of cells
    return list(placements)
          

For each snake shape (placement), we need every valid (head position, head direction) pair. We use DFS over the shape cells: every ordering of the cells in which consecutive cells are adjacent defines a head (first cell) and a direction (from head to second cell). A single-cell snake can face any of the four directions.

def head_dir_pairs_for_placement(shape_cells: List[Tuple[int, int]], grid_size: int) -> Set:
    """Generate all valid (head_pos, head_dir) pairs for a given snake shape.

    Args:
        shape_cells: a list of (r, c) tuples representing the snake's body cells
        grid_size: size of the grid

    Returns:
        set: a set of (head_pos, head_dir) pairs
    """
    shape_cells = set(shape_cells)
    pairs = set()

    # Special-case single cell: any facing is valid
    if len(shape_cells) == 1:
        only = next(iter(shape_cells))
        for d in ['upward', 'downward', 'rightward', 'leftward']:
            pairs.add((only, d))
        return pairs

    def dfs(path):
        """Recursive DFS to build paths through the shape cells."""
        # If the path covers all shape cells, record the head and direction
        if len(path) == len(shape_cells):
            head, second = path[0], path[1]
            pairs.add((head, dir_from(head, second)))
            return

        # If not complete, extend the path
        tail = path[-1]
        for nb in neighbors(tail, grid_size):
            # Only consider neighbors that are part of the shape and not already in the path
            if nb in shape_cells and nb not in path:
                dfs(path + [nb])

    # Start DFS from each cell in the shape
    for start in shape_cells:
        dfs([start])
    return pairs
          

We combine everything: for each length, each placement, each (head, direction) pair, and each food position not on the snake, we create one state and initialize its Q-values to 0.0 for every action.

def generate_all_valid_states(grid_size: int, actions: List[str]) -> Dict:
    """Generate all valid game states for a given grid size.

    Args:
        grid_size: size of the grid
        actions: list of available actions (used to initialize Q-values)

    Returns:
        Dict: dictionary mapping states to action dictionaries initialized to 0.0
              Format: {(head_pos, head_dir, body_tuple, food_pos): {action: 0.0, ...}, ...}
    """
    game_states = {}
    food_positions = [(x, y) for x in range(grid_size) for y in range(grid_size)]

    for length in range(1, grid_size * grid_size + 1):
        placements = generate_connected_placements(length, grid_size)

        for placement in placements:
            pairs = head_dir_pairs_for_placement(placement, grid_size)

            for head_pos, head_dir in pairs:
                body_tuple = tuple(placement)

                for food_pos in food_positions:
                    if food_pos not in placement:
                        state = (head_pos, head_dir, body_tuple, food_pos)
                        game_states[state] = {action: 0.0 for action in actions}

    return game_states
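
A tiny grid makes it possible to cross-check the enumeration by hand. The brute-force counter below is a hypothetical, self-contained sketch that is independent of the functions above. It walks every simple path (head first), derives the heading from the first two cells, and pairs each configuration with every free food cell. It assumes a 2×2 grid to keep the numbers hand-verifiable.

```python
from itertools import product
from typing import List, Set, Tuple

Cell = Tuple[int, int]
MOVES = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # up, down, left, right


def count_states(grid_size: int) -> int:
    """Count distinct (head, heading, body cells, food) combinations."""
    cells = list(product(range(grid_size), repeat=2))
    states: Set[tuple] = set()

    def record(path: List[Cell]) -> None:
        body = frozenset(path)
        free = [c for c in cells if c not in body]
        if len(path) == 1:
            headings = MOVES  # a one-cell snake may face any direction
        else:
            # Heading is fixed by the head and the segment right behind it
            headings = [(path[0][0] - path[1][0], path[0][1] - path[1][1])]
        for heading, food in product(headings, free):
            states.add((path[0], heading, body, food))

    def extend(path: List[Cell]) -> None:
        record(path)
        r, c = path[-1]
        for dr, dc in MOVES:
            nxt = (r + dr, c + dc)
            inside = 0 <= nxt[0] < grid_size and 0 <= nxt[1] < grid_size
            if inside and nxt not in path:
                extend(path + [nxt])

    for cell in cells:
        extend([cell])
    return len(states)


print(count_states(2))  # 72
```

By hand, the 2×2 total is 48 (length 1) + 16 (length 2) + 8 (length 3) + 0 (length 4, no room for food) = 72, which matches the printed value.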
          

2. Create a Q-Learning Agent

The agent keeps a Q-table (state → action → value), learning rate, discount factor γ, and exploration rate ε. We initialize the Q-table with all valid states and zero Q-values.

import random
from typing import Dict, Tuple

class QLearningAgent:
    REWARD_FOOD = 10
    REWARD_DEATH = -10
    REWARD_STEP = 0

    def __init__(self,
                  actions: list,
                  learning_rate: float,
                  epsilon: float,
                  gamma: float,
                  grid_size: int,
                  num_episodes: int):
        """Define the attributes of the learning agent."""
        self.q_table = {}
        self.learning_rate = learning_rate
        self.epsilon = epsilon
        self.gamma = gamma
        self.actions = actions
        self.grid_size = grid_size
        self.num_episodes = num_episodes

    def set_q_table(self) -> Dict:
        """Initialize the Q-table with all valid game states and zero Q-values."""
        # generate_all_valid_states already sets every Q-value to 0.0
        self.q_table = generate_all_valid_states(self.grid_size, self.actions)
        return self.q_table

We look up Q(s, a) and choose actions with epsilon-greedy: with probability ε we pick a random action; otherwise we pick an action with the highest Q-value for the current state (breaking ties randomly).

    def get_q_value(self, state: Tuple, action: str) -> float:
        return self.q_table.get(state, {}).get(action, 0.0)

    def choose_action(self, state: Tuple, epsilon: float) -> str:
        if random.uniform(0, 1) < epsilon:
            return random.choice(self.actions)
        state_actions = self.q_table.get(state, {})
        max_q = max(state_actions.values(), default=0.0)
        best_actions = [a for a, q in state_actions.items() if q == max_q]
        return random.choice(best_actions) if best_actions else random.choice(self.actions)
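
The effect of ε is easiest to see in isolation. The standalone sketch below mirrors the epsilon-greedy idea on a single hypothetical row of Q-values. The function name `epsilon_greedy`, the explicit `rng` parameter, and the numbers are assumptions made for this illustration and are not part of the agent class.

```python
import random
from typing import Dict, Optional


def epsilon_greedy(q_values: Dict[str, float], epsilon: float,
                   rng: Optional[random.Random] = None) -> str:
    """Pick a random action with probability epsilon, else a best action."""
    if rng is None:
        rng = random.Random()
    actions = list(q_values)

    # Explore: ignore the Q-values entirely
    if rng.random() < epsilon:
        return rng.choice(actions)

    # Exploit: choose among the actions tied for the highest Q-value
    best = max(q_values.values())
    return rng.choice([a for a in actions if q_values[a] == best])


rng = random.Random(0)
q = {"turn_left": 0.0, "go_straight": 2.5, "turn_right": -1.0}
picks = [epsilon_greedy(q, 0.1, rng) for _ in range(1000)]
share = picks.count("go_straight") / len(picks)
print(f"go_straight chosen {share:.0%} of the time")
```

With ε = 0.1 and three actions, the best action is expected to be chosen about 93% of the time: 90% from exploitation plus one third of the 10% exploration steps.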

After each step we update Q(s, a) using the Q-learning rule: the new value is the old value plus the learning rate times the temporal-difference (TD) error, which is the gap between the target (reward + γ max Q(s′, a′)) and the current estimate Q(s, a).

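    def update_q_value(self, state: Tuple, action: str, reward: float, next_state: Tuple) -> float:
        # Lazily create an entry for states missing from the pre-generated table
        if state not in self.q_table:
            self.q_table[state] = {a: 0.0 for a in self.actions}
        current_q_value = self.get_q_value(state, action)
        # Best estimated value from the next state; defaults to 0.0 when
        # next_state has no entry in the table
        max_future_q = max(self.q_table.get(next_state, {}).values(), default=0.0)
        # TD update: move the old value toward reward + gamma * max_future_q
        new_q_value = current_q_value + self.learning_rate * (
            reward + self.gamma * max_future_q - current_q_value
        )
        self.q_table[state][action] = new_q_value
        return self.q_table[state][action]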

3. Environment Setup

The snake is represented as a list of (row, col) positions (head first) and a flag indicating whether it just ate food (so we don't remove the tail on that step).

import random
import os

class Snake:
    def __init__(self, initial_length: int = 1) -> None:
        self.length = initial_length
        self.just_ate_food = False
        self.snake_positions = [(1, 1)]

The environment holds the grid size, current score, and the food cell. We keep a size alias for compatibility.

class GameEnvironment:
    def __init__(self, grid_size: int) -> None:
        self.grid_size = grid_size
        self.size = grid_size
        self.score = 0
        self.grid = [[0] * grid_size for _ in range(grid_size)]
        self.food_pos = (grid_size // 2, grid_size // 2)

Game logic ties the snake and environment together. We define direction deltas (row, col) and load/save the high score from a file.

class GameLogic:
    DOWN = (1, 0)
    UP = (-1, 0)
    LEFT = (0, -1)
    RIGHT = (0, 1)

    def __init__(self, grid_size: int, score_history=None) -> None:
        if score_history is None:
            score_history = []
        self.GameEnvironment = GameEnvironment(grid_size)
        self.Snake = Snake()
        self.directions = [self.DOWN, self.UP, self.LEFT, self.RIGHT]
        self.highest_score = self.load_high_score()

    def load_high_score(self) -> int:
        if os.path.exists("highscore.txt"):
            with open("highscore.txt", "r") as f:
                try:
                    return int(f.read().strip())
                except ValueError:
                    return 0
        return 0

    def save_high_score(self) -> None:
        if self.GameEnvironment.score > self.highest_score:
            self.highest_score = self.GameEnvironment.score
            with open("highscore.txt", "w") as f:
                f.write(str(self.highest_score))

Food is placed on a random empty cell. The move method updates the snake: it computes the new head from the direction, then checks for self-collision and wall collision; if the move is valid, it prepends the new head and removes the tail unless the snake just ate.

    def place_food(self) -> None:
        # Place food at a random position in the grid where the snake is not located
        empty = [
            (r, c)
            for r in range(self.GameEnvironment.grid_size)
            for c in range(self.GameEnvironment.grid_size)
            if (r, c) not in self.Snake.snake_positions
        ]
        self.GameEnvironment.food_pos = random.choice(empty)

    def move(self, direction):
        # Get the snake's head position (row, column)
        head = self.Snake.snake_positions[0]

        # Calculate the new head position based on the direction
        new_head = (head[0] + direction[0], head[1] + direction[1])

        # Check self-collision
        if new_head in self.Snake.snake_positions[1:]:
            raise Exception("Game Over: Self-collision")

        # Check wall-collision
        if not (0 <= new_head[0] < self.GameEnvironment.grid_size and 0 <= new_head[1] < self.GameEnvironment.grid_size):
            raise Exception("Game Over: Wall-collision")

        # Add the new head position to the snake's positions
        self.Snake.snake_positions.insert(0, new_head)

        # Check if the snake has eaten food
        if new_head == self.GameEnvironment.food_pos:
            self.Snake.just_ate_food = True
            self.GameEnvironment.score += 1
            self.place_food()

        # If the snake has not just eaten food, remove the last position
        if not self.Snake.just_ate_food:
            self.Snake.snake_positions.pop()
        else:
            self.Snake.just_ate_food = False

When the game ends we save the high score and return the final score. get_state exposes snake, food, score, and grid size for the agent; render prints a simple grid for debugging.

    def stop_game(self) -> int:
        # update the highest score
        self.save_high_score()
        return self.GameEnvironment.score

    def get_state(self):
        return {
            "snake": self.Snake.snake_positions,
            "food": self.GameEnvironment.food_pos,
            "score": self.GameEnvironment.score,
            "grid_size": self.GameEnvironment.size
        }

    def render(self):
        size = self.GameEnvironment.grid_size
        grid = [['.'] * size for _ in range(size)]

        for r, c in self.Snake.snake_positions:
            grid[r][c] = 'S'

        fr, fc = self.GameEnvironment.food_pos
        grid[fr][fc] = 'F'

        for row in grid:
            print(' '.join(row))
        print()
          

4. Set up the training loop

We use a fixed action set, learning rate, grid size, and number of episodes. A moving-average window smooths the learning curves for plotting.

import random
from typing import List, Dict
import matplotlib.pyplot as plt
from tqdm import tqdm

ACTIONS = ["turn_left", "go_straight", "turn_right", "turn_around"]
LEARNING_RATE = 0.2
GRID_SIZE = 4
NUM_EPISODES = 20000
MAX_STEPS_PER_EPISODE = 200
MOVING_AVG_WINDOW = 50

run_training builds a Q-learning agent, initializes the Q-table, then for each episode runs a training game (with exploration) and an evaluation game (epsilon=0). It returns the list of evaluation scores per episode.

def run_training(
    train_epsilon: float,
    gamma: float,
    num_episodes: int = NUM_EPISODES,
    grid_size: int = GRID_SIZE,
    learning_rate: float = LEARNING_RATE,
) -> List[float]:
    """
    Train a fresh Q-learning agent for a given training epsilon and,
    after each training episode, run a separate evaluation episode
    with epsilon = 0.0. Returns the per-episode evaluation scores
    over num_episodes.

    This mirrors the logic in train.py but without visualization or Q-table I/O.
    Training uses epsilon = train_epsilon, evaluation uses epsilon = 0.0
    (pure exploitation of the learned Q-table).
    """
    agent = QLearningAgent(
        actions=ACTIONS,
        learning_rate=learning_rate,
        epsilon=train_epsilon,
        gamma=gamma,
        grid_size=grid_size,
        num_episodes=num_episodes,
    )

    # Start from a fresh Q-table for each training run
    agent.set_q_table()

    eval_scores: List[float] = []

    for episode in tqdm(
        range(num_episodes), desc=f"Training (epsilon={train_epsilon})"
    ):
        # -------- Training episode (epsilon = train_epsilon) --------
        game = GameLogic(grid_size=agent.grid_size)
        game.place_food()

        steps = 0
        while True:
            steps += 1
            current_state = get_state_representation(game)
            action = agent.choose_action(current_state, train_epsilon)
            current_dir = get_current_direction(game.Snake.snake_positions)
            direction = get_direction(action, current_dir)

            try:
                old_score = game.GameEnvironment.score
                game.move(direction)
                new_score = game.GameEnvironment.score

                # Reward structure consistent with train.py
                if new_score > old_score:
                    reward = 10  # Ate food
                else:
                    reward = -0.1  # Normal step

                next_state = get_state_representation(game)

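            except Exception:
                # Every game-over cause (wall or self-collision) gets the same penalty
                reward = -10

                # Terminal transition: there is no next state, so the future
                # value must be 0. Passing None makes update_q_value fall back
                # to its default of 0.0 instead of bootstrapping from current_state.
                next_state = None
                agent.update_q_value(current_state, action, reward, next_state)
                break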

            # Update Q-value for non-terminal transition
            agent.update_q_value(current_state, action, reward, next_state)

            if steps >= MAX_STEPS_PER_EPISODE:
                break

        # -------- Evaluation episode (epsilon = 0.0, pure exploitation) --------
        eval_game = GameLogic(grid_size=agent.grid_size)
        eval_game.place_food()

        eval_steps = 0
        while True:
            eval_steps += 1
            eval_state = get_state_representation(eval_game)
            eval_action = agent.choose_action(eval_state, 0.0)
            eval_current_dir = get_current_direction(eval_game.Snake.snake_positions)
            eval_direction = get_direction(eval_action, eval_current_dir)

            try:
                eval_game.move(eval_direction)
            except Exception:
                # Game over; record evaluation score and stop this eval episode
                break

            if eval_steps >= MAX_STEPS_PER_EPISODE:
                break

        eval_scores.append(eval_game.GameEnvironment.score)

    return eval_scores
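
run_training calls three helpers whose definitions are not shown in this tutorial: get_state_representation, get_current_direction, and get_direction. The block below is one possible sketch of them, not the author's implementation. It assumes headings are (row, col) deltas like those in GameLogic, that a one-cell snake defaults to facing right, and that the state key follows the (head_pos, head_dir, body_tuple, food_pos) format of the Q-table with head_dir naming the direction of travel. Whether that naming matches the pre-generated table keys is not confirmed by the text above; unseen keys are simply created on first use by update_q_value.

```python
from types import SimpleNamespace
from typing import List, Tuple

Cell = Tuple[int, int]
UP, DOWN, LEFT, RIGHT = (-1, 0), (1, 0), (0, -1), (0, 1)
DIR_NAMES = {UP: "upward", DOWN: "downward", LEFT: "leftward", RIGHT: "rightward"}


def get_current_direction(snake_positions: List[Cell]) -> Cell:
    """Heading as a (d_row, d_col) delta, inferred from the first two segments."""
    if len(snake_positions) < 2:
        return RIGHT  # assumption: a one-cell snake has no history, so use a default
    (hr, hc), (nr, nc) = snake_positions[0], snake_positions[1]
    return (hr - nr, hc - nc)


def get_direction(action: str, current_dir: Cell) -> Cell:
    """Turn a relative action into an absolute (d_row, d_col) move."""
    dr, dc = current_dir
    if action == "turn_left":
        return (-dc, dr)    # rotate 90 degrees counterclockwise on screen
    if action == "turn_right":
        return (dc, -dr)    # rotate 90 degrees clockwise on screen
    if action == "turn_around":
        return (-dr, -dc)
    return (dr, dc)         # "go_straight"


def get_state_representation(game) -> tuple:
    """Build a key shaped like (head_pos, head_dir, body_tuple, food_pos)."""
    positions = game.Snake.snake_positions
    head_dir = DIR_NAMES[get_current_direction(positions)]
    body = tuple(sorted(positions))  # sorted cells, like the generated placements
    return (positions[0], head_dir, body, game.GameEnvironment.food_pos)


# Stand-in for a GameLogic instance, used only for this demonstration
demo_game = SimpleNamespace(
    Snake=SimpleNamespace(snake_positions=[(1, 2), (1, 1)]),
    GameEnvironment=SimpleNamespace(food_pos=(0, 0)),
)
print(get_state_representation(demo_game))
# ((1, 2), 'rightward', ((1, 1), (1, 2)), (0, 0))
print(get_direction("turn_left", RIGHT))  # (-1, 0), i.e. up
```

Expressing actions relative to the current heading keeps the action set the same in every state; the rotation formulas rely on rows increasing downward, as in the grid used throughout this tutorial.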

We smooth the raw per-episode scores with a simple moving average (same-length output). Then we plot each run's smoothed curve and save the figure.

def moving_average(values: List[float], window_size: int) -> List[float]:
    """
    Compute a simple moving average over the given values.
    The output list has the same length as the input list.
    """
    if not values or window_size <= 1:
        return values

    averaged: List[float] = []
    cumulative_sum = 0.0

    for i, v in enumerate(values):
        cumulative_sum += v
        if i >= window_size:
            cumulative_sum -= values[i - window_size]
            averaged.append(cumulative_sum / window_size)
        else:
            averaged.append(cumulative_sum / (i + 1))

    return averaged


def plot_learning_trajectories(results: Dict[str, List[float]]) -> None:
    """
    Plot smoothed average score per episode for each epsilon regime.
    """
    plt.figure(figsize=(10, 6))

    for label, scores in results.items():
        smoothed = moving_average(scores, MOVING_AVG_WINDOW)
        # Build the x-axis from each series so the lengths always match
        episodes = range(1, len(smoothed) + 1)
        plt.plot(episodes, smoothed, label=label)

    plt.xlabel("Episode")
    plt.ylabel("Average score per episode (moving average)")
    plt.title(
        f"Snake Q-learning: Learning trajectories over {NUM_EPISODES} episodes\n"
        f"(window size = {MOVING_AVG_WINDOW})"
    )
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("learning_trajectories.png", dpi=200)
    # plt.show()

main fixes the random seed, runs training for several epsilon values, smooths and plots the results, and saves the figure.

def main() -> None:
    # Make runs reproducible
    random.seed(42)

    epsilon_configs = {
        "Pure exploitation (epsilon=0.0)": 0.0,
        "Pure exploration (epsilon=1.0)": 1.0,
        "Mixed (epsilon=0.1)": 0.1,
    }

    results: Dict[str, List[float]] = {}

    for label, eps in epsilon_configs.items():
        scores = run_training(train_epsilon=eps, gamma=0.5, num_episodes=NUM_EPISODES)
        results[label] = scores

    plot_learning_trajectories(results)


if __name__ == "__main__":
    main()

Results and Takeaways

After training for several thousand episodes, our Q-learning agent should learn to navigate the snake toward food while avoiding walls and its own tail. Performance will depend on state representation and hyperparameters like learning rate and exploration rate.

Snake Q-learning training output
Learning trajectories. Smoothed evaluation score per episode for different exploration rates (ε). With pure exploitation (ε=0) from the start, the agent never explores and cannot discover better actions, so performance stays flat. With pure exploration (ε=1), actions are random every step and the agent does not exploit what it has learned, so scores remain low. A mixed setting (e.g. ε=0.1) balances exploration and exploitation: the agent improves over episodes as the Q-table is updated and evaluation runs (ε=0) reflect that learned policy.

This simple setup lays the groundwork for more advanced methods like Deep Q-Networks (DQN), which we'll explore in the Flappy Bird tutorial.

References

  1. Sutton, R. S., & Barto, A. G. (2018). Reinforcement learning: An introduction (2nd ed.). The MIT Press.
  2. Ghasemi, M., & Ebrahimi, F. (2024). Reinforcement learning in artificial intelligence. Artificial Intelligence Review.
  3. Ouyang, L., et al. (2022). Training language models to follow instructions with human feedback. Advances in Neural Information Processing Systems, 35.
  4. Wang, Y., et al. (2023). Reinforcement learning for autonomous driving: A survey. IEEE Transactions on Intelligent Transportation Systems.
  5. Silver, D., et al. (2016). Mastering the game of Go with deep neural networks and tree search. Nature, 529(7587), 484-489.