diff --git a/cf-tlsa-acmesh.go b/cf-tlsa-acmesh.go index 4c8319f..209bffa 100644 --- a/cf-tlsa-acmesh.go +++ b/cf-tlsa-acmesh.go @@ -50,21 +50,18 @@ func main() { requiredEnvVars := []string{"KEY_FILE", "KEY_FILE_NEXT", "ZONE_ID", "API_TOKEN", "DOMAIN"} for _, envVar := range requiredEnvVars { if os.Getenv(envVar) == "" { - log.Println("Error:", envVar, "environment variable is not defined") - os.Exit(1) + log.Fatalln("Fatal:", envVar, "environment variable is not defined") } } cert, err := generateCert(os.Getenv("KEY_FILE")) if err != nil { - log.Println("Error generating cert:", err) - os.Exit(1) + log.Fatalln("Fatal: failed to generate current cert:", err) } certNext, err := generateCert(os.Getenv("KEY_FILE_NEXT")) if err != nil { - log.Println("Error generating next cert:", err) - os.Exit(1) + log.Fatalln("Fatal: failed to generate next cert:", err) } log.Println("Current cert:", cert) @@ -72,8 +69,7 @@ func main() { tlsaRecords, err := getTLSARecords() if err != nil { - log.Println("Error:", err) - return + log.Fatalln("Fatal: failed to get TLSA records:", err) } for i, record := range tlsaRecords { @@ -82,27 +78,49 @@ func main() { if len(tlsaRecords) != 2 { log.Println("Incorrect number of DNS entries. Deleting them and generating new ones.") - deleteAll(tlsaRecords) - addRequest(certNext) - addRequest(cert) - return + + err = deleteAll(tlsaRecords) + if err != nil { + log.Fatalln("Fatal: failed to delete all TLSA recors:", err) + } + + err = addRequest(certNext) + if err != nil { + log.Fatalln("Fatal: failed to add TLSA record for current cert:", err) + } + + err = addRequest(cert) + if err != nil { + log.Fatalln("Fatal: failed to add TLSA record for next cert:", err) + } + + os.Exit(0) } - if (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) || - (checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)) { + switch { + case (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) || + (checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)): log.Println("Nothing to do!") - } else if checkData(tlsaRecords[0], cert) { - modifyRequest(certNext, tlsaRecords[1].ID) - } else if checkData(tlsaRecords[0], certNext) { - modifyRequest(cert, tlsaRecords[1].ID) - } else if checkData(tlsaRecords[1], cert) { - modifyRequest(certNext, tlsaRecords[0].ID) - } else if checkData(tlsaRecords[1], certNext) { - modifyRequest(cert, tlsaRecords[0].ID) - } else { - modifyRequest(certNext, tlsaRecords[1].ID) - modifyRequest(cert, tlsaRecords[0].ID) + case checkData(tlsaRecords[0], cert): + err = modifyRequest(certNext, tlsaRecords[1].ID) + case checkData(tlsaRecords[0], certNext): + err = modifyRequest(cert, tlsaRecords[1].ID) + case checkData(tlsaRecords[1], cert): + err = modifyRequest(certNext, tlsaRecords[0].ID) + case checkData(tlsaRecords[1], certNext): + err = modifyRequest(cert, tlsaRecords[0].ID) + default: + err = modifyRequest(certNext, tlsaRecords[1].ID) + if err != nil { + break + } + err = modifyRequest(cert, tlsaRecords[0].ID) } + if err != nil { + log.Fatalln("Fatal: failed to modify TLSA records:", err) + } + + os.Exit(0) } func getTLSARecords() ([]tlsaRecord, error) { @@ -131,12 +149,18 @@ func getTLSARecords() ([]tlsaRecord, error) { } }(resp.Body) + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed reading response body: %v", err) + } + if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("HTTP request failed with status code: %s", resp.Status) + return nil, fmt.Errorf("recieved %d HTTP response status code for GET request, response body: %s", resp.StatusCode, string(body)) } var response tlsaRecordsResponse - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + err = json.Unmarshal(body, &response) + if err != nil { return nil, fmt.Errorf("failed to decode JSON response: %v", err) } @@ -171,18 +195,23 @@ func generateCert(keyPath string) (string, error) { return hex.EncodeToString(hashSum), nil } -func deleteAll(tlsaRecords []tlsaRecord) { +func deleteAll(tlsaRecords []tlsaRecord) error { zoneID, authToken := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN") for _, record := range tlsaRecords { log.Println("Deleting DNS record:", record.ID) url := cloudflareAPI + zoneID + "/dns_records/" + record.ID resp, err := makeHTTPRequest("DELETE", url, authToken, nil) - handleResponse(resp, err, "DELETE") + err = handleResponse(resp, err, "DELETE") + if err != nil { + return err + } } + + return nil } -func addRequest(hash string) { +func addRequest(hash string) error { log.Println("Adding DNS record with hash:", hash) zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN") @@ -193,10 +222,10 @@ func addRequest(hash string) { port, protocol, domain, usage, selector, matchingType, hash) resp, err := makeHTTPRequest("POST", url, authToken, []byte(payload)) - handleResponse(resp, err, "POST") + return handleResponse(resp, err, "POST") } -func modifyRequest(hash, id string) { +func modifyRequest(hash, id string) error { log.Println("Modifying DNS record:", id, "with hash:", hash) zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN") @@ -207,7 +236,7 @@ func modifyRequest(hash, id string) { port, protocol, domain, usage, selector, matchingType, hash) resp, err := makeHTTPRequest("PUT", url, authToken, []byte(payload)) - handleResponse(resp, err, "PUT") + return handleResponse(resp, err, "PUT") } func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Response, error) { @@ -223,29 +252,27 @@ func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Respo return client.Do(req) } -func handleResponse(resp *http.Response, err error, action string) { +func handleResponse(resp *http.Response, err error, action string) error { if err != nil { - log.Println("Error:", err) - os.Exit(1) + return err } defer func(Body io.ReadCloser) { err := Body.Close() if err != nil { - log.Println("Error closing HTTP body", err) + log.Println("Error closing HTTP body:", err) } }(resp.Body) - log.Println(action, "HTTP Status Code:", resp.Status) + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed reading response body: %v", err) + } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Println("Error reading response body:", err) - } else { - log.Println("Response Body:", string(body)) - } - os.Exit(1) + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("recieved %d HTTP response status code for %s request, response body: %s", resp.StatusCode, action, string(body)) } + + return nil } func checkData(record tlsaRecord, hash string) (correct bool) {