Skip to content

Commit

Permalink
Implement new argument synthesis pass. (#2031)
Browse files Browse the repository at this point in the history
This pass performs function specialization by replacing 1 or more
arguments to a function with custom code blocks, which would be constant
expressions in the typical application. This functionality is an
alternative to some of the functionality in the current quake synthesis
pass. Some key differences:

  - It will support substitutions of any argument type.
  - It can be used to replace any permutation of arguments.
  - Substitutions must be type correct.
  - Can be tested from the command-line in full generality.

Add a test that demonstrates the command-line option in action.
  • Loading branch information
schweitzpgi authored Aug 6, 2024
1 parent 0b69865 commit 324e85e
Show file tree
Hide file tree
Showing 7 changed files with 293 additions and 1 deletion.
48 changes: 48 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1604,4 +1604,52 @@ def cc_CreateStringLiteralOp : CCOp<"string_literal"> {
$stringLiteral `:` qualified(type(results)) attr-dict
}];
}

def cc_ArgumentSubstitutionOp : CCOp<"arg_subst",
[IsolatedFromAbove, NoRegionArguments, NoTerminator, SingleBlock]> {
let summary = "An argument substition.";
let description = [{
This operation is used to define computations to produce a particular value.
The last Op in the block is the result and specifies the result type.
The code in the block will be substituted into a FuncOp to replace an
argument. The argument is erased from the function's signature, specializing
the function into a new function of reduced arity. (Typically, all arguments
are erased turning the function into a nullary.)

For example, given a function
```mlir
func.func @foo(%arg0: i32, %arg1: f32) ...
```
and a set of argument substitutions for the scalar arguments
```mlir
cc.arg_subst[0] {
%c42_i32 = arith.constant 42 : i32
}
cc.arg_subst[1] {
%cst = arith.constant 3.100000e+00 : f32
}
```
the argument synthesis pass can substitute the arguments and create a new
nullary function
```mlir
func.func @foo() {
%arg0 = arith.constant 42 : i32
%arg1 = arith.constant 3.1 : f32
...
```

Each arg_subst can hold an arbitrary block of code, allowing for the
construction of non-trivial values.

See also the `argument-synthesis` pass.
}];

let arguments = (ins
ConfinedAttr<I32Attr, [IntNonNegative]>:$position
);
let regions = (region SizedRegion<1>:$body);

let assemblyFormat = "`[` $position `]` $body attr-dict";
}

#endif // CUDAQ_OPTIMIZER_DIALECT_CC_OPS
7 changes: 7 additions & 0 deletions include/cudaq/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ inline std::unique_ptr<mlir::Pass> createPySynthCallableBlockArgs() {
return createPySynthCallableBlockArgs({});
}

/// Helper function to build an argument synthesis pass. The names of the
/// functions and the substitutions text can be built as an unzipped pair of
/// lists.
std::unique_ptr<mlir::Pass> createArgumentSynthesisPass(
const mlir::ArrayRef<mlir::StringRef> &funcNames,
const mlir::ArrayRef<mlir::StringRef> &substitutions);

// declarative passes
#define GEN_PASS_DECL
#define GEN_PASS_REGISTRATION
Expand Down
37 changes: 36 additions & 1 deletion include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,42 @@ def ApplySpecialization : Pass<"apply-op-specialization", "mlir::ModuleOp"> {
];
}

def BasisConversionPass: Pass<"basis-conversion", "mlir::ModuleOp"> {
def ArgumentSynthesis : Pass<"argument-synthesis", "mlir::func::FuncOp"> {
let summary = "Specialize a function by replacing arguments with constants";
let description = [{
This pass takes a list of functions and argument substitutions. For each
function in the list, the arguments to the function in the substitutions
list will be erased and replaced with a computed value (e.g., a constant)
provided in the substitution list. All arguments or some subset of arguments
may be substituted in this way.

To facilitate command-line testing, this pass can be run with the functions
suboption using filenames containing the argument substitutions. For
example, one might run
```console
cudaq-opt input.qke \
--argument-synthesis=functions="kernel1:subst1.qke,kernel2:subst2.qke"
```
where `kernel1`, `kernel2` are the names of functions and `subst1.qke` and
`subst2.qke` are quake source files contains arg_subst operations.

For running this pass from code, one can build the substitution code in a
std::string and use a prefix character '*' to indicate the text is inline
and not in a file.
```
kernel1:*"cc.arg_subst [0] { ... }"
```
}];

let options = [
ListOption<"funcList", "functions", "std::string",
"Function name and substitutions pairs (<func>:<filename>)">,
];
let dependentDialects = ["cudaq::cc::CCDialect",
"mlir::cf::ControlFlowDialect"];
}

def BasisConversionPass : Pass<"basis-conversion", "mlir::ModuleOp"> {
let summary = "Converts kernels to a set of basis operations.";
let description = [{
This pass takes as input a list of target (allowed) quantum operations.
Expand Down
146 changes: 146 additions & 0 deletions lib/Optimizer/Transforms/ArgumentSynthesis.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/*******************************************************************************
* Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. *
* All rights reserved. *
* *
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

#include "PassDetails.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

namespace cudaq::opt {
#define GEN_PASS_DEF_ARGUMENTSYNTHESIS
#include "cudaq/Optimizer/Transforms/Passes.h.inc"
} // namespace cudaq::opt

#define DEBUG_TYPE "argument-synthesis"

using namespace mlir;

namespace {
class ArgumentSynthesisPass
: public cudaq::opt::impl::ArgumentSynthesisBase<ArgumentSynthesisPass> {
public:
using ArgumentSynthesisBase::ArgumentSynthesisBase;

void runOnOperation() override {
func::FuncOp func = getOperation();
StringRef funcName = func.getName();
std::string text;
if (std::find_if(funcList.begin(), funcList.end(),
[&](const std::string &item) {
auto pos = item.find(':');
if (pos == std::string::npos)
return false;
std::string itemName = item.substr(0, pos);
bool result = itemName == funcName;
if (result)
text = item.substr(pos + 1);
return result;
}) == funcList.end()) {
// If the function isn't on the list, do nothing.
LLVM_DEBUG(llvm::dbgs() << funcName << " not in list.\n");
return;
}

// If there are no substitutions, we're done.
if (text.empty()) {
LLVM_DEBUG(llvm::dbgs() << funcName << " has no substitutions.");
return;
}

// If we're here, we have a FuncOp and we have substitutions that can be
// applied.
//
// 1. Create a Module with the substitutions that we'll be making.
auto *ctx = func.getContext();
LLVM_DEBUG(llvm::dbgs() << "substitution pattern: '" << text << "'\n");
auto substMod = [&]() -> OwningOpRef<ModuleOp> {
if (text.front() == '*') {
// Substitutions are a raw string after the '*' character.
return parseSourceString<ModuleOp>(text.substr(1), ctx);
}
// Substitutions are in a text file (command-line usage).
return parseSourceFile<ModuleOp>(text, ctx);
}();
assert(*substMod && "module must have been created");

// 2. Go through the Module and process each substitution.
std::vector<bool> processedArgs(func.getFunctionType().getNumInputs());
std::vector<std::tuple<unsigned, Value, Value>> replacements;
for (auto &op : *substMod) {
auto subst = dyn_cast<cudaq::cc::ArgumentSubstitutionOp>(op);
if (!subst)
continue;
auto pos = subst.getPosition();
if (pos >= processedArgs.size()) {
func.emitError("Argument " + std::to_string(pos) + " is invalid.");
signalPassFailure();
return;
}
if (processedArgs[pos]) {
func.emitError("Argument " + std::to_string(pos) +
" was already substituted.");
signalPassFailure();
return;
}

// OK, substitute the code for the argument.
Block &entry = func.getRegion().front();
processedArgs[pos] = true;
OpBuilder builder{ctx};
Block *splitBlock = entry.splitBlock(entry.begin());
builder.setInsertionPointToEnd(&entry);
builder.create<cf::BranchOp>(func.getLoc(), &subst.getBody().front());
Operation *lastOp = &subst.getBody().front().back();
builder.setInsertionPointToEnd(&subst.getBody().front());
builder.create<cf::BranchOp>(func.getLoc(), splitBlock);
func.getBlocks().splice(Region::iterator{splitBlock},
subst.getBody().getBlocks());
if (lastOp &&
lastOp->getResult(0).getType() == entry.getArgument(pos).getType()) {
LLVM_DEBUG(llvm::dbgs()
<< funcName << " argument " << std::to_string(pos)
<< " was substituted.\n");
replacements.emplace_back(pos, entry.getArgument(pos),
lastOp->getResult(0));
}
}

// Note: if we exited before here, any code that was cloned into the
// function is still dead and can be removed by a DCE.

// 3. Replace the block argument values with the freshly inserted new code.
BitVector replacedArgs(processedArgs.size());
for (auto [pos, fromVal, toVal] : replacements) {
replacedArgs.set(pos);
fromVal.replaceAllUsesWith(toVal);
}

// 4. Finish specializing func and erase any of func's arguments that were
// substituted.
func.eraseArguments(replacedArgs);
}
};
} // namespace

// Helper function that takes an unzipped pair of lists of function names and
// substitution code strings. This is meant to make adding this pass to a
// pipeline easier from within a tool (such as the JIT compiler).
std::unique_ptr<mlir::Pass> cudaq::opt::createArgumentSynthesisPass(
const ArrayRef<StringRef> &funcNames,
const ArrayRef<StringRef> &substitutions) {
SmallVector<std::string> pairs;
if (funcNames.size() == substitutions.size())
for (auto [name, text] : llvm::zip(funcNames, substitutions))
pairs.emplace_back(name.str() + ":*" + text.str());
return std::make_unique<ArgumentSynthesisPass>(
ArgumentSynthesisOptions{pairs});
}
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_cudaq_library(OptTransforms
AggressiveEarlyInlining.cpp
ApplyControlNegations.cpp
ApplyOpSpecialization.cpp
ArgumentSynthesis.cpp
BasisConversion.cpp
CombineQuantumAlloc.cpp
ConstPropComplex.cpp
Expand Down
15 changes: 15 additions & 0 deletions test/Quake/arg_subst.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// ========================================================================== //
// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

cc.arg_subst [0] {
%1 = arith.constant 42 : i32
}

cc.arg_subst [1] {
%1 = arith.constant 3.1 : f32
}
40 changes: 40 additions & 0 deletions test/Quake/arg_subst_func.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// ========================================================================== //
// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt --argument-synthesis=functions=foo:%S/arg_subst.txt,blink:%S/arg_subst.txt --canonicalize %s | FileCheck %s

func.func private @bar(i32)
func.func private @baz(f32)

func.func @foo(%arg0: i32, %arg1: f32) {
call @bar(%arg0) : (i32) -> ()
call @baz(%arg1) : (f32) -> ()
return
}

// CHECK-LABEL: func.func @foo() {
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 42 : i32
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 3.100000e+00 : f32
// CHECK: call @bar(%[[VAL_0]]) : (i32) -> ()
// CHECK: call @baz(%[[VAL_1]]) : (f32) -> ()
// CHECK: return
// CHECK: }

func.func @blink(%arg0: i32, %arg1: i32) {
call @bar(%arg0) : (i32) -> ()
call @bar(%arg1) : (i32) -> ()
return
}

// CHECK-LABEL: func.func @blink(
// CHECK-SAME: %[[VAL_0:.*]]: i32) {
// CHECK: %[[VAL_1:.*]] = arith.constant 42 : i32
// CHECK: call @bar(%[[VAL_1]]) : (i32) -> ()
// CHECK: call @bar(%[[VAL_0]]) : (i32) -> ()
// CHECK: return
// CHECK: }

0 comments on commit 324e85e

Please sign in to comment.