Skip to content

Commit

Permalink
Added workaround for Cuda Test Runner.
Browse files Browse the repository at this point in the history
  • Loading branch information
MoFtZ committed Jan 31, 2024
1 parent 1bb9adc commit 55d36d4
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 55 deletions.
107 changes: 53 additions & 54 deletions Src/ILGPU.Algorithms.Tests/Generic/AlgorithmsTestBase.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU Algorithms
// Copyright (c) 2020-2023 ILGPU Project
// Copyright (c) 2020-2024 ILGPU Project
// www.ilgpu.net
//
// File: AlgorithmsTestBase.cs
Expand All @@ -27,7 +27,7 @@ protected AlgorithmsTestBase(ITestOutputHelper output, TestContext testContext)
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class HalfPrecisionComparer
internal class HalfPrecisionComparer
: EqualityComparer<Half>
{
public readonly float Margin;
Expand Down Expand Up @@ -59,7 +59,7 @@ public override int GetHashCode(Half obj) =>
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class FloatPrecisionComparer
internal class FloatPrecisionComparer
: EqualityComparer<float>
{
public readonly float Margin;
Expand Down Expand Up @@ -91,7 +91,7 @@ public override int GetHashCode(float obj) =>
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class DoublePrecisionComparer
internal class DoublePrecisionComparer
: EqualityComparer<double>
{
public readonly double Margin;
Expand Down Expand Up @@ -123,7 +123,7 @@ public override int GetHashCode(double obj) =>
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class HalfRelativeErrorComparer
internal class HalfRelativeErrorComparer
: EqualityComparer<Half>
{
public readonly float RelativeError;
Expand Down Expand Up @@ -163,7 +163,7 @@ public override int GetHashCode(Half obj) =>
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class FloatRelativeErrorComparer
internal class FloatRelativeErrorComparer
: EqualityComparer<float>
{
public readonly float RelativeError;
Expand Down Expand Up @@ -203,7 +203,7 @@ public override int GetHashCode(float obj) =>
/// <summary>
/// Compares two numbers for equality, within a defined tolerance.
/// </summary>
private class DoubleRelativeErrorComparer
internal class DoubleRelativeErrorComparer
: EqualityComparer<double>
{
public readonly double RelativeError;
Expand Down Expand Up @@ -245,19 +245,33 @@ public override int GetHashCode(double obj) =>
/// </summary>
/// <param name="buffer">The target buffer.</param>
/// <param name="expected">The expected values.</param>
/// <param name="decimalPlaces">The acceptable error margin.</param>
public void VerifyWithinPrecision(
ArrayView<Half> buffer,
Half[] expected,
uint decimalPlaces)
/// <param name="comparer">The comparer to use.</param>
public void VerifyUsingComparer<T>(
ArrayView<T> buffer,
T[] expected,
IEqualityComparer<T> comparer)
where T : unmanaged
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new HalfPrecisionComparer(decimalPlaces);
Assert.Equal(expected, data, comparer);
}

/// <summary>
/// Verifies the contents of the given memory buffer.
/// </summary>
/// <param name="buffer">The target buffer.</param>
/// <param name="expected">The expected values.</param>
/// <param name="decimalPlaces">The acceptable error margin.</param>
public void VerifyWithinPrecision(
ArrayView<Half> buffer,
Half[] expected,
uint decimalPlaces) =>
VerifyUsingComparer(
buffer,
expected,
new HalfPrecisionComparer(decimalPlaces));

/// <summary>
/// Verifies the contents of the given memory buffer.
/// </summary>
Expand All @@ -267,14 +281,11 @@ public void VerifyWithinPrecision(
public void VerifyWithinPrecision(
ArrayView<float> buffer,
float[] expected,
uint decimalPlaces)
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new FloatPrecisionComparer(decimalPlaces);
Assert.Equal(expected, data, comparer);
}
uint decimalPlaces) =>
VerifyUsingComparer(
buffer,
expected,
new FloatPrecisionComparer(decimalPlaces));

/// <summary>
/// Verifies the contents of the given memory buffer.
Expand All @@ -285,14 +296,11 @@ public void VerifyWithinPrecision(
public void VerifyWithinPrecision(
ArrayView<double> buffer,
double[] expected,
uint decimalPlaces)
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new DoublePrecisionComparer(decimalPlaces);
Assert.Equal(expected, data, comparer);
}
uint decimalPlaces) =>
VerifyUsingComparer(
buffer,
expected,
new DoublePrecisionComparer(decimalPlaces));

/// <summary>
/// Verifies the contents of the given memory buffer.
Expand All @@ -303,14 +311,11 @@ public void VerifyWithinPrecision(
public void VerifyWithinRelativeError(
ArrayView<Half> buffer,
Half[] expected,
double relativeError)
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new HalfRelativeErrorComparer((float)relativeError);
Assert.Equal(expected, data, comparer);
}
double relativeError) =>
VerifyUsingComparer(
buffer,
expected,
new HalfRelativeErrorComparer((float)relativeError));

/// <summary>
/// Verifies the contents of the given memory buffer.
Expand All @@ -321,14 +326,11 @@ public void VerifyWithinRelativeError(
public void VerifyWithinRelativeError(
ArrayView<float> buffer,
float[] expected,
double relativeError)
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new FloatRelativeErrorComparer((float)relativeError);
Assert.Equal(expected, data, comparer);
}
double relativeError) =>
VerifyUsingComparer(
buffer,
expected,
new FloatRelativeErrorComparer((float)relativeError));

/// <summary>
/// Verifies the contents of the given memory buffer.
Expand All @@ -339,13 +341,10 @@ public void VerifyWithinRelativeError(
public void VerifyWithinRelativeError(
ArrayView<double> buffer,
double[] expected,
double relativeError)
{
var data = buffer.GetAsArray(Accelerator.DefaultStream);
Assert.Equal(data.Length, expected.Length);

var comparer = new DoubleRelativeErrorComparer(relativeError);
Assert.Equal(expected, data, comparer);
}
double relativeError) =>
VerifyUsingComparer(
buffer,
expected,
new DoubleRelativeErrorComparer(relativeError));
}
}
42 changes: 41 additions & 1 deletion Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU Algorithms
// Copyright (c) 2020-2023 ILGPU Project
// Copyright (c) 2020-2024 ILGPU Project
// www.ilgpu.net
//
// File: XMathTests.Pow.tt/XMathTests.Pow.cs
Expand Down Expand Up @@ -48,6 +48,32 @@ namespace ILGPU.Algorithms.Tests
// and ensures a minimum error on each accelerator type.
partial class XMathTests
{
#region Nested Types

/// <summary>
/// WORKAROUND: The output of LibDevice __nv_pow(double, double) and
/// .NET Math.Pow(double, double) on Cuda Test Runner are different.
/// </summary>
private class CudaPowDoubleRelativeErrorComparer : DoubleRelativeErrorComparer
{
public CudaPowDoubleRelativeErrorComparer(double relativeError)
: base(relativeError)
{ }

public override bool Equals(double x, double y)
{
if ((double.IsPositiveInfinity(x) && double.IsNegativeInfinity(y)) ||
(double.IsNegativeInfinity(x) && double.IsPositiveInfinity(y)))
{
return true;
}

return base.Equals(x, y);
}
}

#endregion

<# foreach (var function in powFunctions) { #>
internal static void <#= function.KernelName #>(
Index1D index,
Expand Down Expand Up @@ -120,10 +146,24 @@ namespace ILGPU.Algorithms.Tests
v => Math<#= function.MathSuffix #>.<#= function.Name #>(v.X, v.Y))
.ToArray();
if (Accelerator.AcceleratorType == AcceleratorType.Cuda)
<#
if (function.DataType == "double") {
#>
VerifyUsingComparer(
output.View,
expected,
new CudaPowDoubleRelativeErrorComparer(
(<#= function.DataType #>)<#= function.RelativeError.Cuda #>));
<#
} else {
#>
VerifyWithinRelativeError(
output.View,
expected,
<#= function.RelativeError.Cuda #>);
<#
}
#>
else if (Accelerator.AcceleratorType == AcceleratorType.OpenCL)
VerifyWithinRelativeError(
output.View,
Expand Down

0 comments on commit 55d36d4

Please sign in to comment.