diff --git a/pkg/hashicorp/tfregistry/handlers.go b/pkg/hashicorp/tfregistry/handlers.go index 72a8aa1..975922b 100644 --- a/pkg/hashicorp/tfregistry/handlers.go +++ b/pkg/hashicorp/tfregistry/handlers.go @@ -213,3 +213,127 @@ func ModuleDetails(registryClient *http.Client, logger *log.Logger) (tool mcp.To } } } + +func SearchPolicies(registryClient *http.Client, logger *log.Logger) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("searchPolicies", + mcp.WithDescription(`Searches for Terraform policies based on a query string. This tool returns a list of matching policies, which can be used to retrieve detailed policy information using the 'policyDetails' tool. + You MUST call this function before 'providerDetails' to obtain a valid terraformPolicyID. + When selecting the best match, consider: - Name similarity to the query - Title relevance - Verification status (verified) - Download counts (popularity) Return the selected policyID and explain your choice. + If there are multiple good matches, mention this but proceed with the most relevant one. If no policies were found, reattempt the search with a new policyQuery.`), + mcp.WithTitleAnnotation("Search and match Terraform policies based on name and relevance"), + mcp.WithOpenWorldHintAnnotation(true), + mcp.WithString("policyQuery", + mcp.Required(), + mcp.Description("The query to search for Terraform modules."), + ), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + var terraformPolicies TerraformPolicyList + policyQuery := request.Params.Arguments["policyQuery"] + if pq, ok := policyQuery.(string); !ok { + return nil, logAndReturnError(logger, "error finding the policy based on that name;", nil) + } else { + + // static list of 100 is fine for now + policyResp, err := sendRegistryCall(registryClient, "GET", "policies?page%5Bsize%5D=100&include=latest-version", logger, "v2") + if err != nil { + return nil, logAndReturnError(logger, "Failed to fetch policies: registry API did not return a successful response", err) + } + + err = json.Unmarshal(policyResp, &terraformPolicies) + if err != nil { + return nil, logAndReturnError(nil, "unmarshalling policy list", err) + } + + var builder strings.Builder + builder.WriteString(fmt.Sprintf("Matching Terraform Policies for query: %s\n\n", pq)) + builder.WriteString("Each result includes:\n- terraformPolicyID: Unique identifier to be used with policyDetails tool\n- Name: Policy name\n- Title: Policy description\n- Downloads: Policy downloads\n---\n\n") + + contentAvailable := false + for _, policy := range terraformPolicies.Data { + if strings.Contains(strings.ToLower(policy.Attributes.Name), strings.ToLower(pq)) || + strings.Contains(strings.ToLower(policy.Attributes.Title), strings.ToLower(pq)) { + contentAvailable = true + ID := strings.ReplaceAll(policy.Relationships.LatestVersion.Links.Related, "/v2/", "") + builder.WriteString(fmt.Sprintf( + "- terraformPolicyID: %s\n- Name: %s\n- Title: %s\n- Downloads: %d\n---\n", + ID, + policy.Attributes.Name, + policy.Attributes.Title, + policy.Attributes.Downloads, + )) + } + } + + policyData := builder.String() + if !contentAvailable { + policyData = fmt.Sprintf("No policies found matching the query: %s. Try a different policyQuery.", pq) + } + + return mcp.NewToolResultText(policyData), nil + } + } +} + +func PolicyDetails(registryClient *http.Client, logger *log.Logger) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("policyDetails", + mcp.WithDescription(`Fetches up-to-date documentation for a specific policy from the Terraform registry. You must call 'searchPolicies' first to obtain the exact terraformPolicyID required to use this tool.`), + mcp.WithTitleAnnotation("Fetch detailed Terraform policy documentation using a terraformPolicyID"), + mcp.WithOpenWorldHintAnnotation(true), + mcp.WithString("terraformPolicyID", + mcp.Required(), + mcp.Description("Matching terraformPolicyID retrieved from the 'searchPolicies' tool (e.g., 'policies/hashicorp/CIS-Policy-Set-for-AWS-Terraform/1.0.1')"), + ), + ), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + terraformPolicyID, ok := request.Params.Arguments["terraformPolicyID"].(string) + if !ok || terraformPolicyID == "" { + return nil, logAndReturnError(logger, "terraformPolicyID is required and must be a string, it is fetched by running the searchPolicies tool", nil) + } + + policyResp, err := sendRegistryCall(registryClient, "GET", fmt.Sprintf("%s?include=policies,policy-modules,policy-library", terraformPolicyID), logger, "v2") + if err != nil { + return nil, logAndReturnError(logger, "Failed to fetch policy details: registry API did not return a successful response", err) + } + + var policyDetails TerraformPolicyDetails + if err := json.Unmarshal(policyResp, &policyDetails); err != nil { + return nil, logAndReturnError(logger, fmt.Sprintf("error unmarshalling policy details for %s", terraformPolicyID), err) + } + + readme := extractReadme(policyDetails.Data.Attributes.Readme) + var builder strings.Builder + builder.WriteString(fmt.Sprintf("## Policy details about %s \n\n%s", terraformPolicyID, readme)) + policyList := "" + moduleList := "" + for _, policy := range policyDetails.Included { + if policy.Type == "policy-modules" { + moduleList += fmt.Sprintf(` +module "%s" { + source = "https://registry.terraform.io/v2%s/policy-module/%s.sentinel?checksum=sha256:%s" +} + `, policy.Attributes.Name, terraformPolicyID, policy.Attributes.Name, policy.Attributes.Shasum) + } + + if policy.Type == "policies" { + policyList += fmt.Sprintf("- POLICY_NAME: %s\n- POLICY_CHECKSUM: sha256:%s\n", policy.Attributes.Name, policy.Attributes.Shasum) + policyList += "\n---\n" + } + } + builder.WriteString("---\n") + builder.WriteString("## Usage\n\n") + builder.WriteString("Generate the content for a HashiCorp Configuration Language (HCL) file named policies.hcl. This file should define a set of policies. For each policy provided, create a distinct policy block using the following template.\n") + builder.WriteString("\n```hcl\n") + builder.WriteString(moduleList) + builder.WriteString(fmt.Sprintf(` +policy "<>" { + source = "https://registry.terraform.io/v2%s/policy/<>.sentinel?checksum=<>" + enforcement_level = "advisory" +} + `, terraformPolicyID)) + builder.WriteString("\n```\n") + builder.WriteString(fmt.Sprintf("Available policies with SHA for %s are: \n\n", terraformPolicyID)) + builder.WriteString(policyList) + + policyData := builder.String() + return mcp.NewToolResultText(policyData), nil + } +} diff --git a/pkg/hashicorp/tfregistry/tools.go b/pkg/hashicorp/tfregistry/tools.go index bb01f60..5d18b90 100644 --- a/pkg/hashicorp/tfregistry/tools.go +++ b/pkg/hashicorp/tfregistry/tools.go @@ -15,4 +15,6 @@ func InitTools(hcServer *server.MCPServer, registryClient *http.Client, logger * hcServer.AddTool(GetProviderDocs(registryClient, logger)) hcServer.AddTool(SearchModules(registryClient, logger)) hcServer.AddTool(ModuleDetails(registryClient, logger)) + hcServer.AddTool(SearchPolicies(registryClient, logger)) + hcServer.AddTool(PolicyDetails(registryClient, logger)) } diff --git a/pkg/hashicorp/tfregistry/types.go b/pkg/hashicorp/tfregistry/types.go index c533203..e5c0835 100644 --- a/pkg/hashicorp/tfregistry/types.go +++ b/pkg/hashicorp/tfregistry/types.go @@ -309,3 +309,122 @@ type ProviderDocData struct { Self string `json:"self"` } `json:"links"` } + +type TerraformPolicyList struct { + Data []struct { + Type string `json:"type"` + ID string `json:"id"` + Attributes struct { + Downloads int `json:"downloads"` + FullName string `json:"full-name"` + Ingress string `json:"ingress"` + Name string `json:"name"` + Namespace string `json:"namespace"` + OwnerName string `json:"owner-name"` + Source string `json:"source"` + Title string `json:"title"` + Verified bool `json:"verified"` + } `json:"attributes"` + Relationships struct { + LatestVersion struct { + Data struct { + ID string `json:"id"` + Type string `json:"type"` + } `json:"data"` + Links struct { + Related string `json:"related"` + } `json:"links"` + } `json:"latest-version"` + } `json:"relationships"` + Links struct { + Self string `json:"self"` + } `json:"links"` + } `json:"data"` + Included []struct { + Type string `json:"type"` + ID string `json:"id"` + Attributes struct { + Description string `json:"description"` + Downloads int `json:"downloads"` + PublishedAt time.Time `json:"published-at"` + Readme string `json:"readme"` + Source string `json:"source"` + Tag string `json:"tag"` + Version string `json:"version"` + } `json:"attributes"` + Links struct { + Self string `json:"self"` + } `json:"links"` + } `json:"included"` + Links struct { + First string `json:"first"` + Last string `json:"last"` + Next any `json:"next"` + Prev any `json:"prev"` + } `json:"links"` + Meta struct { + Pagination struct { + PageSize int `json:"page-size"` + CurrentPage int `json:"current-page"` + NextPage any `json:"next-page"` + PrevPage any `json:"prev-page"` + TotalPages int `json:"total-pages"` + TotalCount int `json:"total-count"` + } `json:"pagination"` + } `json:"meta"` +} + +type TerraformPolicyDetails struct { + Data struct { + Type string `json:"type"` + ID string `json:"id"` + Attributes struct { + Description string `json:"description"` + Downloads int `json:"downloads"` + PublishedAt time.Time `json:"published-at"` + Readme string `json:"readme"` + Source string `json:"source"` + Tag string `json:"tag"` + Version string `json:"version"` + } `json:"attributes"` + Relationships struct { + Policies struct { + Data []struct { + Type string `json:"type"` + ID string `json:"id"` + } `json:"data"` + } `json:"policies"` + PolicyLibrary struct { + Data struct { + Type string `json:"type"` + ID string `json:"id"` + } `json:"data"` + } `json:"policy-library"` + PolicyModules struct { + Data []struct { + Type string `json:"type"` + ID string `json:"id"` + } `json:"data"` + } `json:"policy-modules"` + } `json:"relationships"` + Links struct { + Self string `json:"self"` + } `json:"links"` + } `json:"data"` + Included []struct { + Type string `json:"type"` + ID string `json:"id"` + Attributes struct { + Description string `json:"description"` + Downloads int `json:"downloads"` + FullName string `json:"full-name"` + Name string `json:"name"` + Shasum string `json:"shasum"` + ShasumType string `json:"shasum-type"` + Title string `json:"title"` + } `json:"attributes"` + Links struct { + Self string `json:"self"` + } `json:"links"` + } `json:"included"` +} diff --git a/pkg/hashicorp/tfregistry/utils.go b/pkg/hashicorp/tfregistry/utils.go index 37af8fa..ca77de1 100644 --- a/pkg/hashicorp/tfregistry/utils.go +++ b/pkg/hashicorp/tfregistry/utils.go @@ -522,3 +522,25 @@ func isV2ProviderDataType(dataType string) bool { v2Categories := []string{"guides", "functions", "overview"} return slices.Contains(v2Categories, dataType) } + +func extractReadme(readme string) string { + if readme == "" { + return "" + } + + extractedReadme := "" + headerFound := false + strArr := strings.Split(readme, "\n") + for _, str := range strArr { + if strings.Contains(str, "#") { + if headerFound { + return extractedReadme + } + headerFound = true + } + extractedReadme += str + "\n" + } + + extractedReadme = strings.TrimSuffix(extractedReadme, "\n") + return extractedReadme +} diff --git a/pkg/hashicorp/tfregistry/utils_test.go b/pkg/hashicorp/tfregistry/utils_test.go index 62bbc44..54b97ea 100644 --- a/pkg/hashicorp/tfregistry/utils_test.go +++ b/pkg/hashicorp/tfregistry/utils_test.go @@ -351,3 +351,56 @@ func TestIsV2ProviderDataType(t *testing.T) { } } } + +func TestExtractReadme(t *testing.T) { + tests := []struct { + name string + readme string + expected string + }{ + { + name: "SingleSection", + readme: "# Title\nSome content here.", + expected: "# Title\nSome content here.", + }, + { + name: "TwoSections", + readme: "# Title\nSome content here.\n# Section2\nMore content.", + expected: "# Title\nSome content here.\n", + }, + { + name: "NoHash", + readme: "No hash at all", + expected: "No hash at all", + }, + { + name: "MultipleHashes", + readme: "# First\nContent1\n# Second\nContent2\n# Third\nContent3", + expected: "# First\nContent1\n", + }, + { + name: "HashAtEnd", + readme: "Some intro\n# OnlySection", + expected: "Some intro\n# OnlySection", + }, + { + name: "HashWithoutNextLine", + readme: "Some intro\n# OnlySection ## More Content", + expected: "Some intro\n# OnlySection ## More Content", + }, + { + name: "EmptyString", + readme: "", + expected: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := extractReadme(tc.readme) + if result != tc.expected { + t.Errorf("extractReadme(%q) = %q; want %q", tc.readme, result, tc.expected) + } + }) + } +}