Skip to content

Commit

Permalink
Add USE_NUMBA_CACHE
Browse files Browse the repository at this point in the history
Numba caching seems to sometimes cause trouble, not sure why, but it's
probably good to allow disabling it anyway.
  • Loading branch information
otsaw committed Jun 14, 2023
1 parent 4e0cf20 commit 908b303
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
PENDING: Dataiter 0.45
======================

* `USE_NUMBA_CACHE`: New option, read from environment variable
`DATAITER_USE_NUMBA_CACHE` if exists, defauls to `True`

2023-06-13: Dataiter 0.44
=========================

Expand Down
6 changes: 5 additions & 1 deletion dataiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
PRINT_THOUSAND_SEPARATOR = ""
PRINT_TRUNCATE_WIDTH = 36
USE_NUMBA = False
USE_NUMBA_CACHE = True

with contextlib.suppress(LookupError):
USE_NUMBA_CACHE = util.parse_env_boolean("DATAITER_USE_NUMBA_CACHE")

try:
# Force Numba on or off if environment variable defined.
Expand All @@ -47,7 +51,7 @@
# and calling a trivial function works.
import numba
try:
@numba.njit(cache=True)
@numba.njit(cache=USE_NUMBA_CACHE)
def check(x):
return x**2
assert check(10) == 100
Expand Down
14 changes: 7 additions & 7 deletions dataiter/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def count_unique_apply(x, group, drop_na):
for xg in yield_groups(x, group, drop_na):
yield len(set(xg))

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def count_unique_apply_numba(x, group, drop_na):
out = []
for xg in yield_groups_numba(x, group, drop_na):
Expand Down Expand Up @@ -210,7 +210,7 @@ def aggregate(x, group, drop_na, default, nrequired):

@functools.lru_cache(256)
def generic_numba(function):
@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def aggregate(x, group, drop_na, default, nrequired):
out = []
for xg in yield_groups_numba(x, group, drop_na):
Expand All @@ -237,7 +237,7 @@ def is_na_item_numba_overload(x):
return lambda x: x == ""
return lambda x: False

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def is_na_numba(x):
na = np.full(len(x), False)
for i in range(len(x)):
Expand Down Expand Up @@ -411,7 +411,7 @@ def mode_apply(x, group, drop_na):
for xg in yield_groups(x, group, drop_na):
yield mode1(xg) if len(xg) >= 1 else None

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def mode_apply_numba(x, group, drop_na):
out = []
for xg in yield_groups_numba(x, group, drop_na):
Expand Down Expand Up @@ -473,7 +473,7 @@ def nth_apply(x, group, index, drop_na):
except IndexError:
yield None

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def nth_apply_numba(x, group, index, drop_na):
out = []
for xg in yield_groups_numba(x, group, drop_na):
Expand Down Expand Up @@ -519,7 +519,7 @@ def quantile_apply(x, group, q, drop_na):
for xg in yield_groups(x, group, drop_na):
yield np.quantile(xg, q) if len(xg) >= 1 else np.nan

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def quantile_apply_numba(x, group, q, drop_na):
out = []
for xg in yield_groups_numba(x, group, drop_na):
Expand Down Expand Up @@ -654,7 +654,7 @@ def yield_groups(x, group, drop_na):
yield xij
i = j

@njit(cache=True)
@njit(cache=dataiter.USE_NUMBA_CACHE)
def yield_groups_numba(x, group, drop_na):
# Groups must be contiguous for this to work!
i = 0
Expand Down

0 comments on commit 908b303

Please sign in to comment.