Skip to content

Commit a69abcc

Browse files
erinzmoorecopybara-github
authored andcommitted
Check if context belongs to inference table before using.
PiperOrigin-RevId: 750581701
1 parent 9de370a commit a69abcc

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

xls/dslx/type_system_v2/inference_table.cc

+11-12
Original file line numberDiff line numberDiff line change
@@ -400,17 +400,16 @@ class InferenceTableImpl : public InferenceTable {
400400
(it == type_annotations_per_type_variable_.end())
401401
? std::vector<const TypeAnnotation*>()
402402
: it->second;
403-
if (parametric_context.has_value()) {
404-
if (mutable_parametric_context_data_.contains(*parametric_context)) {
405-
const auto& invocation_specific_annotations =
406-
mutable_parametric_context_data_.at(*parametric_context)
407-
.type_annotations_per_type_variable;
408-
const auto invocation_specific_it =
409-
invocation_specific_annotations.find(variable);
410-
if (invocation_specific_it != invocation_specific_annotations.end()) {
411-
absl::c_copy(invocation_specific_it->second,
412-
std::back_inserter(result));
413-
}
403+
if (parametric_context.has_value() &&
404+
(*parametric_context)->node()->owner() == &module_) {
405+
const auto& invocation_specific_annotations =
406+
mutable_parametric_context_data_.at(*parametric_context)
407+
.type_annotations_per_type_variable;
408+
const auto invocation_specific_it =
409+
invocation_specific_annotations.find(variable);
410+
if (invocation_specific_it != invocation_specific_annotations.end()) {
411+
absl::c_copy(invocation_specific_it->second,
412+
std::back_inserter(result));
414413
}
415414
}
416415
return result;
@@ -604,7 +603,7 @@ class InferenceTableImpl : public InferenceTable {
604603
std::optional<const ParametricContext*> context,
605604
const InferenceVariable* variable, const TypeAnnotation* annotation) {
606605
CHECK(variable->kind() == InferenceVariableKind::kType);
607-
if (context.has_value()) {
606+
if (context.has_value() && (*context)->node()->owner() == &module_) {
608607
mutable_parametric_context_data_.at(*context)
609608
.type_annotations_per_type_variable[variable]
610609
.push_back(annotation);

xls/dslx/type_system_v2/typecheck_module_v2_test.cc

+47
Original file line numberDiff line numberDiff line change
@@ -6200,6 +6200,53 @@ fn main() -> u5 {
62006200
IsOkAndHolds(HasTypeInfo(HasNodeWithType("var", "uN[5]"))));
62016201
}
62026202

6203+
TEST(TypecheckV2Test, ImportParametricFunctionWithMultipleInvocations) {
6204+
constexpr std::string_view kImported = R"(
6205+
pub fn add_one(x: u32) -> u32 { x + 1 }
6206+
6207+
pub fn some_function<N: u32, M: u32 = { add_one(N) }>() -> uN[M] { uN[M]:0 }
6208+
6209+
pub fn another_fn() -> u3 { some_function<2>() }
6210+
6211+
pub fn parametric_call<M: u32>() -> uN[M] { some_function<3, M>() }
6212+
6213+
)";
6214+
constexpr std::string_view kInt = R"(
6215+
import imported;
6216+
6217+
pub fn mid() -> u32 { imported::some_function<31>() }
6218+
6219+
pub fn default_import<N: u32, M: u32 = { imported::add_one(N) }>() -> uN[M] { uN[M]:0 }
6220+
6221+
)";
6222+
constexpr std::string_view kProgram = R"(
6223+
import imported;
6224+
import int;
6225+
6226+
const VAR = imported::some_function<4>();
6227+
const VAR2 = imported::some_function<3, 6>();
6228+
const VAR3 = imported::some_function<7>();
6229+
const VAR4 = imported::another_fn();
6230+
const VAR5 = imported::parametric_call<15>();
6231+
const VAR6 = int::mid();
6232+
6233+
fn main() -> u26 {
6234+
int::default_import<25>()
6235+
}
6236+
)";
6237+
ImportData import_data = CreateImportDataForTest();
6238+
XLS_EXPECT_OK(TypecheckV2(kImported, "imported", &import_data).status());
6239+
XLS_EXPECT_OK(TypecheckV2(kInt, "int", &import_data).status());
6240+
EXPECT_THAT(
6241+
TypecheckV2(kProgram, "main", &import_data),
6242+
IsOkAndHolds(HasTypeInfo(AllOf(
6243+
HasNodeWithType("VAR", "uN[5]"), HasNodeWithType("VAR2", "uN[6]"),
6244+
HasNodeWithType("VAR4", "uN[3]"), HasNodeWithType("VAR5", "uN[15]"),
6245+
HasNodeWithType("VAR6", "uN[32]"),
6246+
HasNodeWithType("int::default_import<25>()", "uN[26]"),
6247+
HasNodeWithType("VAR3", "uN[8]")))));
6248+
}
6249+
62036250
TEST(TypecheckV2Test, ImportParametricFunction) {
62046251
constexpr std::string_view kImported = R"(
62056252
pub fn some_function<N: u32>() -> uN[N] { uN[N]:0 }

0 commit comments

Comments
 (0)