// copy_blocks: copies fixed-size cache blocks between two tensors
// (CPU<->GPU or same-GPU) according to a src->dst block-number mapping.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cassert>
#include <map>
// Copies fixed-size blocks between two cache tensors.
//
// block_mapping maps a source block number to a destination block number.
// A "block" is one slice along dim 0 (src[0].numel() elements), so src and
// dst are assumed to share the same per-block layout — TODO confirm callers
// guarantee dst blocks are at least block_size_in_bytes.
//
// Supported device combinations: GPU->GPU (same device only), GPU->CPU,
// CPU->GPU. Copies are enqueued asynchronously on the current CUDA stream;
// the caller is responsible for synchronizing before reading the results
// (for a CPU-side dst this matters even though the call returns promptly,
// unless the host memory is pinned the copy may still be staged).
void copy_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,
  const std::map<int64_t, int64_t>& block_mapping) {
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type;
  if (src_device.is_cuda() && dst_device.is_cuda()) {
    // Peer (cross-GPU) copies are not supported by this routine.
    // TORCH_CHECK instead of assert: assert() compiles out under NDEBUG
    // (the default for release builds), which previously allowed a
    // mismatched-device call to proceed with a bogus copy.
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "copy_blocks: src and dst must be on the same CUDA device");
    memcpy_type = cudaMemcpyDeviceToDevice;
  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
    memcpy_type = cudaMemcpyDeviceToHost;
  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
    memcpy_type = cudaMemcpyHostToDevice;
  } else {
    // CPU->CPU: previously assert(false), which is a no-op under NDEBUG and
    // left memcpy_type uninitialized (undefined behavior). Fail loudly.
    TORCH_CHECK(false, "copy_blocks: unsupported device combination (CPU->CPU)");
  }

  // char* rather than void*: byte-offset pointer arithmetic on void* is a
  // non-standard GNU extension and ill-formed in standard C++.
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  // Bytes per block: element size times the number of elements in one
  // dim-0 slice of src.
  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  for (const auto& pair : block_mapping) {
    const int64_t src_offset = pair.first * block_size_in_bytes;
    const int64_t dst_offset = pair.second * block_size_in_bytes;
    // Check the return code: the original ignored it, so launch-time
    // failures (bad pointers, invalid sizes) were silently dropped.
    cudaError_t err = cudaMemcpyAsync(
      dst_ptr + dst_offset,
      src_ptr + src_offset,
      block_size_in_bytes,
      memcpy_type,
      stream);
    TORCH_CHECK(err == cudaSuccess,
                "copy_blocks: cudaMemcpyAsync failed: ",
                cudaGetErrorString(err));
  }
}