Authors:
- @qihqi
A comprehensive set of APIs for defining a PyTorch backend entirely in Python. The design goals of this API are:

- Easy to write a minimal backend: backend developers should be able to implement a PyTorch backend with only the essential work. PyTorch should infer what it can from the minimal set of required things implemented.
- Maximize customization with optional components: a backend developer can choose to implement more than necessary for their backend, say, to improve performance.
- Everything should be doable in Python for fastest experimentation. C++ APIs can be made available for speed, with the understanding that their use is optional; i.e. everything doable in C++ is also doable in Python.

In other words, we start by asking the question: “What is the absolute necessary amount of information that a backend has to tell PyTorch for it to create tensors and run math on that backend?”
If a backend provides that, PyTorch should infer the rest needed to do its job.
In this RFC, I hope to discuss the following items with the community:
- What is considered the "minimal set", and how to infer the rest.
- Proposed implementation.
- Look and feel of the registration API itself.
There are a few ways to register functionality with PyTorch and customize its behavior. Notably:

- C++ backend registration API: https://docs.pytorch.org/tutorials/advanced/privateuseone.html
  - An eager-mode backend
  - You can define both the data representation and operators in this API
- Python custom operators: https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html
  - You can only define operators
  - Cannot define the data representation
- Dynamo custom backend: https://docs.pytorch.org/docs/stable/torch.compiler_custom_backends.html
  - A torch.compile backend: you can define how to compile an FX graph into a callable
  - Cannot define the data representation
In summary:
| | C++ | Python |
|---|---|---|
| Define data representation | Yes | No |
| Define single operator | Yes | Yes |
| Define graph compiler | No | Yes |
So now, to define a full-featured backend for a device, using torch_xla and Google TPU as the example here:

- Define the XLATensor in C++, which contains the internal data representation
- Define and register operators that know how to do math on XLATensor (can be done in either C++ or Python)
- Define a Dynamo backend, in Python
In the above process, one also has to learn a lot of implementation details and concepts of PyTorch, such as DeviceGuard and DispatchKey.
If we ask the question, “What is the absolute necessary amount of information that a backend has to tell PyTorch for it to create tensors and run math on that backend?”, the answer is more minimal:
- We have to define a data structure that is opaque to PyTorch, representing our on-device data.
- We have to implement the ops for some op set (Core ATen / Prim only?).
- We have to tell PyTorch how to move a CPU tensor to our backend's tensor and back.
Can we instead have an API that lets the backend developer tell us exactly that?
The next question is: what are the optional things that PyTorch can infer, but a backend can choose to tell PyTorch to improve performance and usability?

- Tensor constructors: if absent, PyTorch can create tensors on CPU and transfer via (3) above; if present, construction is faster / more natural.
- Non-core ATen ops: registering a direct lowering for, say, `einsum` may improve performance.
- Dynamo backend: if I only have an eager backend, then PyTorch can produce a Dynamo backend identical to `torch.compile(backend='eager')`, which already exists. Conversely, if I only provide a Dynamo backend, PyTorch should be able to generate an eager backend by calling my compiler with graphs containing only one node.
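The last bullet can be sketched concretely: given only a Dynamo-style graph compiler, eager execution can be recovered by wrapping each op call in a single-node FX graph and compiling it. The helper below is a hypothetical illustration of that inference, not an existing PyTorch API.

```python
import torch
import torch.fx as fx


def eager_from_compiler(compiler):
    """Given a Dynamo-style backend (fx.GraphModule, inputs -> callable),
    produce a per-op eager executor. Hypothetical helper for illustration."""
    def run_op(op, *args):
        # Build an FX graph containing exactly one call to `op`.
        g = fx.Graph()
        placeholders = tuple(g.placeholder(f"arg{i}") for i in range(len(args)))
        out = g.call_function(op, placeholders)
        g.output(out)
        gm = fx.GraphModule(torch.nn.Module(), g)
        # Hand the single-node graph to the backend's compiler and run it.
        compiled = compiler(gm, list(args))
        return compiled(*args)
    return run_op


# Example: the trivial "eager" compiler just runs the graph module as-is.
run_op = eager_from_compiler(lambda gm, example_inputs: gm.forward)
x = torch.ones(2)
print(run_op(torch.ops.aten.add.Tensor, x, x))  # tensor([2., 2.])
```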
PyTorch provides:

```python
class BackendEnvironment:

    def __init__(self, name, blob_type, cpu_to_backend, backend_to_cpu):
        ...

    def register_operator(self, aten_op_or_torch_function, callable):
        ...
```
The backend developer writes (below using Apple MLX as the example):

```python
import mlx.core as mx
import numpy as np
import torch


def torch_to_mlx(tensor):
    return mx.array(tensor.numpy())


def mlx_to_torch(array):
    return torch.from_numpy(np.array(array))


environment = BackendEnvironment(
    'mlx',
    mx.array,
    torch_to_mlx,
    mlx_to_torch,
)
environment.register_default_decompositions()

from torch.ops import aten


def register_op(op):
    def inner(func):
        environment.register_op(op, func)
        return func
    return inner


@register_op(aten.detach)
@register_op(aten.detach.default)
@register_op(aten.clone.default)
def aten_detach_default(x, *args):
    return x


@register_op(aten.view.default)
@register_op(aten._unsafe_view.default)
def aten_view(x, shape):
    return mx.reshape(x, shape)

...
```
Let's parse the above:
This section:

```python
environment = BackendEnvironment(
    'mlx',
    mx.array,
    torch_to_mlx,
    mlx_to_torch,
)
```

is saying:

- My backend's name is 'mlx'.
- My opaque data is `mx.array`; this could be any Python class, so a backend that wants to have a tuple of elements, etc., can do that too.
- The two functions that map a CPU torch.Tensor to my blob and back are these two, so if a user does `tensor.to('mlx')`, PyTorch knows what to call.
- For everything else, including tensor constructors, refer to the registered operators. If a particular tensor constructor doesn't exist, run the CPU one and move the result to the device.
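The constructor fallback in the last bullet can be sketched as follows. This is an illustration under assumed names (`FakeEnvironment` and its `ops` registry stand in for `BackendEnvironment`), using NumPy arrays as the backend blobs.

```python
import torch


class FakeEnvironment:
    """Stand-in for BackendEnvironment; NumPy arrays play the blob role."""
    def __init__(self):
        self.ops = {}                      # op -> registered callable

    def cpu_to_backend(self, tensor):
        return tensor.numpy()              # the "device transfer"


def call_constructor(env, op, *args, **kwargs):
    # If the backend registered a kernel for this constructor, use it.
    registered = env.ops.get(op)
    if registered is not None:
        return registered(*args, **kwargs)
    # Otherwise fall back: construct on CPU, then move to the backend.
    kwargs.pop("device", None)
    cpu_tensor = op(*args, device="cpu", **kwargs)
    return env.cpu_to_backend(cpu_tensor)


blob = call_constructor(FakeEnvironment(), torch.zeros, 2, 3)
print(type(blob), blob.shape)  # a NumPy array of shape (2, 3)
```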
A strawman (runnable) version of this API is located here: https://github.com/qihqi/tnt/blob/main/torch_mlx.py (although it uses the alternative implementation described below).
- Tensor creation:
When we create a backend tensor, we will first create the blob, then attach it to an empty CPU tensor.
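A minimal sketch of this wrapping scheme, with illustrative names (the `_backend_blob` attribute is an assumption; the linked strawman stores the blob differently), using a NumPy array as the backend's data:

```python
import numpy as np
import torch


def wrap(blob):
    # The CPU tensor carries PyTorch's metadata (shape, dtype);
    # the blob carries the actual backend data.
    shell = torch.empty(tuple(blob.shape), dtype=torch.float32)
    shell._backend_blob = blob  # hypothetical attribute name
    return shell


def unwrap(tensor):
    # Fall back to the tensor itself if it is a plain CPU tensor.
    return getattr(tensor, "_backend_blob", tensor)


t = wrap(np.zeros((2, 3), dtype=np.float32))
assert isinstance(t, torch.Tensor)
assert unwrap(t).shape == (2, 3)
```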
- Operator wrapping / unwrapping
On a call to a particular operator, say aten.add, PyTorch will:

- Intercept the call,
- unwrap the tensors to get the backend blobs,
- call the registered op, passing down the blobs.

This is done via the torch.library registry, which installs a handler capturing each operator.
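The interception flow can be sketched like this. Names such as `environment.ops`, `unwrap`, and `wrap` are assumptions for illustration; the linked unit test does the real registration against the PrivateUse1 dispatch key.

```python
import torch


def make_handler(environment, op_name):
    def handler(*args):
        # 1. unwrap: pull the backend blob out of each wrapper tensor
        blobs = [environment.unwrap(a) if isinstance(a, torch.Tensor) else a
                 for a in args]
        # 2. call the backend's registered implementation on the blobs
        result = environment.ops[op_name](*blobs)
        # 3. wrap the resulting blob back into a PyTorch tensor
        return environment.wrap(result)
    return handler


# Registration would then look roughly like:
#   lib = torch.library.Library("aten", "IMPL")
#   lib.impl("add.Tensor", make_handler(env, "add.Tensor"), "PrivateUse1")
```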
A strawman implementation of this is illustrated in this unit test:
https://github.com/pytorch/pytorch/blob/main/test/test_privateuseone_python_backend.py
using numpy as the backend array.
- Dynamo integration
Currently the above implementation does not work with Dynamo; some minor changes will probably be needed in Dynamo itself.
However, in the limit, having a Numba-based Dynamo backend for the NumPy backend should be doable.
There are many ways to extend PyTorch (https://docs.google.com/presentation/d/1piuv9nBzyoqdH49D1SoE5OZUPSMpOOFqfSKOhr-ab2c/edit); this adds yet another, though it is built on existing mechanisms.
I have tried a tensor-subclass-based mechanism, used in the MLX demo above. While it works, it does not work well with torch.compile.
The spirit of tensor subclasses is “a wrapper around other Tensors that eventually desugars into ops on plain Tensors”, so we should probably respect that convention.
- Alban's new_device demo: https://github.com/albanD/subclass_zoo/blob/main/new_device.py -- uses torch_dispatch
- Tinygrad's torch backend: https://github.com/tinygrad/tinygrad/tree/82f10cfe2ee309fc048c4b04279e70102e84ca98/extra/torch_backend -- uses the mechanism proposed by this RFC
- Torchax: https://google.github.io/torchax/ -- uses torch dispatch
One thing to distinguish in terminology is backend vs. device. A backend is a way to run compute; it could be for a new device or for an existing one.
What parts of the design do you expect to resolve through the RFC process before this gets merged?
The primary goal of this RFC is to understand PyTorch maintainers' and the community's opinions on this approach.
The questions to maintainers are:
- Is this a good idea in general?
- If yes, is there interest in formalizing and implementing it together?
- Is the API look-and-feel OK? Any improvements to suggest?
The question to the community is:
- Do you see yourself using something like this? If yes, what are the use cases?
Choose one of the following:
- 1: Overwhelming positive feedback.
- 2: Positive feedback.
- 3: Majority Acceptance, with conflicting Feedback.
- 4: Acceptance, with Little Feedback.
- 5: Unclear Resolution.
- 6: RFC Rejected.
- 7: RFC Rejected, with Conflicting Feedback.