Discussions about JIT #235
Replies: 8 comments
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    """Evaluate ``x*y + 2x + 3y + x**2 + y**3`` element-wise with PyTorch ops.

    :param x: first operand
    :param y: second operand, broadcastable against ``x``
    :return: tensor holding the element-wise result
    """
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)
# CUDA source of the element-wise kernel ``fun``: each thread computes one
# output element of x*y + 2x + 3y + x^2 + y^3 and threads with index >= N do
# nothing. The braces and the closing quotes were missing from the listing and
# are restored here.
kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''
def fun_cupy(x, y):
    """Compute the same expression as ``fun_jit`` with the raw CUDA kernel ``fun``.

    :param x: CUDA tensor
    :param y: CUDA tensor with the same number of elements as ``x``
    :return: new tensor ``z`` filled by the kernel
    """
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        # The element count is handed to the kernel as a CuPy array.
        N = cupy.asarray(N)
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z, N = cuda_utils.get_contiguous(x, y, z, N)
        kernel_args = [x, y, z, N]
        # NOTE(review): only the grid/block tuple of this launch survived in
        # the listing; the call and the argument wrapping are reconstructed
        # from the usual spikingjelly pattern - confirm against the original.
        kernel(
            (blocks,), (configure.cuda_threads,),
            cuda_utils.wrap_args_to_raw_kernel(device_id, *kernel_args)
        )
    return z
# Driver: prints the largest absolute difference between fun_jit and fun_cupy,
# then passes each function to cuda_utils.cal_fun_t (presumably a timing
# helper - confirm).
# NOTE(review): indentation was lost in this listing; the statements below are
# presumably the body of the ``with`` block, and t1/t2 are never reported.
with torch.no_grad():
device = 'cuda:1'
x = torch.rand([1024, 1024, 16], device=device)
y = torch.rand_like(x)
print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
Beta Was this translation helpful? Give feedback.
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    """Evaluate ``x*y + 2x + 3y + x**2 + y**3`` element-wise with PyTorch ops.

    :param x: first operand
    :param y: second operand, broadcastable against ``x``
    :return: tensor holding the element-wise result
    """
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)
# CUDA source of the element-wise kernel ``fun``: each thread computes one
# output element of x*y + 2x + 3y + x^2 + y^3 and threads with index >= N do
# nothing. The braces and the closing quotes were missing from the listing and
# are restored here.
kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''
def fun_cupy(x, y):
    """Compute the same expression as ``fun_jit`` by launching ``fun`` chunk by chunk.

    The input is processed in slices of at most ``blocks * cuda_threads``
    elements; each launch receives the sliced tensors and the slice length.

    :param x: CUDA tensor (sliced along dim 0, so 1-D input is assumed)
    :param y: CUDA tensor with the same shape as ``x``
    :return: new tensor ``z`` filled by the kernel
    """
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z = cuda_utils.get_contiguous(x, y, z)
        mini_numel = blocks * configure.cuda_threads
        start = 0
        # Looping on ``start < N`` also skips the launch for empty input.
        while start < N:
            end = min(start + mini_numel, N)
            # The slice [start, end) holds exactly ``end - start`` elements.
            # The former ``+ 1`` let the kernel touch one element past the
            # end of the slice.
            kernel_args = [x[start: end], y[start: end], z[start: end], cupy.asarray(end - start)]
            # NOTE(review): only the grid/block tuple of this launch survived
            # in the listing; the call is reconstructed from the usual
            # spikingjelly pattern - confirm against the original.
            kernel(
                (blocks,), (configure.cuda_threads,),
                cuda_utils.wrap_args_to_raw_kernel(device_id, *kernel_args)
            )
            start = end
    return z
# Driver: prints the largest absolute difference between fun_jit and fun_cupy
# on a 2**30-element input, then passes each function to cuda_utils.cal_fun_t
# (presumably a timing helper - confirm).
# NOTE(review): indentation was lost in this listing, so it is unclear which
# statements belong to the ``with`` block and to the ``for`` loop.
with torch.no_grad():
device = 'cuda:1'
for i in range(4):
x = torch.rand([1024 * 1024 * 1024], device=device)
y = torch.rand_like(x)
print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
This method is slower, so I will revert this commit.
Beta Was this translation helpful? Give feedback.
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    """Evaluate ``x*y + 2x + 3y + x**2 + y**3`` element-wise with PyTorch ops.

    :param x: first operand
    :param y: second operand, broadcastable against ``x``
    :return: tensor holding the element-wise result
    """
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)
# CUDA source of the element-wise kernel ``fun``: each thread computes one
# output element of x*y + 2x + 3y + x^2 + y^3 and threads with index >= N do
# nothing. The braces and the closing quotes were missing from the listing and
# are restored here.
kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''
def wrap_args_to_raw_kernel(offset: int, device: int, *args):
    """Turn tensors and CuPy arrays into the argument tuple of a raw kernel launch.

    :param offset: number of elements to skip at the start of every tensor
    :param device: device index every argument is asserted to live on
    :param args: contiguous ``torch.Tensor`` or ``cupy.ndarray`` objects
    :return: tuple with one entry per argument; a tensor becomes its data
        address advanced by ``offset`` elements, a CuPy array is passed as is
    :raises TypeError: if an argument is neither a tensor nor a CuPy array
    """
    ret_list = []
    for item in args:
        if isinstance(item, torch.Tensor):
            assert item.get_device() == device
            assert item.is_contiguous()
            # Shift the base address by ``offset`` elements (not bytes).
            ret_list.append(item.data_ptr() + offset * item.element_size())
        elif isinstance(item, cupy.ndarray):
            assert item.device.id == device
            assert item.flags['C_CONTIGUOUS']
            # NOTE(review): this append and the ``else`` below were missing
            # from the listing, which made every CuPy argument raise
            # TypeError; reconstructed - confirm against the original.
            ret_list.append(item)
        else:
            raise TypeError(f'unsupported kernel argument type: {type(item)}')
    return tuple(ret_list)
def fun_cupy(x, y):
    """Compute the same expression as ``fun_jit`` by launching ``fun`` chunk by chunk.

    Instead of slicing the tensors, every launch passes the full tensors and
    lets ``wrap_args_to_raw_kernel`` advance their addresses by ``start``
    elements.

    :param x: CUDA tensor
    :param y: CUDA tensor with the same number of elements as ``x``
    :return: new tensor ``z`` filled by the kernel
    """
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z = cuda_utils.get_contiguous(x, y, z)
        mini_numel = blocks * configure.cuda_threads
        start = 0
        # Looping on ``start < N`` also skips the launch for empty input.
        while start < N:
            end = min(start + mini_numel, N)
            # The chunk [start, end) holds exactly ``end - start`` elements.
            # The former ``+ 1`` let the kernel touch one element past the
            # end of the chunk.
            kernel_args = [x, y, z, cupy.asarray(end - start)]
            # NOTE(review): only the grid/block tuple of this launch survived
            # in the listing; the call is reconstructed using the local
            # offset-aware wrapper - confirm against the original.
            kernel(
                (blocks,), (configure.cuda_threads,),
                wrap_args_to_raw_kernel(start, device_id, *kernel_args)
            )
            start = end
    return z
# Driver: prints the largest absolute difference between fun_jit and fun_cupy
# on a 2**24-element input, then passes each function to cuda_utils.cal_fun_t
# (presumably a timing helper - confirm).
# NOTE(review): indentation was lost in this listing, so it is unclear which
# statements belong to the ``with`` block and to the ``for`` loop.
with torch.no_grad():
device = 'cuda:0'
for i in range(4):
x = torch.rand([1024 * 1024 * 16], device=device)
y = torch.rand_like(x)
print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
Beta Was this translation helpful? Give feedback.
import torch
import math
from spikingjelly.activation_based import neuron, surrogate, functional, cuda_utils
import tqdm
def IFNode_fptt_hardReset(x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
    """Run the hard-reset IF neuron forward over all time steps.

    Per step: ``h = v + x``; ``spike = (h >= v_th)``; ``v`` becomes ``v_reset``
    where a spike fired and ``h`` elsewhere.

    :param x_seq: input sequence, indexed by time step along dim 0
    :param v: membrane potential before the first step
    :param v_th: firing threshold
    :param v_reset: potential assigned after a spike
    :return: ``(v_seq, h_seq, spike_seq)``, each shaped like ``x_seq``
    """
    h_seq = torch.zeros_like(x_seq)
    spike_seq = torch.zeros_like(x_seq)
    v_seq = torch.zeros_like(x_seq)
    # Row 0 holds the initial potential, row t + 1 the potential after step t.
    v_v_seq = torch.cat((v.unsqueeze(0), v_seq), 0)
    for t in range(x_seq.shape[0]):
        h_seq[t] = v_v_seq[t] + x_seq[t]
        spike_seq[t] = (h_seq[t] >= v_th).to(x_seq)
        v_v_seq[t + 1] = h_seq[t] * (1. - spike_seq[t]) + spike_seq[t] * v_reset
    return v_v_seq[1:], h_seq, spike_seq
def IFNode_bptt_hardReset(grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor, h_seq: torch.Tensor, spike_seq: torch.Tensor, v_th: float, v_reset: float):
    """Back-propagate through time for the hard-reset IF neuron.

    The spike derivative uses ``alpha / 2 / (1 + (pi / 2 * alpha * (h - v_th))**2)``
    with ``alpha = 2``.

    :param grad_spike_seq: gradient w.r.t. the output spikes, time along dim 0
    :param grad_v_seq: gradient w.r.t. the membrane potential sequence
    :param h_seq: pre-reset potentials saved by the forward pass
    :param spike_seq: spikes saved by the forward pass
    :param v_th: firing threshold
    :param v_reset: potential assigned after a spike
    :return: ``(grad_x_seq, grad_v_last)``; ``grad_v_last`` is the gradient
        left after processing time step 0
    """
    grad_x_seq = torch.zeros_like(grad_spike_seq)
    over_th = h_seq - v_th
    alpha = 2.
    grad_s_to_h = alpha / 2 / (1 + (math.pi / 2 * alpha * over_th).pow(2))
    grad_v_to_h = 1.0 - spike_seq + (v_reset - h_seq) * grad_s_to_h
    grad_h = torch.zeros_like(grad_spike_seq[0])
    # Walk backwards in time, carrying grad_h from step t + 1 into step t.
    for t in range(grad_spike_seq.shape[0] - 1, -1, -1):
        grad_h = grad_spike_seq[t] * grad_s_to_h[t] + (grad_v_seq[t] + grad_h) * grad_v_to_h[t]
        grad_x_seq[t] = grad_h
    grad_v_last = grad_h
    return grad_x_seq, grad_v_last
class IFNodeATan(torch.autograd.Function):
    """Autograd function wiring the hand-written IF forward and backward passes."""

    # torch.autograd.Function requires forward/backward to be static methods;
    # the decorators were missing from the listing.
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
        """Run the forward pass and stash what the backward pass needs.

        :return: ``(spike_seq, v_seq)``
        """
        v_seq, h_seq, spike_seq = IFNode_fptt_hardReset(x_seq, v, v_th, v_reset)
        ctx.save_for_backward(h_seq, spike_seq)
        ctx.v_th = v_th
        ctx.v_reset = v_reset
        return spike_seq, v_seq

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """Return gradients for ``(x_seq, v, v_th, v_reset)``; the two floats get ``None``."""
        h_seq, spike_seq = ctx.saved_tensors
        grad_x_seq, grad_v_last = IFNode_bptt_hardReset(grad_spike_seq, grad_v_seq, h_seq, spike_seq, ctx.v_th, ctx.v_reset)
        return grad_x_seq, grad_v_last, None, None
class JITIFNode(neuron.IFNode):
    """IF neuron whose multi-step training path runs through ``IFNodeATan``."""

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Process a whole input sequence and return the spike sequence.

        In training mode the custom autograd function is used; otherwise one
        of the inherited ``jit_eval_multi_step_forward_*`` helpers is picked
        from ``v_reset`` and ``store_v_seq``.
        """
        # NOTE(review): IFNodeATan receives self.v as a tensor; if the base
        # class can still hold a float state at this point it has to be
        # converted first - confirm against the original script.
        if self.training:
            spike_seq, v_seq = IFNodeATan.apply(
                x_seq, self.v, self.v_threshold, self.v_reset)
            if self.store_v_seq:
                self.v_seq = v_seq
            self.v = v_seq[-1].clone()
            return spike_seq

        # The ``else`` branches were missing from the listing; without them
        # every eval call advanced the neuron state twice.
        if self.v_reset is None:
            if self.store_v_seq:
                spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq, self.v, self.v_threshold)
            else:
                spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
        else:
            if self.store_v_seq:
                spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq, self.v, self.v_threshold, self.v_reset)
            else:
                spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold, self.v_reset)
        return spike_seq
# Callable handed to cuda_utils.cal_fun_t by cal_forward_backward_t together
# with the neuron and the input.
# NOTE(review): the body of this function is missing from the listing;
# presumably it runs one forward and one backward pass - restore it from the
# original script.
def forward_backward(multi_step_neuron, x):
def cal_forward_backward_t(multi_step_neuron, x, repeat_times):
    """Measure ``forward_backward(multi_step_neuron, x)`` via ``cuda_utils.cal_fun_t``.

    :param multi_step_neuron: neuron module under test
    :param x: input tensor; its device is passed on to ``cal_fun_t``
    :param repeat_times: repeat count passed on to ``cal_fun_t``
    :return: the value reported by ``cal_fun_t`` multiplied by 1000
        (presumably seconds converted to milliseconds - confirm)
    """
    used_t = cuda_utils.cal_fun_t(repeat_times, x.device, forward_backward, multi_step_neuron, x)
    return used_t * 1000
# Driver: builds the CuPy-backend IF neuron and the JIT variant, prints the
# largest difference of their outputs and input gradients, then writes one CSV
# row per sequence length T to ./test_half.csv using half-precision inputs.
# NOTE(review): indentation was lost in this listing. No backward() call is
# visible before either ``x.grad.clone()``, so those calls (and any gradient
# or state reset between the two runs) were presumably dropped - confirm.
device = 'cuda:0'
ms_if = neuron.IFNode(step_mode='m', surrogate_function=surrogate.ATan(), backend='cupy')
jit_if = JITIFNode(step_mode='m')
T = 128
repeat_times = 32
x = torch.rand(T, 1024, device=device, requires_grad=True)
y_ms = ms_if(x)
grad_ms = x.grad.clone()
y_jit = jit_if(x)
grad_jit = x.grad.clone()
with torch.no_grad():
print((y_jit - y_ms).abs().max())
print((grad_jit - grad_ms).abs().max())
with open('./test_half.csv', 'w+') as csv_file:
csv_file.write('T, jit_if, t_cupy\n')
N = 2 ** 20
print('forward and backward')
# The loop variable reuses the name T, overwriting the value set above.
for T in tqdm.trange(1, T + 1):
# for T in [2, 4, 8, 16, 32, 64, 128]:
x = torch.rand(T, N, device=device).half()
t_jit = cal_forward_backward_t(jit_if, x, repeat_times)
t_cupy = cal_forward_backward_t(ms_if, x, repeat_times)
csv_file.write(f'{T}, {t_jit}, {t_cupy}\n')
The CuPy forward and backward pass is still faster than the JIT version.
Beta Was this translation helpful? Give feedback.
Compare `bool2float` between the JIT and CuPy implementations:
import torch
import torch.nn.functional as F
from spikingjelly.activation_based import tensor_cache, cuda_utils
def float2bool(x: torch.Tensor, mask: torch.Tensor):
    """Pack a 0/1-valued tensor into bytes, eight elements per byte.

    The flattened input is zero-padded to a multiple of eight, each group of
    eight is left-shifted by ``mask`` and summed into one ``uint8`` value.

    :param x: tensor whose elements are 0 or 1
    :param mask: per-position shift amounts, broadcast against groups of 8
    :return: ``(packed, x_shape, padding)``; ``padding`` is the number of
        zeros appended, or 8 when the element count is already a multiple of 8
    """
    x_shape = x.shape
    # padding is always in 1..8, so 8 is the only value meaning "no padding".
    padding = 8 - x.numel() % 8
    if padding != 8:
        x = x.flatten()
        x = F.pad(x, (0, padding))
    # reshape (rather than view) also accepts non-contiguous input.
    x = x.reshape(-1, 8)
    return (torch.bitwise_left_shift(x, mask)).sum(1).to(torch.uint8), x_shape, padding
def bool2float(x: torch.Tensor, mask: torch.Tensor, x_shape: list[int], padding: int, dtype: int = torch.float32):
    """Unpack bytes produced by ``float2bool`` back into a 0/1-valued tensor.

    :param x: 1-D tensor of packed bytes
    :param mask: eight single-bit masks, one per position inside a byte
    :param x_shape: shape of the tensor before packing
    :param padding: padding count returned by ``float2bool`` (0 or 8 means none)
    :param dtype: dtype of the result
    :return: tensor of shape ``x_shape`` holding 0/1 values
    """
    # NOTE(review): ``dtype: int`` looks like a TorchScript-style annotation
    # for a torch.dtype argument - confirm before changing it.
    # Broadcasting (n, 1) against (8,) tests every bit of every byte without
    # materialising eight copies of the input first.
    x = torch.bitwise_and(x.unsqueeze(1), mask).ne(0).to(dtype)
    if padding != 0 and padding != 8:
        x = x.flatten()
        x = x[:x.numel() - padding]
    return x.reshape(x_shape)
def cp_convert(x):
    """Round-trip ``x`` through ``tensor_cache``'s float-to-bool and bool-to-float conversion."""
    return tensor_cache.bool_spike_to_float(*tensor_cache.float_spike_to_bool(x))
def jit_convert(x, float2bool_mask, bool2float_mask):
    """Round-trip ``x`` through ``float2bool`` and ``bool2float``.

    :param x: 0/1-valued tensor
    :param float2bool_mask: shift amounts used for packing
    :param bool2float_mask: bit masks used for unpacking
    :return: unpacked tensor with the shape and dtype of ``x``
    """
    dtype = x.dtype
    xb, x_shape, padding = float2bool(x, float2bool_mask)
    y = bool2float(xb, bool2float_mask, x_shape, padding, dtype)
    return y
# Driver: builds the shift and bit masks, checks that the pack/unpack round
# trip reproduces x and matches cp_convert, then passes both converters to
# cuda_utils.cal_fun_t for growing input sizes.
# NOTE(review): indentation was lost in this listing; the three statements
# after the ``for`` line, together with the print on the following line, are
# presumably the loop body.
dtype = torch.float32
device = 'cuda:0'
float2bool_mask = torch.arange(8, device=device, dtype=dtype)
bool2float_mask = (2 ** float2bool_mask).to(torch.uint8)
x = (torch.rand([64, 2, 3], device=device) > 0.5).to(dtype)
xb, x_shape, padding = float2bool(x, float2bool_mask)
y = bool2float(xb, bool2float_mask, x_shape, padding, dtype)
print((x - y).abs().sum())
print((cp_convert(x) - jit_convert(x, float2bool_mask, bool2float_mask)).abs().sum())
repeats = 32
dtype = torch.float32
for i in range(1, 33):
x = (torch.rand([i, 16, 128, 64, 64], device=device) > 0.5).to(dtype)
t_cp = cuda_utils.cal_fun_t(repeats, device, cp_convert, x)
t_jit = cuda_utils.cal_fun_t(repeats, device, jit_convert, x, float2bool_mask, bool2float_mask)
print(f'{i}, {t_cp}, {t_jit}, {t_cp / t_jit}')
The results show that the CuPy implementation is faster.
Beta Was this translation helpful? Give feedback.
About using if-else inside a JIT function, rather than using a Python if-else to choose between different JIT functions:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import cuda_utils, neuron
device = 'cuda:0'
class IFNode(neuron.IFNode):
    """IF neuron whose eval path uses one function that branches internally.

    A NaN ``v_reset`` is the sentinel for soft reset, because the helper takes
    ``v_reset`` as a plain float.
    """

    # NOTE(review): the decorators were missing from the listing. The helper
    # takes no ``self`` and is called through ``self``, so it must be static;
    # a ``torch.jit.script`` decorator was presumably present too - confirm.
    @staticmethod
    def jit_eval_multi_step_forward(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, store_v_seq: bool):
        """Run the IF neuron over all time steps without building a graph.

        :param x_seq: input sequence, indexed by time step along dim 0
        :param v: membrane potential before the first step
        :param v_threshold: firing threshold
        :param v_reset: reset potential, or NaN for soft reset
        :param store_v_seq: also return the potential after every step
        :return: ``(spike_seq, v)`` or ``(spike_seq, v, v_seq)``
        """
        spike_seq = torch.zeros_like(x_seq)
        if store_v_seq:
            v_seq = torch.zeros_like(x_seq)
        # ``v_reset == math.nan`` is always False because NaN never compares
        # equal to anything; math.isnan is the correct test.
        soft_reset = math.isnan(v_reset)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            if soft_reset:
                v = v - spike * v_threshold
            else:
                v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
            if store_v_seq:
                v_seq[t] = v
        if store_v_seq:
            return spike_seq, v, v_seq
        return spike_seq, v

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Use the base implementation when training and the helper above otherwise."""
        if self.training:
            # Zero-argument super() keeps working when the module-level name
            # ``IFNode`` is rebound by a later class definition.
            return super().multi_step_forward(x_seq)
        # Map "no reset value" onto the NaN sentinel the helper expects.
        v_reset = math.nan if self.v_reset is None else self.v_reset
        if self.store_v_seq:
            spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward(x_seq, self.v, self.v_threshold, v_reset, self.store_v_seq)
        else:
            spike_seq, self.v = self.jit_eval_multi_step_forward(x_seq, self.v, self.v_threshold, v_reset, self.store_v_seq)
        return spike_seq
class IFNode(neuron.IFNode):
    """IF neuron variant whose eval helper never records the potential sequence.

    A NaN ``v_reset`` is the sentinel for soft reset.
    """

    # NOTE(review): the decorators were missing from the listing. The helper
    # takes no ``self``, so it must be static; a ``torch.jit.script``
    # decorator was presumably present too - confirm.
    @staticmethod
    def jit_eval_multi_step_forward(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
        """Run the IF neuron over all time steps.

        :param x_seq: input sequence, indexed by time step along dim 0
        :param v: membrane potential before the first step
        :param v_threshold: firing threshold
        :param v_reset: reset potential, or NaN for soft reset
        :return: ``(spike_seq, v)``
        """
        spike_seq = torch.zeros_like(x_seq)
        # ``v_reset == math.nan`` is always False because NaN never compares
        # equal to anything; math.isnan is the correct test.
        soft_reset = math.isnan(v_reset)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            if soft_reset:
                v = v - spike * v_threshold
            else:
                v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
        return spike_seq, v
Beta Was this translation helpful? Give feedback.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import cuda_utils, neuron
device = 'cuda:0'
def jit_eval_single_step_forward_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
    """Advance the hard-reset IF neuron by one time step.

    :param x: input at this step
    :param v: membrane potential before this step
    :param v_threshold: firing threshold
    :param v_reset: potential assigned where a spike fired
    :return: ``(spike, v)`` with the updated potential
    """
    v = v + x
    spike = (v >= v_threshold).to(x)
    v = v_reset * spike + (1. - spike) * v
    return spike, v
class IFNode(neuron.IFNode):
    """IF neuron whose multi-step eval helper calls the single-step function per step."""

    # test: use many jit functions, rather than one jit function
    # NOTE(review): the decorators were missing from the listing. The helper
    # takes no ``self``, so it must be static; a ``torch.jit.script``
    # decorator was presumably present too - confirm.
    @staticmethod
    def jit_eval_multi_step_forward_hard_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
        """Apply ``jit_eval_single_step_forward_hard_reset`` to every time step.

        :return: ``(spike_seq, v)`` with the potential after the last step
        """
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            spike_seq[t], v = jit_eval_single_step_forward_hard_reset(x_seq[t], v, v_threshold, v_reset)
        return spike_seq, v
# Driver: passes the stock neuron and the subclass defined above to
# cuda_utils.cal_fun_t with the same input.
# NOTE(review): both neurons are built with step_mode='s' although the
# override under test is a multi-step helper, and t1/t2 are never reported -
# lines may be missing from this listing.
net = neuron.IFNode(step_mode='s')
net2 = IFNode(step_mode='s')
x = torch.rand([16, 1024, 1024], device=device)
t1 = cuda_utils.cal_fun_t(64, device, net, x)
t2 = cuda_utils.cal_fun_t(64, device, net2, x)
Beta Was this translation helpful? Give feedback.
from curses import tigetflag
import torch
import math
from spikingjelly.activation_based import neuron, surrogate, functional, cuda_utils
import tqdm
from functorch.compile import memory_efficient_fusion
# torch.manual_seed(0)
# @torch.jit.script
def IFNode_fptt_hardReset(x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
    """Run the hard-reset IF neuron forward, collecting per-step results in lists.

    Per step: ``h = v + x``; ``spike = (h >= v_th)``; ``v`` becomes ``v_reset``
    where a spike fired and ``h`` elsewhere.

    :param x_seq: input sequence, indexed by time step along dim 0
    :param v: membrane potential before the first step
    :param v_th: firing threshold
    :param v_reset: potential assigned after a spike
    :return: ``(v_seq, h_seq, spike_seq)`` stacked along a new leading dim
    """
    h_seq = []
    spike_seq = []
    v_seq = []
    vt = v
    for t in range(x_seq.shape[0]):
        h = vt + x_seq[t]
        spike = (h >= v_th).to(x_seq)
        vt = h * (1. - spike) + spike * v_reset
        # The three appends were missing from the listing; without them the
        # lists stay empty and torch.stack fails.
        h_seq.append(h)
        spike_seq.append(spike)
        v_seq.append(vt)
    return torch.stack(v_seq), torch.stack(h_seq), torch.stack(spike_seq)
# Rebind the forward function to the wrapper returned by
# memory_efficient_fusion, declaring the two float arguments (positions 2 and
# 3) as static.
IFNode_fptt_hardReset = memory_efficient_fusion(IFNode_fptt_hardReset, static_argnums=(2, 3))
# @torch.jit.script
def IFNode_bptt_hardReset(grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor, h_seq: torch.Tensor, spike_seq: torch.Tensor, v_th: float, v_reset: float):
    """Back-propagate through time for the hard-reset IF neuron.

    The spike derivative uses ``alpha / 2 / (1 + (pi / 2 * alpha * (h - v_th))**2)``
    with ``alpha = 2``.

    :param grad_spike_seq: gradient w.r.t. the output spikes, time along dim 0
    :param grad_v_seq: gradient w.r.t. the membrane potential sequence
    :param h_seq: pre-reset potentials saved by the forward pass
    :param spike_seq: spikes saved by the forward pass
    :param v_th: firing threshold
    :param v_reset: potential assigned after a spike
    :return: ``(grad_x_seq, grad_v_last)``; ``grad_v_last`` is the gradient
        left after processing time step 0
    """
    over_th = h_seq - v_th
    alpha = 2.
    grad_s_to_h = alpha / 2 / (1 + (math.pi / 2 * alpha * over_th).pow(2))
    grad_v_to_h = 1.0 - spike_seq + (v_reset - h_seq) * grad_s_to_h
    grad_h = torch.zeros_like(grad_spike_seq[0])
    # Gradients are produced last step first; collect them in that order and
    # reverse once instead of inserting at the front on every step.
    grad_x_reversed = []
    for t in range(grad_spike_seq.shape[0] - 1, -1, -1):
        grad_h = grad_spike_seq[t] * grad_s_to_h[t] + (grad_v_seq[t] + grad_h) * grad_v_to_h[t]
        grad_x_reversed.append(grad_h)
    grad_v_last = grad_h
    return torch.stack(grad_x_reversed[::-1]), grad_v_last
# Rebind the backward function to the wrapper returned by
# memory_efficient_fusion, declaring the two float arguments (positions 4 and
# 5) as static.
IFNode_bptt_hardReset = memory_efficient_fusion(IFNode_bptt_hardReset, static_argnums=(4, 5))
class IFNodeATan(torch.autograd.Function):
    """Autograd function wiring the hand-written IF forward and backward passes."""

    # torch.autograd.Function requires forward/backward to be static methods;
    # the decorators were missing from the listing.
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
        """Run the forward pass and stash what the backward pass needs.

        :return: ``(spike_seq, v_seq)``
        """
        v_seq, h_seq, spike_seq = IFNode_fptt_hardReset(x_seq, v, v_th, v_reset)
        ctx.save_for_backward(h_seq, spike_seq)
        ctx.v_th = v_th
        ctx.v_reset = v_reset
        return spike_seq, v_seq

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        """Return gradients for ``(x_seq, v, v_th, v_reset)``; the two floats get ``None``."""
        h_seq, spike_seq = ctx.saved_tensors
        grad_x_seq, grad_v_last = IFNode_bptt_hardReset(grad_spike_seq, grad_v_seq, h_seq, spike_seq, ctx.v_th, ctx.v_reset)
        return grad_x_seq, grad_v_last, None, None
class JITIFNode(neuron.IFNode):
    """IF neuron whose multi-step training path runs through ``IFNodeATan``."""

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Process a whole input sequence and return the spike sequence.

        In training mode the custom autograd function is used; otherwise one
        of the inherited ``jit_eval_multi_step_forward_*`` helpers is picked
        from ``v_reset`` and ``store_v_seq``.
        """
        # NOTE(review): IFNodeATan receives self.v as a tensor; if the base
        # class can still hold a float state at this point it has to be
        # converted first - confirm against the original script.
        if self.training:
            spike_seq, v_seq = IFNodeATan.apply(
                x_seq, self.v, self.v_threshold, self.v_reset)
            if self.store_v_seq:
                self.v_seq = v_seq
            self.v = v_seq[-1].clone()
            return spike_seq

        # The ``else`` branches were missing from the listing; without them
        # every eval call advanced the neuron state twice.
        if self.v_reset is None:
            if self.store_v_seq:
                spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq, self.v, self.v_threshold)
            else:
                spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
        else:
            if self.store_v_seq:
                spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq, self.v, self.v_threshold, self.v_reset)
            else:
                spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold, self.v_reset)
        return spike_seq
# Callable handed to cuda_utils.cal_fun_t by cal_forward_backward_t together
# with the neuron and the input.
# NOTE(review): the body of this function is missing from the listing;
# presumably it runs one forward and one backward pass - restore it from the
# original script.
def forward_backward(multi_step_neuron, x):
def cal_forward_backward_t(multi_step_neuron, x, repeat_times):
    """Measure ``forward_backward(multi_step_neuron, x)`` via ``cuda_utils.cal_fun_t``.

    :param multi_step_neuron: neuron module under test
    :param x: input tensor; its device is passed on to ``cal_fun_t``
    :param repeat_times: repeat count passed on to ``cal_fun_t``
    :return: the value reported by ``cal_fun_t`` multiplied by 1000
        (presumably seconds converted to milliseconds - confirm)
    """
    used_t = cuda_utils.cal_fun_t(repeat_times, x.device, forward_backward, multi_step_neuron, x)
    return used_t * 1000
# Driver: builds the CuPy-backend IF neuron and the fused JIT variant, prints
# the largest difference of their outputs and input gradients, then measures
# both for a fixed list of sequence lengths using half-precision inputs.
# NOTE(review): indentation was lost in this listing. No backward() call is
# visible before either ``x.grad.clone()``, so those calls (and any gradient
# or state reset between the two runs) were presumably dropped - confirm.
device = 'cuda:0'
ms_if = neuron.IFNode(step_mode='m', surrogate_function=surrogate.ATan(), backend='cupy')
jit_if = JITIFNode(step_mode='m')
T = 32
repeat_times = 1024
x = torch.rand(T, 1, device=device, requires_grad=True)
y_ms = ms_if(x)
grad_ms = x.grad.clone()
y_jit = jit_if(x)
grad_jit = x.grad.clone()
with torch.no_grad():
print((y_jit - y_ms).abs().max())
print((grad_jit - grad_ms).abs().max())
with open('./test_half.csv', 'w+') as csv_file:
csv_file.write('T, jit_if, t_cupy\n')
N = 2 ** 20
print('forward and backward')
# for T in tqdm.trange(1, T + 1):
for T in [2, 4, 8, 16, 32, 64, 128]:
x = torch.rand(T, N, device=device).half()
t_jit = cal_forward_backward_t(jit_if, x, repeat_times)
t_cupy = cal_forward_backward_t(ms_if, x, repeat_times)
print(f'{T}, {t_jit}, {t_cupy}')
csv_file.write(f'{T}, {t_jit}, {t_cupy}\n')
(sj-dev) wfang@mlg-ThinkStation-P920:~$ /home/wfang/anaconda3/envs/sj-dev/bin/python /home/wfang/tempdir/w_jit.py
tensor(0., device='cuda:0')
tensor(1.1921e-07, device='cuda:0')
forward and backward
2, 2041.997970198281, 1939.34284534771
4, 2219.71859713085, 1938.3031243924052
8, 2580.2629694808275, 1901.0408432222903
16, 2918.2764051947743, 1735.7245936291292
32, 4113.050622399896, 2867.309346329421
64, 7693.348492030054, 5766.34434890002
128, 15210.309787653387, 11805.190281942487
Beta Was this translation helpful? Give feedback.
Beta Was this translation helpful? Give feedback.
All reactions