In this tutorial we will learn how to build, compose, and transform Iteration/Expression Trees (IETs).
`Dimension`s are the building blocks of both `Iteration`s and `Expression`s.
from devito import SpaceDimension, TimeDimension
# Three space dimensions followed by two time dimensions, keyed by name.
dims = {name: SpaceDimension(name=name) for name in ('i', 'j', 'k')}
dims.update({name: TimeDimension(name=name) for name in ('t0', 't1')})
dims
{'i': i, 'j': j, 'k': k, 't0': t0, 't1': t1}
Elements such as `Scalar`s, `Constant`s and `Function`s are used to build SymPy equations.
from devito import Grid, Constant, Function, TimeFunction
from devito.types import Array, Scalar
grid = Grid(shape=(10, 10))

# Symbols used to assemble the sample equations; array-like objects are
# indexified so that they appear as indexed accesses (e.g. c[i]).
symbs = {}
symbs['a'] = Scalar(name='a')
symbs['b'] = Constant(name='b')
symbs['c'] = Array(name='c', shape=(3,), dimensions=(dims['i'],)).indexify()
symbs['d'] = Array(
    name='d',
    shape=(3, 3),
    dimensions=(dims['j'], dims['k']),
).indexify()
symbs['e'] = Function(
    name='e',
    shape=(3, 3, 3),
    dimensions=(dims['t0'], dims['t1'], dims['i']),
).indexify()
symbs['f'] = TimeFunction(name='f', grid=grid).indexify()
symbs
{'a': a, 'b': b, 'c': c[i], 'd': d[j, k], 'e': e[t0, t1, i], 'f': f[t, x, y]}
An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of `sympy.Eq` with some metadata attached. Which metadata is attached, and when and how, is irrelevant here.
from devito.ir.iet import Expression
from devito.ir.equations import DummyEq
from devito.tools import pprint
def get_exprs(a, b, c, d, e, f):
    """Build the four sample IET Expressions used in this tutorial.

    Each (lhs, rhs) pair below is turned into a DummyEq and wrapped in an
    Expression. The returned list follows the order of the pairs.
    """
    assignments = (
        (a, b + c + 5.),
        (d, e - f),
        (a, 4 * (b * a)),
        (a, (6. / b) + (8. * a)),
    )
    return [Expression(DummyEq(lhs, rhs)) for lhs, rhs in assignments]
# Pass the symbols positionally, in the order get_exprs expects (a..f).
exprs = get_exprs(*(symbs[key] for key in 'abcdef'))
pprint(exprs)
<Expression a = b + c[i] + 5.0> <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]> <Expression a = 4*a*b> <Expression a = 8.0*a + 6.0/b>
An `Iteration` typically wraps one or more `Expression`s.
from devito.ir.iet import Iteration
def get_iters(dims):
    """Return five factories, each wrapping a body in an Iteration.

    The factories iterate, in order, over dims['i'], dims['j'], dims['k'],
    dims['t0'] and dims['t1'], each with its own (start, end, step) triple.
    Calling a factory with a body `ex` builds the corresponding Iteration.
    """
    def make(name, limits):
        # dims[name] is resolved only when the factory is invoked.
        return lambda ex: Iteration(ex, dims[name], limits)

    specs = (
        ('i', (0, 3, 1)),
        ('j', (0, 5, 1)),
        ('k', (0, 7, 1)),
        ('t0', (0, 4, 1)),
        ('t1', (0, 4, 1)),
    )
    return [make(name, limits) for name, limits in specs]
# Iteration factories over i, j, k, t0 and t1 (in that order), reused below.
iters = get_iters(dims)
Here, we can see how blocks of `Iteration`s over `Expression`s can be used to build loop nests.
def get_block1(exprs, iters):
    """Assemble a perfect three-level loop nest around exprs[0].

    Resulting structure:
        for i
          for j
            for k
              expr0
    """
    # Build from the innermost loop outwards.
    k_loop = iters[2](exprs[0])
    j_loop = iters[1](k_loop)
    return iters[0](j_loop)
def get_block2(exprs, iters):
    """Assemble a simple non-perfect loop nest.

    Resulting structure:
        for i
          expr0
          for j
            for k
              expr1
    """
    # The j/k nest sits next to expr0 inside the outer i loop.
    k_loop = iters[2](exprs[1])
    j_loop = iters[1](k_loop)
    outer_body = [exprs[0], j_loop]
    return iters[0](outer_body)
def get_block3(exprs, iters):
    """Assemble a non-trivial, non-perfect loop nest.

    With the factories returned by get_iters, the resulting structure is:
        for i
          for t0
            expr0
          for j
            for k
              expr1
              expr2
          for t1
            expr3
    """
    # Three sibling nests, built in the order they appear in the outer body.
    first = iters[3](exprs[0])
    k_loop = iters[2]([exprs[1], exprs[2]])
    middle = iters[1](k_loop)
    last = iters[4](exprs[3])
    return iters[0]([first, middle, last])
block1 = get_block1(exprs, iters)
block2 = get_block2(exprs, iters)
block3 = get_block3(exprs, iters)

# Print the three trees, separating consecutive ones with blank lines.
pprint(block1)
print('\n')
pprint(block2)
print('\n')
pprint(block3)
<Iteration i::i::(0, 3, 1)> <Iteration j::j::(0, 5, 1)> <Iteration k::k::(0, 7, 1)> <Expression a = b + c[i] + 5.0> <Iteration i::i::(0, 3, 1)> <Expression a = b + c[i] + 5.0> <Iteration j::j::(0, 5, 1)> <Iteration k::k::(0, 7, 1)> <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]> <Iteration i::i::(0, 3, 1)> <Iteration t0::t0::(0, 4, 1)> <Expression a = b + c[i] + 5.0> <Iteration j::j::(0, 5, 1)> <Iteration k::k::(0, 7, 1)> <Expression d[j, k] = e[t0, t1, i] - f[t, x, y]> <Expression a = 4*a*b> <Iteration t1::t1::(0, 4, 1)> <Expression a = 8.0*a + 6.0/b>
And, finally, we can build `Callable` kernels that will be used to generate C code. Note that `Operator` is a subclass of `Callable`.
from devito.ir.iet import Callable
# One 'void foo()' kernel per loop nest, with no parameters.
kernels = [Callable('foo', body, 'void', ()) for body in (block1, block2, block3)]

# Show the generated C code of each kernel, numbered from 1.
for number, kernel in enumerate(kernels, start=1):
    print(f'kernel no.{number}:\n{kernel.ccode!s}\n')
kernel no.1: void foo() { for (int i = 0; i <= 3; i += 1) { for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { a = b + c[i] + 5.0F; } } } } kernel no.2: void foo() { for (int i = 0; i <= 3; i += 1) { a = b + c[i] + 5.0F; for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { d[j][k] = e[t0][t1][i] - f[t][x][y]; } } } } kernel no.3: void foo() { for (int i = 0; i <= 3; i += 1) { for (int t0 = 0; t0 <= 4; t0 += 1) { a = b + c[i] + 5.0F; } for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { d[j][k] = e[t0][t1][i] - f[t][x][y]; a = 4*a*b; } } for (int t1 = 0; t1 <= 4; t1 += 1) { a = 8.0F*a + 6.0F/b; } } }
An IET is immutable. It can be "transformed" by replacing or dropping some of its inner nodes, but what this actually means is that a new IET is created. IETs are transformed by `Transformer` visitors. A `Transformer` takes as input a dictionary encoding the replacement rules.
from devito.ir.iet import Transformer
# Replace the body of a Callable (block1) with another loop nest (block2).
# kernels[0] itself is left untouched; visit() returns a new tree.
rules = {block1: block2}
transformer = Transformer(rules)
kernel_alt = transformer.visit(kernels[0])
print(kernel_alt)
void foo() { for (int i = 0; i <= 3; i += 1) { a = b + c[i] + 5.0F; for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { d[j][k] = e[t0][t1][i] - f[t][x][y]; } } } }
Specific `Expression`s within the loop nest can also be substituted.
# Substitute one expression (exprs[0]) with another (exprs[1]) inside block1.
rules = {exprs[0]: exprs[1]}
transformer = Transformer(rules)
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)
for (int i = 0; i <= 3; i += 1) { for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { d[j][k] = e[t0][t1][i] - f[t][x][y]; } } }
from devito.ir.iet import Block
import cgen as c
# Build a replacement node: a Block whose header is a single C comment line,
# then substitute it for exprs[1] inside block2.
line1 = '// Replaced expression'
replacer = Block(c.Line(line1))
rules = {exprs[1]: replacer}
transformer = Transformer(rules)
newblock = transformer.visit(block2)
newcode = str(newblock.ccode)
print(newcode)
for (int i = 0; i <= 3; i += 1) { a = b + c[i] + 5.0F; for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { // Replaced expression { } } } }
# Surround an expression with an opening and a closing comment.
line1 = '// This is the opening comment'
line2 = '// This is the closing comment'


def wrapper(n):
    """Return a Block holding `n` between the two comment lines."""
    return Block(c.Line(line1), n, c.Line(line2))


rules = {exprs[0]: wrapper(exprs[0])}
transformer = Transformer(rules)
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)
for (int i = 0; i <= 3; i += 1) { for (int j = 0; j <= 5; j += 1) { for (int k = 0; k <= 7; k += 1) { // This is the opening comment { a = b + c[i] + 5.0F; } // This is the closing comment } } }