
Recursion Unrolling for Divide and Conquer Programs



Presentation Transcript


  1. Recursion Unrolling for Divide and Conquer Programs
      Radu Rugina and Martin Rinard
      Laboratory for Computer Science, Massachusetts Institute of Technology

  2. What This Talk Is About • Automatic generation of efficient large base cases for divide and conquer programs

  3. Outline
      • Motivating Example
      • Computation Structure
      • Transformations
      • Related Work
      • Conclusion

  4. 1. Motivating Example

  5. Divide and Conquer Matrix Multiply
      A × B = R
      • Divide matrices into sub-matrices: A0, A1, A2, etc.
      • Use blocked matrix multiply equations

  6. Divide and Conquer Matrix Multiply
      A × B = R
      • Recursively multiply sub-matrices

  7. Divide and Conquer Matrix Multiply
      A × B = R
      • Terminate recursion with a simple base case
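The blocked equations themselves did not survive in the transcript. Assuming the quadrant numbering implied by the matmul code on slide 8 (0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right), they would read as follows; this is a reconstruction, not the original slide content:

```latex
\[
\left(\begin{array}{cc} A_0 & A_1 \\ A_2 & A_3 \end{array}\right)
\times
\left(\begin{array}{cc} B_0 & B_1 \\ B_2 & B_3 \end{array}\right)
=
\left(\begin{array}{cc}
A_0 B_0 + A_1 B_2 & A_0 B_1 + A_1 B_3 \\
A_2 B_0 + A_3 B_2 & A_2 B_1 + A_3 B_3
\end{array}\right)
\]
```

Each of the eight products on the right corresponds to one recursive call in the code.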

  8. Divide and Conquer Matrix Multiply
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }
      Implements R += A × B

  9. Divide and Conquer Matrix Multiply
      Divide matrices into sub-matrices and recursively multiply sub-matrices
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }

  10. Divide and Conquer Matrix Multiply
      Identify sub-matrices with pointers
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }

  11. Divide and Conquer Matrix Multiply
      Use a simple algorithm for the base case
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }

  12. Divide and Conquer Matrix Multiply
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }
      • Advantage of small base case: simplicity
      • Code is easy to:
        • Write
        • Maintain
        • Debug
        • Understand

  13. Divide and Conquer Matrix Multiply
      • Disadvantage: inefficiency
      • Large control flow overhead:
        • Most of the time is spent dividing the matrices into sub-matrices
          void matmul(int *A, int *B, int *R, int n) {
            if (n == 1) {
              (*R) += (*A) * (*B);
            } else {
              matmul(A, B, R, n/4);
              matmul(A, B+(n/4), R+(n/4), n/4);
              matmul(A+2*(n/4), B, R+2*(n/4), n/4);
              matmul(A+2*(n/4), B+(n/4), R+3*(n/4), n/4);
              matmul(A+(n/4), B+2*(n/4), R, n/4);
              matmul(A+(n/4), B+3*(n/4), R+(n/4), n/4);
              matmul(A+3*(n/4), B+2*(n/4), R+2*(n/4), n/4);
              matmul(A+3*(n/4), B+3*(n/4), R+3*(n/4), n/4);
            }
          }
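The pointer arithmetic in matmul only makes sense if each matrix is stored so that its four quadrants are contiguous, each holding n/4 elements and laid out recursively in the same way. The slides do not state this layout, so the sketch below assumes it. The helpers block_index and matmul_matches_reference are hypothetical names introduced here to compare matmul against a plain triple loop on a 4 x 4 input (n = 16).

```c
#include <assert.h>
#include <string.h>

/* matmul as on the slides, with n/4 named q: R += A x B.
   n is the number of elements per matrix and must be a power of 4. */
void matmul(int *A, int *B, int *R, int n) {
    assert(n >= 1);
    if (n == 1) {
        (*R) += (*A) * (*B);
    } else {
        int q = n / 4;
        matmul(A,       B,       R,       q);
        matmul(A,       B + q,   R + q,   q);
        matmul(A + 2*q, B,       R + 2*q, q);
        matmul(A + 2*q, B + q,   R + 3*q, q);
        matmul(A + q,   B + 2*q, R,       q);
        matmul(A + q,   B + 3*q, R + q,   q);
        matmul(A + 3*q, B + 2*q, R + 2*q, q);
        matmul(A + 3*q, B + 3*q, R + 3*q, q);
    }
}

/* Hypothetical helper: offset of element (i, j) of a side x side matrix
   in the ASSUMED recursive quadrant layout
   (0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right). */
int block_index(int i, int j, int side) {
    if (side == 1) return 0;
    int half = side / 2;
    int quadrant = 2 * (i >= half) + (j >= half);
    return quadrant * half * half + block_index(i % half, j % half, half);
}

/* Hypothetical check: returns 1 if matmul agrees with a triple loop
   on a 4 x 4 test input, 0 otherwise. */
int matmul_matches_reference(void) {
    enum { SIDE = 4, N = SIDE * SIDE };
    int A[N], B[N], R[N];
    memset(R, 0, sizeof R);
    for (int i = 0; i < SIDE; i++) {
        for (int j = 0; j < SIDE; j++) {
            A[block_index(i, j, SIDE)] = i + 2 * j + 1;  /* A(i,j) */
            B[block_index(i, j, SIDE)] = 3 * i - j;      /* B(i,j) */
        }
    }
    matmul(A, B, R, N);
    for (int i = 0; i < SIDE; i++) {
        for (int j = 0; j < SIDE; j++) {
            int expected = 0;
            for (int k = 0; k < SIDE; k++)
                expected += (i + 2 * k + 1) * (3 * k - j);
            if (R[block_index(i, j, SIDE)] != expected) return 0;
        }
    }
    return 1;
}
```

Calling matmul_matches_reference() from a small main is enough to exercise it. With an ordinary row-major layout the same pointer offsets would not address quadrants, which is why the layout assumption matters.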

  14. Hand Coded Implementation void serialmul(block *As, block *Bs, block *Rs) { int i, j; DOUBLE *A = (DOUBLE *) As; DOUBLE *B = (DOUBLE *) Bs; DOUBLE *R = (DOUBLE *) Rs; for (j = 0; j < 16; j += 2) { DOUBLE *bp = &B[j]; for (i = 0; i < 16; i += 2) { DOUBLE *ap = &A[i * 16]; DOUBLE *rp = &R[j + i * 16]; register DOUBLE s0_0 = rp[0], s0_1 = rp[1]; register DOUBLE s1_0 = rp[16], s1_1 = rp[17]; s0_0 += ap[0] * bp[0]; s0_1 += ap[0] * bp[1]; s1_0 += ap[16] * bp[0]; s1_1 += ap[16] * bp[1]; s0_0 += ap[1] * bp[16]; s0_1 += ap[1] * bp[17]; s1_0 += ap[17] * bp[16]; s1_1 += ap[17] * bp[17]; s0_0 += ap[2] * bp[32]; s0_1 += ap[2] * bp[33]; s1_0 += ap[18] * bp[32]; s1_1 += ap[18] * bp[33]; s0_0 += ap[3] * bp[48]; s0_1 += ap[3] * bp[49]; s1_0 += ap[19] * bp[48]; s1_1 += ap[19] * bp[49]; s0_0 += ap[4] * bp[64]; s0_1 += ap[4] * bp[65]; s1_0 += ap[20] * bp[64]; s1_1 += ap[20] * bp[65]; s0_0 += ap[5] * bp[80]; s0_1 += ap[5] * bp[81]; s1_0 += ap[21] * bp[80]; s1_1 += ap[21] * bp[81]; s0_0 += ap[6] * bp[96]; s0_1 += ap[6] * bp[97]; s1_0 += ap[22] * bp[96]; s1_1 += ap[22] * bp[97]; s0_0 += ap[7] * bp[112]; s0_1 += ap[7] * bp[113]; s1_0 += ap[23] * bp[112]; s1_1 += ap[23] * bp[113]; s0_0 += ap[8] * bp[128]; s0_1 += ap[8] * bp[129]; s1_0 += ap[24] * bp[128]; s1_1 += ap[24] * bp[129]; s0_0 += ap[9] * bp[144]; s0_1 += ap[9] * bp[145]; s1_0 += ap[25] * bp[144]; s1_1 += ap[25] * bp[145]; s0_0 += ap[10] * bp[160]; s0_1 += ap[10] * bp[161]; s1_0 += ap[26] * bp[160]; s1_1 += ap[26] * bp[161]; s0_0 += ap[11] * bp[176]; s0_1 += ap[11] * bp[177]; s1_0 += ap[27] * bp[176]; s1_1 += ap[27] * bp[177]; s0_0 += ap[12] * bp[192]; s0_1 += ap[12] * bp[193]; s1_0 += ap[28] * bp[192]; s1_1 += ap[28] * bp[193]; s0_0 += ap[13] * bp[208]; s0_1 += ap[13] * bp[209]; s1_0 += ap[29] * bp[208]; s1_1 += ap[29] * bp[209]; s0_0 += ap[14] * bp[224]; s0_1 += ap[14] * bp[225]; s1_0 += ap[30] * bp[224]; s1_1 += ap[30] * bp[225]; s0_0 += ap[15] * bp[240]; s0_1 += ap[15] * bp[241]; s1_0 += ap[31] * bp[240]; s1_1 += ap[31] * 
bp[241]; rp[0] = s0_0; rp[1] = s0_1; rp[16] = s1_0; rp[17] = s1_1; } } } cilk void matrixmul(long nb, block *A, block *B, block *R) { if (nb == 1) { flops = serialmul(A, B, R); } else if (nb >= 4) { spawn matrixmul(nb/4, A, B, R); spawn matrixmul(nb/4, A, B+(nb/4), R+(nb/4)); spawn matrixmul(nb/4, A+2*(nb/4), B+(nb/4), R+2*(nb/4)); spawn matrixmul(nb/4, A+2*(nb/4), B, R+3*(nb/4)); sync; spawn matrixmul(nb/4, A+(nb/4), B+2*(nb/4), R); spawn matrixmul(nb/4, A+(nb/4), B+3*(nb/4), R+(nb/4)); spawn matrixmul(nb/4, A+3*(nb/4), B+3*(nb/4), R+2*(nb/4)); spawn matrixmul(nb/4, A+3*(nb/4), B+3*(nb/4), R+3*(nb/4)); sync; } }

  15. Goal • The programmer writes simple code with small base cases • The compiler automatically generates efficient code with large base cases

  16. 2. Computation Structure

  17. Running Example – Array Increment
          void f(char *p, int n) {
            if (n == 1) {
              /* base case: increment one element */
              (*p) += 1;
            } else {
              f(p, n/2);      /* increment first half */
              f(p+n/2, n/2);  /* increment second half */
            }
          }

  18. Dynamic Call Tree for n=4
      Execution of f(p,4)

  19. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: root node containing "Test n=1", "Call f", "Call f"]

  20. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: root node containing "Test n=1", "Call f", "Call f", with the label "Activation Frame on the Stack"]

  21. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: root node containing "Test n=1", "Call f", "Call f", with the label "Executed Instructions"]

  22. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: root node containing "Test n=1", "Call f", "Call f"]

  23. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: call tree by level]
      • n=4: one node (Test n=1, Call f, Call f)
      • n=2: two nodes (each: Test n=1, Call f, Call f)

  24. Dynamic Call Tree for n=4
      Execution of f(p,4)
      [Figure: call tree by level]
      • n=4: one node (Test n=1, Call f, Call f)
      • n=2: two nodes (each: Test n=1, Call f, Call f)
      • n=1: four nodes (each: Test n=1, Inc *p)

  25. Control Flow Overhead
      Execution of f(p,4)
      • Call overhead
      [Figure: call tree by level]
      • n=4: one node (Test n=1, Call f, Call f)
      • n=2: two nodes (each: Test n=1, Call f, Call f)
      • n=1: four nodes (each: Test n=1, Inc *p)

  26. Control Flow Overhead
      Execution of f(p,4)
      • Call overhead + Test overhead
      [Figure: call tree by level]
      • n=4: one node (Test n=1, Call f, Call f)
      • n=2: two nodes (each: Test n=1, Call f, Call f)
      • n=1: four nodes (each: Test n=1, Inc *p)

  27. Computation
      Execution of f(p,4)
      • Call overhead + Test overhead
      • Computation
      [Figure: call tree by level]
      • n=4: one node (Test n=1, Call f, Call f)
      • n=2: two nodes (each: Test n=1, Call f, Call f)
      • n=1: four nodes (each: Test n=1, Inc *p)

  28. Large Base Cases = Reduced Overhead
      Execution of f(p,4)
      [Figure: call tree by level]
      • n=4: one node (Test n=2, Call f, Call f)
      • n=2: two nodes (each: Test n=2, Inc *p, Inc *(p+1))
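The two call trees can be compared by counting. The sketch below instruments the array-increment procedure with hypothetical counters (struct counts, f_small, f_large, count_small and count_large are names introduced here, not from the slides). For n=4 the version with a base case at n=1 performs 7 calls and 7 tests for 4 increments, while the version with a base case at n=2 performs 3 calls and 3 tests for the same 4 increments.

```c
#include <assert.h>

/* Hypothetical instrumentation: activations, base-case tests, increments. */
struct counts { int calls, tests, incs; };

/* Running example with the small base case (n == 1); n must be a power of two. */
void f_small(char *p, int n, struct counts *c) {
    assert(n >= 1);
    c->calls++;
    c->tests++;
    if (n == 1) {
        (*p) += 1;
        c->incs++;
    } else {
        f_small(p, n/2, c);
        f_small(p + n/2, n/2, c);
    }
}

/* Same computation with a larger base case (n == 2), as in the slide 28 tree. */
void f_large(char *p, int n, struct counts *c) {
    assert(n >= 2);
    c->calls++;
    c->tests++;
    if (n == 2) {
        *p += 1;
        *(p + 1) += 1;
        c->incs += 2;
    } else {
        f_large(p, n/2, c);
        f_large(p + n/2, n/2, c);
    }
}

/* Hypothetical drivers: run on a zeroed buffer and return the counters. */
struct counts count_small(int n) {
    char buf[64] = {0};
    struct counts c = {0, 0, 0};
    assert(n >= 1 && n <= 64);
    f_small(buf, n, &c);
    return c;
}

struct counts count_large(int n) {
    char buf[64] = {0};
    struct counts c = {0, 0, 0};
    assert(n >= 2 && n <= 64);
    f_large(buf, n, &c);
    return c;
}
```

The counters only model the number of calls and tests, not their actual cost on a given machine.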

  29. 3. Transformations

  30. Transformation 1: Recursion Inlining
      Start with the original recursive procedure
          void f(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f(p, n/2);
              f(p+n/2, n/2);
            }
          }

  31. Transformation 1: Recursion Inlining
      Make two copies of the original procedure
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f1(p, n/2);
              f1(p+n/2, n/2);
            }
          }
          void f2(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f2(p, n/2);
              f2(p+n/2, n/2);
            }
          }

  32. Transformation 1: Recursion Inlining
      Transform direct recursion to mutual recursion
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f2(p, n/2);
              f2(p+n/2, n/2);
            }
          }
          void f2(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f1(p, n/2);
              f1(p+n/2, n/2);
            }
          }

  33. Transformation 1: Recursion Inlining
      Inline procedure f2 at call sites in f1
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f2(p, n/2);
              f2(p+n/2, n/2);
            }
          }
          void f2(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              f1(p, n/2);
              f1(p+n/2, n/2);
            }
          }

  34. Transformation 1: Recursion Inlining
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              if (n/2 == 1) {
                *p += 1;
              } else {
                f1(p, n/2/2);
                f1(p+n/2/2, n/2/2);
              }
              if (n/2 == 1) {
                *(p+n/2) += 1;
              } else {
                f1(p+n/2, n/2/2);
                f1(p+n/2+n/4, n/2/2);
              }
            }
          }

  35. Transformation 1: Recursion Inlining
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              if (n/2 == 1) {
                *p += 1;
              } else {
                f1(p, n/2/2);
                f1(p+n/2/2, n/2/2);
              }
              if (n/2 == 1) {
                *(p+n/2) += 1;
              } else {
                f1(p+n/2, n/2/2);
                f1(p+n/2+n/4, n/2/2);
              }
            }
          }
      • Reduced procedure call overhead
      • More code exposed at the intra-procedural level
      • Opportunities to simplify control flow in the inlined code

  36. Transformation 1: Recursion Inlining
          void f1(char *p, int n) {
            if (n == 1) {
              (*p) += 1;
            } else {
              if (n/2 == 1) {
                *p += 1;
              } else {
                f1(p, n/2/2);
                f1(p+n/2/2, n/2/2);
              }
              if (n/2 == 1) {
                *(p+n/2) += 1;
              } else {
                f1(p+n/2, n/2/2);
                f1(p+n/2+n/4, n/2/2);
              }
            }
          }
      • Reduced procedure call overhead
      • More code exposed at the intra-procedural level
      • Opportunities to simplify control flow in the inlined code:
        • identical condition expressions

  37. Transformation 2: Conditional Fusion
      Merge if statements with identical conditions
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else {
              f1(p, n/2/2);
              f1(p+n/2/2, n/2/2);
              f1(p+n/2, n/2/2);
              f1(p+n/2+n/4, n/2/2);
            }
          }

  38. Transformation 2: Conditional Fusion
      Merge if statements with identical conditions
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else {
              f1(p, n/2/2);
              f1(p+n/2/2, n/2/2);
              f1(p+n/2, n/2/2);
              f1(p+n/2+n/4, n/2/2);
            }
          }
      • Reduced branching overhead and bigger basic blocks
      • Larger base case for n/2 = 1

  39. Unrolling Iterations
      Repeatedly apply inlining and conditional fusion
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else {
              f1(p, n/2/2);
              f1(p+n/2/2, n/2/2);
              f1(p+n/2, n/2/2);
              f1(p+n/2+n/4, n/2/2);
            }
          }
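Inlining followed by conditional fusion should leave the computed result unchanged. A minimal sketch of a check, assuming power-of-two sizes (the only sizes for which the original f increments every element): same_result_up_to is a hypothetical helper that runs the original f and the fused f1 on zeroed buffers and compares them.

```c
#include <assert.h>
#include <string.h>

/* Original recursive procedure from the running example. */
void f(char *p, int n) {
    if (n == 1) {
        (*p) += 1;
    } else {
        f(p, n/2);
        f(p + n/2, n/2);
    }
}

/* Procedure after one round of inlining and conditional fusion. */
void f1(char *p, int n) {
    if (n == 1) {
        *p += 1;
    } else if (n/2 == 1) {
        *p += 1;
        *(p + n/2) += 1;
    } else {
        f1(p, n/2/2);
        f1(p + n/2/2, n/2/2);
        f1(p + n/2, n/2/2);
        f1(p + n/2 + n/4, n/2/2);
    }
}

/* Hypothetical check: returns 1 if, for every power-of-two n up to max_n,
   f and f1 leave identical buffers with the first n elements equal to 1. */
int same_result_up_to(int max_n) {
    char a[256], b[256];
    assert(max_n >= 1 && max_n <= 256);
    for (int n = 1; n <= max_n; n *= 2) {
        memset(a, 0, sizeof a);
        memset(b, 0, sizeof b);
        f(a, n);
        f1(b, n);
        if (memcmp(a, b, sizeof a) != 0) return 0;
        for (int i = 0; i < n; i++)
            if (a[i] != 1) return 0;
    }
    return 1;
}
```

A test like this checks only that the results match on the sampled sizes; it is not a proof that the transformation is correct in general.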

  40. Second Unrolling Iteration
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else {
              f1(p, n/2/2);
              f1(p+n/2/2, n/2/2);
              f1(p+n/2, n/2/2);
              f1(p+n/2+n/4, n/2/2);
            }
          }
          void f2(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else {
              f2(p, n/2);
              f2(p+n/2, n/2);
            }
          }

  41. Second Unrolling Iteration
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else {
              f2(p, n/2/2);
              f2(p+n/2/2, n/2/2);
              f2(p+n/2, n/2/2);
              f2(p+n/2+n/4, n/2/2);
            }
          }
          void f2(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else {
              f1(p, n/2);
              f1(p+n/2, n/2);
            }
          }

  42. Result of Second Unrolling Iteration
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else if (n/2/2 == 1) {
              *p += 1;
              *(p+n/2/2) += 1;
              *(p+n/2) += 1;
              *(p+n/2+n/2/2) += 1;
            } else {
              f1(p, n/2/2/2);
              f1(p+n/2/2/2, n/2/2/2);
              f1(p+n/2/2, n/2/2/2);
              f1(p+n/2/2+n/2/2/2, n/2/2/2);
              f1(p+n/2, n/2/2/2);
              f1(p+n/2+n/2/2/2, n/2/2/2);
              f1(p+n/2+n/2/2, n/2/2/2);
              f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2);
            }
          }

  43. Unrolling Iterations
      • The unrolling process stops when the number of iterations reaches the desired unrolling factor
      • The unrolled recursive procedure:
        • Has base cases for larger problem sizes
        • Divides the given problem into more sub-problems of smaller sizes
      • In our example:
        • Base cases for n=1, n=2, and n=4
        • Problems are divided into 8 problems of 1/8 size

  44. Speedup for Matrix Multiply Matrix of 512 x 512 elements

  45. Speedup for Matrix Multiply Matrix of 512 x 512 elements

  46. Speedup for Matrix Multiply Matrix of 1024 x 1024 elements

  47. Efficiency of Unrolled Recursive Part
      • Because the recursive part is also unrolled, recursion may not exercise the large base cases
      • Which base case is executed depends on the size of the input problem
      • In our example:
        • For a problem of size n=8, the base case for n=1 is executed
        • For a problem of size n=16, the base case for n=2 is executed
        • The efficient base case for n=4 is not executed in these cases
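The dependence on input size can be observed directly. The sketch below instruments the twice-unrolled procedure from slide 42 with a hypothetical hits array and a hypothetical driver base_case_used, which reports the problem size (1, 2, or 4) of the only base case that ran. Sizes are assumed to be powers of two.

```c
#include <assert.h>
#include <string.h>

/* Hypothetical counters: executions of the base cases for n=1, n=2, n=4. */
static int hits[3];

/* Twice-unrolled procedure (slide 42), instrumented. */
void f1_unrolled(char *p, int n) {
    assert(n >= 1);
    if (n == 1) {
        hits[0]++;
        *p += 1;
    } else if (n/2 == 1) {
        hits[1]++;
        *p += 1;
        *(p + n/2) += 1;
    } else if (n/2/2 == 1) {
        hits[2]++;
        *p += 1;
        *(p + n/2/2) += 1;
        *(p + n/2) += 1;
        *(p + n/2 + n/2/2) += 1;
    } else {
        int e = n/2/2/2;  /* one eighth of the problem */
        f1_unrolled(p, e);
        f1_unrolled(p + e, e);
        f1_unrolled(p + n/2/2, e);
        f1_unrolled(p + n/2/2 + e, e);
        f1_unrolled(p + n/2, e);
        f1_unrolled(p + n/2 + e, e);
        f1_unrolled(p + n/2 + n/2/2, e);
        f1_unrolled(p + n/2 + n/2/2 + e, e);
    }
}

/* Hypothetical driver: returns 1, 2, or 4 (size of the single base case
   that ran), or -1 if several kinds ran or the result is wrong. */
int base_case_used(int n) {
    static char buf[1024];
    assert(n >= 1 && n <= 1024);
    memset(buf, 0, sizeof buf);
    hits[0] = hits[1] = hits[2] = 0;
    f1_unrolled(buf, n);
    for (int i = 0; i < n; i++)
        if (buf[i] != 1) return -1;
    if ((hits[0] > 0) + (hits[1] > 0) + (hits[2] > 0) != 1) return -1;
    return hits[0] ? 1 : (hits[1] ? 2 : 4);
}
```

Under these assumptions base_case_used(8) reports 1 and base_case_used(16) reports 2, matching the slide; only sizes such as n=4 or n=32 reach the n=4 base case.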

  48. Solution: Recursion Re-Rolling
      • Roll back the recursive part of the unrolled procedure after the large base cases are generated
      • Re-rolling ensures that larger base cases are always executed, independent of the input problem size
      • The compiler unrolls the recursive part only temporarily, to generate the base cases

  49. Transformation 3: Recursion Re-Rolling
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else if (n/2/2 == 1) {
              *p += 1;
              *(p+n/2/2) += 1;
              *(p+n/2) += 1;
              *(p+n/2+n/2/2) += 1;
            } else {
              f1(p, n/2/2/2);
              f1(p+n/2/2/2, n/2/2/2);
              f1(p+n/2/2, n/2/2/2);
              f1(p+n/2/2+n/2/2/2, n/2/2/2);
              f1(p+n/2, n/2/2/2);
              f1(p+n/2+n/2/2/2, n/2/2/2);
              f1(p+n/2+n/2/2, n/2/2/2);
              f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2);
            }
          }

  50. Transformation 3: Recursion Re-Rolling
      Identify the recursive part
          void f1(char *p, int n) {
            if (n == 1) {
              *p += 1;
            } else if (n/2 == 1) {
              *p += 1;
              *(p+n/2) += 1;
            } else if (n/2/2 == 1) {
              *p += 1;
              *(p+n/2/2) += 1;
              *(p+n/2) += 1;
              *(p+n/2+n/2/2) += 1;
            } else {
              f1(p, n/2/2/2);
              f1(p+n/2/2/2, n/2/2/2);
              f1(p+n/2/2, n/2/2/2);
              f1(p+n/2/2+n/2/2/2, n/2/2/2);
              f1(p+n/2, n/2/2/2);
              f1(p+n/2+n/2/2/2, n/2/2/2);
              f1(p+n/2+n/2/2, n/2/2/2);
              f1(p+n/2+n/2/2+n/2/2/2, n/2/2/2);
            }
          }
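A minimal sketch of what the re-rolled procedure might look like, assuming that re-rolling keeps the three generated base cases and restores the two-way recursive step of the original f. The exact code the compiler emits is not shown here, so f1_rerolled, the hits counters, and the driver n4_base_case_hits are all hypothetical.

```c
#include <assert.h>
#include <string.h>

/* Hypothetical counters: executions of the base cases for n=1, n=2, n=4. */
static int hits[3];

/* ASSUMED re-rolled form: large base cases kept, recursion halves the problem. */
void f1_rerolled(char *p, int n) {
    assert(n >= 1);
    if (n == 1) {
        hits[0]++;
        *p += 1;
    } else if (n/2 == 1) {
        hits[1]++;
        *p += 1;
        *(p + n/2) += 1;
    } else if (n/2/2 == 1) {
        hits[2]++;
        *p += 1;
        *(p + n/2/2) += 1;
        *(p + n/2) += 1;
        *(p + n/2 + n/2/2) += 1;
    } else {
        f1_rerolled(p, n/2);
        f1_rerolled(p + n/2, n/2);
    }
}

/* Hypothetical driver: number of times the n=4 base case ran for a
   power-of-two n, or -1 if a smaller base case ran or the result is wrong. */
int n4_base_case_hits(int n) {
    static char buf[1024];
    assert(n >= 1 && n <= 1024);
    memset(buf, 0, sizeof buf);
    hits[0] = hits[1] = hits[2] = 0;
    f1_rerolled(buf, n);
    for (int i = 0; i < n; i++)
        if (buf[i] != 1) return -1;
    if (hits[0] > 0 || hits[1] > 0) return -1;
    return hits[2];
}
```

With this form, every power-of-two problem of size 4 or more bottoms out in the n=4 base case: n=8 reaches it twice and n=16 four times, whereas the fully unrolled version falls through to the n=1 and n=2 cases for those sizes.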
