def statemethod(method):
    """Decorate a method so that the object's current state may override it.

    When the decorated method is called, the call is routed to
    ``self.state.<name>`` if the state object has an attribute with the
    method's name; otherwise the decorated (default) implementation runs.
    In both cases the owning object is passed as the first positional
    argument, followed by the caller's arguments.

    The undecorated function remains available as ``wrapper.default``.
    """
    import functools

    # ``func_name`` only exists on Python 2; ``__name__`` works on 2 and 3.
    name = method.__name__

    @functools.wraps(method)
    def call_statemethod(self, *args, **kwargs):
        # Use self.state.<method> if available, else method itself.
        real_method = getattr(self.state, name, method)
        return real_method(self, *args, **kwargs)

    call_statemethod.default = method
    return call_statemethod
# Sample usage:
class State(object):
    """Base class for states that are used as classes, without instantiation.

    The class object itself serves as the state of a Base object. This is
    convenient when there are many Base objects and no per-state data needs
    to be stored.
    """

    @classmethod
    def get_state(cls):
        """Return the value to assign as the state of an existing Base object.

        For this class (and subclasses that do not override it) that value is
        the class itself.
        """
        return cls

    @classmethod
    def new(cls):
        """Build a new Base object whose initial state is ``cls.get_state()``."""
        initial_state = cls.get_state()
        return Base(initial_state)
class InstantiatedState(State):
    """State whose get_state() builds a brand-new instance on every call.

    Because each call returns a separate object, several Base objects can
    each keep independent per-state data on their own state instance.
    """

    @classmethod
    def get_state(cls):
        """Return a fresh instance of this state for an existing Base object."""
        fresh_instance = cls()
        return fresh_instance
class Base(object):
    """Object whose @statemethod methods can be overridden by its ``state``.

    ``state`` is typically the value returned by a State subclass's
    ``get_state()``: either the class itself or a fresh instance.
    """

    def __init__(self, initial_state):
        self.state = initial_state

    def ordinary_method(self):
        # Not a statemethod, so the state cannot override it.
        print("This method is ordinary.")

    @statemethod
    def default_method(self):
        print("This is a default method that has not been overridden.")

    @statemethod
    def overridden_method(self):
        # Every sample state overrides this, so reaching the default is a bug.
        print("You shouldn't see this.")
        # Explicit raise rather than ``assert False``: ``python -O`` strips
        # asserts, which would let this call silently succeed.
        raise AssertionError("Base.overridden_method default should not run")
class SimpleState(State):
    """Uninstantiated state that overrides Base.overridden_method."""

    @staticmethod
    def overridden_method(base):
        # staticmethod: the state here is the class itself, so the only
        # argument received is the owning Base object passed by statemethod.
        print("The method on %r has been overridden by SimpleState." % (base,))
class DataState(InstantiatedState):
    """Per-Base state instance that carries its own ``message``."""

    # Class-level default; assigning ``message`` on an instance shadows it
    # for that instance only.
    message = "Awesome."

    def overridden_method(self, base):
        # Bound to this state instance; the owning Base object is passed in.
        print("This method on %r has been overridden by DataState. %s"
              % (base, self.message))
# Demonstrate shared (class) states and per-object (instance) states.
# Single-argument print(...) behaves identically on Python 2 and Python 3.
print("Base A")
print("======")
base_a = SimpleState.new()
print("Calling default_method:")
base_a.default_method()
print("Calling overridden_method:")
base_a.overridden_method()
print("Switching to DataState.")
base_a.state = DataState.get_state()
print("Calling overridden_method:")
base_a.overridden_method()
print("Changing message.")
# Only affects base_a's own DataState instance.
base_a.state.message = "Excellent."
print("Calling overridden_method:")
base_a.overridden_method()
# print("") rather than bare print(): on Python 2, print() would output "()".
print("")
print("Base B")
print("======")
base_b = DataState.new()
print("Calling default_method:")
base_b.default_method()
print("Calling overridden_method:")
base_b.overridden_method()