diff --git a/armor/armor.go b/armor/armor.go index 7cc66157..def9fa11 100644 --- a/armor/armor.go +++ b/armor/armor.go @@ -28,7 +28,7 @@ const ( type armoredWriter struct { started, closed bool - encoder io.WriteCloser + encoder *format.WrappedBase64Encoder dst io.Writer } @@ -50,15 +50,20 @@ func (a *armoredWriter) Close() error { if err := a.encoder.Close(); err != nil { return err } - _, err := io.WriteString(a.dst, "\n"+Footer+"\n") + footer := Footer + "\n" + if !a.encoder.LastLineIsEmpty() { + footer = "\n" + footer + } + _, err := io.WriteString(a.dst, footer) return err } func NewWriter(dst io.Writer) io.WriteCloser { // TODO: write a test with aligned and misaligned sizes, and 8 and 10 steps. - return &armoredWriter{dst: dst, - encoder: base64.NewEncoder(base64.StdEncoding.Strict(), - format.NewlineWriter(dst))} + return &armoredWriter{ + dst: dst, + encoder: format.NewWrappedBase64Encoder(base64.StdEncoding, dst), + } } type armoredReader struct { diff --git a/armor/armor_test.go b/armor/armor_test.go index 5ffff190..523eae07 100644 --- a/armor/armor_test.go +++ b/armor/armor_test.go @@ -8,6 +8,7 @@ package armor_test import ( "bytes" + "crypto/rand" "encoding/pem" "fmt" "io" @@ -18,6 +19,7 @@ import ( "filippo.io/age" "filippo.io/age/armor" + "filippo.io/age/internal/format" ) func ExampleNewWriter() { @@ -87,9 +89,15 @@ kB/RRusYjn+KVJ+KTioxj0THtzZPXcjFKuQ1 } func TestArmor(t *testing.T) { + t.Run("PartialLine", func(t *testing.T) { testArmor(t, 611) }) + t.Run("FullLine", func(t *testing.T) { testArmor(t, 10*format.BytesPerLine) }) +} + +func testArmor(t *testing.T, size int) { buf := &bytes.Buffer{} w := armor.NewWriter(buf) - plain := make([]byte, 611) + plain := make([]byte, size) + rand.Read(plain) if _, err := w.Write(plain); err != nil { t.Fatal(err) } @@ -101,9 +109,18 @@ func TestArmor(t *testing.T) { if block == nil { t.Fatal("PEM decoding failed") } + if len(block.Headers) != 0 { + t.Error("unexpected headers") + } + if block.Type != "AGE ENCRYPTED FILE" { + t.Errorf("unexpected type %q", block.Type) + } if !bytes.Equal(block.Bytes, plain) { t.Error("PEM decoded value doesn't match") } + if !bytes.Equal(buf.Bytes(), pem.EncodeToMemory(block)) { + t.Error("PEM re-encoded value doesn't match") + } r := armor.NewReader(buf) out, err := ioutil.ReadAll(r) diff --git a/internal/format/format.go b/internal/format/format.go index 8f79b613..ecded7d1 100644 --- a/internal/format/format.go +++ b/internal/format/format.go @@ -43,25 +43,40 @@ func DecodeString(s string) ([]byte, error) { var EncodeToString = b64.EncodeToString const ColumnsPerLine = 64 + const BytesPerLine = ColumnsPerLine / 4 * 3 -// NewlineWriter returns a Writer that writes to dst, inserting an LF character -// every ColumnsPerLine bytes. It does not insert a newline neither at the -// beginning nor at the end of the stream, but it ensures the last line is -// shorter than ColumnsPerLine, which means it might be empty. -func NewlineWriter(dst io.Writer) io.Writer { - return &newlineWriter{dst: dst} +// NewWrappedBase64Encoder returns a WrappedBase64Encoder that writes to dst. +func NewWrappedBase64Encoder(enc *base64.Encoding, dst io.Writer) *WrappedBase64Encoder { + w := &WrappedBase64Encoder{dst: dst} + w.enc = base64.NewEncoder(enc, WriterFunc(w.writeWrapped)) + return w } -type newlineWriter struct { +type WriterFunc func(p []byte) (int, error) + +func (f WriterFunc) Write(p []byte) (int, error) { return f(p) } + +// WrappedBase64Encoder is a standard base64 encoder that inserts an LF +// character every ColumnsPerLine bytes. It does not insert a newline neither at +// the beginning nor at the end of the stream, but it ensures the last line is +// shorter than ColumnsPerLine, which means it might be empty. +type WrappedBase64Encoder struct { + enc io.WriteCloser dst io.Writer written int buf bytes.Buffer } -func (w *newlineWriter) Write(p []byte) (int, error) { +func (w *WrappedBase64Encoder) Write(p []byte) (int, error) { return w.enc.Write(p) } + +func (w *WrappedBase64Encoder) Close() error { + return w.enc.Close() +} + +func (w *WrappedBase64Encoder) writeWrapped(p []byte) (int, error) { if w.buf.Len() != 0 { - panic("age: internal error: non-empty newlineWriter.buf") + panic("age: internal error: non-empty WrappedBase64Encoder.buf") } for len(p) > 0 { toWrite := ColumnsPerLine - (w.written % ColumnsPerLine) @@ -84,9 +99,18 @@ func (w *newlineWriter) Write(p []byte) (int, error) { return len(p), nil } +// LastLineIsEmpty returns whether the last output line was empty, either +// because no input was written, or because a multiple of BytesPerLine was. +// +// Calling LastLineIsEmpty before Close is meaningless. +func (w *WrappedBase64Encoder) LastLineIsEmpty() bool { + return w.written%ColumnsPerLine == 0 +} + const intro = "age-encryption.org/v1\n" var recipientPrefix = []byte("->") + var footerPrefix = []byte("---") func (r *Stanza) Marshal(w io.Writer) error { @@ -101,7 +125,7 @@ func (r *Stanza) Marshal(w io.Writer) error { if _, err := io.WriteString(w, "\n"); err != nil { return err } - ww := base64.NewEncoder(b64, NewlineWriter(w)) + ww := NewWrappedBase64Encoder(b64, w) if _, err := ww.Write(r.Body); err != nil { return err }