Skip to content

Commit 6a2b53b

Browse files
committed
Return a failure status instead of hitting llvm_unreachable when TensorRT weight conversion encounters an unsupported splat element type
1 parent 6e80e0f commit 6a2b53b

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

mlir-tensorrt/tensorrt/lib/Target/TensorRTEncodingOpInterface/NetworkEncoder.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -520,41 +520,41 @@ static void packNonSplatInt4Tensor(ElementsAttr values, int64_t count,
520520
}
521521
}
522522

523-
static void serializeSplatElements(DenseIntOrFPElementsAttr values,
524-
std::vector<int8_t> &data) {
523+
static LogicalResult serializeSplatElements(DenseIntOrFPElementsAttr values,
524+
std::vector<int8_t> &data) {
525525
assert(values.isSplat() && "expected SplatElementsAttr");
526526

527527
auto rtt = cast<RankedTensorType>(values.getType());
528528
if (rtt.getElementType().isInteger(32)) {
529529
std::fill_n(reinterpret_cast<int32_t *>(data.data()),
530530
values.getNumElements(), values.getSplatValue<int32_t>());
531-
return;
531+
return llvm::success();
532532
}
533533
if (rtt.getElementType().isInteger(8)) {
534534
std::fill_n(reinterpret_cast<int8_t *>(data.data()),
535535
values.getNumElements(), values.getSplatValue<int8_t>());
536-
return;
536+
return llvm::success();
537537
}
538538
if (rtt.getElementType().isF32()) {
539539
std::fill_n(reinterpret_cast<float *>(data.data()), values.getNumElements(),
540540
values.getSplatValue<float>());
541-
return;
541+
return llvm::success();
542542
}
543543
if (rtt.getElementType().isF16() || rtt.getElementType().isBF16()) {
544544
APInt tmp = values.getSplatValue<APFloat>().bitcastToAPInt();
545545
assert(tmp.getBitWidth() == 16 && "unexpected bitwidth");
546546
uint16_t fillValue = *reinterpret_cast<const uint16_t *>(tmp.getRawData());
547547
std::fill_n(reinterpret_cast<uint16_t *>(data.data()),
548548
values.getNumElements(), fillValue);
549-
return;
549+
return llvm::success();
550550
}
551551
if (rtt.getElementType().isFloat8E4M3FN()) {
552552
APInt tmp = values.getSplatValue<APFloat>().bitcastToAPInt();
553553
assert(tmp.getBitWidth() == 8 && "unexpected bitwidth");
554554
uint8_t fillValue = *reinterpret_cast<const uint8_t *>(tmp.getRawData());
555555
std::fill_n(reinterpret_cast<uint8_t *>(data.data()),
556556
values.getNumElements(), fillValue);
557-
return;
557+
return llvm::success();
558558
}
559559
if (rtt.getElementType().isInteger(4)) {
560560
APInt tmp = values.getSplatValue<APInt>();
@@ -566,11 +566,12 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values,
566566
packed |= ((value & 0x0F) << 4);
567567
// Fill `data` vector with `packed`
568568
std::fill_n(reinterpret_cast<uint8_t *>(data.data()), data.size(), packed);
569-
return;
569+
return llvm::success();
570570
}
571-
572-
llvm_unreachable("unsupported data type to convert MLIR splat attribute to "
573-
"TensorRT weights!");
571+
llvm::errs() << "Error: "
572+
<< "unsupported data type to convert MLIR splat attribute to "
573+
"TensorRT weights!";
574+
return llvm::failure();
574575
}
575576

576577
FailureOr<nvinfer1::Weights>
@@ -615,8 +616,10 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) {
615616
weights.values = data.data();
616617

617618
if (values.isSplat() && isa<DenseIntOrFPElementsAttr>(values)) {
618-
serializeSplatElements(cast<DenseIntOrFPElementsAttr>(values),
619-
weightsMap[values]);
619+
LogicalResult status = serializeSplatElements(
620+
cast<DenseIntOrFPElementsAttr>(values), weightsMap[values]);
621+
if (failed(status))
622+
return failure();
620623
return weights;
621624
}
622625

0 commit comments

Comments
 (0)