diff --git a/api_file.go b/api_file.go index 8450252..176370b 100644 --- a/api_file.go +++ b/api_file.go @@ -94,11 +94,12 @@ func (s *Service) HandleDownloadFile( preview := req.QueryBool("preview") newPath := utils.NormalizePath(path) - entry, downloader, err := s.FileSystemManager.DownloadFile(req.Context(), utils.FullPath(newPath)) + downloader, entry, err := s.FileSystemManager.DownloadFile(req.Context(), utils.FullPath(newPath)) if err != nil { logger.Error("download %s: %v", path, err) return resp.InternalServerError("download file failed, " + err.Error()) } + defer downloader.Close() if !preview { resp.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", entry.FullPath.Name())) diff --git a/engine/filesystem_manager.go b/engine/filesystem_manager.go index 3fea3f4..4c2ed2c 100644 --- a/engine/filesystem_manager.go +++ b/engine/filesystem_manager.go @@ -158,7 +158,7 @@ func (f *FileSystemManager) CreateFile(ctx context.Context, path utils.FullPath, return nil } -func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPath) (*utils.Entry, io.ReadSeeker, error) { +func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPath) (*utils.S3ReadSeeker, *utils.Entry, error) { f.RLock() defer f.RUnlock() @@ -176,7 +176,7 @@ func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPat return nil, nil, fmt.Errorf("read s3 object failed: %v", err) } - return entry, downloader, nil + return downloader, entry, nil } func (f *FileSystemManager) DeleteFile(ctx context.Context, path utils.FullPath, isDir bool) error { diff --git a/engine/storage_manager.go b/engine/storage_manager.go index 1353f81..296d7b7 100644 --- a/engine/storage_manager.go +++ b/engine/storage_manager.go @@ -121,7 +121,7 @@ func (sm *StorageManager) UploadFile(s3Key string, data io.Reader, contentType s return output, nil } -func (sm *StorageManager) DownloadFile(S3Key string) (io.ReadSeeker, error) { +func (sm *StorageManager) DownloadFile(S3Key string) (*utils.S3ReadSeeker, error) { head, err := sm.s3Client.HeadObject(&s3.HeadObjectInput{ Bucket: aws.String(sm.bucketName), Key: aws.String(S3Key), diff --git a/utils/s3ReadSeeker.go b/utils/s3ReadSeeker.go index 7fbff61..f10c4dc 100644 --- a/utils/s3ReadSeeker.go +++ b/utils/s3ReadSeeker.go @@ -8,6 +8,10 @@ import ( "github.com/aws/aws-sdk-go/service/s3" ) +const ( + bufferSize = 4 * 1024 * 1024 // 4MB +) + type S3ReadSeeker struct { client *s3.S3 bucket string @@ -15,6 +19,9 @@ type S3ReadSeeker struct { offset int64 body io.ReadCloser contentLen int64 + buffer []byte + bufLen int + bufPos int } func NewS3ReadSeeker(client *s3.S3, bucket, key string, body io.ReadCloser, contentLen int64) *S3ReadSeeker { @@ -24,22 +31,44 @@ func NewS3ReadSeeker(client *s3.S3, bucket, key string, body io.ReadCloser, cont key: key, body: body, contentLen: contentLen, + buffer: make([]byte, bufferSize), } } func (r *S3ReadSeeker) Read(p []byte) (n int, err error) { + if r.bufPos < r.bufLen { + n = copy(p, r.buffer[r.bufPos:r.bufLen]) + r.bufPos += n + return n, nil + } + if r.body == nil { return 0, io.EOF } - return r.body.Read(p) + + if len(p) > bufferSize { + n, err = r.body.Read(p) + r.offset += int64(n) + return n, err + } + + r.bufLen, err = r.body.Read(r.buffer) + if err != nil && err != io.EOF { + return 0, err + } + + if r.bufLen > 0 { + r.bufPos = 0 + n = copy(p, r.buffer[:r.bufLen]) + r.bufPos = n + r.offset += int64(n) + return n, nil + } + + return 0, io.EOF } func (r *S3ReadSeeker) Seek(offset int64, whence int) (int64, error) { - if r.body != nil { - r.body.Close() - r.body = nil - } - var newOffset int64 switch whence { case io.SeekStart: @@ -60,6 +89,19 @@ func (r *S3ReadSeeker) Seek(offset int64, whence int) (int64, error) { return r.contentLen, nil } + if newOffset >= r.offset-int64(r.bufLen) && newOffset < r.offset { + r.bufPos = int(newOffset - (r.offset - int64(r.bufLen))) + r.offset = newOffset + return newOffset, nil + } + + if r.body != nil { + r.body.Close() + r.body = nil + } + r.bufLen = 0 + r.bufPos = 0 + result, err := r.client.GetObject(&s3.GetObjectInput{ Bucket: aws.String(r.bucket), Key: aws.String(r.key),