Towhee compiler is a Python JIT compiler that speeds up AI-related code via native code generation. The project is inspired by Numba, Pyjion and TorchDynamo. Towhee compiler uses a frame evaluation hook (see [PEP 523](https://www.python.org/dev/peps/pep-0523/)) to intercept Python frames and compile their bytecode into native code.
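Conceptually, the hook lets the compiler intercept every Python frame about to be executed, compile its code object once, and dispatch to the compiled version on later calls. Below is a pure-Python toy sketch of that dispatch idea; it is only an illustration, not towhee compiler's implementation, which installs the real hook at the C level and emits native code rather than reusing the Python function.

```python
import functools

_code_cache = {}  # compiled artifacts, keyed by code object

def frame_hook(func):
    """Toy stand-in for a PEP 523 eval-frame hook (the real hook is set in C)."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        compiled = _code_cache.get(func.__code__)
        if compiled is None:
            compiled = func  # a real JIT would generate native code here
            _code_cache[func.__code__] = compiled
        return compiled(*args, **kwargs)
    return wrapper
```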
The code is based on a forked version of TorchDynamo, which extracts an `fx.Graph` by tracing the execution of Python code. The goal of towhee compiler, however, is whole-program code generation, which also covers programs that cannot be represented by an `fx.Graph`.
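For intuition (this example is not from the original text), `torch.fx` can symbolically trace straight-line tensor code into an `fx.Graph`, but it rejects data-dependent Python control flow, which is exactly the kind of program whole-program code generation still has to handle:

```python
import torch
import torch.fx

def straight_line(x):
    return torch.relu(x) + 1  # pure tensor ops: traceable into an fx.Graph

def data_dependent(x):
    if x.sum() > 0:  # branches on a tensor value: not representable in fx.Graph
        return x + 1
    return x - 1

print(torch.fx.symbolic_trace(straight_line).graph)

try:
    torch.fx.symbolic_trace(data_dependent)
except torch.fx.proxy.TraceError as err:
    print('fx cannot trace:', err)
```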
- Install with pip

Some environments (such as Apple M1) are not yet supported; on those platforms, please install from source instead.

```bash
$ pip install towhee.compiler
```

- Install from source code

```bash
$ git clone https://github.com/towhee-io/towhee-compiler.git
$ cd towhee-compiler && pip install -r requirements.txt
$ python3 setup.py develop
```
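After either method, a quick import (a hypothetical smoke test, not part of the original instructions) confirms the installation succeeded:

```python
import towhee.compiler
from towhee.compiler import jit_compile  # raises ImportError if the install failed
```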
- Compile
Towhee compiler can speed up any model. For example, we just need to run the `image_embedding` function inside a `jit_compile` context.
```python
import torch
import torchvision.models as models
import numpy as np
import towhee.compiler
from towhee.compiler import jit_compile

# towhee.compiler.config.debug = True

torch_model = models.resnet50()
torch_model = torch.nn.Sequential(*(list(torch_model.children())[:-1]))
torch_model = torch_model.eval()

def image_embedding(inputs):
    imgs = torch.tensor(inputs)
    embedding = torch_model(imgs).detach().numpy()
    return embedding.reshape([2048])

inputs = np.random.randn(1, 3, 244, 244).astype(np.float32)
with jit_compile():
    embeddings = image_embedding(inputs)
```
- Timer
We have compiled the model with the nebullvm backend (the default backend in `towhee.compiler`). To measure the speedup, we can define a `Timer` class that records the elapsed time.
```python
import time

class Timer:
    def __init__(self, name):
        self._name = name

    def __enter__(self):
        self._start = time.time()
        return self

    def __exit__(self, *args):
        self._interval = time.time() - self._start
        print('%s: %.2fs' % (self._name, self._interval))
```
And we can see that the compiled function is more than 3 times faster.
```python
with Timer('Image Embedding'):
    embeddings = image_embedding(inputs)

with Timer('Image Embedding with towhee compiler'), jit_compile():
    embeddings_jit = image_embedding(inputs)
```

```
Image Embedding: 0.14s
Image Embedding with towhee compiler: 0.04s
```
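As a sanity check (not part of the original example), the compiled path should produce the same embedding as the eager run up to numerical tolerance, since compiler backends may reorder floating-point operations:

```python
assert np.allclose(embeddings, embeddings_jit, atol=1e-4)
```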
Towhee supports setting JIT to use `towhee.compiler` for compilation.
- Set JIT
For example, we can add `set_jit('towhee')` to the image embedding pipeline; the operators that follow will then be compiled automatically:
```python
import towhee

embeddings_towhee = (
    towhee.dc(['https://raw.githubusercontent.com/towhee-io/towhee/main/towhee_logo.png'])
        .image_decode()
        .set_jit('towhee')
        .image_embedding.timm(model_name='resnet50')
)
```
- Timer
And we can build two towhee pipeline functions to compare the time cost.
```python
towhee_func = (towhee.dummy_input()
        .image_embedding.timm(model_name='resnet50')
        .as_function()
)
towhee_func_jit = (towhee.dummy_input()
        .set_jit('towhee')
        .image_embedding.timm(model_name='resnet50')
        .as_function()
)

data = towhee.ops.image_decode()('https://raw.githubusercontent.com/towhee-io/towhee/main/towhee_logo.png')

with Timer('Towhee function'):
    emb = towhee_func(data)

with Timer('Towhee function with Compiler'):
    emb_jit = towhee_func_jit(data)
```

```
Towhee function: 0.14s
Towhee function with Compiler: 0.08s
```
Following the README of each Operator on the Towhee Hub, we enabled JIT compilation to speed up the models; the results are as follows. A value such as 5.5 means the operator runs 5.5 times faster after JIT compilation, and N means no speedup or a compilation failure. More test results will be added continuously.
| Field | Task | Operator | Speedup (CPU/GPU) |
| --- | --- | --- | --- |
| Image | Image Embedding | image_embedding.timm | 1.3/1.3 |
|  |  | image_embedding.data2vec | 1.2/1.7 |
|  |  | image_embedding.swag | 1.4/N |
|  | Face Embedding | face_embedding.inceptionresnetv1 | 3.2/N |
|  | Face Landmark | face_landmark_detection.mobilefacenet | 2.1/2.1 |
| NLP | Text Embedding | text_embedding.transformers | 2.6/N |
|  |  | text_embedding.data2vec | 1.8/N |
|  |  | text_embedding.realm | 5.5/1.9 |
|  |  | text_embedding.xlm_prophetnet | 2.1/2.8 |
| Audio | Audio Classification | audio_classification.panns | 1.6/N |
|  | Audio Embedding | audio_embedding.vggish | 1.5/N |
|  |  | audio_embedding.data2vec | 1.5/N |
| Multimodal | Image Text | image_text_embedding.blip | 2.3/N |
|  | Video Text | video_text_embedding.bridge_former(modality='text') | 2.1/N |
|  |  | video_text_embedding.frozen_in_time(modality='text') | 2.2/N |