refac: add more generic to improve design

This commit is contained in:
Dmitriy Pleshevskiy 2021-08-21 08:55:52 +03:00
parent b9bbe16b6b
commit 67e07c79e6
2 changed files with 115 additions and 71 deletions

View file

@ -75,10 +75,10 @@ Deno.test("should change state", async function () {
assertEquals(sm.allowedTransitionStates(), [active, archived]);
await sm.changeState(ProjectStatus.Active);
await sm.tryChangeState(ProjectStatus.Active, null);
assertEquals(sm.allowedTransitionStates(), [completed]);
await sm.changeState(ProjectStatus.Completed);
await sm.tryChangeState(ProjectStatus.Completed, null);
assertEquals(sm.allowedTransitionStates(), []);
});
@ -111,14 +111,20 @@ Deno.test("should trigger state actions", async function () {
assertEquals(triggeredTimes, { beforeExit: 0, onEntry: 0 });
assertEquals(sm.allowedTransitionStates(), [active, archived]);
assertEquals(
sm.allowedTransitionStateNames(),
[ProjectStatus.Active, ProjectStatus.Archived],
);
await sm.changeState(ProjectStatus.Active);
await sm.tryChangeState(ProjectStatus.Active, null);
assertEquals(triggeredTimes, { beforeExit: 1, onEntry: 1 });
assertEquals(sm.allowedTransitionStates(), [completed]);
assertEquals(sm.allowedTransitionStateNames(), [ProjectStatus.Completed]);
await sm.changeState(ProjectStatus.Completed);
await sm.tryChangeState(ProjectStatus.Completed, null);
assertEquals(triggeredTimes, { beforeExit: 2, onEntry: 2 });
assertEquals(sm.allowedTransitionStates(), []);
assertEquals(sm.allowedTransitionStateNames(), []);
});
Deno.test("should stringify state", function () {
@ -146,12 +152,22 @@ Deno.test("should throw error if transition to the state doesn't exist", () => {
.withStates(Object.values(ProjectStatus))
.build(ProjectStatus.Pending);
assertThrowsAsync(
() => sm.changeState(ProjectStatus.Active),
() => sm.tryChangeState(ProjectStatus.Active, null),
fsm.FsmError,
`cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`,
);
});
Deno.test("should return null if transition to the state doesn't exist", async () => {
const sm = new fsm.StateMachineBuilder()
.withStates(Object.values(ProjectStatus))
.build(ProjectStatus.Pending);
assertEquals(
await sm.maybeChangeState(ProjectStatus.Active, null),
null,
);
});
Deno.test("should throw error if beforeExit action returns false", () => {
const sm = new fsm.StateMachineBuilder()
.withStates(
@ -163,7 +179,7 @@ Deno.test("should throw error if beforeExit action returns false", () => {
])
.build(ProjectStatus.Pending);
assertThrowsAsync(
() => sm.changeState(ProjectStatus.Active),
() => sm.tryChangeState(ProjectStatus.Active, null),
fsm.FsmError,
`cannot change state from "${ProjectStatus.Pending}" to "${ProjectStatus.Active}"`,
);

158
fsm.ts
View file

@ -1,18 +1,19 @@
type StateTransitions<Context> = WeakMap<
State<Context>,
WeakSet<State<Context>>
type StateTransitions<Context, StateName extends string> = WeakMap<
State<Context, StateName>,
WeakSet<State<Context, StateName>>
>;
type StateName = string;
type StateOrName<Context> = State<Context> | StateName;
type StateOrName<Context, StateName extends string> =
| State<Context, StateName>
| StateName;
export const _states = Symbol("states");
export const _stateTransitions = Symbol("state transitions");
export const _prevState = Symbol("previous state");
export const _currState = Symbol("current state");
export class StateMachineBuilder<Context> {
[_states]: Map<StateName, Actions<Context>>;
export class StateMachineBuilder<Context, StateName extends string = string> {
[_states]: Map<StateName, Events<Context, StateName>>;
[_stateTransitions]: Array<[StateName, Array<StateName>]> | undefined;
@ -25,17 +26,20 @@ export class StateMachineBuilder<Context> {
return this;
}
withStates(names: StateName[], actions?: Actions<Context>) {
withStates(names: StateName[], actions?: Events<Context, StateName>) {
names.forEach((name) => this.addStateUnchecked(name, actions));
return this;
}
withState(name: StateName, actions?: Actions<Context>) {
withState(name: StateName, actions?: Events<Context, StateName>) {
this.addStateUnchecked(name, actions);
return this;
}
private addStateUnchecked(name: StateName, actions?: Actions<Context>) {
private addStateUnchecked(
name: StateName,
actions?: Events<Context, StateName>,
) {
const oldActions = this[_states].get(name);
return this[_states].set(name, { ...oldActions, ...actions });
}
@ -44,7 +48,7 @@ export class StateMachineBuilder<Context> {
const states = this.buildStates();
const transitions = this.buildTransitions(states);
const currState = validStateFromName(states, currentStateName);
return new StateMachine(states, transitions, currState);
return new StateMachine<Context, StateName>(states, transitions, currState);
}
private buildStates() {
@ -53,7 +57,7 @@ export class StateMachineBuilder<Context> {
);
}
private buildTransitions(states: State<Context>[]) {
private buildTransitions(states: State<Context, StateName>[]) {
const sourceTransitions = this[_stateTransitions] || [];
return new WeakMap(
@ -65,28 +69,31 @@ export class StateMachineBuilder<Context> {
}
}
export class StateMachine<Context> {
[_states]: State<Context>[];
export class StateMachine<Context, StateName extends string = string> {
[_states]: State<Context, StateName>[];
[_stateTransitions]: StateTransitions<Context>;
[_stateTransitions]: StateTransitions<Context, StateName>;
[_prevState]: State<Context> | undefined;
[_prevState]: State<Context, StateName> | undefined;
[_currState]: State<Context>;
[_currState]: State<Context, StateName>;
constructor(
states: State<Context>[],
transitions: StateTransitions<Context>,
currentState: State<Context>,
states: State<Context, StateName>[],
transitions: StateTransitions<Context, StateName>,
currentState: State<Context, StateName>,
) {
this[_states] = states;
this[_stateTransitions] = transitions;
this[_currState] = currentState;
}
async changeState(sourceState: StateOrName<Context>, context?: Context) {
const fromState = validState(this[_currState]);
const toState = validNormalizedState(this[_states], sourceState);
async tryChangeState(
state: StateOrName<Context, StateName>,
context: Context,
) {
const fromState = validState<Context, StateName>(this[_currState]);
const toState = validNormalizedState(this[_states], state);
if (
!this.hasTransition(toState) ||
@ -99,11 +106,17 @@ export class StateMachine<Context> {
await toState.entry(fromState, toState, context);
this[_currState] = toState;
this[_prevState] = fromState;
this[_currState] = toState;
return this[_currState];
}
hasTransition(to: StateOrName<Context>) {
maybeChangeState(state: StateOrName<Context, StateName>, context: Context) {
return this.tryChangeState(state, context).catch(() => null);
}
hasTransition(to: StateOrName<Context, StateName>) {
return hasTransition(
this[_stateTransitions],
this[_currState],
@ -117,26 +130,30 @@ export class StateMachine<Context> {
hasTransition.bind(null, this[_stateTransitions], fromState),
);
}
allowedTransitionStateNames() {
return this.allowedTransitionStates().map(String);
}
}
const _stateName = Symbol("state name");
const _stateActions = Symbol("state actions");
const _stateEvents = Symbol("state events");
interface Actions<Context> {
interface Events<Context, StateName extends string> {
beforeExit?(
fromState: State<Context>,
toState: State<Context>,
context?: Context,
fromState: State<Context, StateName>,
toState: State<Context, StateName>,
context: Context,
): boolean;
onEntry?(
fromState: State<Context>,
toState: State<Context>,
context?: Context,
fromState: State<Context, StateName>,
toState: State<Context, StateName>,
context: Context,
): Promise<void> | void;
}
export class State<Context> {
[_stateActions]: Actions<Context>;
export class State<Context, StateName extends string = string> {
[_stateEvents]: Events<Context, StateName>;
[_stateName]: StateName;
@ -144,25 +161,29 @@ export class State<Context> {
return this[_stateName];
}
constructor(name: StateName, actions: Actions<Context> = {}) {
constructor(name: StateName, events: Events<Context, StateName> = {}) {
this[_stateName] = name;
this[_stateActions] = actions;
this[_stateEvents] = events;
}
async entry(
fromState: State<Context>,
toState: State<Context>,
context?: Context,
fromState: State<Context, StateName>,
toState: State<Context, StateName>,
context: Context,
) {
const action = this[_stateActions].onEntry;
if (isFn(action)) {
await action(fromState, toState, context);
const event = this[_stateEvents].onEntry;
if (isFn(event)) {
await event(fromState, toState, context);
}
}
exit(fromState: State<Context>, toState: State<Context>, context: Context) {
const action = this[_stateActions].beforeExit;
return isFn(action) ? action(fromState, toState, context) : true;
exit(
fromState: State<Context, StateName>,
toState: State<Context, StateName>,
context: Context,
) {
const event = this[_stateEvents].beforeExit;
return isFn(event) ? event(fromState, toState, context) : true;
}
toString() {
@ -174,46 +195,53 @@ export class State<Context> {
}
}
function validNormalizedState<Context>(
states: State<Context>[],
state: StateOrName<Context>,
function validNormalizedState<Context, StateName extends string>(
states: State<Context, StateName>[],
state: StateOrName<Context, StateName>,
) {
return validState<Context>(normalizeState(states, state));
return validState<Context, StateName>(normalizeState(states, state));
}
function normalizeState<Context>(
states: State<Context>[],
state: StateOrName<Context>,
): State<Context> | undefined {
function normalizeState<Context, StateName extends string>(
states: State<Context, StateName>[],
state: StateOrName<Context, StateName>,
): State<Context, StateName> | undefined {
return isStr(state) ? stateFromName(states, state) : state;
}
function validStateFromName<Context>(
states: State<Context>[],
function validStateFromName<Context, StateName extends string>(
states: State<Context, StateName>[],
name: StateName,
) {
return validState<Context>(stateFromName(states, name));
return validState<Context, StateName>(stateFromName(states, name));
}
function stateFromName<Context>(states: State<Context>[], name: StateName) {
function stateFromName<Context, StateName extends string>(
states: State<Context, StateName>[],
name: StateName,
) {
return states.find((state) => state.name === name);
}
function validState<Context>(val: unknown): State<Context> {
if (!isState<Context>(val)) {
function validState<Context, StateName extends string>(
val: unknown,
): State<Context, StateName> {
if (!isState<Context, StateName>(val)) {
throw new TypeError("an instance of State class is expected");
}
return val;
}
function isState<Context>(val: unknown): val is State<Context> {
function isState<Context, StateName extends string>(
val: unknown,
): val is State<Context, StateName> {
return val instanceof State;
}
function hasTransition<Context>(
transitions: StateTransitions<Context>,
from: State<Context>,
to: State<Context>,
function hasTransition<Context, StateName extends string>(
transitions: StateTransitions<Context, StateName>,
from: State<Context, StateName>,
to: State<Context, StateName>,
) {
return transitions.get(from)?.has(to) || false;
}