Skip to content

Commit

Permalink
Function template specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
aneshlya committed Aug 8, 2023
1 parent 212d34f commit 81e507b
Show file tree
Hide file tree
Showing 14 changed files with 459 additions and 67 deletions.
30 changes: 29 additions & 1 deletion src/func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,11 @@ void FunctionTemplate::Print() const {
void FunctionTemplate::GenerateIR() const {
for (const auto &inst : instantiations) {
Function *func = const_cast<Function *>(inst.second->parentFunction);
func->GenerateIR();
if (func != nullptr) {
func->GenerateIR();
} else {
Error(inst.second->pos, "Template function specialization was declared but never defined.");
}
}
}

Expand Down Expand Up @@ -984,6 +988,30 @@ Symbol *FunctionTemplate::AddInstantiation(const std::vector<std::pair<const Typ
return instSym;
}

Symbol *FunctionTemplate::AddSpecialization(const FunctionType *ftype,
const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos) {
const TemplateParms *typenames = GetTemplateParms();
Assert(typenames);
TemplateInstantiation templInst(*typenames, types);

// Create a function symbol
Symbol *instSym = templInst.InstantiateTemplateSymbol(sym);
instSym->type = ftype;
instSym->pos = pos;

TemplateArgs *templArgs = new TemplateArgs(types);

// Check if we have previously declared specialization and we are about to define it.
Symbol *funcSym = LookupInstantiation(types);
if (funcSym != nullptr) {
return funcSym;
} else {
instantiations.push_back(std::make_pair(templArgs, instSym));
}
return instSym;
}

///////////////////////////////////////////////////////////////////////////
// TemplateInstantiation

Expand Down
4 changes: 3 additions & 1 deletion src/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@ class FunctionTemplate {

Symbol *LookupInstantiation(const std::vector<std::pair<const Type *, SourcePos>> &types);
Symbol *AddInstantiation(const std::vector<std::pair<const Type *, SourcePos>> &types);
Symbol *AddSpecialization(const FunctionType *ftype, const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos);

// Generate code for instantiations
// Generate code for instantiations and specializations.
void GenerateIR() const;

void Print() const;
Expand Down
136 changes: 92 additions & 44 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,62 +1136,110 @@ void Module::AddFunctionTemplateDefinition(const TemplateParms *templateParmList
ast->AddFunctionTemplate(sym, code);
}

void Module::AddFunctionTemplateInstantiation(const std::string &name,
const std::vector<std::pair<const Type *, SourcePos>> &types,
const FunctionType *ftype, SourcePos pos) {
FunctionTemplate *Module::MatchFunctionTemplate(const std::string &name, const FunctionType *ftype,
std::vector<std::pair<const Type *, SourcePos>> &normTypes,
SourcePos pos) {
if (ftype == nullptr) {
Assert(m->errorCount > 0);
return nullptr;
}
std::vector<TemplateSymbol *> matches;
bool found = symbolTable->LookupFunctionTemplate(name, &matches);
if (!found) {
Error(pos, "No matching function template was found.");
return nullptr;
}
// Do template argument "normalization", i.e apply "varying type default":
//
// template <typename T> void foo(T t);
// foo<int>(1); // T is assumed to be "varying int" here.
for (auto &arg : normTypes) {
if (arg.first->GetVariability() == Variability::Unbound) {
arg.first = arg.first->GetAsVaryingType();
}
}

if (found) {
// TODO: need to outline this copy-paste code.
// Do template argument "normalization", i.e apply "varying type default":
//
// template <typename T> void foo(T t);
// foo<int>(1); // T is assumed to be "varying int" here.
std::vector<std::pair<const Type *, SourcePos>> normTypes(types);
for (auto &arg : normTypes) {
if (arg.first->GetVariability() == Variability::Unbound) {
arg.first = arg.first->GetAsVaryingType();
}
FunctionTemplate *templ = nullptr;
for (auto &templateSymbol : matches) {
// Number of template parameters must match.
if (normTypes.size() != templateSymbol->templateParms->GetCount()) {
// We don't have default parameters yet, so just matching the size exactly.
continue;
}

FunctionTemplate *templ = nullptr;
for (auto &templateSymbol : matches) {
// Number of template parameters must match.
if (normTypes.size() != templateSymbol->templateParms->GetCount()) {
// We don't have default parameters yet, so just matching the size exactly.
continue;
// Number of function parameters must match.
if (!ftype || !templateSymbol->type || ftype->GetNumParameters() != templateSymbol->type->GetNumParameters()) {
continue;
}
bool matched = true;
TemplateInstantiation inst(*(templateSymbol->templateParms), normTypes);
for (int i = 0; i < ftype->GetNumParameters(); i++) {
const Type *instParam = ftype->GetParameterType(i);
const Type *templateParam = templateSymbol->type->GetParameterType(i)->ResolveDependence(inst);
if (!Type::Equal(instParam, templateParam)) {
matched = false;
break;
}
}
if (matched) {
templ = templateSymbol->functionTemplate;
}
}
return templ;
}

// Number of function parameters must match.
if (!ftype || !templateSymbol->type ||
ftype->GetNumParameters() != templateSymbol->type->GetNumParameters()) {
continue;
}
void Module::AddFunctionTemplateInstantiation(const std::string &name,
const std::vector<std::pair<const Type *, SourcePos>> &types,
const FunctionType *ftype, SourcePos pos) {
std::vector<std::pair<const Type *, SourcePos>> normTypes(types);
FunctionTemplate *templ = MatchFunctionTemplate(name, ftype, normTypes, pos);
if (templ) {
templ->AddInstantiation(normTypes);
} else {
Error(pos, "No matching function template found for instantiation.");
}
}

TemplateInstantiation inst(*(templateSymbol->templateParms), normTypes);
bool matched = true;
for (int i = 0; i < ftype->GetNumParameters(); i++) {
const Type *instParam = ftype->GetParameterType(i);
const Type *templateParam = templateSymbol->type->GetParameterType(i)->ResolveDependence(inst);
if (!Type::Equal(instParam, templateParam)) {
matched = false;
break;
}
}
void Module::AddFunctionTemplateSpecializationDefinition(const std::string &name, const FunctionType *ftype,
const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos, Stmt *code) {
std::vector<std::pair<const Type *, SourcePos>> normTypes(types);
FunctionTemplate *templ = MatchFunctionTemplate(name, ftype, normTypes, pos);
if (templ == nullptr) {
Error(pos, "No matching function template found for specialization.");
return;
}
Symbol *sym = templ->LookupInstantiation(normTypes);
if (sym == nullptr || code == nullptr) {
Assert(m->errorCount > 0);
return;
}
sym->pos = code->pos;

if (matched) {
templ = templateSymbol->functionTemplate;
break;
}
}
// Update already created symbol with real function type and function implementation
sym->type = ftype;
Function *inst = new Function(sym, code);
sym->parentFunction = inst;
}

if (templ) {
templ->AddInstantiation(normTypes);
} else {
Error(pos, "No matching function template found for instantiation.");
void Module::AddFunctionTemplateSpecializationDeclaration(const std::string &name, const FunctionType *ftype,
const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos) {
std::vector<std::pair<const Type *, SourcePos>> normTypes(types);
FunctionTemplate *templ = MatchFunctionTemplate(name, ftype, normTypes, pos);
if (templ == nullptr) {
Error(pos, "No matching function template found for specialization.");
return;
}
Symbol *sym = templ->LookupInstantiation(normTypes);
if (sym != nullptr) {
if (Type::Equal(sym->type, ftype) && sym->parentFunction != nullptr) {
Error(pos, "Template function specialization was already defined.");
return;
}
}

templ->AddSpecialization(ftype, normTypes, pos);
}

void Module::AddExportedTypes(const std::vector<std::pair<const Type *, SourcePos>> &types) {
Expand Down
13 changes: 13 additions & 0 deletions src/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ class Module {
const std::vector<std::pair<const Type *, SourcePos>> &types,
const FunctionType *ftype, SourcePos pos);

void AddFunctionTemplateSpecializationDeclaration(const std::string &name, const FunctionType *ftype,
const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos);

void AddFunctionTemplateSpecializationDefinition(const std::string &name, const FunctionType *ftype,
const std::vector<std::pair<const Type *, SourcePos>> &types,
SourcePos pos, Stmt *code);

/** Adds the given type to the set of types that have their definitions
included in automatically generated header files. */
void AddExportedTypes(const std::vector<std::pair<const Type *, SourcePos>> &types);
Expand All @@ -98,6 +106,11 @@ class Module {
function symbol for it. */
Symbol *AddLLVMIntrinsicDecl(const std::string &name, ExprList *args, SourcePos po);

/** Returns pointer to FunctionTemplate based on template name and template argument types provided. Also makes
* template argument types normalization.*/
FunctionTemplate *MatchFunctionTemplate(const std::string &name, const FunctionType *ftype,
std::vector<std::pair<const Type *, SourcePos>> &normTypes, SourcePos pos);

/** After a source file has been compiled, output can be generated in a
number of different formats. */
enum OutputType {
Expand Down
Loading

0 comments on commit 81e507b

Please sign in to comment.