| 
 | 1 | +kernel void kernel_concat_f32_contiguous(  | 
 | 2 | +    global const char * p_src0, ulong off_src0,  | 
 | 3 | +    global const char * p_src1, ulong off_src1,  | 
 | 4 | +    global char * p_dst, ulong off_dst,  | 
 | 5 | +    int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice  | 
 | 6 | +    int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes)  | 
 | 7 | +    int d_ne0,  int d_ne1,  int d_ne2,  // dst->ne[0..2] for the slice  | 
 | 8 | +    int dim  | 
 | 9 | +) {  | 
 | 10 | +    global const float * src0 = (global const float*)((global char*)p_src0 + off_src0);  | 
 | 11 | +    global const float * src1 = (global const float*)((global char*)p_src1 + off_src1);  | 
 | 12 | +    global float * dst        = (global float*)((global char*)p_dst + off_dst);  | 
 | 13 | + | 
 | 14 | +    int i0 = get_global_id(0); // Index along dst's 0th dimension  | 
 | 15 | +    int i1 = get_global_id(1); // Index along dst's 1st dimension  | 
 | 16 | +    int i2 = get_global_id(2); // Index along dst's 2nd dimension  | 
 | 17 | + | 
 | 18 | +    if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) {  | 
 | 19 | +        return;  | 
 | 20 | +    }  | 
 | 21 | + | 
 | 22 | +    ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0;  | 
 | 23 | +    ulong src_idx;  | 
 | 24 | + | 
 | 25 | +    if (dim == 0) {  | 
 | 26 | +        if (i0 < d_ne00) { // Data from src0  | 
 | 27 | +            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;  | 
 | 28 | +            dst[dst_idx] = src0[src_idx];  | 
 | 29 | +        } else { // Data from src1  | 
 | 30 | +            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00);  | 
 | 31 | +            dst[dst_idx] = src1[src_idx];  | 
 | 32 | +        }  | 
 | 33 | +    } else if (dim == 1) {  | 
 | 34 | +        if (i1 < d_ne01) { // Data from src0  | 
 | 35 | +            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;  | 
 | 36 | +            dst[dst_idx] = src0[src_idx];  | 
 | 37 | +        } else { // Data from src1  | 
 | 38 | +            src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0;  | 
 | 39 | +            dst[dst_idx] = src1[src_idx];  | 
 | 40 | +        }  | 
 | 41 | +    } else if (dim == 2) {  | 
 | 42 | +        if (i2 < d_ne02) { // Data from src0  | 
 | 43 | +            src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0;  | 
 | 44 | +            dst[dst_idx] = src0[src_idx];  | 
 | 45 | +        } else { // Data from src1  | 
 | 46 | + | 
 | 47 | +            src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0;  | 
 | 48 | +            dst[dst_idx] = src1[src_idx];  | 
 | 49 | +        }  | 
 | 50 | +    }  | 
 | 51 | +}  | 
 | 52 | + | 
 | 53 | +kernel void kernel_concat_f32_non_contiguous(  | 
 | 54 | +    global const char * p_src0, ulong off_src0,  | 
 | 55 | +    global const char * p_src1, ulong off_src1,  | 
 | 56 | +    global char * p_dst, ulong off_dst,  | 
 | 57 | + | 
 | 58 | +    long ne00, long ne01, long ne02, long ne03,  | 
 | 59 | +    ulong nb00, ulong nb01, ulong nb02, ulong nb03,  | 
 | 60 | + | 
 | 61 | +    ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1  | 
 | 62 | + | 
 | 63 | +    long d_ne0, long d_ne1, long d_ne2, long d_ne3,  | 
 | 64 | +    ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3,  | 
 | 65 | +    int dim  | 
 | 66 | +) {  | 
 | 67 | +    global const char * src0_base = p_src0 + off_src0;  | 
 | 68 | +    global const char * src1_base = p_src1 + off_src1;  | 
 | 69 | +    global char * dst_base        = p_dst + off_dst;  | 
 | 70 | + | 
 | 71 | +    long current_i1 = get_global_id(0); // Index for dst_dim_1  | 
 | 72 | +    long current_i2 = get_global_id(1); // Index for dst_dim_2  | 
 | 73 | +    long current_i3 = get_global_id(2); // Index for dst_dim_3  | 
 | 74 | + | 
 | 75 | +    if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) {  | 
 | 76 | +        return;  | 
 | 77 | +    }  | 
 | 78 | + | 
 | 79 | +    global const float * x_val_ptr;  | 
 | 80 | +    global float * y_val_ptr;  | 
 | 81 | + | 
 | 82 | +    for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) {  | 
 | 83 | +        bool use_src0;  | 
 | 84 | +        long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3;  | 
 | 85 | + | 
 | 86 | +        if (dim == 0) {  | 
 | 87 | +            use_src0 = (current_i0 < ne00);  | 
 | 88 | +            if (!use_src0) { s_i0 = current_i0 - ne00; }  | 
 | 89 | +        } else if (dim == 1) {  | 
 | 90 | +            use_src0 = (current_i1 < ne01);  | 
 | 91 | +            if (!use_src0) { s_i1 = current_i1 - ne01; }  | 
 | 92 | +        } else if (dim == 2) {  | 
 | 93 | +            use_src0 = (current_i2 < ne02);  | 
 | 94 | +            if (!use_src0) { s_i2 = current_i2 - ne02; }  | 
 | 95 | +        } else { // dim == 3  | 
 | 96 | +            use_src0 = (current_i3 < ne03);  | 
 | 97 | +            if (!use_src0) { s_i3 = current_i3 - ne03; }  | 
 | 98 | +        }  | 
 | 99 | + | 
 | 100 | +        if (use_src0) {  | 
 | 101 | +            x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00);  | 
 | 102 | +        } else {  | 
 | 103 | +            x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10);  | 
 | 104 | +        }  | 
 | 105 | + | 
 | 106 | +        y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0);  | 
 | 107 | +        *y_val_ptr = *x_val_ptr;  | 
 | 108 | +    }  | 
 | 109 | +}  | 
0 commit comments