Merge pull request #17 from graphcore-research/half_distance_matrix
Fix distance_matrix in float16
AlCatt91 committed Apr 27, 2023
2 parents 5468162 + c7c63b2 commit f07a396
Showing 2 changed files with 54 additions and 21 deletions.
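Background on why the fix is needed (an illustration, not part of the commit): both distance kernels reduce over the feature dimension K, and keeping that running sum in float16 discards low-order contributions once the total grows. The toy PyTorch snippet below shows the effect with an explicit float16 accumulator; the sizes and values are made up purely for illustration.

```python
import torch

# Toy illustration of float16 accumulation error (values chosen for illustration only).
K = 4096
a = torch.full((K,), 0.2548828125, dtype=torch.float16)  # exactly representable in float16
b = torch.zeros(K, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)  # running sum of |a_k - b_k| kept in float16
acc32 = torch.tensor(0.0, dtype=torch.float32)  # the same sum kept in float32
for ak, bk in zip(a, b):
    acc16 = acc16 + (ak - bk).abs()
    acc32 = acc32 + (ak - bk).abs().float()

print(acc16)                    # stalls near 1024: later terms round away entirely
print(acc32.to(torch.float16))  # ~1044, the exact sum, surviving the final cast back to float16
```

Computing the reduction in float32 and casting only the finished result back to float16, which is what this commit does on the IPU side, avoids that drift.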
36 changes: 24 additions & 12 deletions poptorch_experimental_addons/cpp/distance_matrix.cpp
@@ -3,6 +3,7 @@
#include <cmath>
#include <sstream>

+#include <popops/Cast.hpp>
#include <popops/ElementWise.hpp>
#include <poputil/TileMapping.hpp>
#include <poputil/Util.hpp>
@@ -59,21 +60,33 @@ poplar::Tensor getCachedCopy(std::map<std::pair<size_t, size_t>, poplar::Tensor>
return copy;
}

+// Like popops::cast, but alias rather than copying when the type already matches
+poplar::Tensor castMaybe(poplar::Graph& graph,
+const poplar::Tensor& tensor,
+poplar::Type type,
+poplar::program::Sequence& prog) {
+if (type == tensor.elementType()) {
+return tensor;
+}
+auto casted = graph.clone(type, tensor);
+prog.add(popops::cast(graph, tensor, casted));
+return casted;
+}

poplar::Tensor l1distance(poplar::Graph& graph,
const poplar::Tensor& a,
const poplar::Tensor& b,
poplar::program::Sequence& prog,
const poplar::DebugContext& debugContext) {
if (a.rank() != 2 || b.rank() != 2 || a.dim(1) != b.dim(1)) {
std::ostringstream msg;
msg << "Bad arguments to l1distance, expected a.shape (M, K), b.shape (N, "
"K), actual"
msg << "Bad arguments to l1distance, expected a.shape (M, K), b.shape (N, K), actual"
<< " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() << ".";
throw std::invalid_argument(msg.str());
}
const size_t n = b.dim(0);
poplar::Tensor out =
-graph.addVariable(a.elementType(), {a.dim(0), b.dim(0)}, {debugContext, "l1dist_out"});
+graph.addVariable(poplar::FLOAT, {a.dim(0), b.dim(0)}, {debugContext, "l1dist_out"});
mapTensor2Dblocks(graph, out);
const auto& mapping = graph.getTileMapping(out);
poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l1dist"});
@@ -93,7 +106,7 @@ poplar::Tensor l1distance(poplar::Graph& graph,
}
}
prog.add(poplar::program::Execute(cs));
-return out;
+return castMaybe(graph, out, a.elementType(), prog);
}

poplar::Tensor l1distancegrad(poplar::Graph& graph,
@@ -113,7 +126,7 @@ poplar::Tensor l1distancegrad(poplar::Graph& graph,
}
const size_t k = a.dim(1);
poplar::Tensor grad =
-graph.addVariable(a.elementType(), a.shape(), {debugContext, "l1dist_grad"});
+graph.addVariable(poplar::FLOAT, a.shape(), {debugContext, "l1dist_grad"});
mapTensor2Dblocks(graph, grad);
const auto& mapping = graph.getTileMapping(grad);
poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l1dist_grad"});
@@ -138,7 +151,7 @@ }
}
}
prog.add(poplar::program::Execute(cs));
-return grad;
+return castMaybe(graph, grad, a.elementType(), prog);
}

poplar::Tensor l2distance(poplar::Graph& graph,
@@ -148,14 +161,13 @@
const poplar::DebugContext& debugContext) {
if (a.rank() != 2 || b.rank() != 2 || a.dim(1) != b.dim(1)) {
std::ostringstream msg;
msg << "Bad arguments to l2distance, expected a.shape (M, K), b.shape (N, "
"K), actual"
msg << "Bad arguments to l2distance, expected a.shape (M, K), b.shape (N, K), actual"
<< " a.shape = " << a.shapeToString() << ", b.shape = " << b.shapeToString() << ".";
throw std::invalid_argument(msg.str());
}
const size_t n = b.dim(0);
poplar::Tensor out =
-graph.addVariable(a.elementType(), {a.dim(0), b.dim(0)}, {debugContext, "l2dist_out"});
+graph.addVariable(poplar::FLOAT, {a.dim(0), b.dim(0)}, {debugContext, "l2dist_out"});
mapTensor2Dblocks(graph, out);
const auto& mapping = graph.getTileMapping(out);
poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l2dist"});
@@ -175,7 +187,7 @@ }
}
}
prog.add(poplar::program::Execute(cs));
-return out;
+return castMaybe(graph, out, a.elementType(), prog);
}

poplar::Tensor l2distancegrad(poplar::Graph& graph,
@@ -197,7 +209,7 @@
}
const size_t k = a.dim(1);
poplar::Tensor grad =
-graph.addVariable(a.elementType(), a.shape(), {debugContext, "l2dist_grad"});
+graph.addVariable(poplar::FLOAT, a.shape(), {debugContext, "l2dist_grad"});
mapTensor2Dblocks(graph, grad);
const auto& mapping = graph.getTileMapping(grad);
poplar::ComputeSet cs = graph.addComputeSet({debugContext, "l2dist_grad"});
@@ -219,7 +231,7 @@ }
}
}
prog.add(poplar::program::Execute(cs));
-return grad;
+return castMaybe(graph, grad, a.elementType(), prog);
}

const popart::OperatorIdentifier L1DistanceId = {"ai.graphcore.pea", "L1Distance", 1};
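To summarize the C++ change in one place: the kernels now allocate their outputs as poplar::FLOAT, so the per-element reductions run in float32, and castMaybe converts the finished result back to the input element type, aliasing rather than copying when it is already float32. Below is a rough PyTorch analogue of that pattern, purely as a sketch; `distance_matrix_fp32_accum` and `_cast_maybe` are hypothetical names, not part of this library or of the commit.

```python
import torch

def _cast_maybe(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Mirrors castMaybe: leave the tensor untouched when it already has the target dtype.
    return t if t.dtype == dtype else t.to(dtype)

def distance_matrix_fp32_accum(a: torch.Tensor, b: torch.Tensor, p: int) -> torch.Tensor:
    # Accumulate the pairwise distances in float32 regardless of the input dtype...
    out = torch.cdist(a.float(), b.float(), p=p)
    # ...then hand back a tensor in the caller's dtype, as the Poplar code now does.
    return _cast_maybe(out, a.dtype)
```

For float32 inputs this is a no-op round trip; for float16 inputs only the final rounding back to half is incurred.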
39 changes: 30 additions & 9 deletions tests/test_core.py
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

-from typing import Callable, Dict
+from typing import Callable, Dict, cast

import poptorch
import pytest
@@ -61,12 +61,13 @@ def test_autograd_proxy(device: str) -> None:
assert_close(outputs["grad_x"], torch.tensor(3.0))


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("p", [1, 2])
def test_distance_matrix(p: int) -> None:
def test_distance_matrix(p: int, dtype: torch.dtype) -> None:
torch.manual_seed(1234)
-M, N, K = 10, 30, 50
-tensor1 = 10 + 20 * torch.randn(size=(M, K), dtype=torch.float32)
-tensor2 = -10 + 10 * torch.randn(size=(N, K), dtype=torch.float32)
+M, N, K = 10, 30, 5
+tensor1 = torch.randn(size=(M, K), dtype=dtype)
+tensor2 = torch.randn(size=(N, K), dtype=dtype)

output_ipu = run_forward_and_backward(
lambda tensor1, tensor2: pea.distance_matrix(tensor1, tensor2, p),
@@ -75,12 +76,32 @@ def test_distance_matrix(p: int) -> None:
device="ipu",
)
output_torch = run_forward_and_backward(
-lambda tensor1, tensor2: torch.cdist(tensor1, tensor2, p),
+lambda tensor1, tensor2: cast(
+torch.Tensor, (tensor1[:, None] - tensor2[None, :]).norm(p=p, dim=-1)
+),
dict(tensor1=tensor1, tensor2=tensor2),
patterns={},
device="cpu",
)

assert_close(output_ipu["output"], output_torch["output"])
assert_close(output_ipu["grad_tensor1"], output_torch["grad_tensor1"])
assert_close(output_ipu["grad_tensor2"], output_torch["grad_tensor2"])
atol = {torch.float32: 1e-5, torch.float16: 2e-3}[dtype]
rtol = {torch.float32: 2e-6, torch.float16: 2e-3}[dtype]

assert_close(
output_ipu["output"],
output_torch["output"],
rtol=rtol,
atol=atol,
)
assert_close(
output_ipu["grad_tensor1"],
output_torch["grad_tensor1"],
rtol=rtol,
atol=atol,
)
assert_close(
output_ipu["grad_tensor2"],
output_torch["grad_tensor2"],
rtol=rtol,
atol=atol,
)
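A note on the tolerances chosen in the test (reasoning added here, not taken from the diff): float16 carries about 11 significand bits, so even a correctly rounded result can differ from a float32 reference by roughly one part in a thousand. That is why the float16 case uses rtol/atol of 2e-3, about two float16 ulps at unit magnitude, while the float32 case keeps the much tighter 1e-5/2e-6.

```python
import torch

# Machine epsilons that motivate the per-dtype tolerances above.
print(torch.finfo(torch.float16).eps)  # 0.0009765625 == 2**-10
print(torch.finfo(torch.float32).eps)  # ~1.1921e-07  == 2**-23
```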
