【CUDA】grid、block、thread的关系及thread索引的计算
https://hujingshuang.blog.csdn.net/article/details/53097222
# Demo: multiply two square float32 matrices on the GPU with a hand-written
# CUDA kernel (compiled through PyCUDA) and print the result next to
# numpy.dot for a visual comparison.
#
# Import order is kept exactly as in the original snippet.
import torch

print(torch.version.cuda)  # 11.0
print(torch.__version__)  # 1.7.0

import numpy
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule

# One thread computes one element of dest.  `i` is the column index and `j`
# the row index into the row-major width x width matrices.
mod = SourceModule("""
__global__ void matrix_mul(float *dest, float *a, float *b, int width)
{
    int i = threadIdx.x + blockDim.x * blockIdx.x;
    int j = threadIdx.y + blockDim.y * blockIdx.y;

    // The grid is rounded up to whole blocks, so threads past the matrix
    // edge must not touch memory.
    if (i >= width || j >= width) {
        return;
    }

    float sum = 0;
    for (int k = 0; k < width; k++) {
        float a_k = a[j * width + k];
        float b_k = b[k * width + i];
        sum += a_k * b_k;
    }
    dest[j * width + i] = sum;
}
""")

matrix_mul = mod.get_function("matrix_mul")

MATRIX_DIM = 400  # matrices are MATRIX_DIM x MATRIX_DIM
BLOCK_DIM = 16    # threads per block along x and along y

a = numpy.random.randn(MATRIX_DIM, MATRIX_DIM).astype(numpy.float32)
b = numpy.random.randn(MATRIX_DIM, MATRIX_DIM).astype(numpy.float32)
dest = numpy.zeros_like(a)

# The kernel takes a C `int`, so pass an explicitly sized numpy scalar.
width = numpy.int32(MATRIX_DIM)

# Ceiling division: enough blocks to cover every row/column even when
# MATRIX_DIM is not a multiple of BLOCK_DIM (25 for 400 / 16).
blocks_per_axis = (MATRIX_DIM + BLOCK_DIM - 1) // BLOCK_DIM

matrix_mul(
    cuda.Out(dest),
    cuda.In(a),
    cuda.In(b),
    width,
    block=(BLOCK_DIM, BLOCK_DIM, 1),
    grid=(blocks_per_axis, blocks_per_axis),
)

print(dest)
print("=" * 10)
print(numpy.dot(a, b))  # CPU reference result for comparison
from __future__ import print_function, division import os from PIL import Image import torch import torch.utils.data import torchvision from skimage import io from torch.utils.data import Dataset import random import numpy as np import pickle import lmdb import sys import cv2, numpy as np import sys import torch print(torch.version.cuda) # 11.0 print(torch.__version__) # 1.7.0 import pycuda.autoinit import pycuda.driver as cuda from pycuda.compiler import SourceModule mod = SourceModule(""" __global__ void test(float *a, float *dest) { float x = threadIdx.y; //[0,3) float y = blockIdx.x; //[0,2) int width = 4; int height = 3; int dest_width = 3; int dest_height = 2; int n=0; int srcIdxOffl = width * (height * (2 * n + 0) + (int) (y + (float) 0.0)) + (int) (x + (float) 0.0); float label = (a[srcIdxOffl]); dest[dest_width * (dest_height * (2 * n + 0) + (int) y) + (int) x] = label; } """) a = np.arange(24).reshape(2, 3, 4).astype(np.float32) a[0][0][0] = -1 b = np.zeros([2,2,3]).astype(np.float32) test = mod.get_function("test") test(cuda.In(a),cuda.Out(b), block=(1, 3, 1), grid=(2, 1, 1)) ccc = 0