Skip to content

Adding policy tools #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions pkg/hashicorp/tfregistry/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,127 @@ func ModuleDetails(registryClient *http.Client, logger *log.Logger) (tool mcp.To
}
}
}

func SearchPolicies(registryClient *http.Client, logger *log.Logger) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("searchPolicies",
mcp.WithDescription(`Searches for Terraform policies based on a query string. This tool returns a list of matching policies, which can be used to retrieve detailed policy information using the 'policyDetails' tool.
You MUST call this function before 'providerDetails' to obtain a valid terraformPolicyID.
When selecting the best match, consider: - Name similarity to the query - Title relevance - Verification status (verified) - Download counts (popularity) Return the selected policyID and explain your choice.
If there are multiple good matches, mention this but proceed with the most relevant one. If no policies were found, reattempt the search with a new policyQuery.`),
mcp.WithTitleAnnotation("Search and match Terraform policies based on name and relevance"),
mcp.WithOpenWorldHintAnnotation(true),
mcp.WithString("policyQuery",
mcp.Required(),
mcp.Description("The query to search for Terraform modules."),
),
), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
var terraformPolicies TerraformPolicyList
policyQuery := request.Params.Arguments["policyQuery"]
if pq, ok := policyQuery.(string); !ok {
return nil, logAndReturnError(logger, "error finding the policy based on that name;", nil)
} else {

// static list of 100 is fine for now
policyResp, err := sendRegistryCall(registryClient, "GET", "policies?page%5Bsize%5D=100&include=latest-version", logger, "v2")
if err != nil {
return nil, logAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err)
}

err = json.Unmarshal(policyResp, &terraformPolicies)
if err != nil {
return nil, logAndReturnError(nil, "unmarshalling policy list", err)
}

var builder strings.Builder
builder.WriteString(fmt.Sprintf("Matching Terraform Policies for query: %s\n\n", pq))
builder.WriteString("Each result includes:\n- terraformPolicyID: Unique identifier to be used with policyDetails tool\n- Name: Policy name\n- Title: Policy description\n- Downloads: Policy downloads\n---\n\n")

contentAvailable := false
for _, policy := range terraformPolicies.Data {
if strings.Contains(strings.ToLower(policy.Attributes.Name), strings.ToLower(pq)) ||
strings.Contains(strings.ToLower(policy.Attributes.Title), strings.ToLower(pq)) {
contentAvailable = true
ID := strings.ReplaceAll(policy.Relationships.LatestVersion.Links.Related, "/v2/", "")
builder.WriteString(fmt.Sprintf(
"- terraformPolicyID: %s\n- Name: %s\n- Title: %s\n- Downloads: %d\n---\n",
ID,
policy.Attributes.Name,
policy.Attributes.Title,
policy.Attributes.Downloads,
))
}
}

policyData := builder.String()
if !contentAvailable {
policyData = fmt.Sprintf("No policies found matching the query: %s. Try a different policyQuery.", pq)
}

return mcp.NewToolResultText(policyData), nil
}
}
}

func PolicyDetails(registryClient *http.Client, logger *log.Logger) (tool mcp.Tool, handler server.ToolHandlerFunc) {
return mcp.NewTool("policyDetails",
mcp.WithDescription(`Fetches up-to-date documentation for a specific policy from the Terraform registry. You must call 'searchPolicies' first to obtain the exact terraformPolicyID required to use this tool.`),
mcp.WithTitleAnnotation("Fetch detailed Terraform policy documentation using a terraformPolicyID"),
mcp.WithOpenWorldHintAnnotation(true),
mcp.WithString("terraformPolicyID",
mcp.Required(),
mcp.Description("Matching terraformPolicyID retrieved from the 'searchPolicies' tool (e.g., 'policies/hashicorp/CIS-Policy-Set-for-AWS-Terraform/1.0.1')"),
),
), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
terraformPolicyID, ok := request.Params.Arguments["terraformPolicyID"].(string)
if !ok || terraformPolicyID == "" {
return nil, logAndReturnError(logger, "terraformPolicyID is required and must be a string, it is fetched by running the searchPolicies tool", nil)
}

policyResp, err := sendRegistryCall(registryClient, "GET", fmt.Sprintf("%s?include=policies,policy-modules,policy-library", terraformPolicyID), logger, "v2")
if err != nil {
return nil, logAndReturnError(logger, "Failed to fetch policy details: registry API did not return a successful response", err)
}

var policyDetails TerraformPolicyDetails
if err := json.Unmarshal(policyResp, &policyDetails); err != nil {
return nil, logAndReturnError(logger, fmt.Sprintf("error unmarshalling policy details for %s", terraformPolicyID), err)
}

readme := extractReadme(policyDetails.Data.Attributes.Readme)
var builder strings.Builder
builder.WriteString(fmt.Sprintf("## Policy details about %s \n\n%s", terraformPolicyID, readme))
policyList := ""
moduleList := ""
for _, policy := range policyDetails.Included {
if policy.Type == "policy-modules" {
moduleList += fmt.Sprintf(`
module "%s" {
source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s"
}
`, policy.Attributes.Name, terraformPolicyID, policy.Attributes.Name, policy.Attributes.Shasum)
}

if policy.Type == "policies" {
policyList += fmt.Sprintf("- POLICY_NAME: %s\n- POLICY_CHECKSUM: sha256:%s\n", policy.Attributes.Name, policy.Attributes.Shasum)
policyList += "\n---\n"
}
}
builder.WriteString("---\n")
builder.WriteString("## Usage\n\n")
builder.WriteString("Generate the content for a HashiCorp Configuration Language (HCL) file named policies.hcl. This file should define a set of policies. For each policy provided, create a distinct policy block using the following template.\n")
builder.WriteString("\n```hcl\n")
builder.WriteString(moduleList)
builder.WriteString(fmt.Sprintf(`
policy "<<POLICY_NAME>>" {
source = "https://registry.terraform.io/v2%s/policy/<<POLICY_NAME>>.sentinel?checksum=<<POLICY_CHECKSUM>>"
enforcement_level = "advisory"
}
`, terraformPolicyID))
builder.WriteString("\n```\n")
builder.WriteString(fmt.Sprintf("Available policies with SHA for %s are: \n\n", terraformPolicyID))
builder.WriteString(policyList)

policyData := builder.String()
return mcp.NewToolResultText(policyData), nil
}
}
2 changes: 2 additions & 0 deletions pkg/hashicorp/tfregistry/tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ func InitTools(hcServer *server.MCPServer, registryClient *http.Client, logger *
hcServer.AddTool(GetProviderDocs(registryClient, logger))
hcServer.AddTool(SearchModules(registryClient, logger))
hcServer.AddTool(ModuleDetails(registryClient, logger))
hcServer.AddTool(SearchPolicies(registryClient, logger))
hcServer.AddTool(PolicyDetails(registryClient, logger))
}
119 changes: 119 additions & 0 deletions pkg/hashicorp/tfregistry/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,122 @@ type ProviderDocData struct {
Self string `json:"self"`
} `json:"links"`
}

type TerraformPolicyList struct {
Data []struct {
Type string `json:"type"`
ID string `json:"id"`
Attributes struct {
Downloads int `json:"downloads"`
FullName string `json:"full-name"`
Ingress string `json:"ingress"`
Name string `json:"name"`
Namespace string `json:"namespace"`
OwnerName string `json:"owner-name"`
Source string `json:"source"`
Title string `json:"title"`
Verified bool `json:"verified"`
} `json:"attributes"`
Relationships struct {
LatestVersion struct {
Data struct {
ID string `json:"id"`
Type string `json:"type"`
} `json:"data"`
Links struct {
Related string `json:"related"`
} `json:"links"`
} `json:"latest-version"`
} `json:"relationships"`
Links struct {
Self string `json:"self"`
} `json:"links"`
} `json:"data"`
Included []struct {
Type string `json:"type"`
ID string `json:"id"`
Attributes struct {
Description string `json:"description"`
Downloads int `json:"downloads"`
PublishedAt time.Time `json:"published-at"`
Readme string `json:"readme"`
Source string `json:"source"`
Tag string `json:"tag"`
Version string `json:"version"`
} `json:"attributes"`
Links struct {
Self string `json:"self"`
} `json:"links"`
} `json:"included"`
Links struct {
First string `json:"first"`
Last string `json:"last"`
Next any `json:"next"`
Prev any `json:"prev"`
} `json:"links"`
Meta struct {
Pagination struct {
PageSize int `json:"page-size"`
CurrentPage int `json:"current-page"`
NextPage any `json:"next-page"`
PrevPage any `json:"prev-page"`
TotalPages int `json:"total-pages"`
TotalCount int `json:"total-count"`
} `json:"pagination"`
} `json:"meta"`
}

type TerraformPolicyDetails struct {
Data struct {
Type string `json:"type"`
ID string `json:"id"`
Attributes struct {
Description string `json:"description"`
Downloads int `json:"downloads"`
PublishedAt time.Time `json:"published-at"`
Readme string `json:"readme"`
Source string `json:"source"`
Tag string `json:"tag"`
Version string `json:"version"`
} `json:"attributes"`
Relationships struct {
Policies struct {
Data []struct {
Type string `json:"type"`
ID string `json:"id"`
} `json:"data"`
} `json:"policies"`
PolicyLibrary struct {
Data struct {
Type string `json:"type"`
ID string `json:"id"`
} `json:"data"`
} `json:"policy-library"`
PolicyModules struct {
Data []struct {
Type string `json:"type"`
ID string `json:"id"`
} `json:"data"`
} `json:"policy-modules"`
} `json:"relationships"`
Links struct {
Self string `json:"self"`
} `json:"links"`
} `json:"data"`
Included []struct {
Type string `json:"type"`
ID string `json:"id"`
Attributes struct {
Description string `json:"description"`
Downloads int `json:"downloads"`
FullName string `json:"full-name"`
Name string `json:"name"`
Shasum string `json:"shasum"`
ShasumType string `json:"shasum-type"`
Title string `json:"title"`
} `json:"attributes"`
Links struct {
Self string `json:"self"`
} `json:"links"`
} `json:"included"`
}
22 changes: 22 additions & 0 deletions pkg/hashicorp/tfregistry/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,25 @@ func isV2ProviderDataType(dataType string) bool {
v2Categories := []string{"guides", "functions", "overview"}
return slices.Contains(v2Categories, dataType)
}

func extractReadme(readme string) string {
if readme == "" {
return ""
}

extractedReadme := ""
headerFound := false
strArr := strings.Split(readme, "\n")
for _, str := range strArr {
if strings.Contains(str, "#") {
if headerFound {
return extractedReadme
}
headerFound = true
}
extractedReadme += str + "\n"
}

extractedReadme = strings.TrimSuffix(extractedReadme, "\n")
return extractedReadme
}
53 changes: 53 additions & 0 deletions pkg/hashicorp/tfregistry/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,56 @@ func TestIsV2ProviderDataType(t *testing.T) {
}
}
}

func TestExtractReadme(t *testing.T) {
tests := []struct {
name string
readme string
expected string
}{
{
name: "SingleSection",
readme: "# Title\nSome content here.",
expected: "# Title\nSome content here.",
},
{
name: "TwoSections",
readme: "# Title\nSome content here.\n# Section2\nMore content.",
expected: "# Title\nSome content here.\n",
},
{
name: "NoHash",
readme: "No hash at all",
expected: "No hash at all",
},
{
name: "MultipleHashes",
readme: "# First\nContent1\n# Second\nContent2\n# Third\nContent3",
expected: "# First\nContent1\n",
},
{
name: "HashAtEnd",
readme: "Some intro\n# OnlySection",
expected: "Some intro\n# OnlySection",
},
{
name: "HashWithoutNextLine",
readme: "Some intro\n# OnlySection ## More Content",
expected: "Some intro\n# OnlySection ## More Content",
},
{
name: "EmptyString",
readme: "",
expected: "",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := extractReadme(tc.readme)
if result != tc.expected {
t.Errorf("extractReadme(%q) = %q; want %q", tc.readme, result, tc.expected)
}
})
}
}