Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
linkous8 committed Nov 14, 2023
1 parent 7c6f90b commit d882c35
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 29 deletions.
18 changes: 9 additions & 9 deletions statesman.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ class Guard(str, enum.Enum):
warning = "warning"
exception = "exception"

class StateMachineConfig:
state_entry = Entry.allow
guard_with = Guard.silence

class StateMachine(pydantic.BaseModel):
"""StateMachine objects model state machines comprised of states, events,
Expand All @@ -378,6 +381,7 @@ class StateMachine(pydantic.BaseModel):
_state: Optional[State] = pydantic.PrivateAttr(None)
_states: List[State] = pydantic.PrivateAttr([])
_events: List[Event] = pydantic.PrivateAttr([])
_config: StateMachineConfig = pydantic.PrivateAttr(StateMachineConfig())

def __init__(
self,
Expand Down Expand Up @@ -736,7 +740,7 @@ async def enter_state(
LookupError: Raised if the state cannot be found by name or enum value.
TypeError: Raised if the state value given is not a State, StateEnum, or str object.
"""
state_entry = self.__config__.state_entry
state_entry = self._config.state_entry
if state_entry == Entry.allow:
pass
elif state_entry == Entry.initial:
Expand Down Expand Up @@ -833,10 +837,6 @@ async def after_transition(self, transition: 'Transition', *args, **kwargs) -> N
def __repr_args__(self) -> pydantic.ReprArgs:
return [('states', self.states), ('events', self.events), ('state', self.state)]

class Config:
state_entry = Entry.allow
guard_with = Guard.silence


# The types that transition results can be represented as.
Result = TypeVar('Result', bool, object, tuple, list, 'Transition')
Expand Down Expand Up @@ -924,7 +924,7 @@ async def __call__(self, *args, **kwargs) -> bool:
# Guards can cancel the transition via return value or failed assertion
self.cancelled = False
self.succeeded = False
guard_with = self.state_machine.__config__.guard_with
guard_with = self.state_machine._config.guard_with
try:
result = await _call_with_matching_parameters(self.state_machine.guard_transition, self, *args, **kwargs)
if result not in (True, False, None):
Expand Down Expand Up @@ -1483,9 +1483,9 @@ def _state_entry(
) -> Iterator[StateMachine]:
"""Temporarily override the value of the `entry` setting on a StateMachine object.
"""
original = obj.__config__.state_entry
obj.__config__.state_entry = entry
original = obj._config.state_entry
obj._config.state_entry = entry
try:
yield obj
finally:
obj.__config__.state_entry = original
obj._config.state_entry = original
30 changes: 10 additions & 20 deletions statesman_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ class TestEntryConfig:
@pytest.mark.asyncio
async def test_allow(self) -> None:
state_machine = statesman.StateMachine(states=statesman.State.from_enum(States))
state_machine.__config__.state_entry = statesman.Entry.allow
state_machine._config.state_entry = statesman.Entry.allow
assert state_machine.state is None
assert await state_machine.enter_state(States.starting)
assert state_machine.state == States.starting
Expand All @@ -584,7 +584,7 @@ async def test_allow(self) -> None:
async def test_initial(self) -> None:
# Enter once for initial, then raise on next try
state_machine = statesman.StateMachine(states=statesman.State.from_enum(States))
state_machine.__config__.state_entry = statesman.Entry.initial
state_machine._config.state_entry = statesman.Entry.initial
assert state_machine.state is None
assert await state_machine.enter_state(States.starting)
assert state_machine.state == States.starting
Expand All @@ -596,7 +596,7 @@ async def test_initial(self) -> None:
async def test_ignore(self) -> None:
# Return false every time
state_machine = statesman.StateMachine(states=statesman.State.from_enum(States))
state_machine.__config__.state_entry = statesman.Entry.ignore
state_machine._config.state_entry = statesman.Entry.ignore
assert state_machine.state is None
assert not await state_machine.enter_state(States.starting)
assert state_machine.state is None
Expand All @@ -608,7 +608,7 @@ async def test_ignore(self) -> None:
@pytest.mark.asyncio
async def test_forbid(self) -> None:
state_machine = statesman.StateMachine(states=statesman.State.from_enum(States))
state_machine.__config__.state_entry = statesman.Entry.forbid
state_machine._config.state_entry = statesman.Entry.forbid
assert state_machine.state is None
with pytest.raises(RuntimeError, match="state entry failed: use of the `enter_state` method is forbidden"):
assert await state_machine.enter_state(States.starting)
Expand Down Expand Up @@ -893,7 +893,7 @@ async def test_non_assertion_errors_raise(self, state_machine: statesman.StateMa

@pytest.mark.asyncio
async def test_guard_with_silence(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.silence
state_machine._config.guard_with = statesman.Guard.silence
await state_machine.enter_state(States.starting)
with extra(state_machine):
guard_mock = mocker.patch.object(state_machine, 'guard_transition')
Expand All @@ -904,7 +904,7 @@ async def test_guard_with_silence(self, state_machine: statesman.StateMachine, m

@pytest.mark.asyncio
async def test_guard_with_warning(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.warning
state_machine._config.guard_with = statesman.Guard.warning
await state_machine.enter_state(States.starting)
with extra(state_machine):
guard_mock = mocker.patch.object(state_machine, 'guard_transition')
Expand All @@ -914,7 +914,7 @@ async def test_guard_with_warning(self, state_machine: statesman.StateMachine, m

@pytest.mark.asyncio
async def test_guard_with_exception(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.exception
state_machine._config.guard_with = statesman.Guard.exception
await state_machine.enter_state(States.starting)
with extra(state_machine):
guard_mock = mocker.patch.object(state_machine, 'guard_transition')
Expand Down Expand Up @@ -999,7 +999,7 @@ async def test_guard_actions_run_sequentially(self, state_machine: statesman.Sta

@pytest.mark.asyncio
async def test_guard_with_silence(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.silence
state_machine._config.guard_with = statesman.Guard.silence
await state_machine.enter_state(States.starting)
event = state_machine.get_event('finish')
guard_action = mocker.stub(name='action')
Expand All @@ -1010,7 +1010,7 @@ async def test_guard_with_silence(self, state_machine: statesman.StateMachine, m

@pytest.mark.asyncio
async def test_guard_with_warning(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.warning
state_machine._config.guard_with = statesman.Guard.warning
await state_machine.enter_state(States.starting)
event = state_machine.get_event('finish')
guard_action = mocker.stub(name='action')
Expand All @@ -1022,7 +1022,7 @@ async def test_guard_with_warning(self, state_machine: statesman.StateMachine, m

@pytest.mark.asyncio
async def test_guard_with_exception(self, state_machine: statesman.StateMachine, mocker) -> None:
state_machine.__config__.guard_with = statesman.Guard.exception
state_machine._config.guard_with = statesman.Guard.exception
await state_machine.enter_state(States.starting)
event = state_machine.get_event('finish')
guard_action = mocker.stub(name='action')
Expand Down Expand Up @@ -1320,16 +1320,6 @@ class States(statesman.StateEnum):
three = '3'
four = statesman.InitialState('4')

@pytest.fixture(autouse=True)
async def reset_config() -> None:
state_entry = statesman.StateMachine.__config__.state_entry
guard_with = statesman.StateMachine.__config__.guard_with
try:
yield
finally:
statesman.StateMachine.__config__.state_entry = state_entry
statesman.StateMachine.__config__.guard_with = guard_with


@contextlib.contextmanager
def extra(
Expand Down

0 comments on commit d882c35

Please sign in to comment.