# reference.py
def ref_kernel(input_tensor):
    """Reference implementation: return the input with every element doubled."""
    doubled = 2 * input_tensor
    return doubled
# my_kernel.py
import torch
import triton
import triton.language as tl
@triton.jit
def kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Elementwise doubling: output[i] = input[i] * 2 for i < n_elements.

    Each program instance processes one contiguous tile of BLOCK_SIZE
    elements of the flattened input.
    """
    tile_id = tl.program_id(0)
    idx = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask keeps the final, possibly ragged tile from reading/writing OOB.
    in_bounds = idx < n_elements
    vals = tl.load(input_ptr + idx, mask=in_bounds)
    tl.store(output_ptr + idx, vals * 2, mask=in_bounds)
def custom_kernel(input_tensor):
    """Double *input_tensor* elementwise using the Triton `kernel`.

    Returns a newly allocated tensor with the same shape, dtype, and
    device as the input; the input itself is not modified.
    """
    total = input_tensor.numel()
    result = torch.empty_like(input_tensor)
    # Launch one program per BLOCK_SIZE-sized tile of the flattened tensor.
    launch_grid = lambda meta: (triton.cdiv(total, meta['BLOCK_SIZE']),)
    kernel[launch_grid](input_tensor, result, total, BLOCK_SIZE=1024)
    return result