diff --git a/fsm.test.ts b/fsm.test.ts index bdac37c..ba3a8f6 100644 --- a/fsm.test.ts +++ b/fsm.test.ts @@ -75,10 +75,10 @@ Deno.test("should change state", async function () { assertEquals(sm.allowedTransitionStates(), [active, archived]); - await sm.changeState(ProjectStatus.Active); + await sm.tryChangeState(ProjectStatus.Active, null); assertEquals(sm.allowedTransitionStates(), [completed]); - await sm.changeState(ProjectStatus.Completed); + await sm.tryChangeState(ProjectStatus.Completed, null); assertEquals(sm.allowedTransitionStates(), []); }); @@ -111,14 +111,20 @@ Deno.test("should trigger state actions", async function () { assertEquals(triggeredTimes, { beforeExit: 0, onEntry: 0 }); assertEquals(sm.allowedTransitionStates(), [active, archived]); + assertEquals( + sm.allowedTransitionStateNames(), + [ProjectStatus.Active, ProjectStatus.Archived], + ); - await sm.changeState(ProjectStatus.Active); + await sm.tryChangeState(ProjectStatus.Active, null); assertEquals(triggeredTimes, { beforeExit: 1, onEntry: 1 }); assertEquals(sm.allowedTransitionStates(), [completed]); + assertEquals(sm.allowedTransitionStateNames(), [ProjectStatus.Completed]); - await sm.changeState(ProjectStatus.Completed); + await sm.tryChangeState(ProjectStatus.Completed, null); assertEquals(triggeredTimes, { beforeExit: 2, onEntry: 2 }); assertEquals(sm.allowedTransitionStates(), []); + assertEquals(sm.allowedTransitionStateNames(), []); }); Deno.test("should stringify state", function () { @@ -146,12 +152,22 @@ Deno.test("should throw error if transition to the state doesn't exist", () => { .withStates(Object.values(ProjectStatus)) .build(ProjectStatus.Pending); assertThrowsAsync( - () => sm.changeState(ProjectStatus.Active), + () => sm.tryChangeState(ProjectStatus.Active, null), fsm.FsmError, `cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`, ); }); +Deno.test("should return null if transition to the state doesn't exist", async () => { + const sm = new fsm.StateMachineBuilder() + .withStates(Object.values(ProjectStatus)) + .build(ProjectStatus.Pending); + assertEquals( + await sm.maybeChangeState(ProjectStatus.Active, null), + null, + ); +}); + Deno.test("should throw error if beforeExit action returns false", () => { const sm = new fsm.StateMachineBuilder() .withStates( @@ -163,7 +179,7 @@ Deno.test("should throw error if beforeExit action returns false", () => { ]) .build(ProjectStatus.Pending); assertThrowsAsync( - () => sm.changeState(ProjectStatus.Active), + () => sm.tryChangeState(ProjectStatus.Active, null), fsm.FsmError, `cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`, ); diff --git a/fsm.ts b/fsm.ts index 96c6422..db374c7 100644 --- a/fsm.ts +++ b/fsm.ts @@ -1,18 +1,19 @@ -type StateTransitions = WeakMap< - State, - WeakSet> +type StateTransitions = WeakMap< + State, + WeakSet> >; -type StateName = string; -type StateOrName = State | StateName; +type StateOrName = + | State + | StateName; export const _states = Symbol("states"); export const _stateTransitions = Symbol("state transitions"); export const _prevState = Symbol("previous state"); export const _currState = Symbol("current state"); -export class StateMachineBuilder { - [_states]: Map>; +export class StateMachineBuilder { + [_states]: Map>; [_stateTransitions]: Array<[StateName, Array]> | undefined; @@ -25,17 +26,20 @@ export class StateMachineBuilder { return this; } - withStates(names: StateName[], actions?: Actions) { + withStates(names: StateName[], actions?: Events) { names.forEach((name) => this.addStateUnchecked(name, actions)); return this; } - withState(name: StateName, actions?: Actions) { + withState(name: StateName, actions?: Events) { this.addStateUnchecked(name, actions); return this; } - private addStateUnchecked(name: StateName, actions?: Actions) { + private addStateUnchecked( + name: StateName, + actions?: Events, + ) { const oldActions = this[_states].get(name); return this[_states].set(name, { ...oldActions, ...actions }); } @@ -44,7 +48,7 @@ export class StateMachineBuilder { const states = this.buildStates(); const transitions = this.buildTransitions(states); const currState = validStateFromName(states, currentStateName); - return new StateMachine(states, transitions, currState); + return new StateMachine(states, transitions, currState); } private buildStates() { @@ -53,7 +57,7 @@ export class StateMachineBuilder { ); } - private buildTransitions(states: State[]) { + private buildTransitions(states: State[]) { const sourceTransitions = this[_stateTransitions] || []; return new WeakMap( @@ -65,28 +69,31 @@ export class StateMachineBuilder { } } -export class StateMachine { - [_states]: State[]; +export class StateMachine { + [_states]: State[]; - [_stateTransitions]: StateTransitions; + [_stateTransitions]: StateTransitions; - [_prevState]: State | undefined; + [_prevState]: State | undefined; - [_currState]: State; + [_currState]: State; constructor( - states: State[], - transitions: StateTransitions, - currentState: State, + states: State[], + transitions: StateTransitions, + currentState: State, ) { this[_states] = states; this[_stateTransitions] = transitions; this[_currState] = currentState; } - async changeState(sourceState: StateOrName, context?: Context) { - const fromState = validState(this[_currState]); - const toState = validNormalizedState(this[_states], sourceState); + async tryChangeState( + state: StateOrName, + context: Context, + ) { + const fromState = validState(this[_currState]); + const toState = validNormalizedState(this[_states], state); if ( !this.hasTransition(toState) || @@ -99,11 +106,17 @@ export class StateMachine { await toState.entry(fromState, toState, context); - this[_currState] = toState; this[_prevState] = fromState; + this[_currState] = toState; + + return this[_currState]; } - hasTransition(to: StateOrName) { + maybeChangeState(state: StateOrName, context: Context) { + return this.tryChangeState(state, context).catch(() => null); + } + + hasTransition(to: StateOrName) { return hasTransition( this[_stateTransitions], this[_currState], @@ -117,26 +130,30 @@ export class StateMachine { hasTransition.bind(null, this[_stateTransitions], fromState), ); } + + allowedTransitionStateNames() { + return this.allowedTransitionStates().map(String); + } } const _stateName = Symbol("state name"); -const _stateActions = Symbol("state actions"); +const _stateEvents = Symbol("state events"); -interface Actions { +interface Events { beforeExit?( - fromState: State, - toState: State, - context?: Context, + fromState: State, + toState: State, + context: Context, ): boolean; onEntry?( - fromState: State, - toState: State, - context?: Context, + fromState: State, + toState: State, + context: Context, ): Promise | void; } -export class State { - [_stateActions]: Actions; +export class State { + [_stateEvents]: Events; [_stateName]: StateName; @@ -144,25 +161,29 @@ export class State { return this[_stateName]; } - constructor(name: StateName, actions: Actions = {}) { + constructor(name: StateName, events: Events = {}) { this[_stateName] = name; - this[_stateActions] = actions; + this[_stateEvents] = events; } async entry( - fromState: State, - toState: State, - context?: Context, + fromState: State, + toState: State, + context: Context, ) { - const action = this[_stateActions].onEntry; - if (isFn(action)) { - await action(fromState, toState, context); + const event = this[_stateEvents].onEntry; + if (isFn(event)) { + await event(fromState, toState, context); } } - exit(fromState: State, toState: State, context: Context) { - const action = this[_stateActions].beforeExit; - return isFn(action) ? action(fromState, toState, context) : true; + exit( + fromState: State, + toState: State, + context: Context, + ) { + const event = this[_stateEvents].beforeExit; + return isFn(event) ? event(fromState, toState, context) : true; } toString() { @@ -174,46 +195,53 @@ export class State { } } -function validNormalizedState( - states: State[], - state: StateOrName, +function validNormalizedState( + states: State[], + state: StateOrName, ) { - return validState(normalizeState(states, state)); + return validState(normalizeState(states, state)); } -function normalizeState( - states: State[], - state: StateOrName, -): State | undefined { +function normalizeState( + states: State[], + state: StateOrName, +): State | undefined { return isStr(state) ? stateFromName(states, state) : state; } -function validStateFromName( - states: State[], +function validStateFromName( + states: State[], name: StateName, ) { - return validState(stateFromName(states, name)); + return validState(stateFromName(states, name)); } -function stateFromName(states: State[], name: StateName) { +function stateFromName( + states: State[], + name: StateName, +) { return states.find((state) => state.name === name); } -function validState(val: unknown): State { - if (!isState(val)) { +function validState( + val: unknown, +): State { + if (!isState(val)) { throw new TypeError("an instance of State class is expected"); } return val; } -function isState(val: unknown): val is State { +function isState( + val: unknown, +): val is State { return val instanceof State; } -function hasTransition( - transitions: StateTransitions, - from: State, - to: State, +function hasTransition( + transitions: StateTransitions, + from: State, + to: State, ) { return transitions.get(from)?.has(to) || false; }