#!/usr/bin/env python
# coding: utf-8

# In this tutorial we will learn how to build, compose, and transform
# Iteration/Expression Trees (IETs).
#
# # Part II - Bottom Up
#
# `Dimensions` are the building blocks of both `Iterations` and `Expressions`.

# In[1]:


from devito import SpaceDimension, TimeDimension

dims = {'i': SpaceDimension(name='i'),
        'j': SpaceDimension(name='j'),
        'k': SpaceDimension(name='k'),
        't0': TimeDimension(name='t0'),
        't1': TimeDimension(name='t1')}

# A bare `dims` expression only displays output inside a notebook; print it
# explicitly so the script shows the same information.
print(dims)


# Elements such as `Scalars`, `Constants` and `Functions` are used to build
# SymPy equations.

# In[2]:


from devito import Grid, Constant, Function, TimeFunction
from devito.types import Array, Scalar

grid = Grid(shape=(10, 10))
symbs = {'a': Scalar(name='a'),
         'b': Constant(name='b'),
         'c': Array(name='c', shape=(3,),
                    dimensions=(dims['i'],)).indexify(),
         'd': Array(name='d', shape=(3, 3),
                    dimensions=(dims['j'], dims['k'])).indexify(),
         'e': Function(name='e', shape=(3, 3, 3),
                       dimensions=(dims['t0'], dims['t1'], dims['i'])).indexify(),
         'f': TimeFunction(name='f', grid=grid).indexify()}

# As above: a bare `symbs` expression would be a no-op outside a notebook.
print(symbs)


# An IET `Expression` wraps a SymPy equation. Below, `DummyEq` is a subclass of
# `sympy.Eq` with some metadata attached. What metadata are attached, and when
# and how, is irrelevant here.

# In[3]:


from devito.ir.iet import Expression
from devito.ir.equations import DummyEq
from devito.tools import pprint


def get_exprs(a, b, c, d, e, f):
    """Return a list of four IET `Expression`s built from the given symbols.

    The wrapped equations are, in order:
      0: a = b + c + 5.
      1: d = e - f
      2: a = 4 * (b * a)
      3: a = 6. / b + 8. * a
    """
    return [Expression(DummyEq(a, b + c + 5.)),
            Expression(DummyEq(d, e - f)),
            Expression(DummyEq(a, 4 * (b * a))),
            Expression(DummyEq(a, (6. / b) + (8. * a)))]


exprs = get_exprs(symbs['a'], symbs['b'], symbs['c'],
                  symbs['d'], symbs['e'], symbs['f'])

pprint(exprs)


# An `Iteration` typically wraps one or more `Expression`s.
# In[4]:


from devito.ir.iet import Iteration


def get_iters(dims):
    """Return a list of `Iteration` factories, one per dimension.

    Each entry is a callable that takes a body (a node or a list of nodes)
    and wraps it in an `Iteration` over one dimension with a fixed limits
    tuple. Index -> dimension: 0 -> i, 1 -> j, 2 -> k, 3 -> t0, 4 -> t1.
    """
    return [lambda ex: Iteration(ex, dims['i'], (0, 3, 1)),
            lambda ex: Iteration(ex, dims['j'], (0, 5, 1)),
            lambda ex: Iteration(ex, dims['k'], (0, 7, 1)),
            lambda ex: Iteration(ex, dims['t0'], (0, 4, 1)),
            lambda ex: Iteration(ex, dims['t1'], (0, 4, 1))]


iters = get_iters(dims)


# Here, we can see how blocks of `Iterations` over `Expressions` can be used
# to build loop nests.

# In[5]:


def get_block1(exprs, iters):
    """Build a perfect loop nest:

        for i
          for j
            for k
              expr0
    """
    return iters[0](iters[1](iters[2](exprs[0])))


def get_block2(exprs, iters):
    """Build a non-perfect, simple loop nest:

        for i
          expr0
          for j
            for k
              expr1
    """
    return iters[0]([exprs[0], iters[1](iters[2](exprs[1]))])


def get_block3(exprs, iters):
    """Build a non-perfect, non-trivial loop nest:

        for i
          for t0
            expr0
          for j
            for k
              expr1
              expr2
          for t1
            expr3
    """
    # iters[3] iterates over t0 and iters[4] over t1 (see get_iters).
    return iters[0]([iters[3](exprs[0]),
                     iters[1](iters[2]([exprs[1], exprs[2]])),
                     iters[4](exprs[3])])


block1 = get_block1(exprs, iters)
block2 = get_block2(exprs, iters)
block3 = get_block3(exprs, iters)

pprint(block1)
print('\n')
pprint(block2)
print('\n')
pprint(block3)


# And, finally, we can build `Callable` _kernels_ that will be used to
# generate C code. Note that `Operator` is a subclass of `Callable`.

# In[6]:


from devito.ir.iet import Callable

kernels = [Callable('foo', block, 'void', ())
           for block in (block1, block2, block3)]

for num, kernel in enumerate(kernels, 1):
    print('kernel no.' + str(num) + ':\n' + str(kernel.ccode) + '\n')


# An IET is immutable. It can be "transformed" by replacing or dropping some
# of its inner nodes, but what this actually means is that a new IET is
# created. IETs are transformed by `Transformer` visitors. A `Transformer`
# takes as input a dictionary encoding replacement rules.
# In[7]:


from devito.ir.iet import Transformer

# Replace the loop nest `block1` (the body of kernels[0]) with `block2`;
# visiting returns a new kernel rather than modifying kernels[0].
transformer = Transformer({block1: block2})
kernel_alt = transformer.visit(kernels[0])
# Print the generated C, as the other cells do, rather than the node object.
print(str(kernel_alt.ccode))


# Specific `Expression`s within the loop nest can also be substituted.

# In[8]:


# Replace one expression with another.
transformer = Transformer({exprs[0]: exprs[1]})
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)


# In[9]:


from devito.ir.iet import Block
import cgen as c

# Replace an expression with a `Block` holding a single comment line.
line1 = '// Replaced expression'
replacer = Block(c.Line(line1))
transformer = Transformer({exprs[1]: replacer})
newblock = transformer.visit(block2)
newcode = str(newblock.ccode)
print(newcode)


# In[10]:


# Wrap an expression in comments.
line1 = '// This is the opening comment'
line2 = '// This is the closing comment'


def wrapper(n):
    """Return a `Block` placing node `n` between the `line1` and `line2` comments."""
    return Block(c.Line(line1), n, c.Line(line2))


transformer = Transformer({exprs[0]: wrapper(exprs[0])})
newblock = transformer.visit(block1)
newcode = str(newblock.ccode)
print(newcode)