Notes: Finite State Machine in Go

This post walks through a finite state machine implementation in Go, using a payment flow as the example.

In simple terms, a state machine is a set of states a system can be in, plus events that move the machine from one state to another. A state change can also trigger an action.
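To make the definition concrete, here is a minimal sketch that reduces a state machine to a lookup from (current state, event) to the next state, using the payment-flow states introduced below. The names `transitions` and `next` are illustrative assumptions; the node-based design later in this post adds actions on top of the same idea.

```go
package main

import "fmt"

type State string
type Event string

// transitions maps (current state, event) to the next state.
// The states and events mirror the payment flow used in this post.
var transitions = map[State]map[Event]State{
	"cartPage":           {"checkout_requested": "checkoutProcessing"},
	"checkoutProcessing": {"payment_requested": "paymentProcessing"},
	"paymentProcessing": {
		"success":   "done",
		"timed_out": "cartPage",
		"failed":    "failed",
	},
}

// next looks up the state reached from current on event. The boolean is
// false when the event is not valid in the current state.
func next(current State, event Event) (State, bool) {
	s, ok := transitions[current][event] // indexing a missing (nil) inner map is safe
	return s, ok
}

func main() {
	s, ok := next("cartPage", "checkout_requested")
	fmt.Println(s, ok) // checkoutProcessing true

	_, ok = next("cartPage", "success")
	fmt.Println(ok) // false
}
```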

Let's use a payment flow as an example. Below is the state transition table we'll be using:

[Image: state transition table for the payment flow]

The table has four columns: Current State is the state the system is in, Event is what triggers the change, Next State is the state the system moves to on that event, and Action is the side effect triggered by the transition.
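The same four-column structure can be held as plain data. In the sketch below the rows are taken from the transitions defined in the full program at the end of this post (whose actions only print a message), while the `row` type and `index` helper are illustrative assumptions. Building the lookup also checks one property the table must have: each (Current State, Event) pair maps to exactly one Next State.

```go
package main

import "fmt"

// row mirrors one line of the table: Current State | Event | Next State | Action.
type row struct {
	Current, Event, Next, Action string
}

// table lists the transitions used by the full program in this post.
var table = []row{
	{"cartPage", "checkout_requested", "checkoutProcessing", "log transition"},
	{"checkoutProcessing", "payment_requested", "paymentProcessing", "log transition"},
	{"paymentProcessing", "success", "done", "log transition"},
	{"paymentProcessing", "timed_out", "cartPage", "log transition"},
	{"paymentProcessing", "failed", "failed", "log transition"},
}

// index builds a (state, event) -> row lookup and rejects duplicate pairs,
// because a deterministic machine allows only one next state per pair.
func index(rows []row) (map[[2]string]row, error) {
	m := make(map[[2]string]row, len(rows))
	for _, r := range rows {
		key := [2]string{r.Current, r.Event}
		if _, dup := m[key]; dup {
			return nil, fmt.Errorf("duplicate transition for %s on %s", r.Current, r.Event)
		}
		m[key] = r
	}
	return m, nil
}

func main() {
	lookup, err := index(table)
	if err != nil {
		panic(err)
	}
	r := lookup[[2]string{"paymentProcessing", "timed_out"}]
	fmt.Printf("%s --%s--> %s (%s)\n", r.Current, r.Event, r.Next, r.Action)
	// paymentProcessing --timed_out--> cartPage (log transition)
}
```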

The same flow can also be represented as a directed graph, with states as nodes and events as edges:

[Image: payment flow as a directed graph]
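If you would rather generate such a graph from code than draw it by hand, one option (an assumption here, not how the diagram above was produced) is to emit Graphviz DOT text: each state becomes a node and each event becomes a labelled, directed edge. The `edge` type and `dot` function are hypothetical names for this sketch.

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

type edge struct{ from, event, to string }

// dot renders transitions as a Graphviz digraph. %q quotes each name, which
// is valid DOT for simple identifiers like the ones used in this post.
func dot(edges []edge) string {
	lines := make([]string, 0, len(edges))
	for _, e := range edges {
		lines = append(lines, fmt.Sprintf("  %q -> %q [label=%q];", e.from, e.to, e.event))
	}
	sort.Strings(lines) // deterministic output regardless of input order
	return "digraph payment {\n" + strings.Join(lines, "\n") + "\n}\n"
}

func main() {
	fmt.Print(dot([]edge{
		{"cartPage", "checkout_requested", "checkoutProcessing"},
		{"checkoutProcessing", "payment_requested", "paymentProcessing"},
		{"paymentProcessing", "success", "done"},
		{"paymentProcessing", "timed_out", "cartPage"},
		{"paymentProcessing", "failed", "failed"},
	}))
}
```

With Graphviz installed, piping the output through `dot -Tpng -o payment.png` renders the image.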

Now, let's dive into writing code.

First, we'll define the node. A node is a data structure that holds its state and the transitions available from that state, keyed by event. We can define it like this:

type Node struct {
    State // embedded field of type State (declared as type State string)
    Transitions map[Event]*Transition
}
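Because a node's outgoing transitions live in a map, a small helper can list which events are valid in a given state, which is handy for error messages or a UI. `AvailableEvents` is a hypothetical helper, not part of the post's code, and `Transition` is trimmed to its target node in this sketch.

```go
package main

import (
	"fmt"
	"sort"
)

type State string
type Event string

type Node struct {
	State // embedded: node.State is the state name
	Transitions map[Event]*Transition
}

// Transition is trimmed to the target node for this sketch (no Action).
type Transition struct {
	*Node
}

// AvailableEvents (hypothetical helper) returns the events accepted in n,
// sorted so the output is deterministic despite random map iteration order.
func (n *Node) AvailableEvents() []Event {
	events := make([]Event, 0, len(n.Transitions))
	for e := range n.Transitions {
		events = append(events, e)
	}
	sort.Slice(events, func(i, j int) bool { return events[i] < events[j] })
	return events
}

func main() {
	var cart, done, failed Node
	done = Node{State: "done"}
	failed = Node{State: "failed"}
	payment := Node{
		State: "paymentProcessing",
		Transitions: map[Event]*Transition{
			"success":   {Node: &done},
			"timed_out": {Node: &cart},
			"failed":    {Node: &failed},
		},
	}
	fmt.Println(payment.State, payment.AvailableEvents())
	// paymentProcessing [failed success timed_out]
}
```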

Next, a Transition holds the node to move to and the Action to trigger along the way. It can be defined like this:

type Transition struct {
    *Node
    Action
}

With these data structures in place, we can define the state machine itself. It keeps track of the initial node and the current node:

type StateMachine struct {
    initialNode *Node
    CurrentNode *Node
}
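The post stores `initialNode` but never reads it. One plausible use, sketched here as a hypothetical `Reset` method that is not part of the original code, is returning the machine to its starting state, for example to begin a fresh checkout. `Node` is trimmed to its state name.

```go
package main

import "fmt"

type State string

// Node is trimmed to the state name for this sketch.
type Node struct{ State }

type StateMachine struct {
	initialNode *Node
	CurrentNode *Node
}

// Reset (hypothetical) moves the machine back to the node it started in.
func (m *StateMachine) Reset() {
	m.CurrentNode = m.initialNode
}

func main() {
	cart := &Node{State: "cartPage"}
	failed := &Node{State: "failed"}
	m := &StateMachine{initialNode: cart, CurrentNode: failed}
	m.Reset()
	fmt.Println(m.CurrentNode.State) // cartPage
}
```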

The only method the state machine needs is Transition(ctx context.Context, event Event) (*Node, error). It takes an event, runs the matching transition's Action, and then moves to and returns the next node. If the event is not valid in the current state, or the action fails, it returns an error and the machine stays where it is. The code can look something like this:

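func (m *StateMachine) Transition(ctx context.Context, event Event) (*Node, error) {
    // Resolve the next node; fails if the event is not valid in the current state.
    node, err := m.getNextNode(event)
    if err != nil {
        return nil, err
    }

    // Run the transition's action (if any) before moving, so that a failed
    // action leaves the machine in its current state.
    if action := m.CurrentNode.Transitions[event].Action; action != nil {
        if err := action(ctx); err != nil {
            return nil, fmt.Errorf("action for event %q failed: %w", event, err)
        }
    }

    m.CurrentNode = node

    return m.CurrentNode, nil
}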

With everything set up, we can move on to the use case.

1) Declare all the nodes up front, since they reference each other by pointer:

var (
    cartPageNode,
    checkoutProcessingNode,
    paymentProcessingNode,
    doneNode,
    failedNode Node
)
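Declaring the nodes first is what allows the cycles in the flow (the payment node links back to the cart page). Taking `&b` before `b` is assigned is fine: the pointer refers to the variable, so the later assignment is visible through it. A two-node sketch, with `Node` trimmed to a single `Next` pointer for illustration:

```go
package main

import "fmt"

// Node is simplified to one outgoing link for this sketch.
type Node struct {
	State string
	Next  *Node
}

func main() {
	var cartPage, payment Node // declared first so each can reference the other

	cartPage = Node{State: "cartPage", Next: &payment} // payment is still zero here
	payment = Node{State: "paymentProcessing", Next: &cartPage}

	// The pointer stored in cartPage now sees payment's final value.
	fmt.Println(cartPage.Next.State, cartPage.Next.Next.State)
	// paymentProcessing cartPage
}
```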

2) Then add the transitions to each node, e.g. for paymentProcessingNode:

paymentProcessingNode = Node{
    State: "paymentProcessing",
    Transitions: map[Event]*Transition{
        "success": &Transition{
            Node: &doneNode,
            Action: func(ctx context.Context) error {
                fmt.Println("payment processing -> success -> done")
                return nil
            },
        },
        "timed_out": &Transition{
            Node: &cartPageNode,
            Action: func(ctx context.Context) error {
                fmt.Println("payment processing -> timed_out -> cart page")
                return nil
            },
        },
        "failed": &Transition{
            Node: &failedNode,
            Action: func(ctx context.Context) error {
                fmt.Println("payment processing -> failed -> failed")
                return nil
            },
        },
    },
}

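3) Finally, create a machine that starts at the cart page and fire events at it:

machine := NewStateMachine(&cartPageNode)
nextNode, _ := machine.Transition(context.TODO(), "checkout_requested")
nextNode, _ = machine.Transition(context.TODO(), "payment_requested")
nextNode, _ = machine.Transition(context.TODO(), "timed_out")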

That completes the state machine implementation. Below is the full working program; try it and adapt it to your requirements:

package main

import (
    "context"
    "fmt"
)

type State string
type Event string

type Node struct {
    State
    Transitions map[Event]*Transition
}

type Transition struct {
    *Node
    Action
}

type Action func(ctx context.Context) error

type StateMachine struct {
    initialNode *Node
    CurrentNode *Node
}

func (m *StateMachine) getCurrentNode() *Node {
    return m.CurrentNode
}

func (m *StateMachine) getNextNode(event Event) (*Node, error) {
    if m.CurrentNode == nil {
        return nil, fmt.Errorf("nowhere to go anymore: no current node")
    }

    transition, ok := m.CurrentNode.Transitions[event]
    if !ok {
        return nil, fmt.Errorf("invalid event: %v", event)
    }

    return transition.Node, nil
}

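func (m *StateMachine) Transition(ctx context.Context, event Event) (*Node, error) {
    // Resolve the next node; fails if the event is not valid in the current state.
    node, err := m.getNextNode(event)
    if err != nil {
        return nil, err
    }

    // Run the transition's action (if any) before moving, so that a failed
    // action leaves the machine in its current state.
    if action := m.CurrentNode.Transitions[event].Action; action != nil {
        if err := action(ctx); err != nil {
            return nil, fmt.Errorf("action for event %q failed: %w", event, err)
        }
    }

    m.CurrentNode = node

    return m.CurrentNode, nil
}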

func NewStateMachine(initialNode *Node) *StateMachine {
    if initialNode == nil {
        return &StateMachine{}
    }

    return &StateMachine{
        initialNode: initialNode,
        CurrentNode: initialNode,
    }
}

func main() {
    var (
        cartPageNode,
        checkoutProcessingNode,
        paymentProcessingNode,
        doneNode,
        failedNode Node
    )

    cartPageNode = Node{
        State: "cartPage",
        Transitions: map[Event]*Transition{
            "checkout_requested": &Transition{
                Node: &checkoutProcessingNode,
                Action: func(ctx context.Context) error {
                    fmt.Println("cart page -> checkout_requested -> checkout processing")
                    return nil
                },
            },
        },
    }

    checkoutProcessingNode = Node{
        State: "checkoutProcessing",
        Transitions: map[Event]*Transition{
            "payment_requested": &Transition{
                Node: &paymentProcessingNode,
                Action: func(ctx context.Context) error {
                    fmt.Println("checkout processing -> payment_requested -> payment processing")
                    return nil
                },
            },
        },
    }

    paymentProcessingNode = Node{
        State: "paymentProcessing",
        Transitions: map[Event]*Transition{
            "success": &Transition{
                Node: &doneNode,
                Action: func(ctx context.Context) error {
                    fmt.Println("payment processing -> success -> done")
                    return nil
                },
            },
            "timed_out": &Transition{
                Node: &cartPageNode,
                Action: func(ctx context.Context) error {
                    fmt.Println("payment processing -> timed_out -> cart page")
                    return nil
                },
            },
            "failed": &Transition{
                Node: &failedNode,
                Action: func(ctx context.Context) error {
                    fmt.Println("payment processing -> failed -> failed")
                    return nil
                },
            },
        },
    }

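    // Terminal states have no outgoing transitions, but still need a name.
    doneNode = Node{State: "done"}
    failedNode = Node{State: "failed"}

    machine := NewStateMachine(&cartPageNode)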

    fmt.Printf("0. initial: %#v\n\n", machine.getCurrentNode())

    nextNode, _ := machine.Transition(context.TODO(), "checkout_requested")
    fmt.Printf("1. next state for event checkout requested: %#v\n\n", nextNode)

    nextNode, err := machine.Transition(context.TODO(), "gibberish")
    fmt.Printf("2. next state for event gibberish: %#v, error: %#v\n\n", nextNode, err)

    nextNode, _ = machine.Transition(context.TODO(), "payment_requested")
    fmt.Printf("3. next state for event payment requested: %#v\n\n", nextNode)

    nextNode, _ = machine.Transition(context.TODO(), "timed_out")
    fmt.Printf("4. next state for event timed out: %#v\n\n", nextNode)

    nextNode, err = machine.Transition(context.TODO(), "success")
    fmt.Printf("5. next state for event success: %#v, error: %#v\n\n", nextNode, err)

    machine.Transition(context.TODO(), "checkout_requested")
    machine.Transition(context.TODO(), "payment_requested")
    nextNode, _ = machine.Transition(context.TODO(), "success")
    fmt.Printf("6. next state for event new success: %#v\n\n", nextNode)
}