@@ -520,41 +520,41 @@ static void packNonSplatInt4Tensor(ElementsAttr values, int64_t count,
520
520
}
521
521
}
522
522
523
- static void serializeSplatElements (DenseIntOrFPElementsAttr values,
524
- std::vector<int8_t > &data) {
523
+ static LogicalResult serializeSplatElements (DenseIntOrFPElementsAttr values,
524
+ std::vector<int8_t > &data) {
525
525
assert (values.isSplat () && " expected SplatElementsAttr" );
526
526
527
527
auto rtt = cast<RankedTensorType>(values.getType ());
528
528
if (rtt.getElementType ().isInteger (32 )) {
529
529
std::fill_n (reinterpret_cast <int32_t *>(data.data ()),
530
530
values.getNumElements (), values.getSplatValue <int32_t >());
531
- return ;
531
+ return llvm::success () ;
532
532
}
533
533
if (rtt.getElementType ().isInteger (8 )) {
534
534
std::fill_n (reinterpret_cast <int8_t *>(data.data ()),
535
535
values.getNumElements (), values.getSplatValue <int8_t >());
536
- return ;
536
+ return llvm::success () ;
537
537
}
538
538
if (rtt.getElementType ().isF32 ()) {
539
539
std::fill_n (reinterpret_cast <float *>(data.data ()), values.getNumElements (),
540
540
values.getSplatValue <float >());
541
- return ;
541
+ return llvm::success () ;
542
542
}
543
543
if (rtt.getElementType ().isF16 () || rtt.getElementType ().isBF16 ()) {
544
544
APInt tmp = values.getSplatValue <APFloat>().bitcastToAPInt ();
545
545
assert (tmp.getBitWidth () == 16 && " unexpected bitwidth" );
546
546
uint16_t fillValue = *reinterpret_cast <const uint16_t *>(tmp.getRawData ());
547
547
std::fill_n (reinterpret_cast <uint16_t *>(data.data ()),
548
548
values.getNumElements (), fillValue);
549
- return ;
549
+ return llvm::success () ;
550
550
}
551
551
if (rtt.getElementType ().isFloat8E4M3FN ()) {
552
552
APInt tmp = values.getSplatValue <APFloat>().bitcastToAPInt ();
553
553
assert (tmp.getBitWidth () == 8 && " unexpected bitwidth" );
554
554
uint8_t fillValue = *reinterpret_cast <const uint8_t *>(tmp.getRawData ());
555
555
std::fill_n (reinterpret_cast <uint8_t *>(data.data ()),
556
556
values.getNumElements (), fillValue);
557
- return ;
557
+ return llvm::success () ;
558
558
}
559
559
if (rtt.getElementType ().isInteger (4 )) {
560
560
APInt tmp = values.getSplatValue <APInt>();
@@ -566,11 +566,12 @@ static void serializeSplatElements(DenseIntOrFPElementsAttr values,
566
566
packed |= ((value & 0x0F ) << 4 );
567
567
// Fill `data` vector with `packed`
568
568
std::fill_n (reinterpret_cast <uint8_t *>(data.data ()), data.size (), packed);
569
- return ;
569
+ return llvm::success () ;
570
570
}
571
-
572
- llvm_unreachable (" unsupported data type to convert MLIR splat attribute to "
573
- " TensorRT weights!" );
571
+ llvm::errs () << " Error: "
572
+ << " unsupported data type to convert MLIR splat attribute to "
573
+ " TensorRT weights!" ;
574
+ return llvm::failure ();
574
575
}
575
576
576
577
FailureOr<nvinfer1::Weights>
@@ -615,8 +616,10 @@ NvInferNetworkEncoder::getNvInferWeights(ElementsAttr values) {
615
616
weights.values = data.data ();
616
617
617
618
if (values.isSplat () && isa<DenseIntOrFPElementsAttr>(values)) {
618
- serializeSplatElements (cast<DenseIntOrFPElementsAttr>(values),
619
- weightsMap[values]);
619
+ LogicalResult status = serializeSplatElements (
620
+ cast<DenseIntOrFPElementsAttr>(values), weightsMap[values]);
621
+ if (failed (status))
622
+ return failure ();
620
623
return weights;
621
624
}
622
625
0 commit comments