Skip to content

Commit

Permalink
Merge branch 'develop' into so-version
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Sep 5, 2019
2 parents a79beae + b4f1161 commit 51bd00b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
15 changes: 7 additions & 8 deletions src/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,13 @@ static void ins_quantize_int8(program& prog,
{
// Current MIOpen convolution does not support alpha and beta,
// so we need a separate multiply to adjust the output
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
auto adjust_factor =
std::round(1.0f / (ins_quant_params[0].first * ins_quant_params[1].first));
auto conv_op = any_cast<op::convolution>(ins->get_operator());
auto padding = conv_op.padding;
auto stride = conv_op.stride;
auto dilation = conv_op.dilation;
auto padding_mode = conv_op.padding_mode;
auto group = conv_op.group;
auto adjust_factor = 1.0f / (ins_quant_params[0].first * ins_quant_params[1].first);

auto quant_conv = prog.insert_instruction(
ins,
Expand Down
45 changes: 44 additions & 1 deletion test/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ TEST_CASE(target_copy)
}
}

TEST_CASE(int8_quantization)
TEST_CASE(int8_quantization_dot)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
Expand Down Expand Up @@ -958,4 +958,47 @@ TEST_CASE(int8_quantization)
}
}

TEST_CASE(int8_quantization_conv)
{
auto run_prog = [](migraphx::program p,
const migraphx::target& t,
std::vector<float>& res,
bool b_quantize = false) {
if(b_quantize)
{
std::vector<migraphx::program::parameter_map> cali_data;
migraphx::quantize_int8(p, t, cali_data);
}
p.compile(t);
migraphx::program::parameter_map m;

auto result = t.copy_from(p.eval(m));
result.visit([&](auto v) { res.assign(v.begin(), v.end()); });
};

auto create_program = [] {
migraphx::program p;
migraphx::shape sx{migraphx::shape::float_type, {4, 2, 2, 2}};
migraphx::shape sw{migraphx::shape::float_type, {4, 2, 2, 2}};
std::vector<float> v(sx.elements(), 0.5f);
auto input = p.add_literal(migraphx::literal(sx, v));
auto weights = p.add_literal(migraphx::literal(sw, v));
p.add_instruction(migraphx::op::convolution{}, input, weights);

return p;
};

{
auto p = create_program();
std::vector<float> quant_result;
migraphx::target cpu_t = migraphx::cpu::target{};
run_prog(p, cpu_t, quant_result, true);

std::vector<float> no_quant_result;
run_prog(p, cpu_t, no_quant_result);

EXPECT(migraphx::verify_range(quant_result, no_quant_result));
}
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 51bd00b

Please sign in to comment.