Using Protocols
Protocols allow you to accept objects of different types that match a particular shape.
You may have used a protocol before without realizing it: `Sized`, `Iterable` and `Mapping` from `collections.abc` are protocols.
They define what methods an object needs to have, but don't restrict what specific class it is.
The prelude
We'll be making a simple toolkit for making ASCII art
The code so far
from collections.abc import Iterable
import warnings
class Block:
    """A widget that renders a string by cutting it into fixed-width chunks."""

    def __init__(self, text: str) -> None:
        self._text = text

    def render(self, width: int) -> Iterable[str]:
        """Yield consecutive `width`-sized slices of the text.

        ``len(text) // width + 1`` slices are produced, so the final slice
        is empty when the text length is an exact multiple of `width`.
        """
        text = self._text
        chunk_count = len(text) // width + 1
        start = 0
        for _ in range(chunk_count):
            yield text[start : start + width]
            start += width
class Guard:
    """Force every line of the wrapped widget to be exactly `width` long."""

    def __init__(self, wrapped) -> None:
        self._wrapped = wrapped

    def render(self, width: int) -> Iterable[str]:
        """Yield the wrapped widget's lines, space-padded or cut to `width`.

        A warning is emitted for every line that has to be cut.
        """
        for line in self._wrapped.render(width):
            if len(line) <= width:
                # Short (or exact) line: fill the remainder with spaces.
                yield line.ljust(width)
                continue
            warnings.warn(f"Line length ({len(line)}) exceeds supplied width ({width})")
            yield line[:width]
class Pad:
    """Draw a one-character frame around the wrapped widget."""

    def __init__(self, char: str, wrapped) -> None:
        if len(char) != 1:
            raise ValueError("Expected a single character for padding")
        self._wrapped = wrapped
        self._char = char

    def render(self, width: int) -> Iterable[str]:
        """Yield a top border, the framed inner lines, then a bottom border.

        The inner widget is only rendered when more than two columns are
        available; it is given `width - 2` columns.
        """
        edge = self._char
        yield edge * width
        if width > 2:
            inner_width = width - 2
            for inner_line in self._wrapped.render(inner_width):
                yield f"{edge}{inner_line}{edge}"
        yield edge * width
# Build the scene inside-out: text block -> fixed-width guard -> "*" frame.
text_block = Block("Hello, world! Lorem Ipsum Dolor Sit Amet, or so I've been told.")
scene = Pad("*", Guard(text_block))
for line in scene.render(12):
    print(line)
The program outputs:
************
*Hello, wor*
*ld! Lorem *
*Ipsum Dolo*
*r Sit Amet*
*, or so I'*
*ve been to*
*ld.       *
************
Without the `Guard` widget, the `*ld.       *` line would be `*ld.*`, which is not what we want — we extracted the padding functionality so that we don't have to trim the line in every widget's `render`.
This is a classic example of the Decorator pattern: one object wraps another object with a well-known interface (the render
method here) and presents the same interface in return. That way, you can add functionality to widgets of potentially different classes and retain their widget-ness.
Defining a protocol
Right now we have an unwritten contract for all the widgets
: they must have a render
method, accepting a width
parameter that is an int
, and returning an iterable of strings. We can formalize this contract with a protocol class:
from collections.abc import Iterable
from typing import Protocol
class Widget(Protocol):
    # Structural interface: any object with a matching `render` method counts
    # as a Widget, whatever its class.
    def render(self, width: int, /) -> Iterable[str]:
        ...
Note
The ...
here doesn't imply something's missing: leave the method body as ...
, the ellipsis object. I'll use # <...>
when I really mean to skip something.
This protocol defines the methods an object needs to support to be a Widget
. We can now add annotations to our wrapped
parameters:
class Guard:
    # Only the annotation on `wrapped` changes; the body is elided.
    def __init__(self, wrapped: Widget) -> None:
        # <...>
class Pad:
    # Only the annotation on `wrapped` changes; the body is elided.
    def __init__(self, char: str, wrapped: Widget) -> None:
        # <...>
Now if you try to call Guard("some text")
, your type checker will complain:
Argument of type "Literal['some text']" cannot be assigned to parameter "wrapped" of type "Widget" in function "__init__"
"Literal['some text']" is incompatible with protocol "Widget"
"render" is not present
Slash (/
) in the protocol definition
A /
in a function definition means that all the parameters before it can only be provided with positional arguments (see PEP 570).
This will be common in protocols: if you just said def render(self, width: int)
, it would mean that to implement this protocol you must name your parameter width
exactly.
But we're only going to pass the width as a positional argument, and it would be silly to reject something from being a Widget
just because its render
parameter is called w
or max_width
.
In addition to ...
, you can provide a docstring for the method, documenting its purpose and potentially other contract requirements:
class Widget(Protocol):
    # The docstring records contract requirements that the signature alone
    # cannot express.
    def render(self, width: int, /) -> Iterable[str]:
        """
        Yield a finite number of text lines.
        Every line must be at most `width` characters long.
        """
        ...
Attributes
A protocol class may also define attributes, for example a `Labelled` protocol declaring a `label: str` attribute.
This requires that the attribute is both readable and writable. So the following is allowed:
def just_read(thing: Labelled) -> None:
    """Print the label; only read access to `label` is needed."""
    current_label = thing.label
    print(current_label)
def calm_down(thing: Labelled) -> None:
    """Remove every exclamation mark from the label and write it back."""
    quieter = thing.label.replace("!", "")
    thing.label = quieter
More often you'll want a read-only attribute. You can do this by defining a property in the protocol class:
With this definition of Labelled
, calm_down
will not be accepted by a type checker.
Note that this doesn't require a Labelled
object to have an actual property, it could be a regular attribute as well:
Now, back to our Widget
s. If we want to do more complex layouts, we might need to know how wide a widget wants to be:
class Widget(Protocol):
    # A widget must now offer both `render` and a readable `desired_width`.
    def render(self, width: int, /) -> Iterable[str]:
        ...

    # Declared as a property so implementers only have to make it readable.
    @property
    def desired_width(self) -> int:
        """
        If I had unlimited width, how much would I want to grow?
        """
        ...
Now we're getting red squiggles in the scene = ...
line because our Guard
and Block
aren't widgets anymore. Let's fix that:
Code with desired_width
from collections.abc import Iterable
from typing import Protocol
import warnings
class Widget(Protocol):
    # Structural interface shared by all the classes below.
    def render(self, width: int, /) -> Iterable[str]:
        ...

    # Read-only from the protocol's point of view.
    @property
    def desired_width(self) -> int:
        ...
class Block:
    """A widget that renders a string by cutting it into fixed-width chunks."""

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def desired_width(self) -> int:
        """Width needed to show the whole text on a single line."""
        return len(self._text)

    def render(self, width: int) -> Iterable[str]:
        """Yield consecutive `width`-sized slices of the text.

        ``len(text) // width + 1`` slices are produced, so the final slice
        is empty when the text length is an exact multiple of `width`.
        """
        text = self._text
        chunk_count = len(text) // width + 1
        start = 0
        for _ in range(chunk_count):
            yield text[start : start + width]
            start += width
class Guard:
    """Force every line of the wrapped widget to be exactly `width` long."""

    def __init__(self, wrapped: Widget) -> None:
        self._wrapped = wrapped

    @property
    def desired_width(self) -> int:
        """Same as the wrapped widget's desired width."""
        inner = self._wrapped
        return inner.desired_width

    def render(self, width: int) -> Iterable[str]:
        """Yield the wrapped widget's lines, space-padded or cut to `width`.

        A warning is emitted for every line that has to be cut.
        """
        for line in self._wrapped.render(width):
            if len(line) <= width:
                # Short (or exact) line: fill the remainder with spaces.
                yield line.ljust(width)
                continue
            warnings.warn(f"Line length ({len(line)}) exceeds supplied width ({width})")
            yield line[:width]
class Pad:
    """Draw a one-character frame around the wrapped widget."""

    def __init__(self, char: str, wrapped: Widget) -> None:
        if len(char) != 1:
            raise ValueError("Expected a single character for padding")
        self._wrapped = wrapped
        self._char = char

    @property
    def desired_width(self) -> int:
        """Wrapped widget's desired width plus one column on each side."""
        frame_columns = 2
        return self._wrapped.desired_width + frame_columns

    def render(self, width: int) -> Iterable[str]:
        """Yield a top border, the framed inner lines, then a bottom border.

        The inner widget is only rendered when more than two columns are
        available; it is given `width - 2` columns.
        """
        edge = self._char
        yield edge * width
        if width > 2:
            inner_width = width - 2
            for inner_line in self._wrapped.render(inner_width):
                yield f"{edge}{inner_line}{edge}"
        yield edge * width
# Build the scene inside-out: text block -> fixed-width guard -> "*" frame.
text_block = Block("Hello, world! Lorem Ipsum Dolor Sit Amet, or so I've been told.")
scene = Pad("*", Guard(text_block))
for line in scene.render(12):
    print(line)
Now we can add classes that handle horizontal and flexible layout!
Advanced layout
This is a lot of code, so if you don't want to read all of it, note:
- `DesireWidth` has a plain `desired_width` attribute
- other classes have a `desired_width` property
- all of them work as `Widget`s!
class DesireWidth:
    """
    Report a fixed `desired_width` for a child widget; rendering is
    delegated to the child unchanged.
    """

    def __init__(self, desired_width: int, wrapped: Widget) -> None:
        self._wrapped = wrapped
        # A plain instance attribute, not a property.
        self.desired_width = desired_width

    def render(self, width: int) -> Iterable[str]:
        """Return whatever the wrapped widget renders at `width`."""
        child = self._wrapped
        return child.render(width)
class Vertical:
    """
    Arrange widgets vertically, one after another.
    """

    def __init__(self, items: Iterable[Widget]) -> None:
        # Materialise once so the items can be iterated repeatedly.
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Widest desired width among the items; 0 when there are none."""
        # `default=0` keeps an empty layout from raising ValueError.
        return max((item.desired_width for item in self._items), default=0)

    def render(self, width: int) -> Iterable[str]:
        """Yield every item's lines in order, each rendered at `width`."""
        for item in self._items:
            yield from item.render(width)
class Horizontal:
    """
    Arrange widgets horizontally. You must specify how much horizontal
    space each item occupies.
    If the total of the spaces is bigger than the required width, the
    lines will be cut off from the right.
    """

    def __init__(self, items: Iterable[tuple[int, Widget]]) -> None:
        # Each entry is (allotted width, widget).
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Sum of the widths allotted to the items."""
        return sum(width for (width, _widget) in self._items)

    def render(self, width: int) -> Iterable[str]:
        """Yield rows made by joining one line from every column.

        Rendering stops once every column has run out of lines; columns
        that finish early are filled with blanks until then. Each row is
        cut to `width` characters.
        """
        if not self._items:
            return
        # One entry per column: (allotted width, line iterator). The iterator
        # is replaced by None once the column is exhausted.
        columns = [
            (subwidth, iter(Guard(widget).render(subwidth)))
            for (subwidth, widget) in self._items
        ]
        active = len(columns)
        while True:
            cells = []
            for i, (subwidth, lines) in enumerate(columns):
                if lines is not None:
                    try:
                        cells.append(next(lines))
                        continue
                    except StopIteration:
                        active -= 1
                        if active == 0:
                            # The last live column just ended: drop the
                            # partially built row and finish.
                            return
                        columns[i] = (subwidth, None)
                # Exhausted column: contribute a blank cell of its width.
                cells.append(" " * subwidth)
            yield "".join(cells)[:width]
class Flex:
    """
    Arrange items in rows, switching to a new row when the next item
    wouldn't fit according to its `desired_width`.
    """

    def __init__(self, items: Iterable[Widget]) -> None:
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Width needed to place every item on a single row."""
        return sum(item.desired_width for item in self._items)

    def render(self, width: int) -> Iterable[str]:
        """Pack the items greedily into rows of at most `width` columns."""
        rows: list[Horizontal] = []
        current: list[tuple[int, Widget]] = []
        budget = width
        for item in self._items:
            wanted = item.desired_width
            if wanted <= budget:
                current.append((wanted, item))
                budget -= wanted
                continue
            # Doesn't fit: close the open row (if any) and start a new one.
            # An item wider than the whole line is clamped to `width`.
            if current:
                rows.append(Horizontal(current))
            current = [(min(width, wanted), item)]
            budget = width - wanted
        if current:
            rows.append(Horizontal(current))
        return Guard(Vertical(rows)).render(width)
Let's test it:
Test program and output
# Four framed text blocks; b2 and b4 report a fixed desired width.
b1 = Pad("1", Guard(Block("First")))
b2 = DesireWidth(5, Pad("2", Guard(Block("SecondSecond"))))
b3 = Pad("3", Guard(Block("The Very Third!!")))
b4 = DesireWidth(10, Pad("4", Guard(Block("Fourth"))))

inner = Pad("+", Flex([b1, b2]))
outer = Pad("*", Flex([inner, b3, b4]))

# Render the same layout at two widths, separated by a blank line.
for index, total_width in enumerate((30, 40)):
    if index:
        print()
    for line in outer.render(total_width):
        print(line)
******************************
*++++++++++++++              *
*+111111122222+              *
*+1First12Sec2+              *
*+1     12ond2+              *
*+11111112Sec2+              *
*+       2ond2+              *
*+       2   2+              *
*+       22222+              *
*++++++++++++++              *
*3333333333333333334444444444*
*3The Very Third!!34Fourth  4*
*3                34444444444*
*333333333333333333          *
******************************

****************************************
*++++++++++++++333333333333333333      *
*+111111122222+3The Very Third!!3      *
*+1First12Sec2+3                3      *
*+1     12ond2+333333333333333333      *
*+11111112Sec2+                        *
*+       2ond2+                        *
*+       2   2+                        *
*+       22222+                        *
*++++++++++++++                        *
*4444444444                            *
*4Fourth  4                            *
*4444444444                            *
****************************************
Intermezzo: points, structural and nominal typing
In statically typed languages like Java and C++, you typically need to specify that you're conforming to an interface, either through dedicated syntax or through inheritance:
That's called nominal typing: the implementor must "name" the things it implements. This is opposed to structural typing: if a value has the required structure, it fits the interface. Both have their pros and cons.
- The obvious benefit of structural typing is that you don't need to explicitly refer to the interface you're implementing. That's how Python has always worked, with dunder methods and other duck typing techniques.
- Protocols make it easy to type existing contracts without requiring users to inherit from a base class all of a sudden.
-
You can't modify something you don't own. But you can make a protocol that just happens to match someone else's type. For example, a protocol with a `status_code: int` property and a `text: str` method matches `requests.Response` and potentially your own class.
-
On the other hand, structural typing can make it hard to realize that a particular class implements a particular interface. You'll need to know about that interface beforehand to understand that intention.
- The errors for structural types are worse, as they occur at usage time. If one of our widgets had a typoed `desired_widht` property, you'd get a nasty error message when passing the widget somewhere else, saying that `"Foo" is not a "Widget"`. This gets worse with unions, more complex protocols, and generics.
As a more subtle point, you can accidentally match a protocol, when your class doesn't semantically satisfy it.
class TwoDimensional(Protocol):
    # Matches any object exposing float attributes `x` and `y`.
    x: float
    y: float
def reset(point: TwoDimensional) -> None:
    """Move the point to the origin by zeroing both coordinates."""
    origin_value = 0.0
    point.x = origin_value
    point.y = origin_value
def distance_to_origin(point: TwoDimensional) -> float:
    """Return the Euclidean distance from `(point.x, point.y)` to `(0, 0)`."""
    squared_length = point.x**2 + point.y**2
    return squared_length**0.5
# Point2D never mentions TwoDimensional; it matches the protocol purely by
# having `x` and `y` float fields.
@dataclass
class Point2D:
    x: float
    y: float

# A 3-4-5 triangle: the distance should be 5, allowing for float rounding.
assert 4.99 < distance_to_origin(Point2D(3.0, 4.0)) < 5.01
A `Point3D` class would also have `x: float` and `y: float` attributes, so it matches `TwoDimensional`, but `reset` and `distance_to_origin` will give you completely wrong results with a `Point3D`. So if it's vital that a class understands when it implements the contract, don't use a protocol.
The nominal counterpart to Protocol
: ABC
If you don't want the structural behaviour of protocols, you can define an abstract base class:
from abc import ABC, abstractmethod
class Widget(ABC):
    # Nominal counterpart of the protocol: classes must inherit from Widget
    # explicitly to be treated as widgets.
    @abstractmethod
    def render(self, width: int, /) -> Iterable[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def desired_width(self) -> int:
        raise NotImplementedError
With this, you'll need to inherit from Widget
explicitly in order to mark a class as a widget:
class Block(Widget):
    """A widget that renders a string by cutting it into fixed-width chunks."""

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def desired_width(self) -> int:
        """Width needed to show the whole text on a single line."""
        return len(self._text)

    def render(self, width: int) -> Iterable[str]:
        """Yield consecutive `width`-sized slices of the text.

        ``len(text) // width + 1`` slices are produced, so the final slice
        is empty when the text length is an exact multiple of `width`.
        """
        text = self._text
        chunk_count = len(text) // width + 1
        start = 0
        for _ in range(chunk_count):
            yield text[start : start + width]
            start += width
You can read more about ABCs in the documentation for the abc
module.
Inheriting from a protocol
We said before that...
... structural typing can make it hard to realize that a particular class implements a particular interface. You'll need to know about that interface beforehand to understand that intention.
... The errors for structural types are worse, as they occur at usage time
If you want to alleviate that without requiring inheritance from your base class, you can just inherit from the protocol!
If you expect other people to inherit from a Protocol
class, you should prepare it for the real world a bit:
from abc import abstractmethod
from typing import Protocol
class Widget(Protocol):
    # Prepared for explicit inheritance: required members are abstract and
    # raise if a subclass reaches them through super().
    @abstractmethod
    def render(self, width: int, /) -> Iterable[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def desired_width(self) -> int:
        raise NotImplementedError

    # or provide a sensible default implementation:
    # NOTE: this is an alternative to the abstract property above. Keep only
    # one of the two: a second definition in the same class body replaces
    # the first.
    @property
    def desired_width(self) -> int:
        return 1_000_000
-
Mark the required-to-implement methods with `abstractmethod`. Yes, the item from the `abc` module. With this change, a class that doesn't implement all the items is considered abstract, and instantiating it will cause an error at runtime:

    class Foo(Protocol):
        @abstractmethod
        def method1(self) -> None:
            # <...>

        @abstractmethod
        def method2(self) -> None:
            # <...>

    class Bar(Foo):
        def method1(self):
            # <...>

    Bar()  # TypeError: Can't instantiate abstract class Bar without an implementation for abstract method 'method2'

A type checker should already prevent you from instantiating `Bar()`, but it's good to catch the error at runtime as well. If you're making a library, you shouldn't assume that its users are using a type checker. There are also many creative ways around type checkers in Python.
-
Raise a `NotImplementedError` in the method body. In an ABC, you're allowed to "bottom out" to an abstract method when using `super`:

    class Foo(ABC):
        @abstractmethod
        def greet(self) -> str:
            pass

    class Bar(Foo):
        def greet(self) -> str:
            return super().greet()

    print(Bar().greet())  # prints: None

So if you leave the method body to be `...`, a `super().render()` call in a widget could give you a `None`, which is wrong. For that reason, you should normally use `raise NotImplementedError` for the body of abstract methods to catch those mistaken calls — both in ABCs and (inheritable) protocols.

(Unless, of course, you do want to allow subclasses to execute some default behaviour if they end up `super`ing into an abstract method. Just make sure that the abstract method makes sense for the type annotation.)
And yes, that's the secret: `Protocol` is mostly the same as `abc.ABC` [1], except type checkers understand that it's structural, not nominal.
A note on "compatibility"
In order to be compatible with a protocol, a class's methods don't need to have exactly the same signatures as the protocol's; they just need to be "compatible".
In other words, an implementation should be "at least as good as the protocol": the parameter types can be wider, and the return type can be narrower.
All these are valid render
methods for a widget:
def render(self, width: int, /) -> Iterable[str]: # exactly the same
def render(self, width: int) -> Iterable[str]: # width can also be a keyword argument
def render(self, width: int, with_unicorn: bool = True) -> Iterable[str]: # _optional_ extra parameter is fine
def render(self, width: int | str) -> Iterable[str]: # wider parameter type is allowed
def render(self, width: object) -> Iterable[str]: # wider parameter type is allowed
def render(self, width: int) -> Iterator[str]: # narrower return type is allowed
def render(self, width: int) -> list[str]: # narrower return type is allowed
Conclusion
Complete program listing
from abc import abstractmethod
from collections.abc import Iterable
from typing import Protocol
import warnings
class Widget(Protocol):
    # Interface shared by every class below. The members are abstract and
    # raise, so the protocol can also be inherited from explicitly.
    @abstractmethod
    def render(self, width: int, /) -> Iterable[str]:
        raise NotImplementedError

    @property
    @abstractmethod
    def desired_width(self) -> int:
        raise NotImplementedError
class Block:
    """Text widget: shows a string as successive fixed-width slices."""

    def __init__(self, text: str) -> None:
        self._text = text

    @property
    def desired_width(self) -> int:
        """Number of columns needed to fit the text on one line."""
        return len(self._text)

    def render(self, width: int) -> Iterable[str]:
        """Yield ``len(text) // width + 1`` slices of `width` characters.

        The last slice is empty when the text length is an exact multiple
        of `width`.
        """
        text = self._text
        offsets = (index * width for index in range(len(text) // width + 1))
        for offset in offsets:
            yield text[offset : offset + width]
class Guard:
    """Normalise the wrapped widget's lines to exactly `width` characters."""

    def __init__(self, wrapped: Widget) -> None:
        self._wrapped = wrapped

    @property
    def desired_width(self) -> int:
        """Pass through the wrapped widget's desired width."""
        inner = self._wrapped
        return inner.desired_width

    def render(self, width: int) -> Iterable[str]:
        """Yield each wrapped line padded with spaces or cut to `width`.

        Cutting a line also emits a warning.
        """
        for line in self._wrapped.render(width):
            fits = len(line) <= width
            if fits:
                yield line.ljust(width)
                continue
            warnings.warn(
                f"Line length ({len(line)}) exceeds supplied width ({width})"
            )
            yield line[:width]
class Pad:
    """Surround the wrapped widget with a frame made of a single character."""

    def __init__(self, char: str, wrapped: Widget) -> None:
        if len(char) != 1:
            raise ValueError("Expected a single character for padding")
        self._wrapped = wrapped
        self._char = char

    @property
    def desired_width(self) -> int:
        """Wrapped widget's desired width plus the two frame columns."""
        frame_columns = 2
        return self._wrapped.desired_width + frame_columns

    def render(self, width: int) -> Iterable[str]:
        """Yield top border, framed inner lines, bottom border.

        The inner widget gets `width - 2` columns and is skipped entirely
        when `width` is 2 or less.
        """
        edge = self._char
        yield edge * width
        if width > 2:
            inner_width = width - 2
            for inner_line in self._wrapped.render(inner_width):
                yield f"{edge}{inner_line}{edge}"
        yield edge * width
class DesireWidth:
    """
    Report a fixed `desired_width` for a child widget; rendering is
    delegated to the child unchanged.
    """

    def __init__(self, desired_width: int, wrapped: Widget) -> None:
        self._wrapped = wrapped
        # A plain instance attribute, not a property.
        self.desired_width = desired_width

    def render(self, width: int) -> Iterable[str]:
        """Return whatever the wrapped widget renders at `width`."""
        child = self._wrapped
        return child.render(width)
class Vertical:
    """
    Arrange widgets vertically, one after another.
    """

    def __init__(self, items: Iterable[Widget]) -> None:
        # Materialise once so the items can be iterated repeatedly.
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Widest desired width among the items; 0 when there are none."""
        # `default=0` keeps an empty layout from raising ValueError.
        return max((item.desired_width for item in self._items), default=0)

    def render(self, width: int) -> Iterable[str]:
        """Yield every item's lines in order, each rendered at `width`."""
        for item in self._items:
            yield from item.render(width)
class Horizontal:
    """
    Arrange widgets horizontally. You must specify how much horizontal
    space each item occupies.
    If the total of the spaces is bigger than the required width, the
    lines will be cut off from the right.
    """

    def __init__(self, items: Iterable[tuple[int, Widget]]) -> None:
        # Each entry is (allotted width, widget).
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Sum of the widths allotted to the items."""
        return sum(width for (width, _widget) in self._items)

    def render(self, width: int) -> Iterable[str]:
        """Yield rows made by joining one line from every column.

        Rendering stops once every column has run out of lines; columns
        that finish early are filled with blanks until then. Each row is
        cut to `width` characters.
        """
        if not self._items:
            return
        # One entry per column: (allotted width, line iterator). The iterator
        # is replaced by None once the column is exhausted.
        columns = [
            (subwidth, iter(Guard(widget).render(subwidth)))
            for (subwidth, widget) in self._items
        ]
        active = len(columns)
        while True:
            cells = []
            for i, (subwidth, lines) in enumerate(columns):
                if lines is not None:
                    try:
                        cells.append(next(lines))
                        continue
                    except StopIteration:
                        active -= 1
                        if active == 0:
                            # The last live column just ended: drop the
                            # partially built row and finish.
                            return
                        columns[i] = (subwidth, None)
                # Exhausted column: contribute a blank cell of its width.
                cells.append(" " * subwidth)
            yield "".join(cells)[:width]
class Flex:
    """
    Arrange items in rows, switching to a new row when the next item
    wouldn't fit according to its `desired_width`.
    """

    def __init__(self, items: Iterable[Widget]) -> None:
        self._items = list(items)

    @property
    def desired_width(self) -> int:
        """Width needed to place every item on a single row."""
        return sum(item.desired_width for item in self._items)

    def render(self, width: int) -> Iterable[str]:
        """Pack the items greedily into rows of at most `width` columns."""
        rows: list[Horizontal] = []
        current: list[tuple[int, Widget]] = []
        budget = width
        for item in self._items:
            wanted = item.desired_width
            if wanted <= budget:
                current.append((wanted, item))
                budget -= wanted
                continue
            # Doesn't fit: close the open row (if any) and start a new one.
            # An item wider than the whole line is clamped to `width`.
            if current:
                rows.append(Horizontal(current))
            current = [(min(width, wanted), item)]
            budget = width - wanted
        if current:
            rows.append(Horizontal(current))
        return Guard(Vertical(rows)).render(width)
# Four framed text blocks; b2 and b4 report a fixed desired width.
b1 = Pad("1", Guard(Block("First")))
b2 = DesireWidth(5, Pad("2", Guard(Block("SecondSecond"))))
b3 = Pad("3", Guard(Block("The Very Third!!")))
b4 = DesireWidth(10, Pad("4", Guard(Block("Fourth"))))

inner = Pad("+", Flex([b1, b2]))
outer = Pad("*", Flex([inner, b3, b4]))

# Render the same layout at two widths, separated by a blank line.
for index, total_width in enumerate((30, 40)):
    if index:
        print()
    for line in outer.render(total_width):
        print(line)
[1] `Protocol` has some added machinery to support `typing.get_protocol_members`, but that's very niche. ↩