Skip to content

Commit

Permalink
Some cleanup in error handling, better logging, and more logical prog…
Browse files Browse the repository at this point in the history
…ram exit behaviour
  • Loading branch information
nixigaj committed Apr 11, 2024
1 parent cec127d commit 4b9d8e5
Showing 1 changed file with 73 additions and 46 deletions.
119 changes: 73 additions & 46 deletions cf-tlsa-acmesh.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,26 @@ func main() {
requiredEnvVars := []string{"KEY_FILE", "KEY_FILE_NEXT", "ZONE_ID", "API_TOKEN", "DOMAIN"}
for _, envVar := range requiredEnvVars {
if os.Getenv(envVar) == "" {
log.Println("Error:", envVar, "environment variable is not defined")
os.Exit(1)
log.Fatalln("Fatal:", envVar, "environment variable is not defined")
}
}

cert, err := generateCert(os.Getenv("KEY_FILE"))
if err != nil {
log.Println("Error generating cert:", err)
os.Exit(1)
log.Fatalln("Fatal: failed to generate current cert:", err)
}

certNext, err := generateCert(os.Getenv("KEY_FILE_NEXT"))
if err != nil {
log.Println("Error generating next cert:", err)
os.Exit(1)
log.Fatalln("Fatal: failed to generate next cert:", err)
}

log.Println("Current cert:", cert)
log.Println("Next cert:", certNext)

tlsaRecords, err := getTLSARecords()
if err != nil {
log.Println("Error:", err)
return
log.Fatalln("Fatal: failed to get TLSA records:", err)
}

for i, record := range tlsaRecords {
Expand All @@ -82,27 +78,49 @@ func main() {

if len(tlsaRecords) != 2 {
log.Println("Incorrect number of DNS entries. Deleting them and generating new ones.")
deleteAll(tlsaRecords)
addRequest(certNext)
addRequest(cert)
return

err = deleteAll(tlsaRecords)
if err != nil {
log.Fatalln("Fatal: failed to delete all TLSA recors:", err)
}

err = addRequest(certNext)
if err != nil {
log.Fatalln("Fatal: failed to add TLSA record for current cert:", err)
}

err = addRequest(cert)
if err != nil {
log.Fatalln("Fatal: failed to add TLSA record for next cert:", err)
}

os.Exit(0)
}

if (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) ||
(checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)) {
switch {
case (checkData(tlsaRecords[0], cert) && checkData(tlsaRecords[1], certNext)) ||
(checkData(tlsaRecords[0], certNext) && checkData(tlsaRecords[1], cert)):
log.Println("Nothing to do!")
} else if checkData(tlsaRecords[0], cert) {
modifyRequest(certNext, tlsaRecords[1].ID)
} else if checkData(tlsaRecords[0], certNext) {
modifyRequest(cert, tlsaRecords[1].ID)
} else if checkData(tlsaRecords[1], cert) {
modifyRequest(certNext, tlsaRecords[0].ID)
} else if checkData(tlsaRecords[1], certNext) {
modifyRequest(cert, tlsaRecords[0].ID)
} else {
modifyRequest(certNext, tlsaRecords[1].ID)
modifyRequest(cert, tlsaRecords[0].ID)
case checkData(tlsaRecords[0], cert):
err = modifyRequest(certNext, tlsaRecords[1].ID)
case checkData(tlsaRecords[0], certNext):
err = modifyRequest(cert, tlsaRecords[1].ID)
case checkData(tlsaRecords[1], cert):
err = modifyRequest(certNext, tlsaRecords[0].ID)
case checkData(tlsaRecords[1], certNext):
err = modifyRequest(cert, tlsaRecords[0].ID)
default:
err = modifyRequest(certNext, tlsaRecords[1].ID)
if err != nil {
break
}
err = modifyRequest(cert, tlsaRecords[0].ID)
}
if err != nil {
log.Fatalln("Fatal: failed to modify TLSA records:", err)
}

os.Exit(0)
}

func getTLSARecords() ([]tlsaRecord, error) {
Expand Down Expand Up @@ -131,12 +149,18 @@ func getTLSARecords() ([]tlsaRecord, error) {
}
}(resp.Body)

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed reading response body: %v", err)
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP request failed with status code: %s", resp.Status)
return nil, fmt.Errorf("recieved %d HTTP response status code for GET request, response body: %s", resp.StatusCode, string(body))
}

var response tlsaRecordsResponse
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
err = json.Unmarshal(body, &response)
if err != nil {
return nil, fmt.Errorf("failed to decode JSON response: %v", err)
}

Expand Down Expand Up @@ -171,18 +195,23 @@ func generateCert(keyPath string) (string, error) {
return hex.EncodeToString(hashSum), nil
}

func deleteAll(tlsaRecords []tlsaRecord) {
func deleteAll(tlsaRecords []tlsaRecord) error {
zoneID, authToken := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN")

for _, record := range tlsaRecords {
log.Println("Deleting DNS record:", record.ID)
url := cloudflareAPI + zoneID + "/dns_records/" + record.ID
resp, err := makeHTTPRequest("DELETE", url, authToken, nil)
handleResponse(resp, err, "DELETE")
err = handleResponse(resp, err, "DELETE")
if err != nil {
return err
}
}

return nil
}

func addRequest(hash string) {
func addRequest(hash string) error {
log.Println("Adding DNS record with hash:", hash)

zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN")
Expand All @@ -193,10 +222,10 @@ func addRequest(hash string) {
port, protocol, domain, usage, selector, matchingType, hash)

resp, err := makeHTTPRequest("POST", url, authToken, []byte(payload))
handleResponse(resp, err, "POST")
return handleResponse(resp, err, "POST")
}

func modifyRequest(hash, id string) {
func modifyRequest(hash, id string) error {
log.Println("Modifying DNS record:", id, "with hash:", hash)

zoneID, authToken, domain := os.Getenv("ZONE_ID"), os.Getenv("API_TOKEN"), os.Getenv("DOMAIN")
Expand All @@ -207,7 +236,7 @@ func modifyRequest(hash, id string) {
port, protocol, domain, usage, selector, matchingType, hash)

resp, err := makeHTTPRequest("PUT", url, authToken, []byte(payload))
handleResponse(resp, err, "PUT")
return handleResponse(resp, err, "PUT")
}

func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Response, error) {
Expand All @@ -223,29 +252,27 @@ func makeHTTPRequest(method, url, authToken string, payload []byte) (*http.Respo
return client.Do(req)
}

func handleResponse(resp *http.Response, err error, action string) {
func handleResponse(resp *http.Response, err error, action string) error {
if err != nil {
log.Println("Error:", err)
os.Exit(1)
return err
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
log.Println("Error closing HTTP body", err)
log.Println("Error closing HTTP body:", err)
}
}(resp.Body)

log.Println(action, "HTTP Status Code:", resp.Status)
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed reading response body: %v", err)
}

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Println("Error reading response body:", err)
} else {
log.Println("Response Body:", string(body))
}
os.Exit(1)
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("recieved %d HTTP response status code for %s request, response body: %s", resp.StatusCode, action, string(body))
}

return nil
}

func checkData(record tlsaRecord, hash string) (correct bool) {
Expand Down

0 comments on commit 4b9d8e5

Please sign in to comment.