diff --git a/api_file.go b/api_file.go index 6baef43..8450252 100644 --- a/api_file.go +++ b/api_file.go @@ -1,6 +1,9 @@ package main import ( + "fmt" + "time" + "gosvc/httpserver" "gosvc/logger" "gosvc/validator" @@ -88,15 +91,20 @@ func (s *Service) HandleDownloadFile( resp *httpserver.Response, ) *httpserver.Response { path := req.QueryString("path") + preview := req.QueryBool("preview") newPath := utils.NormalizePath(path) - _, _, err := s.FileSystemManager.DownloadFile(req.Context(), utils.FullPath(newPath)) + entry, downloader, err := s.FileSystemManager.DownloadFile(req.Context(), utils.FullPath(newPath)) if err != nil { logger.Error("download %s: %v", path, err) return resp.InternalServerError("download file failed, " + err.Error()) } - return resp.NoContent() + if !preview { + resp.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", entry.FullPath.Name())) + } + + return resp.ServeContent(entry.FullPath.Name(), time.Unix(entry.LastModificationTime, 0), downloader) } func (s *Service) HandleDeleteFile( diff --git a/engine/filesystem_manager.go b/engine/filesystem_manager.go index ce07cce..3fea3f4 100644 --- a/engine/filesystem_manager.go +++ b/engine/filesystem_manager.go @@ -158,7 +158,7 @@ func (f *FileSystemManager) CreateFile(ctx context.Context, path utils.FullPath, return nil } -func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPath) ([]byte, *utils.Entry, error) { +func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPath) (*utils.Entry, io.ReadSeeker, error) { f.RLock() defer f.RUnlock() @@ -171,12 +171,12 @@ func (f *FileSystemManager) DownloadFile(ctx context.Context, path utils.FullPat return nil, nil, fmt.Errorf("cannot download directory") } - content, err := f.storage.ReadObject(entry.S3Key) + downloader, err := f.storage.DownloadFile(entry.S3Key) if err != nil { return nil, nil, fmt.Errorf("read s3 object failed: %v", err) } - return content, entry, nil + return entry, downloader, nil } func (f *FileSystemManager) DeleteFile(ctx context.Context, path utils.FullPath, isDir bool) error { diff --git a/engine/storage_manager.go b/engine/storage_manager.go index 570bf2f..1353f81 100644 --- a/engine/storage_manager.go +++ b/engine/storage_manager.go @@ -4,7 +4,6 @@ import ( "bytes" "fmt" "io" - "os" "robotfs/utils" @@ -122,19 +121,24 @@ func (sm *StorageManager) UploadFile(s3Key string, data io.Reader, contentType s return output, nil } -func (sm *StorageManager) DownloadFile(S3Key string, localFilePath string) error { - file, err := os.Create(localFilePath) - if err != nil { - return err - } - defer file.Close() - - _, err = sm.s3Downloader.Download(file, &s3.GetObjectInput{ +func (sm *StorageManager) DownloadFile(S3Key string) (io.ReadSeeker, error) { + head, err := sm.s3Client.HeadObject(&s3.HeadObjectInput{ Bucket: aws.String(sm.bucketName), Key: aws.String(S3Key), }) + if err != nil { + return nil, err + } - return err + result, err := sm.s3Client.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(sm.bucketName), + Key: aws.String(S3Key), + }) + if err != nil { + return nil, err + } + + return utils.NewS3ReadSeeker(sm.s3Client, sm.bucketName, S3Key, result.Body, *head.ContentLength), nil } func (sm *StorageManager) ReadObject(S3Key string) ([]byte, error) { diff --git a/router.go b/router.go index e35a7c3..121fa66 100644 --- a/router.go +++ b/router.go @@ -99,6 +99,11 @@ func (s *Service) RegisterRouteRules() { Method: http.MethodGet, Handler: s.HandleDownloadFile, QueryRules: []*httpserver.QueryRule{ + { + Key: "preview", + Type: httpserver.QueryTypeBool, + Required: true, + }, { Key: "path", Type: httpserver.QueryTypeString, diff --git a/utils/s3ReadSeeker.go b/utils/s3ReadSeeker.go new file mode 100644 index 0000000..7fbff61 --- /dev/null +++ b/utils/s3ReadSeeker.go @@ -0,0 +1,82 @@ +package utils + +import ( + "fmt" + "io" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" +) + +type S3ReadSeeker struct { + client *s3.S3 + bucket string + key string + offset int64 + body io.ReadCloser + contentLen int64 +} + +func NewS3ReadSeeker(client *s3.S3, bucket, key string, body io.ReadCloser, contentLen int64) *S3ReadSeeker { + return &S3ReadSeeker{ + client: client, + bucket: bucket, + key: key, + body: body, + contentLen: contentLen, + } +} + +func (r *S3ReadSeeker) Read(p []byte) (n int, err error) { + if r.body == nil { + return 0, io.EOF + } + return r.body.Read(p) +} + +func (r *S3ReadSeeker) Seek(offset int64, whence int) (int64, error) { + if r.body != nil { + r.body.Close() + r.body = nil + } + + var newOffset int64 + switch whence { + case io.SeekStart: + newOffset = offset + case io.SeekCurrent: + newOffset = r.offset + offset + case io.SeekEnd: + newOffset = r.contentLen + offset + default: + return 0, fmt.Errorf("invalid whence") + } + + if newOffset < 0 { + return 0, fmt.Errorf("negative offset") + } + + if newOffset >= r.contentLen { + return r.contentLen, nil + } + + result, err := r.client.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(r.bucket), + Key: aws.String(r.key), + Range: aws.String(fmt.Sprintf("bytes=%d-", newOffset)), + }) + if err != nil { + return 0, err + } + + r.body = result.Body + r.offset = newOffset + return newOffset, nil +} + +func (r *S3ReadSeeker) Close() error { + if r.body != nil { + return r.body.Close() + } + return nil +}