refac: add more generic to improve design

This commit is contained in:
Dmitriy Pleshevskiy 2021-08-21 08:55:52 +03:00
parent b9bbe16b6b
commit 67e07c79e6
2 changed files with 115 additions and 71 deletions

View file

@ -75,10 +75,10 @@ Deno.test("should change state", async function () {
assertEquals(sm.allowedTransitionStates(), [active, archived]); assertEquals(sm.allowedTransitionStates(), [active, archived]);
await sm.changeState(ProjectStatus.Active); await sm.tryChangeState(ProjectStatus.Active, null);
assertEquals(sm.allowedTransitionStates(), [completed]); assertEquals(sm.allowedTransitionStates(), [completed]);
await sm.changeState(ProjectStatus.Completed); await sm.tryChangeState(ProjectStatus.Completed, null);
assertEquals(sm.allowedTransitionStates(), []); assertEquals(sm.allowedTransitionStates(), []);
}); });
@ -111,14 +111,20 @@ Deno.test("should trigger state actions", async function () {
assertEquals(triggeredTimes, { beforeExit: 0, onEntry: 0 }); assertEquals(triggeredTimes, { beforeExit: 0, onEntry: 0 });
assertEquals(sm.allowedTransitionStates(), [active, archived]); assertEquals(sm.allowedTransitionStates(), [active, archived]);
assertEquals(
sm.allowedTransitionStateNames(),
[ProjectStatus.Active, ProjectStatus.Archived],
);
await sm.changeState(ProjectStatus.Active); await sm.tryChangeState(ProjectStatus.Active, null);
assertEquals(triggeredTimes, { beforeExit: 1, onEntry: 1 }); assertEquals(triggeredTimes, { beforeExit: 1, onEntry: 1 });
assertEquals(sm.allowedTransitionStates(), [completed]); assertEquals(sm.allowedTransitionStates(), [completed]);
assertEquals(sm.allowedTransitionStateNames(), [ProjectStatus.Completed]);
await sm.changeState(ProjectStatus.Completed); await sm.tryChangeState(ProjectStatus.Completed, null);
assertEquals(triggeredTimes, { beforeExit: 2, onEntry: 2 }); assertEquals(triggeredTimes, { beforeExit: 2, onEntry: 2 });
assertEquals(sm.allowedTransitionStates(), []); assertEquals(sm.allowedTransitionStates(), []);
assertEquals(sm.allowedTransitionStateNames(), []);
}); });
Deno.test("should stringify state", function () { Deno.test("should stringify state", function () {
@ -146,12 +152,22 @@ Deno.test("should throw error if transition to the state doesn't exist", () => {
.withStates(Object.values(ProjectStatus)) .withStates(Object.values(ProjectStatus))
.build(ProjectStatus.Pending); .build(ProjectStatus.Pending);
assertThrowsAsync( assertThrowsAsync(
() => sm.changeState(ProjectStatus.Active), () => sm.tryChangeState(ProjectStatus.Active, null),
fsm.FsmError, fsm.FsmError,
`cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`, `cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`,
); );
}); });
Deno.test("should return null if transition to the state doesn't exist", async () => {
const sm = new fsm.StateMachineBuilder()
.withStates(Object.values(ProjectStatus))
.build(ProjectStatus.Pending);
assertEquals(
await sm.maybeChangeState(ProjectStatus.Active, null),
null,
);
});
Deno.test("should throw error if beforeExit action returns false", () => { Deno.test("should throw error if beforeExit action returns false", () => {
const sm = new fsm.StateMachineBuilder() const sm = new fsm.StateMachineBuilder()
.withStates( .withStates(
@ -163,7 +179,7 @@ Deno.test("should throw error if beforeExit action returns false", () => {
]) ])
.build(ProjectStatus.Pending); .build(ProjectStatus.Pending);
assertThrowsAsync( assertThrowsAsync(
() => sm.changeState(ProjectStatus.Active), () => sm.tryChangeState(ProjectStatus.Active, null),
fsm.FsmError, fsm.FsmError,
`cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`, `cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`,
); );

158
fsm.ts
View file

@ -1,18 +1,19 @@
type StateTransitions<Context> = WeakMap< type StateTransitions<Context, StateName extends string> = WeakMap<
State<Context>, State<Context, StateName>,
WeakSet<State<Context>> WeakSet<State<Context, StateName>>
>; >;
type StateName = string; type StateOrName<Context, StateName extends string> =
type StateOrName<Context> = State<Context> | StateName; | State<Context, StateName>
| StateName;
export const _states = Symbol("states"); export const _states = Symbol("states");
export const _stateTransitions = Symbol("state transitions"); export const _stateTransitions = Symbol("state transitions");
export const _prevState = Symbol("previous state"); export const _prevState = Symbol("previous state");
export const _currState = Symbol("current state"); export const _currState = Symbol("current state");
export class StateMachineBuilder<Context> { export class StateMachineBuilder<Context, StateName extends string = string> {
[_states]: Map<StateName, Actions<Context>>; [_states]: Map<StateName, Events<Context, StateName>>;
[_stateTransitions]: Array<[StateName, Array<StateName>]> | undefined; [_stateTransitions]: Array<[StateName, Array<StateName>]> | undefined;
@ -25,17 +26,20 @@ export class StateMachineBuilder<Context> {
return this; return this;
} }
withStates(names: StateName[], actions?: Actions<Context>) { withStates(names: StateName[], actions?: Events<Context, StateName>) {
names.forEach((name) => this.addStateUnchecked(name, actions)); names.forEach((name) => this.addStateUnchecked(name, actions));
return this; return this;
} }
withState(name: StateName, actions?: Actions<Context>) { withState(name: StateName, actions?: Events<Context, StateName>) {
this.addStateUnchecked(name, actions); this.addStateUnchecked(name, actions);
return this; return this;
} }
private addStateUnchecked(name: StateName, actions?: Actions<Context>) { private addStateUnchecked(
name: StateName,
actions?: Events<Context, StateName>,
) {
const oldActions = this[_states].get(name); const oldActions = this[_states].get(name);
return this[_states].set(name, { ...oldActions, ...actions }); return this[_states].set(name, { ...oldActions, ...actions });
} }
@ -44,7 +48,7 @@ export class StateMachineBuilder<Context> {
const states = this.buildStates(); const states = this.buildStates();
const transitions = this.buildTransitions(states); const transitions = this.buildTransitions(states);
const currState = validStateFromName(states, currentStateName); const currState = validStateFromName(states, currentStateName);
return new StateMachine(states, transitions, currState); return new StateMachine<Context, StateName>(states, transitions, currState);
} }
private buildStates() { private buildStates() {
@ -53,7 +57,7 @@ export class StateMachineBuilder<Context> {
); );
} }
private buildTransitions(states: State<Context>[]) { private buildTransitions(states: State<Context, StateName>[]) {
const sourceTransitions = this[_stateTransitions] || []; const sourceTransitions = this[_stateTransitions] || [];
return new WeakMap( return new WeakMap(
@ -65,28 +69,31 @@ export class StateMachineBuilder<Context> {
} }
} }
export class StateMachine<Context> { export class StateMachine<Context, StateName extends string = string> {
[_states]: State<Context>[]; [_states]: State<Context, StateName>[];
[_stateTransitions]: StateTransitions<Context>; [_stateTransitions]: StateTransitions<Context, StateName>;
[_prevState]: State<Context> | undefined; [_prevState]: State<Context, StateName> | undefined;
[_currState]: State<Context>; [_currState]: State<Context, StateName>;
constructor( constructor(
states: State<Context>[], states: State<Context, StateName>[],
transitions: StateTransitions<Context>, transitions: StateTransitions<Context, StateName>,
currentState: State<Context>, currentState: State<Context, StateName>,
) { ) {
this[_states] = states; this[_states] = states;
this[_stateTransitions] = transitions; this[_stateTransitions] = transitions;
this[_currState] = currentState; this[_currState] = currentState;
} }
async changeState(sourceState: StateOrName<Context>, context?: Context) { async tryChangeState(
const fromState = validState(this[_currState]); state: StateOrName<Context, StateName>,
const toState = validNormalizedState(this[_states], sourceState); context: Context,
) {
const fromState = validState<Context, StateName>(this[_currState]);
const toState = validNormalizedState(this[_states], state);
if ( if (
!this.hasTransition(toState) || !this.hasTransition(toState) ||
@ -99,11 +106,17 @@ export class StateMachine<Context> {
await toState.entry(fromState, toState, context); await toState.entry(fromState, toState, context);
this[_currState] = toState;
this[_prevState] = fromState; this[_prevState] = fromState;
this[_currState] = toState;
return this[_currState];
} }
hasTransition(to: StateOrName<Context>) { maybeChangeState(state: StateOrName<Context, StateName>, context: Context) {
return this.tryChangeState(state, context).catch(() => null);
}
hasTransition(to: StateOrName<Context, StateName>) {
return hasTransition( return hasTransition(
this[_stateTransitions], this[_stateTransitions],
this[_currState], this[_currState],
@ -117,26 +130,30 @@ export class StateMachine<Context> {
hasTransition.bind(null, this[_stateTransitions], fromState), hasTransition.bind(null, this[_stateTransitions], fromState),
); );
} }
allowedTransitionStateNames() {
return this.allowedTransitionStates().map(String);
}
} }
const _stateName = Symbol("state name"); const _stateName = Symbol("state name");
const _stateActions = Symbol("state actions"); const _stateEvents = Symbol("state events");
interface Actions<Context> { interface Events<Context, StateName extends string> {
beforeExit?( beforeExit?(
fromState: State<Context>, fromState: State<Context, StateName>,
toState: State<Context>, toState: State<Context, StateName>,
context?: Context, context: Context,
): boolean; ): boolean;
onEntry?( onEntry?(
fromState: State<Context>, fromState: State<Context, StateName>,
toState: State<Context>, toState: State<Context, StateName>,
context?: Context, context: Context,
): Promise<void> | void; ): Promise<void> | void;
} }
export class State<Context> { export class State<Context, StateName extends string = string> {
[_stateActions]: Actions<Context>; [_stateEvents]: Events<Context, StateName>;
[_stateName]: StateName; [_stateName]: StateName;
@ -144,25 +161,29 @@ export class State<Context> {
return this[_stateName]; return this[_stateName];
} }
constructor(name: StateName, actions: Actions<Context> = {}) { constructor(name: StateName, events: Events<Context, StateName> = {}) {
this[_stateName] = name; this[_stateName] = name;
this[_stateActions] = actions; this[_stateEvents] = events;
} }
async entry( async entry(
fromState: State<Context>, fromState: State<Context, StateName>,
toState: State<Context>, toState: State<Context, StateName>,
context?: Context, context: Context,
) { ) {
const action = this[_stateActions].onEntry; const event = this[_stateEvents].onEntry;
if (isFn(action)) { if (isFn(event)) {
await action(fromState, toState, context); await event(fromState, toState, context);
} }
} }
exit(fromState: State<Context>, toState: State<Context>, context: Context) { exit(
const action = this[_stateActions].beforeExit; fromState: State<Context, StateName>,
return isFn(action) ? action(fromState, toState, context) : true; toState: State<Context, StateName>,
context: Context,
) {
const event = this[_stateEvents].beforeExit;
return isFn(event) ? event(fromState, toState, context) : true;
} }
toString() { toString() {
@ -174,46 +195,53 @@ export class State<Context> {
} }
} }
function validNormalizedState<Context>( function validNormalizedState<Context, StateName extends string>(
states: State<Context>[], states: State<Context, StateName>[],
state: StateOrName<Context>, state: StateOrName<Context, StateName>,
) { ) {
return validState<Context>(normalizeState(states, state)); return validState<Context, StateName>(normalizeState(states, state));
} }
function normalizeState<Context>( function normalizeState<Context, StateName extends string>(
states: State<Context>[], states: State<Context, StateName>[],
state: StateOrName<Context>, state: StateOrName<Context, StateName>,
): State<Context> | undefined { ): State<Context, StateName> | undefined {
return isStr(state) ? stateFromName(states, state) : state; return isStr(state) ? stateFromName(states, state) : state;
} }
function validStateFromName<Context>( function validStateFromName<Context, StateName extends string>(
states: State<Context>[], states: State<Context, StateName>[],
name: StateName, name: StateName,
) { ) {
return validState<Context>(stateFromName(states, name)); return validState<Context, StateName>(stateFromName(states, name));
} }
function stateFromName<Context>(states: State<Context>[], name: StateName) { function stateFromName<Context, StateName extends string>(
states: State<Context, StateName>[],
name: StateName,
) {
return states.find((state) => state.name === name); return states.find((state) => state.name === name);
} }
function validState<Context>(val: unknown): State<Context> { function validState<Context, StateName extends string>(
if (!isState<Context>(val)) { val: unknown,
): State<Context, StateName> {
if (!isState<Context, StateName>(val)) {
throw new TypeError("an instance of State class is expected"); throw new TypeError("an instance of State class is expected");
} }
return val; return val;
} }
function isState<Context>(val: unknown): val is State<Context> { function isState<Context, StateName extends string>(
val: unknown,
): val is State<Context, StateName> {
return val instanceof State; return val instanceof State;
} }
function hasTransition<Context>( function hasTransition<Context, StateName extends string>(
transitions: StateTransitions<Context>, transitions: StateTransitions<Context, StateName>,
from: State<Context>, from: State<Context, StateName>,
to: State<Context>, to: State<Context, StateName>,
) { ) {
return transitions.get(from)?.has(to) || false; return transitions.get(from)?.has(to) || false;
} }