`Optional` attribute hell
Summary
Smell
A class has several optional fields which are linked together
Why is it bad?
The class contains invariants that aren't possible to capture at the type level. It makes it harder and more error-prone to work with the class. Often these invariants aren't explicitly documented, which doesn't help.
Solution
Represent each of the possible states of the class as its own thing.
Sometimes inheritance works better, sometimes you're better off with a `Union`.
Example
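Suppose that you need to model an arithmetic expression, such as `(x + 2) * y - 1`, and then evaluate it, given a dict mapping variable names to integers. A straightforward way to represent an expression is to make a syntax tree, in which the leaves are constants or variables and the inner nodes are binary operations.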
Suppose then that you come up with this solution:
from collections.abc import Mapping
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class Operation(Enum):
    """Binary arithmetic operators that an expression node can apply."""

    add = "add"
    subtract = "subtract"
    multiply = "multiply"
@dataclass
class Node:
    """A syntax-tree node with every node variant folded into one class.

    Exactly one of ``variable_name``, ``constant`` and ``operation`` is meant
    to be non-``None``, and the two operands are meant to be set if and only
    if ``operation`` is set. Nothing in this class enforces those rules; they
    are conventions the caller has to respect.
    """

    # Position of the node, as a line and a column.
    line: int
    column: int
    # Non-None only when the node is a variable reference.
    variable_name: Optional[str]
    # Non-None only when the node is an integer constant.
    constant: Optional[int]
    # Non-None only when the node is a binary operation; the two operands
    # below are expected to be set together with it.
    operation: Optional[Operation]
    left_operand: Optional["Node"]
    right_operand: Optional["Node"]
def evaluate(node: Node, variables: Mapping[str, int]) -> int:
    """Evaluate the expression rooted at ``node``.

    Args:
        node: Root of the (sub)tree to evaluate.
        variables: Values of the variables the expression may refer to.

    Returns:
        The integer value of the expression.

    Raises:
        LookupError: A variable node names a variable that has no value in
            ``variables``.
        AssertionError: The node has none of ``constant``, ``variable_name``
            and ``operation`` set, an operation node is missing an operand,
            or the operation is not one of the known members.
    """
    # The order of these branches decides which field wins when more than one
    # of them is set on the same node.
    if node.constant is not None:
        return node.constant
    elif node.variable_name is not None:
        number = variables.get(node.variable_name)
        if number is None:
            raise LookupError(node.variable_name)
        return number
    elif node.operation is not None:
        # The Node type does not guarantee that the operands are present, so
        # they have to be checked by hand here.
        assert node.left_operand is not None
        assert node.right_operand is not None
        left_result = evaluate(node.left_operand, variables)
        right_result = evaluate(node.right_operand, variables)
        if node.operation is Operation.add:
            return left_result + right_result
        elif node.operation is Operation.subtract:
            return left_result - right_result
        elif node.operation is Operation.multiply:
            return left_result * right_result
        else:
            # NOTE: assert statements are removed under `python -O`.
            assert False
    else:
        # Reached when none of the three "kind" fields is set.
        assert False
Nothing about this is right. This `Node` class has some complex invariants that we need to respect:
- At least one of `variable_name`, `constant`, `operation` must be non-`None`
- If `variable_name` is not `None`, `constant` and `operation` must be `None`
- If `constant` is not `None`, `variable_name` and `operation` must be `None`
- If `operation` is not `None`, `variable_name` and `constant` must be `None`
- `left_operand` and `right_operand` are `None` if and only if `operation` is `None`
If we break some invariants (like the first one), we'll get a runtime error.
If we break others (like setting both `constant` and `operation`), the program will be silently wrong.
One benefit of types is that they help us eliminate invalid states. So let's leverage that help.
Solution
One possible solution is to make a `Node` base class and take advantage of polymorphism.
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
# Variable bindings used during evaluation: variable name -> integer value.
Vars = Mapping[str, int]
@dataclass(frozen=True)
class Pos:
    """Immutable position of a node, given as a line and a column."""

    line: int
    column: int
class Node(ABC):
    """Interface shared by every kind of expression node.

    Concrete subclasses must provide both ``pos`` and ``evaluate``; until they
    do, they cannot be instantiated.
    """

    @property
    @abstractmethod
    def pos(self) -> Pos:
        """Position of this node."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(self, variables: Vars) -> int:
        """Return the integer value of this node, given ``variables``."""
        raise NotImplementedError
class Constant(Node):
    """Expression node holding a fixed integer value."""

    def __init__(self, pos: Pos, value: int) -> None:
        """Store the node's position and the integer it evaluates to."""
        self._pos = pos
        self._value = value

    @property
    def pos(self) -> Pos:
        """Position given at construction."""
        return self._pos

    def evaluate(self, variables: Vars) -> int:
        """Return the stored value; ``variables`` is not consulted."""
        return self._value
class Variable(Node):
    """Expression node that refers to a variable by name."""

    def __init__(self, pos: Pos, name: str) -> None:
        """Store the node's position and the name of the variable."""
        self._pos = pos
        self._name = name

    @property
    def pos(self) -> Pos:
        """Position given at construction."""
        return self._pos

    def evaluate(self, variables: Vars) -> int:
        """Look the variable up in ``variables`` and return its value.

        Raises:
            LookupError: ``variables`` has no value for this variable's name.
        """
        number = variables.get(self._name)
        if number is None:
            raise LookupError(self._name)
        return number
class BinaryOperationNode(Node):
    """Base class for nodes that combine two operand nodes.

    Subclasses only supply ``_operation``; evaluating the operands is handled
    here. The class stays abstract because ``_operation`` is abstract.
    """

    def __init__(self, pos: Pos, left: Node, right: Node) -> None:
        """Store the node's position and its left and right operand nodes."""
        self._pos = pos
        self._left = left
        self._right = right

    @property
    def pos(self) -> Pos:
        """Position given at construction."""
        return self._pos

    @abstractmethod
    def _operation(self, left: int, right: int) -> int:
        """Combine the two already-evaluated operand values."""
        raise NotImplementedError

    def evaluate(self, variables: Vars) -> int:
        """Evaluate the left operand, then the right, then combine the results."""
        left_result = self._left.evaluate(variables)
        right_result = self._right.evaluate(variables)
        return self._operation(left_result, right_result)
class Add(BinaryOperationNode):
    """Binary node that adds its two operands."""

    def _operation(self, left: int, right: int) -> int:
        """Return ``left + right``."""
        return left + right
class Sub(BinaryOperationNode):
    """Binary node that subtracts its right operand from its left operand."""

    def _operation(self, left: int, right: int) -> int:
        """Return ``left - right``."""
        return left - right
class Mul(BinaryOperationNode):
    """Binary node that multiplies its two operands."""

    def _operation(self, left: int, right: int) -> int:
        """Return ``left * right``."""
        return left * right