Skip to content

Commit 37c884a

Browse files
authored
support for f128 (#2427)
1 parent dec11bf commit 37c884a

File tree

5 files changed

+74
-0
lines changed

5 files changed

+74
-0
lines changed

enzyme/Enzyme/CApi.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ ConcreteType eunwrap(CConcreteType CDT, llvm::LLVMContext &ctx) {
9999
return ConcreteType(llvm::Type::getX86_FP80Ty(ctx));
100100
case DT_BFloat16:
101101
return ConcreteType(llvm::Type::getBFloatTy(ctx));
102+
case DT_FP128:
103+
return ConcreteType(llvm::Type::getFP128Ty(ctx));
102104
case DT_Unknown:
103105
return BaseType::Unknown;
104106
}
@@ -133,6 +135,8 @@ CConcreteType ewrap(const ConcreteType &CT) {
133135
return DT_X86_FP80;
134136
if (flt->isBFloatTy())
135137
return DT_BFloat16;
138+
if (flt->isFP128Ty())
139+
return DT_FP128;
136140
} else {
137141
switch (CT.SubTypeEnum) {
138142
case BaseType::Integer:

enzyme/Enzyme/CApi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ typedef enum {
6161
DT_Unknown = 6,
6262
DT_X86_FP80 = 7,
6363
DT_BFloat16 = 8,
64+
DT_FP128 = 9,
6465
} CConcreteType;
6566

6667
struct CDataPair {

enzyme/Enzyme/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,8 @@ static inline llvm::Type *FloatToIntTy(llvm::Type *T) {
620620
return llvm::IntegerType::get(T->getContext(), 64);
621621
if (T->isX86_FP80Ty())
622622
return llvm::IntegerType::get(T->getContext(), 80);
623+
if (T->isFP128Ty())
624+
return llvm::IntegerType::get(T->getContext(), 128);
623625
assert(0 && "unknown floating point type");
624626
return nullptr;
625627
}
@@ -641,6 +643,10 @@ static inline llvm::Type *IntToFloatTy(llvm::Type *T) {
641643
return llvm::Type::getFloatTy(T->getContext());
642644
case 64:
643645
return llvm::Type::getDoubleTy(T->getContext());
646+
case 80:
647+
return llvm::Type::getX86_FP80Ty(T->getContext());
648+
case 128:
649+
return llvm::Type::getFP128Ty(T->getContext());
644650
}
645651
}
646652
assert(0 && "unknown int to floating point type");
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s
3+
4+
; Function Attrs: mustprogress noinline nounwind optnone uwtable
5+
define double @tester(double %x0) #0 {
6+
entry:
7+
%x2 = alloca double, align 8
8+
%x3 = alloca fp128, align 16
9+
store double %x0, double* %x2, align 8
10+
%x4 = load double, double* %x2, align 8
11+
%x5 = fpext double %x4 to fp128
12+
store fp128 %x5, fp128* %x3, align 16
13+
%x6 = load fp128, fp128* %x3, align 16
14+
%x7 = fptrunc fp128 %x6 to double
15+
ret double %x7
16+
}
17+
18+
define double @test_derivative(double %x) {
19+
entry:
20+
%0 = tail call double (double (double)*, ...) @__enzyme_fwddiff(double (double)* nonnull @tester, double %x, double 0.0)
21+
ret double %0
22+
}
23+
24+
; Function Attrs: nounwind
25+
declare double @__enzyme_fwddiff(double (double)*, ...)
26+
27+
; CHECK: define internal double @fwddiffetester(double %x0, double %"x0'")
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %"x5'ipc" = fpext double %"x0'" to fp128
30+
; CHECK-NEXT: %"x7'ipc" = fptrunc fp128 %"x5'ipc" to double
31+
; CHECK-NEXT: ret double %"x7'ipc"
32+
; CHECK-NEXT: }
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
3+
4+
; Function Attrs: mustprogress noinline nounwind optnone uwtable
5+
define double @tester(double %x0) #0 {
6+
entry:
7+
%x2 = alloca double, align 8
8+
%x3 = alloca fp128, align 16
9+
store double %x0, double* %x2, align 8
10+
%x4 = load double, double* %x2, align 8
11+
%x5 = fpext double %x4 to fp128
12+
store fp128 %x5, fp128* %x3, align 16
13+
%x6 = load fp128, fp128* %x3, align 16
14+
%x7 = fptrunc fp128 %x6 to double
15+
ret double %x7
16+
}
17+
18+
define double @test_derivative(double %x) {
19+
entry:
20+
%0 = tail call double (double (double)*, ...) @__enzyme_autodiff(double (double)* nonnull @tester, double %x)
21+
ret double %0
22+
}
23+
24+
; Function Attrs: nounwind
25+
declare double @__enzyme_autodiff(double (double)*, ...)
26+
27+
; CHECK: define internal { double } @diffetester(double %x0, double %differeturn)
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %0 = insertvalue { double } undef, double %differeturn, 0
30+
; CHECK-NEXT: ret { double } %0
31+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)