In this example, we show how to build a database DSL such that we can interact with SQL-like queries directly from within Python, as well as optimize their execution.
Say we are given a query like this:
SELECT * FROM T WHERE T.a > 5 + 5
Clearly, this can be optimized using constant folding to:
SELECT * FROM T WHERE T.a > 10
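Before turning to xDSL, here is what constant folding does in its simplest form. This is a minimal sketch, not xDSL code. The WHERE clause is modeled as a small expression tree using hypothetical classes `Const`, `Column` and `BinOp`. Every `+` whose operands are both constants is replaced by its result.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical toy AST for a WHERE clause; these classes are not part of xDSL.
@dataclass(frozen=True)
class Const:
    value: int


@dataclass(frozen=True)
class Column:
    name: str


@dataclass(frozen=True)
class BinOp:
    op: str  # "+" or ">"
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Column, BinOp]


def fold(e: Expr) -> Expr:
    """Fold children first, then replace `Const + Const` by a single Const."""
    if isinstance(e, BinOp):
        lhs, rhs = fold(e.lhs), fold(e.rhs)
        if e.op == "+" and isinstance(lhs, Const) and isinstance(rhs, Const):
            return Const(lhs.value + rhs.value)
        return BinOp(e.op, lhs, rhs)
    return e  # constants and columns are already fully folded


# T.a > 5 + 5
where = BinOp(">", Column("T.a"), BinOp("+", Const(5), Const(5)))
print(fold(where))  # BinOp(op='>', lhs=Column(name='T.a'), rhs=Const(value=10))
```

Folding the children first is what lets nested sums such as `1 + (2 + 3)` collapse to a single constant in one call.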
With xDSL, we can build the abstractions needed for such a query and implement optimizations on them, in particular constant folding.
There are several ways to structure an IR in xDSL. We choose one in which Operations (abstractions for some form of computation) are connected through SSAValues. Every SSAValue has a type, which is compile-time information, and in xDSL such information is expressed as an Attribute. We therefore start by defining an Attribute for bags (multisets of rows). A bag type also has to describe what it contains. This is again compile-time information, so we encode it as an Attribute as well and attach it to the bag attribute as a parameter.
from dataclasses import dataclass
from typing import cast

from xdsl.ir import *
from xdsl.irdl import *
from xdsl.dialects.builtin import *
from xdsl.dialects.arith import *
from xdsl.dialects.scf import *
from xdsl.pattern_rewriter import *
from xdsl.printer import Printer
@irdl_attr_definition
class Bag(ParametrizedAttribute):
    name = "sql.bag"
    # The type of the bag's elements, stored as a compile-time parameter.
    schema: ParameterDef[Attribute]
In the textual IR, which xDSL generates out of the box for Attributes and Operations, this looks as follows:
printer = Printer()
printer.print_attribute(Bag([i32]))
#sql.bag<i32>
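The key property of a parametrized attribute is that the parameter is part of the type: two bags are the same type only if their schemas match. The sketch below models this without xDSL, assuming hypothetical frozen dataclasses `IntType` and `BagType` that print in the same textual style as above. It is an illustration of the idea, not how xDSL implements `ParametrizedAttribute`.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for xDSL attributes; names are illustrative only.
@dataclass(frozen=True)
class IntType:
    width: int

    def __str__(self) -> str:
        return f"i{self.width}"


@dataclass(frozen=True)
class BagType:
    # The schema is a compile-time parameter, like `schema` on sql.bag.
    schema: IntType

    def __str__(self) -> str:
        return f"#sql.bag<{self.schema}>"


i32 = IntType(32)
print(BagType(i32))                          # #sql.bag<i32>
print(BagType(i32) == BagType(IntType(32)))  # True: equal parameters, equal types
print(BagType(i32) == BagType(IntType(64)))  # False: different schema
```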
Now we want to start abstracting forms of computation. First, we just want access to the table T, so we define a table operation. This Operation has an attribute encoding the table name, and a result. The result is an SSAValue, which we can later pass to other operations.
from xdsl.irdl import attr_def, result_def

@irdl_op_definition
class Table(IRDLOperation):
    name = "sql.table"
    table_name = attr_def(StringAttr)  # name of the table to read
    result_bag = result_def(Bag)       # the table's contents as a bag
Using this operation, we can create the first complete query:
SELECT * FROM T
t = Table.build(attributes={"table_name": StringAttr("T")}, result_types=[Bag([i32])])
printer.print_op(t)
%0 = "sql.table"() {"table_name" = "T"} : () -> #sql.bag<i32>
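An SSA value is defined exactly once, by one operation (its owner), and can be read by any number of other operations (its uses). The rewrite patterns later in this example navigate both directions of that link, through `op.lhs.owner` and `op.result.uses`. The sketch below is a framework-free model of that bookkeeping; the classes `Value` and `Op` and the helper `Op.create` are hypothetical and simplified (every toy op gets exactly one result), not xDSL's own classes.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional


# Hypothetical model of SSA use-def links; not xDSL's SSAValue/Operation.
@dataclass(eq=False)
class Value:
    owner: Op                                     # the single op defining this value
    uses: list[Op] = field(default_factory=list)  # every op that reads it


@dataclass(eq=False)
class Op:
    name: str
    operands: list[Value]
    attributes: dict[str, str] = field(default_factory=dict)
    result: Optional[Value] = None

    @staticmethod
    def create(name: str, operands: list[Value],
               attributes: Optional[dict[str, str]] = None) -> Op:
        op = Op(name, operands, dict(attributes or {}))
        # Simplification: every toy op defines exactly one result value.
        op.result = Value(op)
        for value in operands:
            value.uses.append(op)  # keep the def -> use direction in sync
        return op


table = Op.create("sql.table", [], {"table_name": "T"})
sink = Op.create("sql.sink", [table.result])
print(table.result.owner.name, [user.name for user in table.result.uses])
# sql.table ['sql.sink']
```

Keeping the use list updated on every creation is what makes questions like "is this constant still read?" cheap to answer.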
The top-level object in xDSL is a ModuleOp, which is an operation representing a module of code to compile.
module = ModuleOp([t])
print(module)
builtin.module {
  %0 = "sql.table"() {"table_name" = "T"} : () -> #sql.bag<i32>
}
To express our goal query, we need an abstraction for selections. Again, this is a form of computation, so we model it as an operation. The condition to filter with is nested inside that operation, and in xDSL such nested code lives in a Region. Additionally, we reuse dialects already defined in xDSL, in particular the arith dialect.
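@irdl_op_definition
class Selection(IRDLOperation):
    name = "sql.selection"
    input_bag = operand_def(Bag)   # the bag to filter
    filter = region_def()          # the predicate, evaluated per element
    result_bag = result_def(Bag)   # the elements for which the predicate holds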
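Semantically, a selection keeps exactly those elements of its input bag for which the filter region yields true; the region's block argument stands for one element. The sketch below states that meaning in plain Python, assuming the region is modeled as a Python function. Both `select` and `where_clause` are hypothetical names for illustration, not code that xDSL generates.

```python
from __future__ import annotations

from typing import Callable, Iterable


def select(bag: Iterable[int], filter_region: Callable[[int], bool]) -> list[int]:
    """Hypothetical reference semantics of sql.selection: keep the elements for
    which the filter region yields true. Duplicates survive, since a bag is a
    multiset."""
    return [element for element in bag if filter_region(element)]


def where_clause(a: int) -> bool:
    # Mirrors the region body built below: two constants, addi, cmpi sgt, yield.
    return a > 5 + 5


print(select([3, 10, 11, 11, 42], where_clause))  # [11, 11, 42]
```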
We instantiate this in two steps. First, we build the filter region, then the operation itself:
from xdsl.builder import Builder

# The decorator turns the function into a Region whose single block takes one
# i32 argument (the bag element being filtered); ops created inside the
# function body are inserted into that block.
@Builder.implicit_region((i32,))
def filter(args: tuple[BlockArgument, ...]):
    # filter argument
    (arg,) = args
    const1 = Constant.from_int_and_width(5, 32)
    const2 = Constant.from_int_and_width(5, 32)
    add = Addi(const1, const2)
    # sgt stands for `signed greater than`. In xDSL, this is encoded as a
    # predicate attribute with value 4.
    cmp = Cmpi(arg, add, "sgt")
    Yield(cmp)
sel = Selection.build(result_types=[Bag([i32])], operands=[t], regions=[filter])
printer.print_op(sel)
%1 = "sql.selection"(%0) ({
^0(%2 : i32):
  %3 = arith.constant 5 : i32
  %4 = arith.constant 5 : i32
  %5 = arith.addi %3, %4 : i32
  %6 = arith.cmpi sgt, %2, %5 : i32
  scf.yield %6 : i1
}) : (#sql.bag<i32>) -> #sql.bag<i32>
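The comment in the filter region says that `sgt` is stored as predicate value 4. That matches the numbering of MLIR's `arith.cmpi` integer predicates, listed below. The `s`/`u` prefix decides how a 32-bit pattern is read, which is why the predicate must be chosen with care. The helpers `to_signed`, `to_unsigned` and `cmpi` are hypothetical, a sketch of the comparison semantics rather than xDSL's implementation.

```python
import operator

# Integer predicates of MLIR's arith.cmpi and their encodings; sgt is 4.
CMPI_PREDICATES = {
    "eq": 0, "ne": 1, "slt": 2, "sle": 3, "sgt": 4,
    "sge": 5, "ult": 6, "ule": 7, "ugt": 8, "uge": 9,
}

_RELATIONS = {
    "eq": operator.eq, "ne": operator.ne, "lt": operator.lt,
    "le": operator.le, "gt": operator.gt, "ge": operator.ge,
}


def to_unsigned(bits: int, width: int = 32) -> int:
    """Keep the low `width` bits, read as an unsigned integer."""
    return bits & ((1 << width) - 1)


def to_signed(bits: int, width: int = 32) -> int:
    """Read the low `width` bits as a two's-complement integer."""
    bits = to_unsigned(bits, width)
    return bits - (1 << width) if bits >> (width - 1) else bits


def cmpi(predicate: str, lhs: int, rhs: int, width: int = 32) -> bool:
    """Hypothetical evaluator mimicking arith.cmpi on `width`-bit integers."""
    if predicate not in CMPI_PREDICATES:
        raise ValueError(f"unknown predicate {predicate!r}")
    if predicate in ("eq", "ne"):
        # Equality does not depend on signedness; compare the raw bits.
        convert, relation = to_unsigned, _RELATIONS[predicate]
    else:
        convert = to_signed if predicate.startswith("s") else to_unsigned
        relation = _RELATIONS[predicate[1:]]
    return relation(convert(lhs, width), convert(rhs, width))


print(CMPI_PREDICATES["sgt"])  # 4
print(cmpi("sgt", -1, 10))     # False: -1 is negative when read as signed
print(cmpi("ugt", -1, 10))     # True: -1 is 0xFFFFFFFF when read as unsigned
```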
Next, we want to rewrite the IR created in the previous step using constant folding. For that, we use xDSL's pattern rewriting infrastructure: RewritePatterns describe local rewrites, and a PatternRewriteWalker applies them to the IR. As a first step, we define the necessary pattern:
@dataclass
class ConstantFolding(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: Addi, rewriter: PatternRewriter):
        # Only fold additions whose operands are both produced by constants.
        if isinstance(op.lhs.owner, Constant) and isinstance(op.rhs.owner, Constant):
            lhs_data = cast(IntegerAttr[IntegerType], op.lhs.owner.value).value.data
            rhs_data = cast(IntegerAttr[IntegerType], op.rhs.owner.value).value.data
            lhs_type = cast(IntegerAttr[IntegerType], op.lhs.owner.value).type
            # Replace the addition with a single constant holding the sum.
            rewriter.replace_matched_op(
                Constant.from_int_and_width(lhs_data + rhs_data, lhs_type)
            )
walker = PatternRewriteWalker(
    GreedyRewritePatternApplier([ConstantFolding()]),
    walk_regions_first=True,
    apply_recursively=True,
    walk_reverse=False,
)
walker.rewrite_op(sel)
True
printer.print_op(sel)
%1 = "sql.selection"(%0) ({
^0(%2 : i32):
  %3 = arith.constant 5 : i32
  %4 = arith.constant 5 : i32
  %7 = arith.constant 10 : i32
  %6 = arith.cmpi sgt, %2, %7 : i32
  scf.yield %6 : i1
}) : (#sql.bag<i32>) -> #sql.bag<i32>
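In the output above, the addition was replaced by a new constant `%7` and its user now reads `%7`, while `%3` and `%4` stay behind because the pattern only replaces the matched op. To show why a rewrite is usually re-applied until nothing matches, here is a toy fixpoint loop over a flat block in SSA form. The `Op` dataclass, `fold_addi_once` and `fold_to_fixpoint` are hypothetical and simplified: the folded constant keeps the addition's SSA name, whereas xDSL creates a fresh value and redirects the uses.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


# Hypothetical toy IR: one block of ops in SSA form, values referenced by name.
@dataclass
class Op:
    result: str                  # SSA name defined by the op, e.g. "%5"
    name: str                    # "constant", "addi" or "cmpi"
    operands: list[str]          # SSA names read by the op
    value: Optional[int] = None  # payload of "constant" ops


def fold_addi_once(block: list[Op]) -> bool:
    """Fold the first addi whose operands are both constants; report a change."""
    defs = {op.result: op for op in block}
    for i, op in enumerate(block):
        if op.name != "addi":
            continue
        lhs, rhs = (defs.get(name) for name in op.operands)
        if lhs is not None and rhs is not None and lhs.name == rhs.name == "constant":
            # Simplification: the new constant reuses the addi's SSA name, so
            # its users need no rewiring.
            block[i] = Op(op.result, "constant", [], lhs.value + rhs.value)
            return True
    return False


def fold_to_fixpoint(block: list[Op]) -> None:
    # Re-run the pattern until it no longer matches anywhere.
    while fold_addi_once(block):
        pass


# Region body for %2 > (5 + 5) + 1, written in the toy IR.
block = [
    Op("%3", "constant", [], 5),
    Op("%4", "constant", [], 5),
    Op("%5", "addi", ["%3", "%4"]),
    Op("%6", "constant", [], 1),
    Op("%7", "addi", ["%5", "%6"]),
    Op("%8", "cmpi", ["%2", "%7"]),
]
fold_to_fixpoint(block)
for op in block:
    print(op.result, op.name, op.operands, op.value)
```

The second addition only becomes foldable after the first one has been folded, so a single pass would stop at `%5`; the loop ends with `%5` and `%7` as constants 10 and 11, leaving unused constants behind just as in the xDSL output above.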
Now let's remove the leftover constants, which are no longer used:
@dataclass
class DeadConstantElim(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter):
        # A constant whose result is never read can be erased safely.
        if len(op.result.uses) == 0:
            rewriter.erase_matched_op()
walker = PatternRewriteWalker(
    GreedyRewritePatternApplier([DeadConstantElim()]),
    walk_regions_first=True,
    apply_recursively=True,
    walk_reverse=False,
)
walker.rewrite_op(sel)
True
printer.print_op(sel)
%1 = "sql.selection"(%0) ({
^0(%2 : i32):
  %7 = arith.constant 10 : i32
  %6 = arith.cmpi sgt, %2, %7 : i32
  scf.yield %6 : i1
}) : (#sql.bag<i32>) -> #sql.bag<i32>
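`DeadConstantElim` only erases constants, and since constants have no operands, erasing one never makes another op dead. A more general dead-code elimination over side-effect-free ops has to iterate, because removing an op can leave its own operands unused. The sketch below shows that on a hypothetical toy IR; `Op`, `PURE_OPS` and `eliminate_dead_ops` are illustrative names, not xDSL API, and `%9`/`%10` are extra toy ops added to trigger the cascade.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


# Hypothetical toy IR: one block of ops in SSA form, values referenced by name.
@dataclass
class Op:
    result: Optional[str]        # SSA name defined by the op; None for yield
    name: str
    operands: list[str]
    value: Optional[int] = None  # payload of "constant" ops


# Ops without side effects, which may be deleted once nothing reads them.
PURE_OPS = {"constant", "addi", "cmpi"}


def eliminate_dead_ops(block: list[Op]) -> list[Op]:
    """Remove unused pure ops, repeating because a deletion can free more ops."""
    while True:
        used = {name for op in block for name in op.operands}
        live = [op for op in block if op.name not in PURE_OPS or op.result in used]
        if len(live) == len(block):
            return live
        block = live


block = [
    Op("%3", "constant", [], 5),
    Op("%4", "constant", [], 5),
    Op("%7", "constant", [], 10),
    Op("%9", "constant", [], 1),
    Op("%10", "addi", ["%9", "%9"]),  # unused; %9 becomes dead once %10 is gone
    Op("%6", "cmpi", ["%2", "%7"]),
    Op(None, "yield", ["%6"]),
]
for op in eliminate_dead_ops(block):
    print(op.result, op.name, op.operands)
```

The yield is never removed because it is not in `PURE_OPS`; it anchors the chain of values that must stay alive.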
In this example, the SSAValue %1 is not consumed by any operation. We want to make sure it is used somewhere, so that we know what to do with it during compilation. Therefore, we introduce a SinkOp, which returns the data in the bag to the executor of the query.
@irdl_op_definition
class SinkOp(IRDLOperation):
    name = "sql.sink"
    bag = operand_def(Bag)  # the bag whose contents are returned to the executor
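In xDSL, all IRs need a ModuleOp as the outermost Operation, so we add both the selection and the sink to the module created earlier:
module.body.block.add_op(sel)
module.body.block.add_op(SinkOp.build(operands=[sel]))
printer.print_op(module)
builtin.module {
  %0 = "sql.table"() {"table_name" = "T"} : () -> #sql.bag<i32>
  %1 = "sql.selection"(%0) ({
  ^0(%2 : i32):
    %7 = arith.constant 10 : i32
    %6 = arith.cmpi sgt, %2, %7 : i32
    scf.yield %6 : i1
  }) : (#sql.bag<i32>) -> #sql.bag<i32>
  "sql.sink"(%1) : (#sql.bag<i32>) -> ()
}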
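Wrapping operations in a module only yields well-formed IR if every value an operation reads is also defined inside the module, before the operation that reads it (for a single block without control flow). The sketch below checks that rule on a hypothetical flat encoding of the module body; `ToyOp` and `check_defined_before_use` are illustrative names, and the check is a simplification rather than xDSL's own verification logic.

```python
from typing import List, Optional, Tuple

# Hypothetical flat module body entry:
# (SSA name the op defines or None, op name, SSA names the op reads).
ToyOp = Tuple[Optional[str], str, List[str]]


def check_defined_before_use(body: List[ToyOp]) -> List[str]:
    """Report every operand that no earlier op in the same block defines."""
    defined = set()
    errors = []
    for result, name, operands in body:
        for operand in operands:
            if operand not in defined:
                errors.append(f"{name} uses {operand}, which is not defined before it")
        if result is not None:
            defined.add(result)
    return errors


sink_only = [("%0", "sql.table", []), (None, "sql.sink", ["%1"])]
full_query = [
    ("%0", "sql.table", []),
    ("%1", "sql.selection", ["%0"]),
    (None, "sql.sink", ["%1"]),
]
print(check_defined_before_use(sink_only))
# ['sql.sink uses %1, which is not defined before it']
print(check_defined_before_use(full_query))  # []
```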
We now have an abstraction for the query, together with an optimization pass for constant folding.