diff --git a/maya.py b/maya.py index 1b3360e..e7323b0 100644 --- a/maya.py +++ b/maya.py @@ -23,19 +23,26 @@ from tzlocal import get_localzone _EPOCH_START = (1970, 1, 1) -def validate_type_mayadt(func): +def validate_class_type_arguments(operator): """ Decorator to validate all the arguments to function - are of type `MayaDT` + are of the type of calling class """ - def inner(*args, **kwargs): - for arg in args + tuple(kwargs.values()): - if not isinstance(arg, MayaDT): - raise ValueError("Operation allowed only on object of type '{}'".format(MayaDT.__name__)) - return func(*args, **kwargs) + + def inner(function): + def wrapper(self, *args, **kwargs): + for arg in args + tuple(kwargs.values()): + if not isinstance(arg, self.__class__): + raise TypeError('unorderable types: {}() {} {}()'.format( + type(self).__name__, operator, type(arg).__name__)) + return function(self, *args, **kwargs) + + return wrapper + return inner + class MayaDT(object): """The Maya Datetime object.""" @@ -53,29 +60,30 @@ class MayaDT(object): """Return's the datetime's format""" return format(self.datetime(), *args, **kwargs) - @validate_type_mayadt + + @validate_class_type_arguments('==') def __eq__(self, maya_dt): return self._epoch == maya_dt._epoch - @validate_type_mayadt + @validate_class_type_arguments('!=') def __ne__(self, maya_dt): - return not self.__eq__(maya_dt) + return self._epoch != maya_dt._epoch - @validate_type_mayadt + @validate_class_type_arguments('<') def __lt__(self, maya_dt): return self._epoch < maya_dt._epoch - @validate_type_mayadt + @validate_class_type_arguments('<=') def __le__(self, maya_dt): - return self.__lt__(maya_dt) or self.__eq__(maya_dt) + return self._epoch <= maya_dt._epoch - @validate_type_mayadt + @validate_class_type_arguments('>') def __gt__(self, maya_dt): return self._epoch > maya_dt._epoch - @validate_type_mayadt + @validate_class_type_arguments('>=') def __ge__(self, maya_dt): - return self.__gt__(maya_dt) or self.__eq__(maya_dt) + return self._epoch >= maya_dt._epoch # Timezone Crap diff --git a/test_maya.py b/test_maya.py index ad1d8b7..99b2389 100644 --- a/test_maya.py +++ b/test_maya.py @@ -138,3 +138,17 @@ def test_comparison_operations(): assert (now >= now_copy) is True assert (now >= tomorrow) is False + + # Check Exceptions + with pytest.raises(TypeError): + now == 1 + with pytest.raises(TypeError): + now != 1 + with pytest.raises(TypeError): + now < 1 + with pytest.raises(TypeError): + now <= 1 + with pytest.raises(TypeError): + now > 1 + with pytest.raises(TypeError): + now >= 1