Skip to content

Commit 2b81caa

Browse files
authored
fix: incomplete streamReader read (#30)
Signed-off-by: francois samin <[email protected]>
1 parent e548945 commit 2b81caa

File tree

3 files changed

+142
-7
lines changed

3 files changed

+142
-7
lines changed

convergent/convergent_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@ import (
55
"crypto/rand"
66
"encoding/hex"
77
"encoding/json"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
811
"strings"
12+
"sync"
913
"testing"
1014
"time"
1115

1216
"github.com/ovh/configstore"
17+
"github.com/stretchr/testify/assert"
1318
"github.com/stretchr/testify/require"
1419

1520
"github.com/ovh/symmecrypt/ciphers/aesgcm"
@@ -257,3 +262,59 @@ func TestLoadKeyFromStore(t *testing.T) {
257262
require.NoError(t, err)
258263
require.NotNil(t, k)
259264
}
265+
266+
func TestDecryptFromHTTP(t *testing.T) {
267+
// Key config
268+
cfgs := []convergent.ConvergentEncryptionConfig{
269+
{
270+
Cipher: aesgcm.CipherName,
271+
LocatorSalt: symutils.RandomSalt(),
272+
SecretValue: symutils.MustRandomString(10),
273+
},
274+
}
275+
276+
// Encrypt a random content
277+
clearContent := make([]byte, 10*1024)
278+
rand.Read(clearContent) // nolint
279+
280+
h, err := convergent.NewHash(bytes.NewReader(clearContent))
281+
require.NoError(t, err)
282+
283+
k, err := convergent.NewKey(h, cfgs...)
284+
require.NoError(t, err)
285+
require.NotNil(t, k)
286+
287+
dest := new(bytes.Buffer)
288+
err = k.EncryptPipe(bytes.NewReader(clearContent), dest)
289+
require.NoError(t, err)
290+
encryptedContent := dest.String()
291+
292+
// Serve the encrypted content
293+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
294+
time.Sleep(1 * time.Second)
295+
io.WriteString(w, encryptedContent) //nolint
296+
}))
297+
defer ts.Close()
298+
299+
wg := &sync.WaitGroup{}
300+
301+
for i := 0; i < 100; i++ {
302+
wg.Add(1)
303+
go func() {
304+
res, err := http.Get(ts.URL)
305+
require.NoError(t, err)
306+
307+
defer res.Body.Close()
308+
309+
dest := new(bytes.Buffer)
310+
err = k.DecryptPipe(res.Body, dest)
311+
require.NoError(t, err)
312+
313+
// Ensure the content is correctly decrypted
314+
assert.EqualValues(t, clearContent, dest.Bytes())
315+
wg.Done()
316+
}()
317+
}
318+
319+
wg.Wait()
320+
}

stream/stream.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,23 +177,23 @@ func NewReader(r io.Reader, k symmecrypt.Key, chunkSize int, extras ...[]byte) i
177177

178178
func (r *chunksReader) readNewChunk() error {
179179
// read the chunksize
180-
headerBtes := make([]byte, binary.MaxVarintLen32)
181-
if _, err := r.src.Read(headerBtes); err != nil { // READING THE CLEAR HEADER FROM THE ENCRYPTED SOURCE
180+
var headerBtes = new(bytes.Buffer)
181+
if _, err := io.CopyN(headerBtes, r.src, binary.MaxVarintLen32); err != nil {
182182
return err
183183
}
184184

185-
n, err := binary.ReadUvarint(bytes.NewReader(headerBtes)) // READ THE HEADER BUFFER
185+
n, err := binary.ReadUvarint(bytes.NewReader(headerBtes.Bytes())) // READ THE HEADER BUFFER
186186
if err != nil {
187187
return err
188188
}
189189

190190
// read the chunk content
191-
btes := make([]byte, n)
192-
_, err = r.src.Read(btes)
193-
if err != nil && err != io.EOF {
191+
var btsBuff = new(bytes.Buffer)
192+
if _, err := io.CopyN(btsBuff, r.src, int64(n)); err != nil && err != io.EOF {
194193
return err
195194
}
196195

196+
var btes = btsBuff.Bytes()
197197
var clearContent []byte
198198

199199
if r.uncappedK == nil {
@@ -225,7 +225,7 @@ func (r *chunksReader) Read(p []byte) (x int, e error) {
225225
}
226226
}
227227

228-
if len(p)+r.currentChunkReadBytes > r.chunkSize {
228+
if len(p)+r.currentChunkReadBytes >= r.chunkSize {
229229
var pp = p
230230
for {
231231
// The first part of 'p' will store the current chunk

stream/stream_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package stream_test
2+
3+
import (
4+
"bytes"
5+
"crypto/rand"
6+
"io"
7+
"os"
8+
"strings"
9+
"testing"
10+
11+
"github.com/ovh/configstore"
12+
"github.com/ovh/symmecrypt/keyloader"
13+
"github.com/ovh/symmecrypt/stream"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
)
17+
18+
func ProviderTest() (configstore.ItemList, error) {
19+
ret := configstore.ItemList{
20+
Items: []configstore.Item{
21+
configstore.NewItem(
22+
keyloader.EncryptionKeyConfigName,
23+
`{"key":"5fdb8af280b007a46553dfddb3f42bc10619dcabca8d4fdf5239b09445ab1a41","identifier":"test","sealed":false,"timestamp":1522325806,"cipher":"aes-gcm"}`,
24+
1,
25+
),
26+
configstore.NewItem(
27+
keyloader.EncryptionKeyConfigName,
28+
`{"key":"7db2b4b695e11563edca94b0f9c7ad16919fc11eac414c1b1706cbaa3c3e61a4b884301ae4e8fbedcc4f000b9c52904f13ea9456379d373524dea7fef79b39f7","identifier":"test-composite","sealed":false,"timestamp":1522325758,"cipher":"aes-pmac-siv"}`,
29+
1,
30+
),
31+
configstore.NewItem(
32+
keyloader.EncryptionKeyConfigName,
33+
`{"key":"QXdDW4N/jmJzpMu7i1zu4YF1opTn7H+eOk9CLFGBSFg=","identifier":"test-composite","sealed":false,"timestamp":1522325802,"cipher":"xchacha20-poly1305"}`,
34+
1,
35+
),
36+
},
37+
}
38+
return ret, nil
39+
}
40+
41+
func TestMain(m *testing.M) {
42+
configstore.RegisterProvider("test", ProviderTest)
43+
os.Exit(m.Run())
44+
}
45+
46+
func TestIncompleteRead(t *testing.T) {
47+
clearContent := make([]byte, 32*1024+10)
48+
rand.Read(clearContent) // nolint
49+
50+
k, err := keyloader.LoadKey("test")
51+
require.NoError(t, err)
52+
53+
var bufWriter bytes.Buffer
54+
streamWriter := stream.NewWriter(&bufWriter, k, 32*1024)
55+
nbBytesWritten, err := io.Copy(streamWriter, bytes.NewReader(clearContent))
56+
require.NoError(t, err)
57+
t.Logf("%d bytes copied to streamWriter", nbBytesWritten)
58+
require.NoError(t, streamWriter.Close())
59+
60+
streamReader := stream.NewReader(strings.NewReader(bufWriter.String()), k, 32*1024)
61+
var firstPart = make([]byte, 32*1024)
62+
nbBytesReaden1, err := streamReader.Read(firstPart)
63+
t.Logf("%d bytes read the first time", nbBytesReaden1)
64+
require.NoError(t, err)
65+
66+
var secondPart = make([]byte, 32*1024)
67+
nbBytesReaden2, err := streamReader.Read(secondPart)
68+
t.Logf("%d bytes read the second time", nbBytesReaden2)
69+
assert.Error(t, err)
70+
assert.Contains(t, err.Error(), "EOF")
71+
72+
require.Equal(t, 32*1024+10, nbBytesReaden1+nbBytesReaden2)
73+
74+
}

0 commit comments

Comments
 (0)