Discussions about JIT #235
Replies: 8 comments
-
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy

# Element-wise function fused by TorchScript
@torch.jit.script
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)

# The same element-wise function as a hand-written CuPy raw kernel
kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''

def fun_cupy(x, y):
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        N = cupy.asarray(N)
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z, N = cuda_utils.get_contiguous(x, y, z, N)
        kernel_args = [x, y, z, N]
        kernel(
            (blocks,), (configure.cuda_threads,),
            cuda_utils.wrap_args_to_raw_kernel(
                device_id,
                *kernel_args
            )
        )
    return z

with torch.no_grad():
    device = 'cuda:1'
    x = torch.rand([1024, 1024, 16], device=device)
    y = torch.rand_like(x)
    # Check numerical agreement, then compare run time
    print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
    t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
    t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
    print(t1)
    print(t2)
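For reference, cuda_utils.cal_fun_t repeats the call and reports an average run time. A minimal sketch of an equivalent CUDA-event-based timer is shown below; the names and units here are illustrative, and the actual SpikingJelly implementation may differ (e.g., in its warm-up strategy and whether it reports seconds or milliseconds):

import torch

def cal_fun_t_sketch(n: int, device, f, *args, **kwargs):
    # Warm up once so compilation / caching does not pollute the measurement
    f(*args, **kwargs)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize(device)
    start.record()
    for _ in range(n):
        f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize(device)
    return start.elapsed_time(end) / n  # milliseconds per call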
-
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy

@torch.jit.script
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)

kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''

def fun_cupy(x, y):
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z = cuda_utils.get_contiguous(x, y, z)
        # Launch on chunks of at most blocks * cuda_threads elements,
        # passing tensor slices as the kernel arguments
        mini_numel = blocks * configure.cuda_threads
        start = 0
        while True:
            end = min(start + mini_numel, N)
            kernel_args = [x[start: end], y[start: end], z[start: end], cupy.asarray(end - start)]
            kernel(
                (blocks,), (configure.cuda_threads,),
                cuda_utils.wrap_args_to_raw_kernel(
                    device_id,
                    *kernel_args
                )
            )
            start = end
            if start >= N:
                break
    return z

with torch.no_grad():
    device = 'cuda:1'
    for i in range(4):
        x = torch.rand([1024 * 1024 * 1024], device=device)
        y = torch.rand_like(x)
        print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
        t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
        t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
        print(t1)
        print(t2)

This method is slower. I will revert this commit.
-
import torch
from spikingjelly.activation_based import cuda_utils
from spikingjelly import configure
import cupy

@torch.jit.script
def fun_jit(x: torch.Tensor, y: torch.Tensor):
    return x * y + 2. * x + 3. * y + torch.pow(x, 2.) + torch.pow(y, 3.)

kernel_code = r'''
extern "C" __global__
void fun(const float *x, const float *y, float *z, const int &N)
{
    const int index = blockIdx.x * blockDim.x + threadIdx.x;
    if (index < N)
    {
        z[index] = x[index] * y[index] + 2.0f * x[index] + 3.0f * y[index] + powf(x[index], 2.0f) + powf(y[index], 3.0f);
    }
}
'''

def wrap_args_to_raw_kernel(offset: int, device: int, *args):
    # Like cuda_utils.wrap_args_to_raw_kernel, but shifts each tensor's data pointer
    # by `offset` elements, so a chunk can be addressed without slicing the tensor
    ret_list = []
    for item in args:
        if isinstance(item, torch.Tensor):
            assert item.get_device() == device
            assert item.is_contiguous()
            ret_list.append(item.data_ptr() + offset * item.element_size())
        elif isinstance(item, cupy.ndarray):
            assert item.device.id == device
            assert item.flags['C_CONTIGUOUS']
            ret_list.append(item)
        else:
            raise TypeError
    return tuple(ret_list)

def fun_cupy(x, y):
    z = torch.zeros_like(x)
    N = x.numel()
    blocks = cuda_utils.cal_blocks(N)
    device_id = x.get_device()
    with cuda_utils.DeviceEnvironment(device_id):
        kernel = cupy.RawKernel(kernel_code, 'fun', options=('-use_fast_math',), backend=configure.cuda_compiler_backend)
        x, y, z = cuda_utils.get_contiguous(x, y, z)
        mini_numel = blocks * configure.cuda_threads
        start = 0
        while True:
            end = min(start + mini_numel, N)
            kernel_args = [x, y, z, cupy.asarray(end - start)]
            kernel(
                (blocks,), (configure.cuda_threads,),
                wrap_args_to_raw_kernel(
                    start,
                    device_id,
                    *kernel_args
                )
            )
            start = end
            if start >= N:
                break
    return z

with torch.no_grad():
    device = 'cuda:0'
    for i in range(4):
        x = torch.rand([1024 * 1024 * 16], device=device)
        y = torch.rand_like(x)
        print((fun_jit(x, y) - fun_cupy(x, y)).abs().max())
        t1 = cuda_utils.cal_fun_t(2048, device, fun_jit, x, y)
        t2 = cuda_utils.cal_fun_t(2048, device, fun_cupy, x, y)
        print(t1)
        print(t2)
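As a quick sanity check of the pointer arithmetic above (an added illustrative snippet, not part of the original post): for a contiguous 1-D tensor, offsetting data_ptr() by offset * element_size() addresses the same memory as slicing.

import torch

t = torch.arange(10, dtype=torch.float32)
offset = 3
# Both expressions give the address of t[3] for a contiguous 1-D tensor
assert t.data_ptr() + offset * t.element_size() == t[offset:].data_ptr()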
-
import torch
import math
from spikingjelly.activation_based import neuron, surrogate, functional, cuda_utils
import tqdm

@torch.jit.script
def IFNode_fptt_hardReset(x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
    h_seq = torch.zeros_like(x_seq)
    spike_seq = torch.zeros_like(x_seq)
    v_seq = torch.zeros_like(x_seq)
    v_v_seq = torch.cat((v.unsqueeze(0), v_seq), 0)
    for t in range(x_seq.shape[0]):
        h_seq[t] = v_v_seq[t] + x_seq[t]
        spike_seq[t] = (h_seq[t] >= v_th).to(x_seq)
        v_v_seq[t + 1] = h_seq[t] * (1. - spike_seq[t]) + spike_seq[t] * v_reset
    return v_v_seq[1:], h_seq, spike_seq

@torch.jit.script
def IFNode_bptt_hardReset(grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor, h_seq: torch.Tensor, spike_seq: torch.Tensor, v_th: float, v_reset: float):
    grad_x_seq = torch.zeros_like(grad_spike_seq)
    over_th = h_seq - v_th
    # surrogate gradient of the ATan function with alpha = 2
    alpha = 2.
    grad_s_to_h = alpha / 2 / (1 + (math.pi / 2 * alpha * over_th).pow(2))
    grad_v_to_h = 1.0 - spike_seq + (v_reset - h_seq) * grad_s_to_h
    grad_h = torch.zeros_like(grad_spike_seq[0])
    for t in range(grad_spike_seq.shape[0] - 1, -1, -1):
        grad_h = grad_spike_seq[t] * grad_s_to_h[t] + (grad_v_seq[t] + grad_h) * grad_v_to_h[t]
        grad_x_seq[t] = grad_h
    grad_v_last = grad_h
    return grad_x_seq, grad_v_last

class IFNodeATan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
        v_seq, h_seq, spike_seq = IFNode_fptt_hardReset(x_seq, v, v_th, v_reset)
        ctx.save_for_backward(h_seq, spike_seq)
        ctx.v_th = v_th
        ctx.v_reset = v_reset
        return spike_seq, v_seq

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        h_seq, spike_seq = ctx.saved_tensors
        grad_x_seq, grad_v_last = IFNode_bptt_hardReset(grad_spike_seq, grad_v_seq, h_seq, spike_seq, ctx.v_th, ctx.v_reset)
        return grad_x_seq, grad_v_last, None, None

class JITIFNode(neuron.IFNode):
    def multi_step_forward(self, x_seq: torch.Tensor):
        if self.training:
            self.v_float_to_tensor(x_seq[0])
            spike_seq, v_seq = IFNodeATan.apply(
                x_seq, self.v, self.v_threshold, self.v_reset)
            if self.store_v_seq:
                self.v_seq = v_seq
            self.v = v_seq[-1].clone()
            return spike_seq
        else:
            self.v_float_to_tensor(x_seq[0])
            if self.v_reset is None:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq, self.v, self.v_threshold)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
            else:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq, self.v, self.v_threshold, self.v_reset)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold, self.v_reset)
            return spike_seq

def forward_backward(multi_step_neuron, x):
    multi_step_neuron(x).mean().backward()
    multi_step_neuron.reset()
    x.grad.zero_()

def cal_forward_backward_t(multi_step_neuron, x, repeat_times):
    x.requires_grad_(True)
    used_t = cuda_utils.cal_fun_t(repeat_times, x.device, forward_backward, multi_step_neuron, x)
    return used_t * 1000

device = 'cuda:0'
ms_if = neuron.IFNode(step_mode='m', surrogate_function=surrogate.ATan(), backend='cupy')
ms_if.to(device)
jit_if = JITIFNode(step_mode='m')
jit_if.to(device)
T = 128
repeat_times = 32
x = torch.rand(T, 1024, device=device, requires_grad=True)

# Check that outputs and gradients match before timing
y_ms = ms_if(x)
y_ms.sum().backward()
grad_ms = x.grad.clone()
x.grad.zero_()
ms_if.reset()
y_jit = jit_if(x)
y_jit.sum().backward()
grad_jit = x.grad.clone()
x.grad.zero_()
jit_if.reset()
with torch.no_grad():
    print((y_jit - y_ms).abs().max())
    print((grad_jit - grad_ms).abs().max())

with open('./test_half.csv', 'w+') as csv_file:
    csv_file.write('T, t_jit, t_cupy\n')
    N = 2 ** 20
    print('forward and backward')
    ms_if.train()
    jit_if.train()
    for T in tqdm.trange(1, T + 1):
    # for T in [2, 4, 8, 16, 32, 64, 128]:
        x = torch.rand(T, N, device=device).half()
        t_jit = cal_forward_backward_t(jit_if, x, repeat_times)
        t_cupy = cal_forward_backward_t(ms_if, x, repeat_times)
        csv_file.write(f'{T}, {t_jit}, {t_cupy}\n')

CuPy forward and backward is still faster than JIT.
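For reference (added note, not part of the original post): grad_s_to_h in the backward pass above is the derivative of the ATan surrogate function, matching surrogate.ATan() with alpha = 2 and x = h_seq - v_th:

$$
\sigma(x) = \frac{1}{\pi}\arctan\!\left(\frac{\pi}{2}\,\alpha x\right) + \frac{1}{2},
\qquad
\sigma'(x) = \frac{\alpha / 2}{1 + \left(\frac{\pi}{2}\,\alpha x\right)^{2}}
$$

which is exactly grad_s_to_h = alpha / 2 / (1 + (math.pi / 2 * alpha * over_th).pow(2)).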
-
Compare the bool2float conversion between JIT and CuPy:

import torch
import torch.nn.functional as F
from spikingjelly.activation_based import tensor_cache, cuda_utils

@torch.jit.script
def float2bool(x: torch.Tensor, mask: torch.Tensor):
    # Pack 8 spikes (0./1. values) into one uint8: shift each spike to its bit position and sum
    x_shape = x.shape
    padding = 8 - x.numel() % 8
    if padding != 0 and padding != 8:
        x = x.flatten()
        x = F.pad(x, (0, padding))
    x = x.view(-1, 8)
    return (torch.bitwise_left_shift(x, mask)).sum(1).to(torch.uint8), x_shape, padding

@torch.jit.script
def bool2float(x: torch.Tensor, mask: torch.Tensor, x_shape: list[int], padding: int, dtype: int = torch.float32):
    # Unpack each uint8 back into 8 spikes by masking the individual bits
    x = x.unsqueeze(1).repeat(1, 8)
    x = torch.bitwise_and(x, mask).ne(0).to(dtype)
    if padding != 0 and padding != 8:
        x = x.flatten()
        x = x[:x.numel() - padding]
    return x.reshape(x_shape)

def cp_convert(x):
    return tensor_cache.bool_spike_to_float(*tensor_cache.float_spike_to_bool(x))

def jit_convert(x, float2bool_mask, bool2float_mask):
    dtype = x.dtype
    xb, x_shape, padding = float2bool(x, float2bool_mask)
    y = bool2float(xb, bool2float_mask, x_shape, padding, dtype)
    return y

dtype = torch.float32
device = 'cuda:0'
float2bool_mask = torch.arange(8, device=device, dtype=dtype)
bool2float_mask = (2 ** float2bool_mask).to(torch.uint8)
x = (torch.rand([64, 2, 3], device=device) > 0.5).to(dtype)
xb, x_shape, padding = float2bool(x, float2bool_mask)
y = bool2float(xb, bool2float_mask, x_shape, padding, dtype)
print((x - y).abs().sum())
print((cp_convert(x) - jit_convert(x, float2bool_mask, bool2float_mask)).abs().sum())

repeats = 32
dtype = torch.float32
for i in range(1, 33):
    x = (torch.rand([i, 16, 128, 64, 64], device=device) > 0.5).to(dtype)
    t_cp = cuda_utils.cal_fun_t(repeats, device, cp_convert, x)
    t_jit = cuda_utils.cal_fun_t(repeats, device, jit_convert, x, float2bool_mask, bool2float_mask)
    print(f'{i}, {t_cp}, {t_jit}, {t_cp / t_jit}')

The results show that CuPy is faster.
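As a small added illustration of the packing scheme (using integer tensors for clarity, not the scripted functions above): 8 binary spikes map to one byte via their bit positions.

import torch

spikes = torch.tensor([1, 0, 1, 1, 0, 0, 0, 1], dtype=torch.uint8)
weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)  # bit i -> 2**i
packed = (spikes * weights).sum().to(torch.uint8)       # 1 + 4 + 8 + 128 = 141
unpacked = torch.bitwise_and(packed, weights).ne(0).to(torch.uint8)
print(int(packed))                                      # 141
print(torch.equal(unpacked, spikes))                    # True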
-
About using if-else inside a JIT function, rather than using a Python if-else to call different JIT functions (a sketch of the Python-level dispatch alternative is shown after the code below):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import cuda_utils, neuron

device = 'cuda:0'

# Version 1: a single JIT function that branches on store_v_seq inside the script
class IFNode(neuron.IFNode):
    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, store_v_seq: bool):
        spike_seq = torch.zeros_like(x_seq)
        if store_v_seq:
            v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            if math.isnan(v_reset):
                v = v - spike * v_threshold
            else:
                v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
            if store_v_seq:
                v_seq[t] = v
        if store_v_seq:
            return spike_seq, v, v_seq
        else:
            return spike_seq, v

    def multi_step_forward(self, x_seq: torch.Tensor):
        if self.training:
            return super(IFNode, self).multi_step_forward(x_seq)
        else:
            self.v_float_to_tensor(x_seq[0])
            if self.store_v_seq:
                spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward(x_seq, self.v, self.v_threshold, self.v_reset, self.store_v_seq)
            else:
                spike_seq, self.v = self.jit_eval_multi_step_forward(x_seq, self.v, self.v_threshold, self.v_reset, self.store_v_seq)
            return spike_seq

# Version 2: a JIT function without the store_v_seq flag; a Python if-else would
# choose between specialized functions instead
class IFNode(neuron.IFNode):
    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            if math.isnan(v_reset):
                v = v - spike * v_threshold
            else:
                v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
        return spike_seq, v
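For comparison, here is a minimal added sketch of the Python-level dispatch pattern, with plain Python choosing between specialized scripted functions (the function names are illustrative; neuron.IFNode uses a similar scheme with its separate jit_eval_* static methods, as seen in the earlier replies):

import torch

@torch.jit.script
def step_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
    v = v + x
    spike = (v >= v_threshold).to(x)
    v = v_reset * spike + (1. - spike) * v
    return spike, v

@torch.jit.script
def step_soft_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float):
    v = v + x
    spike = (v >= v_threshold).to(x)
    v = v - spike * v_threshold
    return spike, v

def single_step(x, v, v_threshold, v_reset):
    # Plain Python picks the specialized scripted function; no branch inside the script
    if v_reset is None:
        return step_soft_reset(x, v, v_threshold)
    else:
        return step_hard_reset(x, v, v_threshold, v_reset)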
-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import cuda_utils, neuron

device = 'cuda:0'

@torch.jit.script
def jit_eval_single_step_forward_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
    v = v + x
    spike = (v >= v_threshold).to(x)
    v = v_reset * spike + (1. - spike) * v
    return spike, v

class IFNode(neuron.IFNode):
    # test: use many jit functions, rather than one jit function
    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            spike_seq[t], v = jit_eval_single_step_forward_hard_reset(x_seq[t], v, v_threshold, v_reset)
        return spike_seq, v

net = neuron.IFNode(step_mode='s')
net.eval()
net2 = IFNode(step_mode='s')
net2.eval()
x = torch.rand([16, 1024, 1024], device=device)
t1 = cuda_utils.cal_fun_t(64, device, net, x)
t2 = cuda_utils.cal_fun_t(64, device, net2, x)
print(t1)
print(t2)
-
import torch
import math
from spikingjelly.activation_based import neuron, surrogate, functional, cuda_utils
import tqdm
from functorch.compile import memory_efficient_fusion

# torch.manual_seed(0)

# @torch.jit.script
def IFNode_fptt_hardReset(x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
    h_seq = []
    spike_seq = []
    v_seq = []
    vt = v
    for t in range(x_seq.shape[0]):
        h = vt + x_seq[t]
        spike = (h >= v_th).to(x_seq)
        vt = h * (1. - spike) + spike * v_reset
        h_seq.append(h)
        spike_seq.append(spike)
        v_seq.append(vt)
    return torch.stack(v_seq), torch.stack(h_seq), torch.stack(spike_seq)

IFNode_fptt_hardReset = memory_efficient_fusion(IFNode_fptt_hardReset, static_argnums=(2, 3))

# @torch.jit.script
def IFNode_bptt_hardReset(grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor, h_seq: torch.Tensor, spike_seq: torch.Tensor, v_th: float, v_reset: float):
    grad_x_seq = []
    over_th = h_seq - v_th
    alpha = 2.
    grad_s_to_h = alpha / 2 / (1 + (math.pi / 2 * alpha * over_th).pow(2))
    grad_v_to_h = 1.0 - spike_seq + (v_reset - h_seq) * grad_s_to_h
    grad_h = torch.zeros_like(grad_spike_seq[0])
    for t in range(grad_spike_seq.shape[0] - 1, -1, -1):
        grad_h = grad_spike_seq[t] * grad_s_to_h[t] + (grad_v_seq[t] + grad_h) * grad_v_to_h[t]
        grad_x_seq.insert(0, grad_h)
    grad_v_last = grad_h
    return torch.stack(grad_x_seq), grad_v_last

IFNode_bptt_hardReset = memory_efficient_fusion(IFNode_bptt_hardReset, static_argnums=(4, 5))

class IFNodeATan(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v: torch.Tensor, v_th: float, v_reset: float):
        v_seq, h_seq, spike_seq = IFNode_fptt_hardReset(x_seq, v, v_th, v_reset)
        ctx.save_for_backward(h_seq, spike_seq)
        ctx.v_th = v_th
        ctx.v_reset = v_reset
        return spike_seq, v_seq

    @staticmethod
    def backward(ctx, grad_spike_seq, grad_v_seq):
        h_seq, spike_seq = ctx.saved_tensors
        grad_x_seq, grad_v_last = IFNode_bptt_hardReset(grad_spike_seq, grad_v_seq, h_seq, spike_seq, ctx.v_th, ctx.v_reset)
        return grad_x_seq, grad_v_last, None, None

class JITIFNode(neuron.IFNode):
    def multi_step_forward(self, x_seq: torch.Tensor):
        if self.training:
            self.v_float_to_tensor(x_seq[0])
            spike_seq, v_seq = IFNodeATan.apply(
                x_seq, self.v, self.v_threshold, self.v_reset)
            if self.store_v_seq:
                self.v_seq = v_seq
            self.v = v_seq[-1].clone()
            return spike_seq
        else:
            self.v_float_to_tensor(x_seq[0])
            if self.v_reset is None:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq, self.v, self.v_threshold)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
            else:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq, self.v, self.v_threshold, self.v_reset)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold, self.v_reset)
            return spike_seq

def forward_backward(multi_step_neuron, x):
    multi_step_neuron(x).mean().backward()
    multi_step_neuron.reset()
    x.grad.zero_()

def cal_forward_backward_t(multi_step_neuron, x, repeat_times):
    x.requires_grad_(True)
    used_t = cuda_utils.cal_fun_t(repeat_times, x.device, forward_backward, multi_step_neuron, x)
    return used_t * 1000

device = 'cuda:0'
ms_if = neuron.IFNode(step_mode='m', surrogate_function=surrogate.ATan(), backend='cupy')
ms_if.to(device)
jit_if = JITIFNode(step_mode='m')
jit_if.to(device)
T = 32
repeat_times = 1024
x = torch.rand(T, 1, device=device, requires_grad=True)
y_ms = ms_if(x)
y_ms.sum().backward()
grad_ms = x.grad.clone()
x.grad.zero_()
ms_if.reset()
y_jit = jit_if(x)
y_jit.sum().backward()
grad_jit = x.grad.clone()
x.grad.zero_()
jit_if.reset()
with torch.no_grad():
    print((y_jit - y_ms).abs().max())
    print((grad_jit - grad_ms).abs().max())

with open('./test_half.csv', 'w+') as csv_file:
    csv_file.write('T, t_jit, t_cupy\n')
    N = 2 ** 20
    print('forward and backward')
    ms_if.train()
    jit_if.train()
    # for T in tqdm.trange(1, T + 1):
    for T in [2, 4, 8, 16, 32, 64, 128]:
        x = torch.rand(T, N, device=device).half()
        t_jit = cal_forward_backward_t(jit_if, x, repeat_times)
        t_cupy = cal_forward_backward_t(ms_if, x, repeat_times)
        print(f'{T}, {t_jit}, {t_cupy}')
        csv_file.write(f'{T}, {t_jit}, {t_cupy}\n')

(sj-dev) wfang@mlg-ThinkStation-P920:~$ /home/wfang/anaconda3/envs/sj-dev/bin/python /home/wfang/tempdir/w_jit.py
tensor(0., device='cuda:0')
tensor(1.1921e-07, device='cuda:0')
forward and backward
2, 2041.997970198281, 1939.34284534771
4, 2219.71859713085, 1938.3031243924052
8, 2580.2629694808275, 1901.0408432222903
16, 2918.2764051947743, 1735.7245936291292
32, 4113.050622399896, 2867.309346329421
64, 7693.348492030054, 5766.34434890002
128, 15210.309787653387, 11805.190281942487
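For context (added note): memory_efficient_fusion is, as far as I understand, functorch's AOT Autograd wrapper that compiles and fuses both the forward and backward graphs, with static_argnums marking non-tensor arguments (here v_th and v_reset) as compile-time constants. A minimal standalone usage sketch, assuming the same functorch version as above; the toy function and names are illustrative, not part of the original post:

import torch
from functorch.compile import memory_efficient_fusion

def toy(x: torch.Tensor, y: torch.Tensor, scale: float):
    # an element-wise function whose forward and backward graphs get compiled/fused
    return (x * y + scale * x).sin()

# scale is a non-tensor argument, so it is marked static like v_th / v_reset above
fused_toy = memory_efficient_fusion(toy, static_argnums=(2,))

a = torch.rand(1024, device='cuda', requires_grad=True)
b = torch.rand(1024, device='cuda')
out = fused_toy(a, b, 2.0)
out.sum().backward()
print(a.grad.shape)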