750 likes | 770 Views
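Learn how a compiler can automatically generate efficient large base cases for divide-and-conquer programs: the structure of these computations, the transformations that produce the large base cases (recursion inlining, conditional fusion, and recursion re-rolling), and related work. See the approach applied to recursive matrix multiplication, and understand the advantages and disadvantages of writing recursive code with small base cases.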
E N D
Recursion Unrolling for Divide and Conquer Programs Radu Rugina and Martin Rinard Laboratory for Computer Science Massachusetts Institute of Technology
What This Talk Is About • Automatic generation of efficient large base cases for divide and conquer programs
Outline • Motivating Example • Computation Structure • Transformations • Related Work • Conclusion
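Divide and Conquer Matrix Multiply: $A \times B = R$ • Divide matrices into sub-matrices: $A_0, A_1, A_2, A_3$ (likewise $B_0, \dots, B_3$ and $R_0, \dots, R_3$) • Use blocked matrix multiply equations: $R_0 = A_0 B_0 + A_1 B_2$, $R_1 = A_0 B_1 + A_1 B_3$, $R_2 = A_2 B_0 + A_3 B_2$, $R_3 = A_2 B_1 + A_3 B_3$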
Divide and Conquer Matrix Multiply: $A \times B = R$ • Recursively multiply sub-matrices
Divide and Conquer Matrix Multiply: $A \times B = R$ • Terminate recursion with a simple base case
Divide and Conquer Matrix Multiply void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } } Implements $R \leftarrow R + A \times B$
Divide and Conquer Matrix Multiply Divide matrices into sub-matrices and recursively multiply the sub-matrices void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } }
Divide and Conquer Matrix Multiply Identify sub-matrices with pointers void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } }
Divide and Conquer Matrix Multiply Use a simple algorithm for the base case void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } }
Divide and Conquer Matrix Multiply void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } } • Advantage of small base case: simplicity • Code is easy to: • Write • Maintain • Debug • Understand
Divide and Conquer Matrix Multiply • Disadvantage: inefficiency • Large control flow overhead: • Most of the time is spent dividing the matrix into sub-matrices void matmul(int *A, int *B, int *R, int n) { if (n == 1) { (*R) += (*A) * (*B); } else { matmul(A, B, R, n/4); matmul(A, B+(n/4), R+(n/4), n/4); matmul(A+2*(n/4), B, R+2*(n/4), n/4); matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4); matmul(A+(n/4), B+2*(n/4), R, n/4); matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4); matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4); matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4); } }
Hand Coded Implementation void serialmul(block *As, block *Bs, block *Rs) { int i, j; DOUBLE *A = (DOUBLE *) As; DOUBLE *B = (DOUBLE *) Bs; DOUBLE *R = (DOUBLE *) Rs; for (j = 0; j < 16; j += 2) { DOUBLE *bp = &B[j]; for (i = 0; i < 16; i += 2) { DOUBLE *ap = &A[i * 16]; DOUBLE *rp = &R[j + i * 16]; register DOUBLE s0_0 = rp[0], s0_1 = rp[1]; register DOUBLE s1_0 = rp[16], s1_1 = rp[17]; s0_0 += ap[0] * bp[0]; s0_1 += ap[0] * bp[1]; s1_0 += ap[16] * bp[0]; s1_1 += ap[16] * bp[1]; s0_0 += ap[1] * bp[16]; s0_1 += ap[1] * bp[17]; s1_0 += ap[17] * bp[16]; s1_1 += ap[17] * bp[17]; s0_0 += ap[2] * bp[32]; s0_1 += ap[2] * bp[33]; s1_0 += ap[18] * bp[32]; s1_1 += ap[18] * bp[33]; s0_0 += ap[3] * bp[48]; s0_1 += ap[3] * bp[49]; s1_0 += ap[19] * bp[48]; s1_1 += ap[19] * bp[49]; s0_0 += ap[4] * bp[64]; s0_1 += ap[4] * bp[65]; s1_0 += ap[20] * bp[64]; s1_1 += ap[20] * bp[65]; s0_0 += ap[5] * bp[80]; s0_1 += ap[5] * bp[81]; s1_0 += ap[21] * bp[80]; s1_1 += ap[21] * bp[81]; s0_0 += ap[6] * bp[96]; s0_1 += ap[6] * bp[97]; s1_0 += ap[22] * bp[96]; s1_1 += ap[22] * bp[97]; s0_0 += ap[7] * bp[112]; s0_1 += ap[7] * bp[113]; s1_0 += ap[23] * bp[112]; s1_1 += ap[23] * bp[113]; s0_0 += ap[8] * bp[128]; s0_1 += ap[8] * bp[129]; s1_0 += ap[24] * bp[128]; s1_1 += ap[24] * bp[129]; s0_0 += ap[9] * bp[144]; s0_1 += ap[9] * bp[145]; s1_0 += ap[25] * bp[144]; s1_1 += ap[25] * bp[145]; s0_0 += ap[10] * bp[160]; s0_1 += ap[10] * bp[161]; s1_0 += ap[26] * bp[160]; s1_1 += ap[26] * bp[161]; s0_0 += ap[11] * bp[176]; s0_1 += ap[11] * bp[177]; s1_0 += ap[27] * bp[176]; s1_1 += ap[27] * bp[177]; s0_0 += ap[12] * bp[192]; s0_1 += ap[12] * bp[193]; s1_0 += ap[28] * bp[192]; s1_1 += ap[28] * bp[193]; s0_0 += ap[13] * bp[208]; s0_1 += ap[13] * bp[209]; s1_0 += ap[29] * bp[208]; s1_1 += ap[29] * bp[209]; s0_0 += ap[14] * bp[224]; s0_1 += ap[14] * bp[225]; s1_0 += ap[30] * bp[224]; s1_1 += ap[30] * bp[225]; s0_0 += ap[15] * bp[240]; s0_1 += ap[15] * bp[241]; s1_0 += ap[31] * bp[240]; s1_1 += ap[31] * 
bp[241]; rp[0] = s0_0; rp[1] = s0_1; rp[16] = s1_0; rp[17] = s1_1; } } } cilk void matrixmul(long nb, block *A, block *B, block *R) { if (nb == 1) { serialmul(A, B, R); } else if (nb >= 4) { spawn matrixmul(nb/4, A, B, R); spawn matrixmul(nb/4, A, B+(nb/4), R+(nb/4)); spawn matrixmul(nb/4, A+2*(nb/4), B+(nb/4), R+3*(nb/4)); spawn matrixmul(nb/4, A+2*(nb/4), B, R+2*(nb/4)); sync; spawn matrixmul(nb/4, A+(nb/4), B+2*(nb/4), R); spawn matrixmul(nb/4, A+(nb/4), B+3*(nb/4), R+(nb/4)); spawn matrixmul(nb/4, A+3*(nb/4), B+2*(nb/4), R+2*(nb/4)); spawn matrixmul(nb/4, A+3*(nb/4), B+3*(nb/4), R+3*(nb/4)); sync; } }
Goal • The programmer writes simple code with small base cases • The compiler automatically generates efficient code with large base cases
Running Example – Array Increment void f(char *p, int n) { if (n == 1) { /* base case: increment one element */ (*p) += 1; } else { f(p, n/2); /* increment first half */ f(p+n/2, n/2); /* increment second half */ } }
Dynamic Call Tree for n=4 Execution of f(p,4)
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f Activation Frame on the Stack
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f Executed Instructions
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f n=4 Test n=1 Call f Call f Test n=1 Call f Call f n=2
Dynamic Call Tree for n=4 Execution of f(p,4) Test n=1 Call f Call f n=4 Test n=1 Call f Call f Test n=1 Call f Call f n=2 Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p n=1
Control Flow Overhead Execution of f(p,4) • Call overhead Test n=1 Call f Call f n=4 Test n=1 Call f Call f Test n=1 Call f Call f n=2 Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p n=1
Control Flow Overhead Execution of f(p,4) • Call overhead + Test overhead Test n=1 Call f Call f n=4 Test n=1 Call f Call f Test n=1 Call f Call f n=2 Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p n=1
Computation Execution of f(p,4) • Call overhead + Test overhead • Computation Test n=1 Call f Call f n=4 Test n=1 Call f Call f Test n=1 Call f Call f n=2 Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p Test n=1 Inc *p n=1
Large Base Cases = Reduced Overhead Execution of f(p,4) Test n=2 Call f Call f n=4 Test n=2 Inc *p Inc *(p+1) Test n=2 Inc *p Inc *(p+1) n=2
Transformation 1: Recursion Inlining Start with the original recursive procedure void f (char *p, int n) if (n == 1) { (*p) += 1; } else { f(p, n/2); f(p+n/2, n/2); }
Transformation 1: Recursion Inlining Make two copies of the original procedure void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { f1(p, n/2); f1(p+n/2, n/2); } void f2(char *p, int n) if (n == 1) { (*p) += 1; } else { f2(p, n/2); f2(p+n/2, n/2); }
Transformation 1: Recursion Inlining Transform direct recursion to mutual recursion void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { f2(p, n/2); f2(p+n/2, n/2); } void f2(char *p, int n) if (n == 1) { (*p) += 1; } else { f1(p, n/2); f1(p+n/2, n/2); }
Transformation 1: Recursion Inlining Inline procedure f2 at call sites in f1 void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { f2(p, n/2); f2(p+n/2, n/2); } void f2(char *p, int n) if (n == 1) { (*p) += 1; } else { f1(p, n/2); f1(p+n/2, n/2); }
Transformation 1: Recursion Inlining void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { if (n/2 == 1) { *p += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); } if (n/2 == 1) { *(p+n/2) += 1; } else { f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); } }
Transformation 1: Recursion Inlining void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { if (n/2 == 1) { *p += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); } if (n/2 == 1) { *(p+n/2) += 1; } else { f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); } } • Reduced procedure call overhead • More code exposed at the intra-procedural level • Opportunities to simplify control flow in the inlined code
Transformation 1: Recursion Inlining void f1(char *p, int n) if (n == 1) { (*p) += 1; } else { if (n/2 == 1) { *p += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); } if (n/2 == 1) { *(p+n/2) += 1; } else { f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); } } • Reduced procedure call overhead • More code exposed at the intra-procedural level • Opportunities to simplify control flow in the inlined code: • identical condition expressions
Transformation 2: Conditional Fusion Merge if statements with identical conditions void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); }
Transformation 2: Conditional Fusion Merge if statements with identical conditions void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); } • Reduced branching overhead and bigger basic blocks • Larger base case for n/2 = 1
Unrolling Iterations Repeatedly apply inlining and conditional fusion void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); }
Second Unrolling Iteration void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else { f1(p, n/2/2); f1(p+n/2/2, n/2/2); f1(p+n/2, n/2/2); f1(p+n/2+n/4, n/2/2); } void f2(char *p, int n) if (n == 1) { *p += 1; } else { f2(p, n/2); f2(p+n/2, n/2); }
Second Unrolling Iteration void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else { f2(p, n/2/2); f2(p+n/2/2, n/2/2); f2(p+n/2, n/2/2); f2(p+n/2+n/4, n/2/2); } void f2(char *p, int n) if (n == 1) { *p += 1; } else { f1(p, n/2); f1(p+n/2, n/2); }
Result of Second Unrolling Iteration void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else if (n/2/2 == 1) { *p += 1; *(p+n/2/2) += 1; *(p+n/2) += 1; *(p+n/2+n/2/2) += 1; } else { f1(p, n/2/2/2); f1(p+n/2/2/2, n/2/2/2); f1(p+n/2/2, n/2/2/2); f1(p+n/2/2+n/2/2/2, n/2/2/2); f1(p+n/2, n/2/2/2); f1(p+n/2+n/2/2/2, n/2/2/2); f1(p+n/2+n/2/2, n/2/2/2); f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2); }
Unrolling Iterations • The unrolling process stops when the number of iterations reaches the desired unrolling factor • The unrolled recursive procedure: • Has base cases for larger problem sizes • Divides the given problem into more sub-problems of smaller sizes • In our example: • Base cases for n=1, n=2, and n=4 • Problems are divided into 8 sub-problems of size n/8
Speedup for Matrix Multiply Matrix of 512 x 512 elements
Speedup for Matrix Multiply Matrix of 512 x 512 elements
Speedup for Matrix Multiply Matrix of 1024 x 1024 elements
Efficiency of Unrolled Recursive Part • Because the recursive part is also unrolled, recursion may not exercise the large base cases • Which base case is executed depends on the size of the input problem • In our example: • For a problem of size n=8, the base case for n=1 is executed • For a problem of size n=16, the base case for n=2 is executed • The efficient base case for n=4 is not executed in these cases
Solution: Recursion Re-Rolling • Roll back the recursive part of the unrolled procedure after the large base cases are generated • Re-Rolling ensures that larger base cases are always executed, independent of the input problem size • The compiler unrolls the recursive part only temporarily, to generate the base cases
Transformation 3: Recursion Re-Rolling void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else if (n/2/2 == 1) { *p += 1; *(p+n/2/2) += 1; *(p+n/2) += 1; *(p+n/2+n/2/2) += 1; } else { f1(p, n/2/2/2); f1(p+n/2/2/2, n/2/2/2); f1(p+n/2/2, n/2/2/2); f1(p+n/2/2+n/2/2/2, n/2/2/2); f1(p+n/2, n/2/2/2); f1(p+n/2+n/2/2/2, n/2/2/2); f1(p+n/2+n/2/2, n/2/2/2); f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2); }
Transformation 3: Recursion Re-Rolling Identify the recursive part void f1(char *p, int n) if (n == 1) { *p += 1; } else if (n/2 == 1) { *p += 1; *(p+n/2) += 1; } else if (n/2/2 == 1) { *p += 1; *(p+n/2/2) += 1; *(p+n/2) += 1; *(p+n/2+n/2/2) += 1; } else { f1(p, n/2/2/2); f1(p+n/2/2/2, n/2/2/2); f1(p+n/2/2, n/2/2/2); f1(p+n/2/2+n/2/2/2, n/2/2/2); f1(p+n/2, n/2/2/2); f1(p+n/2+n/2/2/2, n/2/2/2); f1(p+n/2+n/2/2, n/2/2/2); f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2); }