Skip to content

Commit

Permalink
Add the return value to omMMapBinaryFile function and fix a z/OS bug (#…
Browse files Browse the repository at this point in the history
…3002)

* Check the return value when mmapping a file and exit gently if failed

Signed-off-by: Tung D. Le <[email protected]>

* Fix a bug in z/OS code

Signed-off-by: Tung D. Le <[email protected]>

* Convert filename to EBCDIC when generating code for zOS

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Nov 15, 2024
1 parent 890ebea commit 868432d
Show file tree
Hide file tree
Showing 10 changed files with 266 additions and 156 deletions.
21 changes: 14 additions & 7 deletions src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ bool extractConstantsToFile(ModuleOp &module, std::string filepath,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(module.getBody());
std::string fname = llvm::sys::path::filename(filepath).str() + '\0';
fname = (isZOS(module)) ? krnl::e2a_s(fname) : fname;
mlir::StringAttr valueAttr = mlir::StringAttr::get(context, fname);
create.llvm.globalOp(LLVM::LLVMArrayType::get(llvmI8Ty, fname.size()),
/*isConstant=*/true, LLVM::Linkage::Internal,
Expand Down Expand Up @@ -612,15 +613,15 @@ void loadConstantsFromFile(ModuleOp &module,
OpBuilder b(ctx);
MultiDialectBuilder<LLVMBuilder> create(b, loc);

Type llvmI1Ty = IntegerType::get(ctx, 1);
Type llvmI8Ty = IntegerType::get(ctx, 8);
Type llvmI64Ty = IntegerType::get(ctx, 64);
Type llvmI8PtrTy = getPointerType(ctx, llvmI8Ty);
Type llvmVoidTy = LLVM::LLVMVoidType::get(ctx);

// The following function will be emitted inside the IR to load constants from
// file.
std::string loadAllConstantsFuncName = "omLoadConstantsFromFile";
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmVoidTy, {}, false);
Type llvmFnType = LLVM::LLVMFunctionType::get(llvmI1Ty, {}, false);

// If calledByEntryPoint, this function will be called by entry points.
// Otherwise, user program (C/C++/Java/Python) would call this function.
Expand All @@ -629,6 +630,7 @@ void loadConstantsFromFile(ModuleOp &module,
Operation *firstEntryPointOp =
getFirstEntryOpInBlock(module, entryGlobalOps);
assert(firstEntryPointOp && "No entry function exists");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(firstEntryPointOp);
funcOp = create.llvm.func(
loadAllConstantsFuncName, llvmFnType, /*createUniqueFunc=*/true);
Expand All @@ -646,13 +648,16 @@ void loadConstantsFromFile(ModuleOp &module,
std::find(entryName.begin(), entryName.end(), '\0'), entryName.end());
auto entryFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(entryName);
assert(entryFunc && "Entry function not found");
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(
&entryFunc.getBody().front(), entryFunc.getBody().front().begin());
FlatSymbolRefAttr loadAllConstantsRef = create.llvm.getOrInsertSymbolRef(
module, LLVMBuilder::SymbolPostfix(module, loadAllConstantsFuncName),
llvmVoidTy, {},
llvmI1Ty, {},
/*isVarArg=*/false);
create.llvm.call({}, loadAllConstantsRef, {});
Value retVal = create.llvm.call({llvmI1Ty}, loadAllConstantsRef, {});
equalOrFailed(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal);
}
} else {
OpBuilder::InsertionGuard guard(b);
Expand Down Expand Up @@ -697,8 +702,11 @@ void loadConstantsFromFile(ModuleOp &module,
// Call a function to mmap the binary file to memory.
Value isleVal = create.llvm.constant(llvmI64Ty, isle);
Value sizeVal = create.llvm.constant(llvmI64Ty, dataSize);
RuntimeAPI::callApi(b, loc, apiRegistry, RuntimeAPI::API::MMAP_BINARY_FILE,
Value retVal = RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::MMAP_BINARY_FILE,
{packedGlobalPtr, fnameI8Ptr, sizeVal, isleVal});
equalOrReturn(module, b, loc,
create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)), retVal, retVal);

// Now set pointers for constants in the IR
module->walk([&](LLVM::GlobalOp dataGlobalOp) -> WalkResult {
Expand All @@ -725,11 +733,10 @@ void loadConstantsFromFile(ModuleOp &module,
RuntimeAPI::callApi(b, loc, apiRegistry,
RuntimeAPI::API::GET_EXTERNAL_CONSTANT_ADDR,
{dataPtr, packedGlobalPtr, offsetVal});

return WalkResult::advance();
});

create.llvm._return();
create.llvm._return(create.llvm.constant(llvmI1Ty, static_cast<int64_t>(1)));
}

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 0 additions & 25 deletions src/Conversion/KrnlToLLVM/KrnlEntryPoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,31 +412,6 @@ class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
rewriter.getI64Type(), {rewriter.getI64Type()});
}

// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`
void equalOrFailed(ModuleOp &module, PatternRewriter &rewriter, Location loc,
Value lhs, Value rhs, std::string errorMsg = "",
bool appendRHS = true) const {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (appendRHS)
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
// Set errno.
krnl::emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

void emitVerificationCodeForInputTensors(ModuleOp &module,
PatternRewriter &rewriter, Location loc,
const RuntimeAPIRegistry &apiRegistry, Value omTensorInputs,
Expand Down
43 changes: 43 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include "onnx-mlir/Compiler/OMCompilerRuntimeTypes.h"
#include "src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp"
#include "src/Dialect/Krnl/DialectBuilder.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
Expand Down Expand Up @@ -342,5 +343,47 @@ bool isZOS(ModuleOp module) {
return zOS;
}

void equalOrFailed(ModuleOp &module, OpBuilder &rewriter, Location loc,
Value lhs, Value rhs, std::string errorMsg, bool appendRHS) {
MLIRContext *context = rewriter.getContext();
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty()) {
if (appendRHS)
create.krnl.printf(
StringRef(errorMsg), rhs, rewriter.getI64Type(), true);
else
create.krnl.printf(StringRef(errorMsg + "\n"));
}
// Set errno.
emitErrNo(module, rewriter, loc, EINVAL);
// Return NULL.
create.llvm._return(create.llvm.null(getI8PointerType(context)));
});
}

void equalOrReturn(ModuleOp &module, OpBuilder &rewriter, Location loc,
Value lhs, Value rhs, Value retVal, std::string errorMsg) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(rewriter, loc);
create.llvm.ifThenElse(/*cond=*/
[&](const LLVMBuilder &createLLVM) {
return createLLVM.icmp(LLVM::ICmpPredicate::ne, lhs, rhs);
}, /*then=*/
[&](const LLVMBuilder &createLLVM) {
MultiDialectBuilder<LLVMBuilder, KrnlBuilder> create(createLLVM);
// Print an error message.
if (!errorMsg.empty())
create.krnl.printf(StringRef(errorMsg + "\n"));
// Return retVal.
create.llvm._return(retVal);
});
}

} // namespace krnl
} // namespace onnx_mlir
10 changes: 10 additions & 0 deletions src/Conversion/KrnlToLLVM/KrnlToLLVMHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ std::string e2a_s(std::string e_s);
void emitErrNo(mlir::ModuleOp module, mlir::OpBuilder &builder,
mlir::Location loc, int err);

/// Emit code for `IF lhs != rhs THEN return null ELSE do nothing`.
void equalOrFailed(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs,
std::string errorMsg = "", bool appendRHS = true);

/// Emit code for `IF lhs != rhs THEN return retVal ELSE do nothing`.
void equalOrReturn(mlir::ModuleOp &module, mlir::OpBuilder &rewriter,
mlir::Location loc, mlir::Value lhs, mlir::Value rhs, mlir::Value retVal,
std::string errorMsg = "");

/// Creates an LLVM pointer type with the given element type and address space.
/// This function is meant to be used in code supporting both typed and opaque
/// pointers, as it will create an opaque pointer with the given address space
Expand Down
3 changes: 2 additions & 1 deletion src/Conversion/KrnlToLLVM/RuntimeAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
: registry() {
MLIRContext *context = module.getContext();
auto voidTy = LLVM::LLVMVoidType::get(context);
Type int1Ty = IntegerType::get(context, 1);
auto int8Ty = IntegerType::get(context, 8);
auto opaquePtrTy = onnx_mlir::krnl::getPointerType(context, int8Ty);
auto opaquePtrPtrTy = onnx_mlir::krnl::getPointerType(context, opaquePtrTy);
Expand All @@ -88,7 +89,7 @@ RuntimeAPIRegistry::RuntimeAPIRegistry(
RuntimeAPI(API::GET_OMT_ARRAY, "omTensorListGetOmtArray", opaquePtrPtrTy, {opaquePtrTy}),
RuntimeAPI(API::PRINT_OMTENSOR, "omTensorPrint", voidTy, {opaquePtrTy, opaquePtrTy}),
RuntimeAPI(API::GET_OMTENSOR_LIST_SIZE, "omTensorListGetSize", int64Ty, {opaquePtrTy}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", voidTy, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::MMAP_BINARY_FILE, "omMMapBinaryFile", int1Ty, {opaquePtrPtrTy, opaquePtrTy, int64Ty, int64Ty}),
RuntimeAPI(API::GET_EXTERNAL_CONSTANT_ADDR, "omGetExternalConstantAddr", voidTy, {opaquePtrPtrTy, opaquePtrPtrTy, int64Ty}),
};
// clang-format on
Expand Down
120 changes: 54 additions & 66 deletions src/Runtime/OMExternalConstant.inc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ typedef int make_iso_compilers_happy;

#include <errno.h>
#include <inttypes.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
Expand Down Expand Up @@ -58,88 +59,75 @@ void checkEndianness(const char constPackIsLE) {
///
/// This function is thread-safe.
///
void omMMapBinaryFile(
void **constAddr, char *filename, int64_t size, int64_t isLE) {
checkEndianness(isLE);
char *fname = filename;
#ifdef __MVS__
// Convert the file name to EBCDIC for the open call.
char *tPath = strdup(fname);
if (!tPath) {
fprintf(stderr, "Error while strdup");
return;
}
__a2e_s(tPath);
fname = tPath;
#endif

bool omMMapBinaryFile(
void **constAddr, char *fname, int64_t size, int64_t isLE) {
if (constAddr == NULL) {
perror("Error: null pointer");
return;
fprintf(stderr, "Error: null pointer.");
return false;
}

if (constAddr[0] == NULL) {
char *filePath;
char *basePath = getenv("OM_CONSTANT_PATH");
if (basePath) {
size_t baseLen = strlen(basePath);
size_t fnameLen = strlen(fname);
size_t sepLen = strlen(DIR_SEPARATOR);
size_t filePathLen = baseLen + sepLen + fnameLen;
filePath = (char *)malloc(filePathLen);
if (!filePath) {
fprintf(stderr, "Error while malloc");
return;
}
memcpy(filePath, basePath, baseLen);
memcpy(filePath + baseLen, DIR_SEPARATOR, sepLen);
memcpy(filePath + baseLen + sepLen, fname, fnameLen);
filePath[filePathLen] = '\0';
} else {
filePath = (char *)fname;
}
int fd = open(filePath, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Error while opening %s\n", filePath);
return;
// Already mmaped. Nothing to do.
if (constAddr[0] != NULL)
return true;

char *filePath;
char *basePath = getenv("OM_CONSTANT_PATH");
if (basePath) {
size_t baseLen = strlen(basePath);
size_t fnameLen = strlen(fname);
size_t sepLen = strlen(DIR_SEPARATOR);
size_t filePathLen = baseLen + sepLen + fnameLen + 1;
filePath = (char *)malloc(filePathLen);
if (!filePath) {
fprintf(stderr, "Error while malloc: %s", strerror(errno));
return false;
}
snprintf(filePath, filePathLen, "%s%s%s", basePath, DIR_SEPARATOR, fname);
} else {
filePath = (char *)fname;
}
int fd = open(filePath, O_RDONLY);
if (fd < 0) {
fprintf(stderr, "Error while opening %s: %s\n", filePath, strerror(errno));
if (basePath)
free(filePath);
return false;
}

#ifdef __MVS__
void *tempAddr = mmap(0, size, PROT_READ, __MAP_MEGA, fd, 0);
#else
void *tempAddr = mmap(0, size, PROT_READ, MAP_SHARED, fd, 0);
#endif

if (tempAddr == MAP_FAILED) {
fprintf(stderr, "Error while mmapping %s\n", fname);
close(fd);
return;
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)tempAddr))
munmap(tempAddr, size);
#else
if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr))
munmap(tempAddr, size);
#endif

/* Either we succeeded in setting constAddr or someone else did it.
* Either way, constAddr is now setup. We can close our fd without
* invalidating the mmap.
*/
if (tempAddr == MAP_FAILED) {
fprintf(stderr, "Error while mmapping %s: %s\n", fname, strerror(errno));
close(fd);
if (basePath)
free(filePath);
return false;
}

/* Prepare to compare-and-swap to setup the shared constAddr.
* If we fail, another thread beat us so free our mmap.
*/
#ifdef __MVS__
free(tPath);
void *expected = NULL;
if (cds((cds_t *)&expected, (cds_t *)&constAddr[0], *(cds_t *)&tempAddr))
munmap(tempAddr, size);
#else
if (!__sync_bool_compare_and_swap(&constAddr[0], NULL, tempAddr))
munmap(tempAddr, size);
#endif

/* Either we succeeded in setting constAddr or someone else did it.
* Either way, constAddr is now setup. We can close our fd without
* invalidating the mmap.
*/
close(fd);
if (basePath)
free(filePath);
return true;
}

/// Return the address of a constant at a given offset.
Expand All @@ -153,11 +141,11 @@ void omMMapBinaryFile(
void omGetExternalConstantAddr(
void **outputAddr, void **baseAddr, int64_t offset) {
if (outputAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer.");
return;
}
if (baseAddr == NULL) {
perror("Error: null pointer");
fprintf(stderr, "Error: null pointer.");
return;
}
// Constant is already loaded. Nothing to do.
Expand Down
Loading

0 comments on commit 868432d

Please sign in to comment.