This commit is contained in:
2024-02-25 08:30:34 +08:00
commit 4947f39e74
273 changed files with 45396 additions and 0 deletions

309
pkg/filesystem/archive.go Normal file
View File

@@ -0,0 +1,309 @@
package filesystem
import (
"archive/zip"
"context"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/mholt/archiver/v4"
)
/* ===============
压缩/解压缩
===============
*/
// Compress 创建给定目录和文件的压缩文件
func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error {
// 查找待压缩目录
folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
if err != nil && len(folderIDs) != 0 {
return ErrDBListObjects
}
// 查找待压缩文件
files, err := model.GetFilesByIDs(fileIDs, fs.User.ID)
if err != nil && len(fileIDs) != 0 {
return ErrDBListObjects
}
// 如果上下文限制了父目录,则进行检查
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
// 检查目录
for _, folder := range folders {
if *folder.ParentID != parent.ID {
return ErrObjectNotExist
}
}
// 检查文件
for _, file := range files {
if file.FolderID != parent.ID {
return ErrObjectNotExist
}
}
}
// 尝试获取请求上下文,以便于后续检查用户取消任务
reqContext := ctx
ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context)
if ok {
reqContext = ginCtx.Request.Context()
}
// 将顶级待处理对象的路径设为根路径
for i := 0; i < len(folders); i++ {
folders[i].Position = ""
}
for i := 0; i < len(files); i++ {
files[i].Position = ""
}
// 创建压缩文件Writer
zipWriter := zip.NewWriter(writer)
defer zipWriter.Close()
ctx = reqContext
// 压缩各个目录及文件
for i := 0; i < len(folders); i++ {
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
default:
fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive)
}
}
for i := 0; i < len(files); i++ {
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
default:
fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive)
}
}
return nil
}
func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
// 如果对象是文件
if file != nil {
// 切换上传策略
fs.Policy = file.GetPolicy()
err := fs.DispatchHandler()
if err != nil {
util.Log().Warning("Failed to compress file %q: %s", file.Name, err)
return
}
// 获取文件内容
fileToZip, err := fs.Handler.Get(
context.WithValue(ctx, fsctx.FileModelCtx, *file),
file.SourceName,
)
if err != nil {
util.Log().Debug("Failed to open %q: %s", file.Name, err)
return
}
if closer, ok := fileToZip.(io.Closer); ok {
defer closer.Close()
}
// 创建压缩文件头
header := &zip.FileHeader{
Name: filepath.FromSlash(path.Join(file.Position, file.Name)),
Modified: file.UpdatedAt,
UncompressedSize64: file.Size,
}
// 指定是压缩还是归档
if isArchive {
header.Method = zip.Store
} else {
header.Method = zip.Deflate
}
writer, err := zipWriter.CreateHeader(header)
if err != nil {
return
}
_, err = io.Copy(writer, fileToZip)
} else if folder != nil {
// 对象是目录
// 获取子文件
subFiles, err := folder.GetChildFiles()
if err == nil && len(subFiles) > 0 {
for i := 0; i < len(subFiles); i++ {
fs.doCompress(ctx, &subFiles[i], nil, zipWriter, isArchive)
}
}
// 获取子目录,继续递归遍历
subFolders, err := folder.GetChildFolder()
if err == nil && len(subFolders) > 0 {
for i := 0; i < len(subFolders); i++ {
fs.doCompress(ctx, nil, &subFolders[i], zipWriter, isArchive)
}
}
}
}
// Decompress 解压缩给定压缩文件到dst目录
func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error {
err := fs.ResetFileIfNotExist(ctx, src)
if err != nil {
return err
}
tempZipFilePath := ""
defer func() {
// 结束时删除临时压缩文件
if tempZipFilePath != "" {
if err := os.Remove(tempZipFilePath); err != nil {
util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err)
}
}
}()
// 下载压缩文件到临时目录
fileStream, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
if err != nil {
return err
}
defer fileStream.Close()
tempZipFilePath = filepath.Join(
util.RelativePath(model.GetSettingByName("temp_path")),
"decompress",
fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
)
zipFile, err := util.CreatNestedFile(tempZipFilePath)
if err != nil {
util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err)
tempZipFilePath = ""
return err
}
defer zipFile.Close()
// 下载前先判断是否是可解压的格式
format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream)
if err != nil {
util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err)
return err
}
extractor, ok := format.(archiver.Extractor)
if !ok {
return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName)
}
// 只有zip格式可以多个文件同时上传
var isZip bool
switch extractor.(type) {
case archiver.Zip:
extractor = archiver.Zip{TextEncoding: encoding}
isZip = true
}
// 除了zip必须下载到本地其余的可以边下载边解压
reader := readStream
if isZip {
_, err = io.Copy(zipFile, readStream)
if err != nil {
util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err)
return err
}
fileStream.Close()
// 设置文件偏移量
zipFile.Seek(0, io.SeekStart)
reader = zipFile
}
var wg sync.WaitGroup
parallel := model.GetIntSetting("max_parallel_transfer", 4)
worker := make(chan int, parallel)
for i := 0; i < parallel; i++ {
worker <- i
}
// 上传文件函数
uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) {
defer func() {
if isZip {
worker <- 1
wg.Done()
}
if err := recover(); err != nil {
util.Log().Warning("Error while uploading files inside of archive file.")
fmt.Println(err)
}
}()
err := fs.UploadFromStream(ctx, &fsctx.FileStream{
File: fileStream,
Size: uint64(size),
Name: path.Base(savePath),
VirtualPath: path.Dir(savePath),
}, true)
fileStream.Close()
if err != nil {
util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err)
}
}
// 解压缩文件回调函数如果出错会停止解压的下一步进行全部return nil
err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error {
rawPath := util.FormSlash(f.NameInArchive)
savePath := path.Join(dst, rawPath)
// 路径是否合法
if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) {
util.Log().Warning("%s: illegal file path", f.NameInArchive)
return nil
}
// 如果是目录
if f.FileInfo.IsDir() {
fs.CreateDirectory(ctx, savePath)
return nil
}
// 上传文件
fileStream, err := f.Open()
if err != nil {
util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err)
return nil
}
if !isZip {
uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
} else {
<-worker
wg.Add(1)
go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
}
return nil
})
wg.Wait()
return err
}

View File

@@ -0,0 +1,74 @@
package backoff
import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strconv"
"time"
)
// Backoff used for retry sleep backoff
type Backoff interface {
Next(err error) bool
Reset()
}
// ConstantBackoff implements Backoff interface with constant sleep time. If the error
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
type ConstantBackoff struct {
Sleep time.Duration
Max int
tried int
}
func (c *ConstantBackoff) Next(err error) bool {
c.tried++
if c.tried > c.Max {
return false
}
var e *RetryableError
if errors.As(err, &e) && e.RetryAfter > 0 {
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
time.Sleep(e.RetryAfter)
} else {
time.Sleep(c.Sleep)
}
return true
}
func (c *ConstantBackoff) Reset() {
c.tried = 0
}
type RetryableError struct {
Err error
RetryAfter time.Duration
}
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
// and existing error.
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
retryAfter := header.Get("retry-after")
if retryAfter == "" {
retryAfter = "0"
}
res := &RetryableError{
Err: err,
}
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
}
return res
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
}

View File

@@ -0,0 +1,167 @@
package chunk
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"os"
)
const bufferTempPattern = "cdChunk.*.tmp"
// ChunkProcessFunc callback function for processing a chunk
type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
// ChunkGroup manage groups of chunks
type ChunkGroup struct {
file fsctx.FileHeader
chunkSize uint64
backoff backoff.Backoff
enableRetryBuffer bool
fileInfo *fsctx.UploadTaskInfo
currentIndex int
chunkNum uint64
bufferTemp *os.File
}
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer bool) *ChunkGroup {
c := &ChunkGroup{
file: file,
chunkSize: chunkSize,
backoff: backoff,
fileInfo: file.Info(),
currentIndex: -1,
enableRetryBuffer: useBuffer,
}
if c.chunkSize == 0 {
c.chunkSize = c.fileInfo.Size
}
if c.fileInfo.Size == 0 {
c.chunkNum = 1
} else {
c.chunkNum = c.fileInfo.Size / c.chunkSize
if c.fileInfo.Size%c.chunkSize != 0 {
c.chunkNum++
}
}
return c
}
// TempAvailable returns if current chunk temp file is available to be read
func (c *ChunkGroup) TempAvailable() bool {
if c.bufferTemp != nil {
state, _ := c.bufferTemp.Stat()
return state != nil && state.Size() == c.Length()
}
return false
}
// Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
reader := io.LimitReader(c.file, c.Length())
// If useBuffer is enabled, tee the reader to a temp file
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern)
reader = io.TeeReader(reader, c.bufferTemp)
}
if c.bufferTemp != nil {
defer func() {
if c.bufferTemp != nil {
c.bufferTemp.Close()
os.Remove(c.bufferTemp.Name())
c.bufferTemp = nil
}
}()
// if temp buffer file is available, use it
if c.TempAvailable() {
if _, err := c.bufferTemp.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek temp file back to chunk start: %w", err)
}
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
reader = io.NopCloser(c.bufferTemp)
}
}
err := processor(c, reader)
if err != nil {
if c.enableRetryBuffer {
request.BlackHole(reader)
}
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
if c.file.Seekable() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)
}
}
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
return c.Process(processor)
}
return err
}
util.Log().Debug("Chunk %d processed", c.currentIndex)
return nil
}
// Start returns the byte index of current chunk
func (c *ChunkGroup) Start() int64 {
return int64(uint64(c.Index()) * c.chunkSize)
}
// Total returns the total length
func (c *ChunkGroup) Total() int64 {
return int64(c.fileInfo.Size)
}
// Num returns the total chunk number
func (c *ChunkGroup) Num() int {
return int(c.chunkNum)
}
// RangeHeader returns header value of Content-Range
func (c *ChunkGroup) RangeHeader() string {
return fmt.Sprintf("bytes %d-%d/%d", c.Start(), c.Start()+c.Length()-1, c.Total())
}
// Index returns current chunk index, starts from 0
func (c *ChunkGroup) Index() int {
return c.currentIndex
}
// Next switch to next chunk, returns whether all chunks are processed
func (c *ChunkGroup) Next() bool {
c.currentIndex++
c.backoff.Reset()
return c.currentIndex < int(c.chunkNum)
}
// Length returns the length of current chunk
func (c *ChunkGroup) Length() int64 {
contentLength := c.chunkSize
if c.Index() == int(c.chunkNum-1) {
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
}
return int64(contentLength)
}
// IsLast returns if current chunk is the last one
func (c *ChunkGroup) IsLast() bool {
return c.Index() == int(c.chunkNum-1)
}

View File

@@ -0,0 +1,427 @@
package cos
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/google/go-querystring/query"
cossdk "github.com/tencentyun/cos-go-sdk-v5"
)
// UploadPolicy 腾讯云COS上传策略
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// MetaData 文件元信息
type MetaData struct {
Size uint64
CallbackKey string
CallbackURL string
}
type urlOption struct {
Speed int `url:"x-cos-traffic-limit,omitempty"`
ContentDescription string `url:"response-content-disposition,omitempty"`
}
// Driver 腾讯云COS适配器模板
type Driver struct {
Policy *model.Policy
Client *cossdk.Client
HTTPClient request.Client
}
// List 列出COS文件
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 初始化列目录参数
opt := &cossdk.BucketGetOptions{
Prefix: strings.TrimPrefix(base, "/"),
EncodingType: "",
MaxKeys: 1000,
}
// 是否为递归列出
if !recursive {
opt.Delimiter = "/"
}
// 手动补齐结尾的slash
if opt.Prefix != "" {
opt.Prefix += "/"
}
var (
marker string
objects []cossdk.Object
commons []string
)
for {
res, _, err := handler.Client.Bucket.Get(ctx, opt)
if err != nil {
return nil, err
}
objects = append(objects, res.Contents...)
commons = append(commons, res.CommonPrefixes...)
// 如果本次未列取完则继续使用marker获取结果
marker = res.NextMarker
// marker 为空时结果列取完毕,跳出
if marker == "" {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(opt.Prefix, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(opt.Prefix, object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: time.Now(),
})
}
return res, nil
}
// CORS 创建跨域策略
func (handler Driver) CORS() error {
_, err := handler.Client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{
Rules: []cossdk.BucketCORSRule{{
AllowedMethods: []string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
},
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"*"},
MaxAgeSeconds: 3600,
ExposeHeaders: []string{},
}},
})
return err
}
// Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
opt := &cossdk.ObjectPutOptions{}
_, err := handler.Client.Object.Put(ctx, file.Info().SavePath, file, opt)
return err
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
obs := []cossdk.Object{}
for _, v := range files {
obs = append(obs, cossdk.Object{Key: v})
}
opt := &cossdk.ObjectDeleteMultiOptions{
Objects: obs,
Quiet: true,
}
res, _, err := handler.Client.Object.DeleteMulti(context.Background(), opt)
if err != nil {
return files, err
}
// 整理删除结果
failed := make([]string, 0, len(files))
for _, v := range res.Errors {
failed = append(failed, v.Key)
}
if len(failed) == 0 {
return failed, nil
}
return failed, errors.New("delete failed")
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://cloud.tencent.com/document/product/436/44893
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (32<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
source, err := handler.signSourceURL(
ctx,
file.SourceName,
int64(model.GetIntSetting("preview_timeout", 60)),
&urlOption{},
)
if err != nil {
return nil, err
}
thumbURL, _ := url.Parse(source)
thumbQuery := thumbURL.Query()
thumbQuery.Add(thumbParam, "")
thumbURL.RawQuery = thumbQuery.Encode()
return &response.ContentResponse{
Redirect: true,
URL: thumbURL.String(),
}, nil
}
// Source 获取外链URL
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 添加各项设置
options := urlOption{}
if speed > 0 {
if speed < 819200 {
speed = 819200
}
if speed > 838860800 {
speed = 838860800
}
options.Speed = speed
}
if isDownload {
options.ContentDescription = "attachment; filename=\"" + url.PathEscape(fileName) + "\""
}
return handler.signSourceURL(ctx, path, ttl, &options)
}
func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, options *urlOption) (string, error) {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
// 公有空间不需要签名
if !handler.Policy.IsPrivate {
file, err := url.Parse(path)
if err != nil {
return "", err
}
// 非签名URL不支持设置响应header
options.ContentDescription = ""
optionQuery, err := query.Values(*options)
if err != nil {
return "", err
}
file.RawQuery = optionQuery.Encode()
sourceURL := cdnURL.ResolveReference(file)
return sourceURL.String(), nil
}
presignedURL, err := handler.Client.Object.GetPresignedURL(ctx, http.MethodGet, path,
handler.Policy.AccessKey, handler.Policy.SecretKey, time.Duration(ttl)*time.Second, options)
if err != nil {
return "", err
}
// 将最终生成的签名URL域名换成用户自定义的加速域名如果有
presignedURL.Host = cdnURL.Host
presignedURL.Scheme = cdnURL.Scheme
return presignedURL.String(), nil
}
// Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI).String()
// 上传策略
savePath := file.Info().SavePath
startTime := time.Now()
endTime := startTime.Add(time.Duration(ttl) * time.Second)
keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix())
postPolicy := UploadPolicy{
Expiration: endTime.UTC().Format(time.RFC3339),
Conditions: []interface{}{
map[string]string{"bucket": handler.Policy.BucketName},
map[string]string{"$key": savePath},
map[string]string{"x-cos-meta-callback": apiURL},
map[string]string{"x-cos-meta-key": uploadSession.Key},
map[string]string{"q-sign-algorithm": "sha1"},
map[string]string{"q-ak": handler.Policy.AccessKey},
map[string]string{"q-sign-time": keyTime},
},
}
if handler.Policy.MaxSize > 0 {
postPolicy.Conditions = append(postPolicy.Conditions,
[]interface{}{"content-length-range", 0, handler.Policy.MaxSize})
}
res, err := handler.getUploadCredential(ctx, postPolicy, keyTime, savePath)
if err == nil {
res.SessionID = uploadSession.Key
res.Callback = apiURL
res.UploadURLs = []string{handler.Policy.Server}
}
return res, err
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
// Meta 获取文件信息
func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.Client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{})
if err != nil {
return nil, err
}
return &MetaData{
Size: uint64(res.ContentLength),
CallbackKey: res.Header.Get("x-cos-meta-key"),
CallbackURL: res.Header.Get("x-cos-meta-callback"),
}, nil
}
func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string, savePath string) (*serializer.UploadCredential, error) {
// 编码上传策略
policyJSON, err := json.Marshal(policy)
if err != nil {
return nil, err
}
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
// 签名上传策略
hmacSign := hmac.New(sha1.New, []byte(handler.Policy.SecretKey))
_, err = io.WriteString(hmacSign, keyTime)
if err != nil {
return nil, err
}
signKey := fmt.Sprintf("%x", hmacSign.Sum(nil))
sha1Sign := sha1.New()
_, err = sha1Sign.Write(policyJSON)
if err != nil {
return nil, err
}
stringToSign := fmt.Sprintf("%x", sha1Sign.Sum(nil))
// 最终签名
hmacFinalSign := hmac.New(sha1.New, []byte(signKey))
_, err = hmacFinalSign.Write([]byte(stringToSign))
if err != nil {
return nil, err
}
signature := hmacFinalSign.Sum(nil)
return &serializer.UploadCredential{
Policy: policyEncoded,
Path: savePath,
AccessKey: handler.Policy.AccessKey,
Credential: fmt.Sprintf("%x", signature),
KeyTime: keyTime,
}, nil
}

View File

@@ -0,0 +1,134 @@
package cos
import (
"archive/zip"
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/url"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
scf "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf/v20180416"
)
const scfFunc = `# -*- coding: utf8 -*-
# SCF配置COS触发向 Cloudreve 发送回调
from qcloud_cos_v5 import CosConfig
from qcloud_cos_v5 import CosS3Client
from qcloud_cos_v5 import CosServiceError
from qcloud_cos_v5 import CosClientError
import sys
import logging
import requests
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger()
def main_handler(event, context):
logger.info("start main handler")
for record in event['Records']:
try:
if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']:
logger.info("Cannot find callback URL, skiped.")
return 'Success'
callback = record['cos']['cosObject']['meta']['x-cos-meta-callback']
key = record['cos']['cosObject']['key']
logger.info("Callback URL is " + callback)
r = requests.get(callback)
print(r.text)
except Exception as e:
print(e)
print('Error getting object {} callback url. '.format(key))
raise e
return "Fail"
return "Success"
`
// CreateSCF 创建回调云函数
func CreateSCF(policy *model.Policy, region string) error {
// 初始化客户端
credential := common.NewCredential(
policy.AccessKey,
policy.SecretKey,
)
cpf := profile.NewClientProfile()
client, err := scf.NewClient(credential, region, cpf)
if err != nil {
return err
}
// 创建回调代码数据
buff := &bytes.Buffer{}
bs64 := base64.NewEncoder(base64.StdEncoding, buff)
zipWriter := zip.NewWriter(bs64)
header := zip.FileHeader{
Name: "callback.py",
Method: zip.Deflate,
}
writer, err := zipWriter.CreateHeader(&header)
if err != nil {
return err
}
_, err = io.Copy(writer, strings.NewReader(scfFunc))
zipWriter.Close()
// 创建云函数
req := scf.NewCreateFunctionRequest()
funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10)
zipFileBytes, _ := ioutil.ReadAll(buff)
zipFileStr := string(zipFileBytes)
codeSource := "ZipFile"
handler := "callback.main_handler"
desc := "Cloudreve 用回调函数"
timeout := int64(60)
runtime := "Python3.6"
req.FunctionName = &funcName
req.Code = &scf.Code{
ZipFile: &zipFileStr,
}
req.Handler = &handler
req.Description = &desc
req.Timeout = &timeout
req.Runtime = &runtime
req.CodeSource = &codeSource
_, err = client.CreateFunction(req)
if err != nil {
return err
}
time.Sleep(time.Duration(5) * time.Second)
// 创建触发器
server, _ := url.Parse(policy.Server)
triggerType := "cos"
triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}`
enable := "OPEN"
trigger := scf.NewCreateTriggerRequest()
trigger.FunctionName = &funcName
trigger.TriggerName = &server.Host
trigger.Type = &triggerType
trigger.TriggerDesc = &triggerDesc
trigger.Enable = &enable
_, err = client.CreateTrigger(trigger)
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,73 @@
package googledrive
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"google.golang.org/api/drive/v3"
)
// Client Google Drive client
type Client struct {
Endpoints *Endpoints
Policy *model.Policy
Credential *Credential
ClientID string
ClientSecret string
Redirect string
Request request.Client
ClusterController cluster.Controller
}
// Endpoints OneDrive客户端相关设置
type Endpoints struct {
UserConsentEndpoint string // OAuth认证的基URL
TokenEndpoint string // OAuth token 基URL
EndpointURL string // 接口请求的基URL
}
const (
TokenCachePrefix = "googledrive_"
oauthEndpoint = "https://oauth2.googleapis.com/token"
userConsentBase = "https://accounts.google.com/o/oauth2/auth"
v3DriveEndpoint = "https://www.googleapis.com/drive/v3"
)
var (
// Defualt required scopes
RequiredScope = []string{
drive.DriveScope,
"openid",
"profile",
"https://www.googleapis.com/auth/userinfo.profile",
}
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
)
// NewClient 根据存储策略获取新的client
func NewClient(policy *model.Policy) (*Client, error) {
client := &Client{
Endpoints: &Endpoints{
TokenEndpoint: oauthEndpoint,
UserConsentEndpoint: userConsentBase,
EndpointURL: v3DriveEndpoint,
},
Credential: &Credential{
RefreshToken: policy.AccessKey,
},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OauthRedirect,
Request: request.NewClient(),
ClusterController: cluster.DefaultController,
}
return client, nil
}

View File

@@ -0,0 +1,65 @@
package googledrive
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver Google Drive 适配器
type Driver struct {
Policy *model.Policy
HTTPClient request.Client
}
// NewDriver 从存储策略初始化新的Driver实例
func NewDriver(policy *model.Policy) (driver.Handler, error) {
return &Driver{
Policy: policy,
HTTPClient: request.NewClient(),
}, nil
}
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
//TODO implement me
panic("implement me")
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
//TODO implement me
panic("implement me")
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,154 @@
package googledrive
import (
"context"
"encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// OAuthURL 获取OAuth认证页面URL
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
query := url.Values{
"client_id": {client.ClientID},
"scope": {strings.Join(scope, " ")},
"response_type": {"code"},
"redirect_uri": {client.Redirect},
"access_type": {"offline"},
"prompt": {"consent"},
}
u, _ := url.Parse(client.Endpoints.UserConsentEndpoint)
u.RawQuery = query.Encode()
return u.String()
}
// ObtainToken 通过code或refresh_token兑换token
func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) {
body := url.Values{
"client_id": {client.ClientID},
"redirect_uri": {client.Redirect},
"client_secret": {client.ClientSecret},
}
if code != "" {
body.Add("grant_type", "authorization_code")
body.Add("code", code)
} else {
body.Add("grant_type", "refresh_token")
body.Add("refresh_token", refreshToken)
}
strBody := body.Encode()
res := client.Request.Request(
"POST",
client.Endpoints.TokenEndpoint,
io.NopCloser(strings.NewReader(strBody)),
request.WithHeader(http.Header{
"Content-Type": {"application/x-www-form-urlencoded"}},
),
request.WithContentLength(int64(len(strBody))),
)
if res.Err != nil {
return nil, res.Err
}
respBody, err := res.GetResponse()
if err != nil {
return nil, err
}
var (
errResp OAuthError
credential Credential
decodeErr error
)
if res.Response.StatusCode != 200 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
} else {
decodeErr = json.Unmarshal([]byte(respBody), &credential)
}
if decodeErr != nil {
return nil, decodeErr
}
if errResp.ErrorType != "" {
return nil, errResp
}
return &credential, nil
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
if isSlave {
return client.fetchCredentialFromMaster(ctx)
}
oauth.GlobalMutex.Lock(client.Policy.ID)
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
// 如果已存在凭证
if client.Credential != nil && client.Credential.AccessToken != "" {
// 检查已有凭证是否过期
if client.Credential.ExpiresIn > time.Now().Unix() {
// 未过期,不要更新
return nil
}
}
// 尝试从缓存中获取凭证
if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok {
credential := cacheCredential.(Credential)
if credential.ExpiresIn > time.Now().Unix() {
client.Credential = &credential
return nil
}
}
// 获取新的凭证
if client.Credential == nil || client.Credential.RefreshToken == "" {
// 无有效的RefreshToken
util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name)
return ErrInvalidRefreshToken
}
credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken)
if err != nil {
return err
}
// 更新有效期为绝对时间戳
expires := credential.ExpiresIn - 60
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
// refresh token for Google Drive does not expire in production
credential.RefreshToken = client.Credential.RefreshToken
client.Credential = credential
// 更新缓存
cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires))
return nil
}
func (client *Client) AccessToken() string {
return client.Credential.AccessToken
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
client.Credential = &Credential{AccessToken: res}
return nil
}

View File

@@ -0,0 +1,43 @@
package googledrive
import "encoding/gob"
// RespError 接口返回错误
type RespError struct {
APIError APIError `json:"error"`
}
// APIError 接口返回的错误内容
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
}
// Credential 获取token时返回的凭证
type Credential struct {
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
UserID string `json:"user_id"`
}
// OAuthError OAuth相关接口的错误响应
type OAuthError struct {
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
}
// Error 实现error接口
func (err OAuthError) Error() string {
return err.ErrorDescription
}
func init() {
gob.Register(Credential{})
}

View File

@@ -0,0 +1,52 @@
package driver
import (
"context"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrorThumbNotExist = fmt.Errorf("thumb not exist")
ErrorThumbNotSupported = fmt.Errorf("thumb not supported")
)
// Handler 存储策略适配器
type Handler interface {
// 上传文件, dst为文件存储路径size 为文件大小。上下文关闭
// 时,应取消上传并清理临时文件
Put(ctx context.Context, file fsctx.FileHeader) error
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
Delete(ctx context.Context, files []string) ([]string, error)
// 获取文件内容
Get(ctx context.Context, path string) (response.RSCloser, error)
// 获取缩略图可直接在ContentResponse中返回文件数据流也可指
// 定为重定向
// 如果缩略图不存在, 且需要 Cloudreve 代理生成并上传,应返回 ErrorThumbNotExist
// 成的缩略图文件存储规则与本机策略一致。
// 如果不支持此文件的缩略图,并且不希望后续继续请求此缩略图,应返回 ErrorThumbNotSupported
Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error)
// 获取外链/下载地址,
// url - 站点本身地址,
// isDownload - 是否直接下载
Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error)
// Token 获取有效期为ttl的上传凭证和签名
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)
// CancelToken 取消已经创建的有状态上传凭证
CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error
// List 递归列取远程端path路径下文件、目录不包含path本身
// 返回的对象路径以path作为起始根目录.
// recursive - 是否递归列出
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
}

View File

@@ -0,0 +1,292 @@
package local
import (
"context"
"errors"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
const (
Perm = 0744
)
// Driver 本地策略适配器
type Driver struct {
Policy *model.Policy
}
// List 递归列取给定物理路径下所有文件
func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
var res []response.Object
// 取得起始路径
root := util.RelativePath(filepath.FromSlash(path))
// 开始遍历路径下的文件、目录
err := filepath.Walk(root,
func(path string, info os.FileInfo, err error) error {
// 跳过根目录
if path == root {
return nil
}
if err != nil {
util.Log().Warning("Failed to walk folder %q: %s", path, err)
return filepath.SkipDir
}
// 将遍历对象的绝对路径转换为相对路径
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
res = append(res, response.Object{
Name: info.Name(),
RelativePath: filepath.ToSlash(rel),
Source: path,
Size: uint64(info.Size()),
IsDir: info.IsDir(),
LastModify: info.ModTime(),
})
// 如果非递归,则不步入目录
if !recursive && info.IsDir() {
return filepath.SkipDir
}
return nil
})
return res, err
}
// Get 获取文件内容
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 打开文件
file, err := os.Open(util.RelativePath(path))
if err != nil {
util.Log().Debug("Failed to open file: %s", err)
return nil, err
}
return file, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
dst := util.RelativePath(filepath.FromSlash(fileInfo.SavePath))
// 如果非 Overwrite则检查是否有重名冲突
if fileInfo.Mode&fsctx.Overwrite != fsctx.Overwrite {
if util.Exists(dst) {
util.Log().Warning("File with the same name existed or unavailable: %s", dst)
return errors.New("file with the same name existed or unavailable")
}
}
// 如果目标目录不存在,创建
basePath := filepath.Dir(dst)
if !util.Exists(basePath) {
err := os.MkdirAll(basePath, Perm)
if err != nil {
util.Log().Warning("Failed to create directory: %s", err)
return err
}
}
var (
out *os.File
err error
)
openMode := os.O_CREATE | os.O_RDWR
if fileInfo.Mode&fsctx.Append == fsctx.Append {
openMode |= os.O_APPEND
} else {
openMode |= os.O_TRUNC
}
out, err = os.OpenFile(dst, openMode, Perm)
if err != nil {
util.Log().Warning("Failed to open or create file: %s", err)
return err
}
defer out.Close()
if fileInfo.Mode&fsctx.Append == fsctx.Append {
stat, err := out.Stat()
if err != nil {
util.Log().Warning("Failed to read file info: %s", err)
return err
}
if uint64(stat.Size()) < fileInfo.AppendStart {
return errors.New("size of unfinished uploaded chunks is not as expected")
} else if uint64(stat.Size()) > fileInfo.AppendStart {
out.Close()
if err := handler.Truncate(ctx, dst, fileInfo.AppendStart); err != nil {
return fmt.Errorf("failed to overwrite chunk: %w", err)
}
out, err = os.OpenFile(dst, openMode, Perm)
defer out.Close()
if err != nil {
util.Log().Warning("Failed to create or open file: %s", err)
return err
}
}
}
// 写入文件内容
_, err = io.Copy(out, file)
return err
}
func (handler Driver) Truncate(ctx context.Context, src string, size uint64) error {
util.Log().Warning("Truncate file %q to [%d].", src, size)
out, err := os.OpenFile(src, os.O_WRONLY, Perm)
if err != nil {
util.Log().Warning("Failed to open file: %s", err)
return err
}
defer out.Close()
return out.Truncate(int64(size))
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
deleteFailed := make([]string, 0, len(files))
var retErr error
for _, value := range files {
filePath := util.RelativePath(filepath.FromSlash(value))
if util.Exists(filePath) {
err := os.Remove(filePath)
if err != nil {
util.Log().Warning("Failed to delete file: %s", err)
retErr = err
deleteFailed = append(deleteFailed, value)
}
}
// 尝试删除文件的缩略图(如果有)
_ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")))
}
return deleteFailed, retErr
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// Quick check thumb existence on master.
if conf.SystemConfig.Mode == "master" && file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusNotExist {
// Tell invoker to generate a thumb
return nil, driver.ErrorThumbNotExist
}
thumbFile, err := handler.Get(ctx, file.ThumbFile())
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = fmt.Errorf("thumb not exist: %w (%w)", err, driver.ErrorThumbNotExist)
}
return nil, err
}
return &response.ContentResponse{
Redirect: false,
Content: thumbFile,
}, nil
}
// Source 获取外链URL
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return "", errors.New("failed to read file model context")
}
var baseURL *url.URL
// 是否启用了CDN
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
baseURL = cdnURL
}
var (
signedURI *url.URL
err error
)
if isDownload {
// 创建下载会话,将文件信息写入缓存
downloadSessionID := util.RandStringRunes(16)
err = cache.Set("download_"+downloadSessionID, file, int(ttl))
if err != nil {
return "", serializer.NewError(serializer.CodeCacheOperation, "Failed to create download session", err)
}
// 签名生成文件记录
signedURI, err = auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID),
ttl,
)
} else {
// 签名生成文件记录
signedURI, err = auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name),
ttl,
)
}
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err)
}
finalURL := signedURI.String()
if baseURL != nil {
finalURL = baseURL.ResolveReference(signedURI).String()
}
return finalURL, nil
}
// Token 获取上传策略和认证Token本地策略直接返回空值
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
if util.Exists(uploadSession.SavePath) {
return nil, errors.New("placeholder file already exist")
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@@ -0,0 +1,595 @@
package onedrive
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
const (
// SmallFileSize 单文件上传接口最大尺寸
SmallFileSize uint64 = 4 * 1024 * 1024
// ChunkSize 服务端中转分片上传分片大小
ChunkSize uint64 = 10 * 1024 * 1024
// ListRetry 列取请求重试次数
ListRetry = 1
chunkRetrySleep = time.Second * 5
notFoundError = "itemNotFound"
)
// GetSourcePath 获取文件的绝对路径
func (info *FileInfo) GetSourcePath() string {
res, err := url.PathUnescape(info.ParentReference.Path)
if err != nil {
return ""
}
return strings.TrimPrefix(
path.Join(
strings.TrimPrefix(res, "/drive/root:"),
info.Name,
),
"/",
)
}
func (client *Client) getRequestURL(api string, opts ...Option) string {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
base, _ := url.Parse(client.Endpoints.EndpointURL)
if base == nil {
return ""
}
if options.useDriverResource {
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
} else {
base.Path = path.Join(base.Path, api)
}
return base.String()
}
// ListChildren 根据路径列取子对象
func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
var requestURL string
dst := strings.TrimPrefix(path, "/")
if dst == "" {
requestURL = client.getRequestURL("root/children")
} else {
requestURL = client.getRequestURL("root:/" + dst + ":/children")
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
if err != nil {
retried := 0
if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok {
retried = v
}
if retried < ListRetry {
retried++
util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
time.Sleep(time.Duration(5) * time.Second)
return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path)
}
return nil, err
}
var (
decodeErr error
fileInfo ListResponse
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return fileInfo.Value, nil
}
// Meta 根据资源ID或文件路径获取文件元信息
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
var requestURL string
if id != "" {
requestURL = client.getRequestURL("items/" + id)
} else {
dst := strings.TrimPrefix(path, "/")
requestURL = client.getRequestURL("root:/" + dst)
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
fileInfo FileInfo
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return &fileInfo, nil
}
// CreateUploadSession 创建分片上传会话
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
body := map[string]map[string]interface{}{
"item": {
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
},
}
bodyBytes, _ := json.Marshal(body)
res, err := client.requestWithStr(ctx, "POST", requestURL, string(bodyBytes), 200)
if err != nil {
return "", err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return "", decodeErr
}
return uploadSession.UploadURL, nil
}
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
siteUrlParsed, err := url.Parse(siteUrl)
if err != nil {
return "", err
}
hostName := siteUrlParsed.Hostname()
relativePath := strings.Trim(siteUrlParsed.Path, "/")
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if reqErr != nil {
return "", reqErr
}
var (
decodeErr error
siteInfo Site
)
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
if decodeErr != nil {
return "", decodeErr
}
return siteInfo.ID, nil
}
// GetUploadSessionStatus 查询上传会话状态
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadSession, nil
}
// UploadChunk 上传分片
func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
res, err := client.request(
ctx, "PUT", uploadURL, content,
request.WithContentLength(current.Length()),
request.WithHeader(http.Header{
"Content-Range": {current.RangeHeader()},
}),
request.WithoutHeader([]string{"Authorization", "Content-Type"}),
request.WithTimeout(0),
)
if err != nil {
return nil, fmt.Errorf("failed to upload OneDrive chunk #%d: %w", current.Index(), err)
}
if current.IsLast() {
return nil, nil
}
var (
decodeErr error
uploadRes UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// Upload 上传文件
func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
fileInfo := file.Info()
// 决定是否覆盖文件
overwrite := "fail"
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
overwrite = "replace"
}
size := int(fileInfo.Size)
dst := fileInfo.SavePath
// 小文件,使用简单上传接口上传
if size <= int(SmallFileSize) {
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
return err
}
// 大文件,进行分片
// 创建上传会话
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
if err != nil {
return err
}
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
_, err := client.UploadChunk(ctx, uploadURL, content, current)
return err
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
// DeleteUploadSession 删除上传会话
func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
if err != nil {
return err
}
return nil
}
// SimpleUpload 上传小文件到dst
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/content")
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
request.WithTimeout(0),
)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadRes UploadResult
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为
// 20个一组调用Delete并行删除
// TODO 测试
func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
groupNum := len(dst)/20 + 1
finalRes := make([]string, 0, len(dst))
res := make([]string, 0, 20)
var err error
for i := 0; i < groupNum; i++ {
end := 20*i + 20
if i == groupNum-1 {
end = len(dst)
}
res, err = client.Delete(ctx, dst[20*i:end])
finalRes = append(finalRes, res...)
}
return finalRes, err
}
// Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误,
// 由于API限制最多删除20个
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
body := client.makeBatchDeleteRequestsBody(dst)
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
WithDriverResource(false)), body, 200)
if err != nil {
return dst, err
}
var (
decodeErr error
deleteRes BatchResponses
)
decodeErr = json.Unmarshal([]byte(res), &deleteRes)
if decodeErr != nil {
return dst, decodeErr
}
// 取得删除失败的文件
failed := getDeleteFailed(&deleteRes)
if len(failed) != 0 {
return failed, ErrDeleteFile
}
return failed, nil
}
func getDeleteFailed(res *BatchResponses) []string {
var failed = make([]string, 0, len(res.Responses))
for _, v := range res.Responses {
if v.Status != 204 && v.Status != 404 {
failed = append(failed, v.ID)
}
}
return failed
}
// makeBatchDeleteRequestsBody 生成批量删除请求正文
func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
req := BatchRequests{
Requests: make([]BatchRequest, len(files)),
}
for i, v := range files {
v = strings.TrimPrefix(v, "/")
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
filePath.Path = path.Join(filePath.Path, v)
req.Requests[i] = BatchRequest{
ID: v,
Method: "DELETE",
URL: filePath.EscapedPath(),
}
}
res, _ := json.Marshal(req)
return string(res)
}
// GetThumbURL 获取给定尺寸的缩略图URL
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if err != nil {
return "", err
}
var (
decodeErr error
thumbRes ThumbResponse
)
decodeErr = json.Unmarshal([]byte(res), &thumbRes)
if decodeErr != nil {
return "", decodeErr
}
if thumbRes.URL != "" {
return thumbRes.URL, nil
}
if len(thumbRes.Value) == 1 {
if res, ok := thumbRes.Value[0]["large"]; ok {
return res.(map[string]interface{})["url"].(string), nil
}
}
return "", ErrThumbSizeNotFound
}
// MonitorUpload 监控客户端分片上传进度
func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) {
// 回调完成通知chan
callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1)
defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan)
timeout := model.GetIntSetting("onedrive_monitor_timeout", 600)
interval := model.GetIntSetting("onedrive_callback_check", 20)
for {
select {
case <-callbackChan:
util.Log().Debug("Client finished OneDrive callback.")
return
case <-time.After(time.Duration(ttl) * time.Second):
// 上传会话到期,仍未完成上传,创建占位符
client.DeleteUploadSession(context.Background(), uploadURL)
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
if err != nil {
util.Log().Debug("Failed to create placeholder file: %s", err)
}
return
case <-time.After(time.Duration(timeout) * time.Second):
util.Log().Debug("Checking OneDrive upload status.")
status, err := client.GetUploadSessionStatus(context.Background(), uploadURL)
if err != nil {
if resErr, ok := err.(*RespError); ok {
if resErr.APIError.Code == notFoundError {
util.Log().Debug("Upload completed, will check upload callback later.")
select {
case <-time.After(time.Duration(interval) * time.Second):
util.Log().Warning("No callback is made, file will be deleted.")
cache.Deletes([]string{callbackKey}, "callback_")
_, err = client.Delete(context.Background(), []string{path})
if err != nil {
util.Log().Warning("Failed to delete file without callback: %s", err)
}
case <-callbackChan:
util.Log().Debug("Client finished callback.")
}
return
}
}
util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error())
continue
}
// 成功获取分片上传状态,检查文件大小
if len(status.NextExpectedRanges) == 0 {
continue
}
sizeRange := strings.Split(
status.NextExpectedRanges[len(status.NextExpectedRanges)-1],
"-",
)
if len(sizeRange) != 2 {
continue
}
uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64)
if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size {
util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...")
// 取消上传会话实测OneDrive取消上传会话后客户端还是可以上传
// 所以上传一个空文件占位,阻止客户端上传
client.DeleteUploadSession(context.Background(), uploadURL)
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
if err != nil {
util.Log().Debug("无法创建占位文件,%s", err)
}
return
}
}
}
}
func sysError(err error) *RespError {
return &RespError{APIError: APIError{
Code: "system",
Message: err.Error(),
}}
}
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
// 获取凭证
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
if err != nil {
return "", sysError(err)
}
option = append(option,
request.WithHeader(http.Header{
"Authorization": {"Bearer " + client.Credential.AccessToken},
"Content-Type": {"application/json"},
}),
request.WithContext(ctx),
request.WithTPSLimit(
fmt.Sprintf("policy_%d", client.Policy.ID),
client.Policy.OptionsSerialized.TPSLimit,
client.Policy.OptionsSerialized.TPSLimitBurst,
),
)
// 发送请求
res := client.Request.Request(
method,
url,
body,
option...,
)
if res.Err != nil {
return "", sysError(res.Err)
}
respBody, err := res.GetResponse()
if err != nil {
return "", sysError(err)
}
// 解析请求响应
var (
errResp RespError
decodeErr error
)
// 如果有错误
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
if decodeErr != nil {
util.Log().Debug("Onedrive returns unknown response: %s", respBody)
return "", sysError(decodeErr)
}
if res.Response.StatusCode == 429 {
util.Log().Warning("OneDrive request is throttled.")
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
}
return "", &errResp
}
return respBody, nil
}
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
// 发送请求
bodyReader := io.NopCloser(strings.NewReader(body))
return client.request(ctx, method, url, bodyReader,
request.WithContentLength(int64(len(body))),
)
}

View File

@@ -0,0 +1,78 @@
package onedrive
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
)
var (
// ErrAuthEndpoint 无法解析授权端点地址
ErrAuthEndpoint = errors.New("failed to parse endpoint url")
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
// ErrDeleteFile 无法删除文件
ErrDeleteFile = errors.New("cannot delete file")
// ErrClientCanceled 客户端取消操作
ErrClientCanceled = errors.New("client canceled")
// Desired thumb size not available
ErrThumbSizeNotFound = errors.New("thumb size not found")
)
// Client OneDrive客户端
type Client struct {
Endpoints *Endpoints
Policy *model.Policy
Credential *Credential
ClientID string
ClientSecret string
Redirect string
Request request.Client
ClusterController cluster.Controller
}
// Endpoints OneDrive客户端相关设置
type Endpoints struct {
OAuthURL string // OAuth认证的基URL
OAuthEndpoints *oauthEndpoint
EndpointURL string // 接口请求的基URL
isInChina bool // 是否为世纪互联
DriverResource string // 要使用的驱动器
}
// NewClient 根据存储策略获取新的client
func NewClient(policy *model.Policy) (*Client, error) {
client := &Client{
Endpoints: &Endpoints{
OAuthURL: policy.BaseURL,
EndpointURL: policy.Server,
DriverResource: policy.OptionsSerialized.OdDriver,
},
Credential: &Credential{
RefreshToken: policy.AccessKey,
},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OauthRedirect,
Request: request.NewClient(),
ClusterController: cluster.DefaultController,
}
if client.Endpoints.DriverResource == "" {
client.Endpoints.DriverResource = "me/drive"
}
oauthBase := client.getOAuthEndpoint()
if oauthBase == nil {
return nil, ErrAuthEndpoint
}
client.Endpoints.OAuthEndpoints = oauthBase
return client, nil
}

View File

@@ -0,0 +1,238 @@
package onedrive
import (
"context"
"errors"
"fmt"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver OneDrive 适配器
type Driver struct {
Policy *model.Policy
Client *Client
HTTPClient request.Client
}
// NewDriver 从存储策略初始化新的Driver实例
func NewDriver(policy *model.Policy) (driver.Handler, error) {
client, err := NewClient(policy)
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 50 << 20 // 50MB
}
return Driver{
Policy: policy,
Client: client,
HTTPClient: request.NewClient(),
}, err
}
// List 列取项目
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
// 列取子项目
objects, _ := handler.Client.ListChildren(ctx, base)
// 获取真实的列取起始根目录
rootPath := base
if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok {
rootPath = realBase
} else {
ctx = context.WithValue(ctx, fsctx.PathCtx, base)
}
// 整理结果
res := make([]response.Object, 0, len(objects))
for _, object := range objects {
source := path.Join(base, object.Name)
rel, err := filepath.Rel(rootPath, source)
if err != nil {
continue
}
res = append(res, response.Object{
Name: object.Name,
RelativePath: filepath.ToSlash(rel),
Source: source,
Size: object.Size,
IsDir: object.Folder != nil,
LastModify: time.Now(),
})
}
// 递归列取子目录
if recursive {
for _, object := range objects {
if object.Folder != nil {
sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive)
res = append(res, sub...)
}
}
}
return res, nil
}
// Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(
ctx,
path,
60,
false,
0,
)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
return handler.Client.Upload(ctx, file)
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return handler.Client.BatchDelete(ctx, files)
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
res, err := handler.Client.GetThumbURL(ctx, file.SourceName, thumbSize[0], thumbSize[1])
if err != nil {
var apiErr *RespError
if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) {
// OneDrive cannot generate thumbnail for this file
return nil, driver.ErrorThumbNotSupported
}
}
return &response.ContentResponse{
Redirect: true,
URL: res,
}, err
}
// Source 获取外链URL
func (handler Driver) Source(
ctx context.Context,
path string,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
}
// 尝试从缓存中查找
if cachedURL, ok := cache.Get(cacheKey); ok {
return handler.replaceSourceHost(cachedURL.(string))
}
// 缓存不存在,重新获取
res, err := handler.Client.Meta(ctx, "", path)
if err == nil {
// 写入新的缓存
cache.Set(
cacheKey,
res.DownloadURL,
model.GetIntSetting("onedrive_source_timeout", 1800),
)
return handler.replaceSourceHost(res.DownloadURL)
}
return "", err
}
func (handler Driver) replaceSourceHost(origin string) (string, error) {
if handler.Policy.OptionsSerialized.OdProxy != "" {
source, err := url.Parse(origin)
if err != nil {
return "", err
}
cdn, err := url.Parse(handler.Policy.OptionsSerialized.OdProxy)
if err != nil {
return "", err
}
// 替换反代地址
source.Scheme = cdn.Scheme
source.Host = cdn.Host
return source.String(), nil
}
return origin, nil
}
// Token 获取上传会话URL
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
fileInfo := file.Info()
uploadURL, err := handler.Client.CreateUploadSession(ctx, fileInfo.SavePath, WithConflictBehavior("fail"))
if err != nil {
return nil, err
}
// 监控回调及上传
go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl)
uploadSession.UploadURL = uploadURL
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadURLs: []string{uploadURL},
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.Client.DeleteUploadSession(ctx, uploadSession.UploadURL)
}

View File

@@ -0,0 +1,25 @@
package onedrive
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}

View File

@@ -0,0 +1,192 @@
package onedrive
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Error 实现error接口
func (err OAuthError) Error() string {
return err.ErrorDescription
}
// OAuthURL 获取OAuth认证页面URL
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
query := url.Values{
"client_id": {client.ClientID},
"scope": {strings.Join(scope, " ")},
"response_type": {"code"},
"redirect_uri": {client.Redirect},
}
client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode()
return client.Endpoints.OAuthEndpoints.authorize.String()
}
// getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址
func (client *Client) getOAuthEndpoint() *oauthEndpoint {
base, err := url.Parse(client.Endpoints.OAuthURL)
if err != nil {
return nil
}
var (
token *url.URL
authorize *url.URL
)
switch base.Host {
case "login.live.com":
token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
case "login.chinacloudapi.cn":
client.Endpoints.isInChina = true
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
default:
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
}
return &oauthEndpoint{
token: *token,
authorize: *authorize,
}
}
// ObtainToken 通过code或refresh_token兑换token
func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
body := url.Values{
"client_id": {client.ClientID},
"redirect_uri": {client.Redirect},
"client_secret": {client.ClientSecret},
}
if options.code != "" {
body.Add("grant_type", "authorization_code")
body.Add("code", options.code)
} else {
body.Add("grant_type", "refresh_token")
body.Add("refresh_token", options.refreshToken)
}
strBody := body.Encode()
res := client.Request.Request(
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
ioutil.NopCloser(strings.NewReader(strBody)),
request.WithHeader(http.Header{
"Content-Type": {"application/x-www-form-urlencoded"}},
),
request.WithContentLength(int64(len(strBody))),
)
if res.Err != nil {
return nil, res.Err
}
respBody, err := res.GetResponse()
if err != nil {
return nil, err
}
var (
errResp OAuthError
credential Credential
decodeErr error
)
if res.Response.StatusCode != 200 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
} else {
decodeErr = json.Unmarshal([]byte(respBody), &credential)
}
if decodeErr != nil {
return nil, decodeErr
}
if errResp.ErrorType != "" {
return nil, errResp
}
return &credential, nil
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
if isSlave {
return client.fetchCredentialFromMaster(ctx)
}
oauth.GlobalMutex.Lock(client.Policy.ID)
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
// 如果已存在凭证
if client.Credential != nil && client.Credential.AccessToken != "" {
// 检查已有凭证是否过期
if client.Credential.ExpiresIn > time.Now().Unix() {
// 未过期,不要更新
return nil
}
}
// 尝试从缓存中获取凭证
if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok {
credential := cacheCredential.(Credential)
if credential.ExpiresIn > time.Now().Unix() {
client.Credential = &credential
return nil
}
}
// 获取新的凭证
if client.Credential == nil || client.Credential.RefreshToken == "" {
// 无有效的RefreshToken
util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name)
return ErrInvalidRefreshToken
}
credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken))
if err != nil {
return err
}
// 更新有效期为绝对时间戳
expires := credential.ExpiresIn - 60
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
client.Credential = credential
// 更新存储策略的 RefreshToken
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
// 更新缓存
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
return nil
}
func (client *Client) AccessToken() string {
return client.Credential.AccessToken
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
client.Credential = &Credential{AccessToken: res}
return nil
}

View File

@@ -0,0 +1,59 @@
package onedrive
import "time"
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
redirect string
code string
refreshToken string
conflictBehavior string
expires time.Time
useDriverResource bool
}
type optionFunc func(*options)
// WithCode 设置接口Code
func WithCode(t string) Option {
return optionFunc(func(o *options) {
o.code = t
})
}
// WithRefreshToken 设置接口RefreshToken
func WithRefreshToken(t string) Option {
return optionFunc(func(o *options) {
o.refreshToken = t
})
}
// WithConflictBehavior 设置文件重名后的处理方式
func WithConflictBehavior(t string) Option {
return optionFunc(func(o *options) {
o.conflictBehavior = t
})
}
// WithConflictBehavior 设置文件重名后的处理方式
func WithDriverResource(t bool) Option {
return optionFunc(func(o *options) {
o.useDriverResource = t
})
}
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
conflictBehavior: "fail",
useDriverResource: true,
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
}
}

View File

@@ -0,0 +1,140 @@
package onedrive
import (
"encoding/gob"
"net/url"
)
// RespError 接口返回错误
type RespError struct {
APIError APIError `json:"error"`
}
// APIError 接口返回的错误内容
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// UploadSessionResponse 分片上传会话
type UploadSessionResponse struct {
DataContext string `json:"@odata.context"`
ExpirationDateTime string `json:"expirationDateTime"`
NextExpectedRanges []string `json:"nextExpectedRanges"`
UploadURL string `json:"uploadUrl"`
}
// FileInfo 文件元信息
type FileInfo struct {
Name string `json:"name"`
Size uint64 `json:"size"`
Image imageInfo `json:"image"`
ParentReference parentReference `json:"parentReference"`
DownloadURL string `json:"@microsoft.graph.downloadUrl"`
File *file `json:"file"`
Folder *folder `json:"folder"`
}
type file struct {
MimeType string `json:"mimeType"`
}
type folder struct {
ChildCount int `json:"childCount"`
}
type imageInfo struct {
Height int `json:"height"`
Width int `json:"width"`
}
type parentReference struct {
Path string `json:"path"`
Name string `json:"name"`
ID string `json:"id"`
}
// UploadResult 上传结果
type UploadResult struct {
ID string `json:"id"`
Name string `json:"name"`
Size uint64 `json:"size"`
}
// BatchRequests 批量操作请求
type BatchRequests struct {
Requests []BatchRequest `json:"requests"`
}
// BatchRequest 批量操作单个请求
type BatchRequest struct {
ID string `json:"id"`
Method string `json:"method"`
URL string `json:"url"`
Body interface{} `json:"body,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
}
// BatchResponses 批量操作响应
type BatchResponses struct {
Responses []BatchResponse `json:"responses"`
}
// BatchResponse 批量操作单个响应
type BatchResponse struct {
ID string `json:"id"`
Status int `json:"status"`
}
// ThumbResponse 获取缩略图的响应
type ThumbResponse struct {
Value []map[string]interface{} `json:"value"`
URL string `json:"url"`
}
// ListResponse 列取子项目响应
type ListResponse struct {
Value []FileInfo `json:"value"`
Context string `json:"@odata.context"`
}
// oauthEndpoint OAuth接口地址
type oauthEndpoint struct {
token url.URL
authorize url.URL
}
// Credential 获取token时返回的凭证
type Credential struct {
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
UserID string `json:"user_id"`
}
// OAuthError OAuth相关接口的错误响应
type OAuthError struct {
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
CorrelationID string `json:"correlation_id"`
}
// Site SharePoint 站点信息
type Site struct {
Description string `json:"description"`
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
WebUrl string `json:"webUrl"`
}
func init() {
gob.Register(Credential{})
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
}

View File

@@ -0,0 +1,117 @@
package oss
import (
"bytes"
"crypto"
"crypto/md5"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
)
// GetPublicKey 从回调请求或缓存中获取OSS的回调签名公钥
func GetPublicKey(r *http.Request) ([]byte, error) {
var pubKey []byte
// 尝试从缓存中获取
pub, exist := cache.Get("oss_public_key")
if exist {
return pub.([]byte), nil
}
// 从请求中获取
pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url"))
if err != nil {
return pubKey, err
}
// 确保这个 public key 是由 OSS 颁发的
if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") &&
!strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") {
return pubKey, errors.New("public key url invalid")
}
// 获取公钥
client := request.NewClient()
body, err := client.Request("GET", string(pubURL), nil).
CheckHTTPResponse(200).
GetResponse()
if err != nil {
return pubKey, err
}
// 写入缓存
_ = cache.Set("oss_public_key", []byte(body), 86400*7)
return []byte(body), nil
}
func getRequestMD5(r *http.Request) ([]byte, error) {
var byteMD5 []byte
// 获取请求正文
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return byteMD5, err
}
r.Body = ioutil.NopCloser(bytes.NewReader(body))
strURLPathDecode, err := url.PathUnescape(r.URL.Path)
if err != nil {
return byteMD5, err
}
strAuth := fmt.Sprintf("%s\n%s", strURLPathDecode, string(body))
md5Ctx := md5.New()
md5Ctx.Write([]byte(strAuth))
byteMD5 = md5Ctx.Sum(nil)
return byteMD5, nil
}
// VerifyCallbackSignature 验证OSS回调请求
func VerifyCallbackSignature(r *http.Request) error {
bytePublicKey, err := GetPublicKey(r)
if err != nil {
return err
}
byteMD5, err := getRequestMD5(r)
if err != nil {
return err
}
strAuthorizationBase64 := r.Header.Get("authorization")
if strAuthorizationBase64 == "" {
return errors.New("no authorization field in Request header")
}
authorization, _ := base64.StdEncoding.DecodeString(strAuthorizationBase64)
pubBlock, _ := pem.Decode(bytePublicKey)
if pubBlock == nil {
return errors.New("pubBlock not exist")
}
pubInterface, err := x509.ParsePKIXPublicKey(pubBlock.Bytes)
if (pubInterface == nil) || (err != nil) {
return err
}
pub := pubInterface.(*rsa.PublicKey)
errorVerifyPKCS1v15 := rsa.VerifyPKCS1v15(pub, crypto.MD5, byteMD5, authorization)
if errorVerifyPKCS1v15 != nil {
return errorVerifyPKCS1v15
}
return nil
}

View File

@@ -0,0 +1,501 @@
package oss
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"path"
"path/filepath"
"strings"
"time"
"github.com/HFO4/aliyun-oss-go-sdk/oss"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// UploadPolicy 阿里云OSS上传策略
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// CallbackPolicy 回调策略
type CallbackPolicy struct {
CallbackURL string `json:"callbackUrl"`
CallbackBody string `json:"callbackBody"`
CallbackBodyType string `json:"callbackBodyType"`
}
// Driver 阿里云OSS策略适配器
type Driver struct {
Policy *model.Policy
client *oss.Client
bucket *oss.Bucket
HTTPClient request.Client
}
type key int
const (
chunkRetrySleep = time.Duration(5) * time.Second
// MultiPartUploadThreshold 服务端使用分片上传的阈值
MultiPartUploadThreshold uint64 = 5 * (1 << 30) // 5GB
// VersionID 文件版本标识
VersionID key = iota
)
func NewDriver(policy *model.Policy) (*Driver, error) {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
driver := &Driver{
Policy: policy,
HTTPClient: request.NewClient(),
}
return driver, driver.InitOSSClient(false)
}
// CORS 创建跨域策略
func (handler *Driver) CORS() error {
return handler.client.SetBucketCORS(handler.Policy.BucketName, []oss.CORSRule{
{
AllowedOrigin: []string{"*"},
AllowedMethod: []string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
},
ExposeHeader: []string{},
AllowedHeader: []string{"*"},
MaxAgeSeconds: 3600,
},
})
}
// InitOSSClient 初始化OSS鉴权客户端
func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error {
if handler.Policy == nil {
return errors.New("empty policy")
}
// 决定是否使用内网 Endpoint
endpoint := handler.Policy.Server
if handler.Policy.OptionsSerialized.ServerSideEndpoint != "" && !forceUsePublicEndpoint {
endpoint = handler.Policy.OptionsSerialized.ServerSideEndpoint
}
// 初始化客户端
client, err := oss.New(endpoint, handler.Policy.AccessKey, handler.Policy.SecretKey)
if err != nil {
return err
}
handler.client = client
// 初始化存储桶
bucket, err := client.Bucket(handler.Policy.BucketName)
if err != nil {
return err
}
handler.bucket = bucket
return nil
}
// List 列出OSS上的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 列取文件
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
var (
delimiter string
marker string
objects []oss.ObjectProperties
commons []string
)
if !recursive {
delimiter = "/"
}
for {
subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base),
oss.MaxKeys(1000), oss.Delimiter(delimiter))
if err != nil {
return nil, err
}
objects = append(objects, subRes.Objects...)
commons = append(commons, subRes.CommonPrefixes...)
marker = subRes.NextMarker
if marker == "" {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(base, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(base, object.Key)
if err != nil {
continue
}
if strings.HasSuffix(object.Key, "/") {
res = append(res, response.Object{
Name: path.Base(object.Key),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
} else {
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: object.LastModified,
})
}
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 通过VersionID禁止缓存
ctx = context.WithValue(ctx, VersionID, time.Now().UnixNano())
// 尽可能使用私有 Endpoint
ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false)
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
// 凭证有效期
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
// 是否允许覆盖
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
options := []oss.Option{
oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)),
oss.ForbidOverWrite(!overwrite),
}
// 小文件直接上传
if fileInfo.Size < MultiPartUploadThreshold {
return handler.bucket.PutObject(fileInfo.SavePath, file, options...)
}
// 超过阈值时使用分片上传
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
if err != nil {
return fmt.Errorf("failed to initiate multipart upload: %w", err)
}
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
_, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1)
return err
}
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
_, err = handler.bucket.CompleteMultipartUpload(imur, oss.CompleteAll("yes"), oss.ForbidOverWrite(!overwrite))
return err
}
// Delete 删除一个或多个文件,
// 返回未删除的文件
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// 删除文件
delRes, err := handler.bucket.DeleteObjects(files)
if err != nil {
return files, err
}
// 统计未删除的文件
failed := util.SliceDifference(files, delRes.DeletedObjects)
if len(failed) > 0 {
return failed, errors.New("failed to delete")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://help.aliyun.com/document_detail/183902.html
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
// 初始化客户端
if err := handler.InitOSSClient(true); err != nil {
return nil, err
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d/quality,q_%d", thumbSize[1], thumbSize[0], thumbEncodeQuality)
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, thumbParam)
thumbOption := []oss.Option{oss.Process(thumbParam)}
thumbURL, err := handler.signSourceURL(
ctx,
file.SourceName,
int64(model.GetIntSetting("preview_timeout", 60)),
thumbOption,
)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: thumbURL,
}, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 初始化客户端
usePublicEndpoint := true
if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok {
usePublicEndpoint = forceUsePublicEndpoint
}
if err := handler.InitOSSClient(usePublicEndpoint); err != nil {
return "", err
}
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 添加各项设置
var signOptions = make([]oss.Option, 0, 2)
if isDownload {
signOptions = append(signOptions, oss.ResponseContentDisposition("attachment; filename=\""+url.PathEscape(fileName)+"\""))
}
if speed > 0 {
// Byte 转换为 bit
speed *= 8
// OSS对速度值有范围限制
if speed < 819200 {
speed = 819200
}
if speed > 838860800 {
speed = 838860800
}
signOptions = append(signOptions, oss.TrafficLimitParam(int64(speed)))
}
return handler.signSourceURL(ctx, path, ttl, signOptions)
}
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64, options []oss.Option) (string, error) {
signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...)
if err != nil {
return "", err
}
// 将最终生成的签名URL域名换成用户自定义的加速域名如果有
finalURL, err := url.Parse(signedURL)
if err != nil {
return "", err
}
// 公有空间替换掉Key及不支持的头
if !handler.Policy.IsPrivate {
query := finalURL.Query()
query.Del("OSSAccessKeyId")
query.Del("Signature")
query.Del("response-content-disposition")
query.Del("x-oss-traffic-limit")
finalURL.RawQuery = query.Encode()
}
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
finalURL.Host = cdnURL.Host
finalURL.Scheme = cdnURL.Scheme
}
return finalURL.String(), nil
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 初始化客户端
if err := handler.InitOSSClient(true); err != nil {
return nil, err
}
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/oss/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 回调策略
callbackPolicy := CallbackPolicy{
CallbackURL: apiURL.String(),
CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`,
CallbackBodyType: "application/json",
}
callbackPolicyJSON, err := json.Marshal(callbackPolicy)
if err != nil {
return nil, fmt.Errorf("failed to encode callback policy: %w", err)
}
callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON)
// 初始化分片上传
fileInfo := file.Info()
options := []oss.Option{
oss.Expires(time.Now().Add(time.Duration(ttl) * time.Second)),
oss.ForbidOverWrite(true),
oss.ContentType(fileInfo.DetectMimeType()),
}
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
if err != nil {
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
}
uploadSession.UploadID = imur.UploadID
// 为每个分片签名上传 URL
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
urls := make([]string, chunks.Num())
for chunks.Next() {
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
signedURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPut, ttl,
oss.PartNumber(c.Index()+1),
oss.UploadID(imur.UploadID),
oss.ContentType("application/octet-stream"))
if err != nil {
return err
}
urls[c.Index()] = signedURL
return nil
})
if err != nil {
return nil, err
}
}
// 签名完成分片上传的URL
completeURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPost, ttl,
oss.ContentType("application/octet-stream"),
oss.UploadID(imur.UploadID),
oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)),
oss.CompleteAll("yes"),
oss.ForbidOverWrite(true),
oss.CallbackParam(callbackPolicyEncoded))
if err != nil {
return nil, err
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadID: imur.UploadID,
UploadURLs: urls,
CompleteURL: completeURL,
}, nil
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.SavePath}, nil)
}

View File

@@ -0,0 +1,354 @@
package qiniu
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
)
// Driver 本地策略适配器
type Driver struct {
Policy *model.Policy
mac *qbox.Mac
cfg *storage.Config
bucket *storage.BucketManager
}
func NewDriver(policy *model.Policy) *Driver {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
mac := qbox.NewMac(policy.AccessKey, policy.SecretKey)
cfg := &storage.Config{UseHTTPS: true}
return &Driver{
Policy: policy,
mac: mac,
cfg: cfg,
bucket: storage.NewBucketManager(mac, cfg),
}
}
// List 列出给定路径下的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
var (
delimiter string
marker string
objects []storage.ListItem
commons []string
)
if !recursive {
delimiter = "/"
}
for {
entries, folders, nextMarker, hashNext, err := handler.bucket.ListFiles(
handler.Policy.BucketName,
base, delimiter, marker, 1000)
if err != nil {
return nil, err
}
objects = append(objects, entries...)
commons = append(commons, folders...)
if !hashNext {
break
}
marker = nextMarker
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(base, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(base, object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Fsize),
IsDir: false,
LastModify: time.Unix(object.PutTime/10000000, 0),
})
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 给文件名加上随机参数以强制拉取
path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano())
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
// 凭证有效期
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
// 生成上传策略
fileInfo := file.Info()
scope := handler.Policy.BucketName
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
scope = fmt.Sprintf("%s:%s", handler.Policy.BucketName, fileInfo.SavePath)
}
putPolicy := storage.PutPolicy{
// 指定为覆盖策略
Scope: scope,
SaveKey: fileInfo.SavePath,
ForceSaveKey: true,
FsizeLimit: int64(fileInfo.Size),
}
// 是否开启了MIMEType限制
if handler.Policy.OptionsSerialized.MimeType != "" {
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
}
// 生成上传凭证
token, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, int64(credentialTTL), false)
if err != nil {
return err
}
// 创建上传表单
cfg := storage.Config{}
formUploader := storage.NewFormUploader(&cfg)
ret := storage.PutRet{}
putExtra := storage.PutExtra{
Params: map[string]string{},
}
// 开始上传
err = formUploader.Put(ctx, &ret, token.Credential, fileInfo.SavePath, file, int64(fileInfo.Size), &putExtra)
if err != nil {
return err
}
return nil
}
// Delete 删除一个或多个文件,
// 返回未删除的文件
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// TODO 大于一千个文件需要分批发送
deleteOps := make([]string, 0, len(files))
for _, key := range files {
deleteOps = append(deleteOps, storage.URIDelete(handler.Policy.BucketName, key))
}
rets, err := handler.bucket.Batch(deleteOps)
// 处理删除结果
if err != nil {
failed := make([]string, 0, len(rets))
for k, ret := range rets {
if ret.Code != 200 && ret.Code != 612 {
failed = append(failed, files[k])
}
}
return failed, errors.New("删除失败")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://developer.qiniu.com/dora/api/basic-processing-images-imageview2
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d/q/%d", file.SourceName, thumbSize[0], thumbSize[1], thumbEncodeQuality)
return &response.ContentResponse{
Redirect: true,
URL: handler.signSourceURL(
ctx,
thumb,
int64(model.GetIntSetting("preview_timeout", 60)),
),
}, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 加入下载相关设置
if isDownload {
path = path + "?attname=" + url.PathEscape(fileName)
}
// 取得原始文件地址
return handler.signSourceURL(ctx, path, ttl), nil
}
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64) string {
var sourceURL string
if handler.Policy.IsPrivate {
deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix()
sourceURL = storage.MakePrivateURL(handler.mac, handler.Policy.BaseURL, path, deadline)
} else {
sourceURL = storage.MakePublicURL(handler.Policy.BaseURL, path)
}
return sourceURL
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/qiniu/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 创建上传策略
fileInfo := file.Info()
putPolicy := storage.PutPolicy{
Scope: handler.Policy.BucketName,
CallbackURL: apiURL.String(),
CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`,
CallbackBodyType: "application/json",
SaveKey: fileInfo.SavePath,
ForceSaveKey: true,
FsizeLimit: int64(handler.Policy.MaxSize),
}
// 是否开启了MIMEType限制
if handler.Policy.OptionsSerialized.MimeType != "" {
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
}
credential, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, ttl, true)
if err != nil {
return nil, fmt.Errorf("failed to init parts: %w", err)
}
credential.SessionID = uploadSession.Key
credential.ChunkSize = handler.Policy.OptionsSerialized.ChunkSize
uploadSession.UploadURL = credential.UploadURLs[0]
uploadSession.Credential = credential.Credential
return credential, nil
}
// getUploadCredential 签名上传策略并创建上传会话
func (handler *Driver) getUploadCredential(ctx context.Context, policy storage.PutPolicy, file *fsctx.UploadTaskInfo, TTL int64, resume bool) (*serializer.UploadCredential, error) {
// 上传凭证
policy.Expires = uint64(TTL)
upToken := policy.UploadToken(handler.mac)
// 初始化分片上传
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
upHost, err := resumeUploader.UpHost(handler.Policy.AccessKey, handler.Policy.BucketName)
if err != nil {
return nil, err
}
ret := &storage.InitPartsRet{}
if resume {
err = resumeUploader.InitParts(ctx, upToken, upHost, handler.Policy.BucketName, file.SavePath, true, ret)
}
return &serializer.UploadCredential{
UploadURLs: []string{upHost + "/buckets/" + handler.Policy.BucketName + "/objects/" + base64.URLEncoding.EncodeToString([]byte(file.SavePath)) + "/uploads/" + ret.UploadID},
Credential: upToken,
}, err
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0)
}

View File

@@ -0,0 +1,195 @@
package remote
import (
"context"
"encoding/json"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
)
const (
basePath = "/api/v3/slave/"
OverwriteHeader = auth.CrHeaderPrefix + "Overwrite"
chunkRetrySleep = time.Duration(5) * time.Second
)
// Client to operate uploading to remote slave server
type Client interface {
// CreateUploadSession creates remote upload session
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error
// GetUploadURL signs an url for uploading file
GetUploadURL(ttl int64, sessionID string) (string, string, error)
// Upload uploads file to remote server
Upload(ctx context.Context, file fsctx.FileHeader) error
// DeleteUploadSession deletes remote upload session
DeleteUploadSession(ctx context.Context, sessionID string) error
}
// NewClient creates new Client from given policy
func NewClient(policy *model.Policy) (Client, error) {
authInstance := auth.HMACAuth{[]byte(policy.SecretKey)}
serverURL, err := url.Parse(policy.Server)
if err != nil {
return nil, err
}
base, _ := url.Parse(basePath)
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return &remoteClient{
policy: policy,
authInstance: authInstance,
httpClient: request.NewClient(
request.WithEndpoint(serverURL.ResolveReference(base).String()),
request.WithCredential(authInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(policy.AccessKey),
),
}, nil
}
type remoteClient struct {
policy *model.Policy
authInstance auth.Auth
httpClient request.Client
}
func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error {
ttl := model.GetIntSetting("upload_session_timeout", 86400)
fileInfo := file.Info()
session := &serializer.UploadSession{
Key: uuid.Must(uuid.NewV4()).String(),
VirtualPath: fileInfo.VirtualPath,
Name: fileInfo.FileName,
Size: fileInfo.Size,
SavePath: fileInfo.SavePath,
LastModified: fileInfo.LastModified,
Policy: *c.policy,
}
// Create upload session
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
if err := c.CreateUploadSession(ctx, session, int64(ttl), overwrite); err != nil {
return fmt.Errorf("failed to create upload session: %w", err)
}
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length())
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
if err := c.DeleteUploadSession(ctx, session.Key); err != nil {
util.Log().Warning("failed to delete upload session: %s", err)
}
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error {
resp, err := c.httpClient.Request(
"DELETE",
"upload/"+sessionID,
nil,
request.WithContext(ctx),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error {
reqBodyEncoded, err := json.Marshal(map[string]interface{}{
"session": session,
"ttl": ttl,
"overwrite": overwrite,
})
if err != nil {
return err
}
bodyReader := strings.NewReader(string(reqBodyEncoded))
resp, err := c.httpClient.Request(
"PUT",
"upload",
bodyReader,
request.WithContext(ctx),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string, error) {
base, err := url.Parse(c.policy.Server)
if err != nil {
return "", "", err
}
base.Path = path.Join(base.Path, basePath, "upload", sessionID)
req, err := http.NewRequest("POST", base.String(), nil)
if err != nil {
return "", "", err
}
req = auth.SignRequest(c.authInstance, req, ttl)
return req.URL.String(), req.Header["Authorization"][0], nil
}
func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error {
resp, err := c.httpClient.Request(
"POST",
fmt.Sprintf("upload/%s?chunk=%d", sessionID, index),
chunk,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithContentLength(size),
request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}

View File

@@ -0,0 +1,311 @@
package remote
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Driver 远程存储策略适配器
type Driver struct {
Client request.Client
Policy *model.Policy
AuthInstance auth.Auth
uploadClient Client
}
// NewDriver initializes a new Driver from policy
// TODO: refactor all method into upload client
func NewDriver(policy *model.Policy) (*Driver, error) {
client, err := NewClient(policy)
if err != nil {
return nil, err
}
return &Driver{
Policy: policy,
Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)},
uploadClient: client,
}, nil
}
// List 列取文件
func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
var res []response.Object
reqBody := serializer.ListRequest{
Path: path,
Recursive: recursive,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return res, err
}
// 发送列表请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := handler.Client.Request(
"POST",
handler.getAPIUrl("list"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return res, err
}
// 处理列取结果
if resp.Code != 0 {
return res, errors.New(resp.Error)
}
if resStr, ok := resp.Data.(string); ok {
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return res, err
}
}
return res, nil
}
// getAPIUrl 获取接口请求地址
func (handler *Driver) getAPIUrl(scope string, routes ...string) string {
serverURL, err := url.Parse(handler.Policy.Server)
if err != nil {
return ""
}
var controller *url.URL
switch scope {
case "delete":
controller, _ = url.Parse("/api/v3/slave/delete")
case "thumb":
controller, _ = url.Parse("/api/v3/slave/thumb")
case "list":
controller, _ = url.Parse("/api/v3/slave/list")
default:
controller = serverURL
}
for _, r := range routes {
controller.Path = path.Join(controller.Path, r)
}
return serverURL.ResolveReference(controller).String()
}
// Get 获取文件内容
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 尝试获取速度限制
speedLimit := 0
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
speedLimit = user.Group.SpeedLimit
}
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.Client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
return handler.uploadClient.Upload(ctx, file)
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// 封装接口请求正文
reqBody := serializer.RemoteDeleteRequest{
Files: files,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return files, err
}
// 发送删除请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := handler.Client.Request(
"POST",
handler.getAPIUrl("delete"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
).CheckHTTPResponse(200).GetResponse()
if err != nil {
return files, err
}
// 处理删除结果
var reqResp serializer.Response
err = json.Unmarshal([]byte(resp), &reqResp)
if err != nil {
return files, err
}
if reqResp.Code != 0 {
var failedResp serializer.RemoteDeleteRequest
if failed, ok := reqResp.Data.(string); ok {
err = json.Unmarshal([]byte(failed), &failedResp)
if err == nil {
return failedResp.Files, errors.New(reqResp.Error)
}
}
return files, errors.New("unknown format of returned response")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
supported := []string{"png", "jpg", "jpeg", "gif"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) {
return nil, driver.ErrorThumbNotSupported
}
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(file.SourceName))
thumbURL := fmt.Sprintf("%s/%s/%s", handler.getAPIUrl("thumb"), sourcePath, filepath.Ext(file.Name))
ttl := model.GetIntSetting("preview_timeout", 60)
signedThumbURL, err := auth.SignURI(handler.AuthInstance, thumbURL, int64(ttl))
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: signedThumbURL.String(),
}, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := "file"
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
serverURL, err := url.Parse(handler.Policy.Server)
if err != nil {
return "", errors.New("无法解析远程服务端地址")
}
// 是否启用了CDN
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
serverURL = cdnURL
}
var (
signedURI *url.URL
controller = "/api/v3/slave/download"
)
if !isDownload {
controller = "/api/v3/slave/source"
}
// 签名下载地址
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path))
signedURI, err = auth.SignURI(
handler.AuthInstance,
fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, url.PathEscape(fileName)),
ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign URL", err)
}
finalURL := serverURL.ResolveReference(signedURI).String()
return finalURL, nil
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret))
apiURL := siteURL.ResolveReference(apiBaseURI)
// 在从机端创建上传会话
uploadSession.Callback = apiURL.String()
if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl, false); err != nil {
return nil, err
}
// 获取上传地址
uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key)
if err != nil {
return nil, fmt.Errorf("failed to sign upload url: %w", err)
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadURLs: []string{uploadURL},
Credential: sign,
}, nil
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key)
}

View File

@@ -0,0 +1,440 @@
package s3
import (
"context"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 适配器模板
type Driver struct {
Policy *model.Policy
sess *session.Session
svc *s3.S3
}
// UploadPolicy S3上传策略
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// MetaData 文件信息
type MetaData struct {
Size uint64
Etag string
}
func NewDriver(policy *model.Policy) (*Driver, error) {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
driver := &Driver{
Policy: policy,
}
return driver, driver.InitS3Client()
}
// InitS3Client 初始化S3会话
func (handler *Driver) InitS3Client() error {
if handler.Policy == nil {
return errors.New("empty policy")
}
if handler.svc == nil {
// 初始化会话
sess, err := session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
Endpoint: &handler.Policy.Server,
Region: &handler.Policy.OptionsSerialized.Region,
S3ForcePathStyle: &handler.Policy.OptionsSerialized.S3ForcePathStyle,
})
if err != nil {
return err
}
handler.sess = sess
handler.svc = s3.New(sess)
}
return nil
}
// List 列出给定路径下的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 初始化列目录参数
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
opt := &s3.ListObjectsInput{
Bucket: &handler.Policy.BucketName,
Prefix: &base,
MaxKeys: aws.Int64(1000),
}
// 是否为递归列出
if !recursive {
opt.Delimiter = aws.String("/")
}
var (
objects []*s3.Object
commons []*s3.CommonPrefix
)
for {
res, err := handler.svc.ListObjectsWithContext(ctx, opt)
if err != nil {
return nil, err
}
objects = append(objects, res.Contents...)
commons = append(commons, res.CommonPrefixes...)
// 如果本次未列取完则继续使用marker获取结果
if *res.IsTruncated {
opt.Marker = res.NextMarker
} else {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(*opt.Prefix, *object.Prefix)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(*object.Prefix),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(*opt.Prefix, *object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(*object.Key),
Source: *object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(*object.Size),
IsDir: false,
LastModify: time.Now(),
})
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
// 初始化客户端
if err := handler.InitS3Client(); err != nil {
return err
}
uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) {
u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize)
})
dst := file.Info().SavePath
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: &handler.Policy.BucketName,
Key: &dst,
Body: io.LimitReader(file, int64(file.Info().Size)),
})
if err != nil {
return err
}
return nil
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
failed := make([]string, 0, len(files))
deleted := make([]string, 0, len(files))
keys := make([]*s3.ObjectIdentifier, 0, len(files))
for _, file := range files {
filePath := file
keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
}
// 发送异步删除请求
res, err := handler.svc.DeleteObjects(
&s3.DeleteObjectsInput{
Bucket: &handler.Policy.BucketName,
Delete: &s3.Delete{
Objects: keys,
},
})
if err != nil {
return files, err
}
// 统计未删除的文件
for _, deleteRes := range res.Deleted {
deleted = append(deleted, *deleteRes.Key)
}
failed = util.SliceDifference(files, deleted)
return failed, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, driver.ErrorThumbNotSupported
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 初始化客户端
if err := handler.InitS3Client(); err != nil {
return "", err
}
contentDescription := aws.String("attachment; filename=\"" + url.PathEscape(fileName) + "\"")
if !isDownload {
contentDescription = nil
}
req, _ := handler.svc.GetObjectRequest(
&s3.GetObjectInput{
Bucket: &handler.Policy.BucketName,
Key: &path,
ResponseContentDisposition: contentDescription,
})
signedURL, err := req.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return "", err
}
// 将最终生成的签名URL域名换成用户自定义的加速域名如果有
finalURL, err := url.Parse(signedURL)
if err != nil {
return "", err
}
// 公有空间替换掉Key及不支持的头
if !handler.Policy.IsPrivate {
finalURL.RawQuery = ""
}
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
finalURL.Host = cdnURL.Host
finalURL.Scheme = cdnURL.Scheme
}
return finalURL.String(), nil
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 检查文件是否存在
fileInfo := file.Info()
if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil {
return nil, fmt.Errorf("file already exist")
}
// 创建分片上传
expires := time.Now().Add(time.Duration(ttl) * time.Second)
res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
Expires: &expires,
ContentType: aws.String(fileInfo.DetectMimeType()),
})
if err != nil {
return nil, fmt.Errorf("failed to create multipart upload: %w", err)
}
uploadSession.UploadID = *res.UploadId
// 为每个分片签名上传 URL
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
urls := make([]string, chunks.Num())
for chunks.Next() {
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
PartNumber: aws.Int64(int64(c.Index() + 1)),
UploadId: res.UploadId,
})
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return err
}
urls[c.Index()] = signedURL
return nil
})
if err != nil {
return nil, err
}
}
// 签名完成分片上传的请求URL
signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
UploadId: res.UploadId,
})
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return nil, err
}
// 生成上传凭证
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadID: *res.UploadId,
UploadURLs: urls,
CompleteURL: signedURL,
}, nil
}
// Meta 获取文件信息
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.svc.HeadObject(
&s3.HeadObjectInput{
Bucket: &handler.Policy.BucketName,
Key: &path,
})
if err != nil {
return nil, err
}
return &MetaData{
Size: uint64(*res.ContentLength),
Etag: *res.ETag,
}, nil
}
// CORS 创建跨域策略
func (handler *Driver) CORS() error {
rule := s3.CORSRule{
AllowedMethods: aws.StringSlice([]string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
}),
AllowedOrigins: aws.StringSlice([]string{"*"}),
AllowedHeaders: aws.StringSlice([]string{"*"}),
ExposeHeaders: aws.StringSlice([]string{"ETag"}),
MaxAgeSeconds: aws.Int64(3600),
}
_, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{
Bucket: &handler.Policy.BucketName,
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{&rule},
},
})
return err
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
_, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
UploadId: &uploadSession.UploadID,
Bucket: &handler.Policy.BucketName,
Key: &uploadSession.SavePath,
})
return err
}

View File

@@ -0,0 +1,7 @@
package masterinslave
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
)

View File

@@ -0,0 +1,60 @@
package masterinslave
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 影子存储策略,用于在从机端上传文件
type Driver struct {
master cluster.Node
handler driver.Handler
policy *model.Policy
}
// NewDriver 返回新的处理器
func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
return &Driver{
master: master,
handler: handler,
policy: policy,
}
}
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
return d.handler.Put(ctx, file)
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
return nil, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@@ -0,0 +1,9 @@
package slaveinmaster
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result")
)

View File

@@ -0,0 +1,124 @@
package slaveinmaster
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
type Driver struct {
node cluster.Node
handler driver.Handler
policy *model.Policy
client request.Client
}
// NewDriver 返回新的从机指派处理器
func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
var endpoint *url.URL
if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return &Driver{
node: node,
handler: handler,
policy: policy,
client: request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)),
request.WithEndpoint(endpoint.String()),
),
}
}
// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
req := serializer.SlaveTransferReq{
Src: fileInfo.Src,
Dst: fileInfo.SavePath,
Policy: d.policy,
}
body, err := json.Marshal(req)
if err != nil {
return err
}
// 订阅转存结果
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)).
CheckHTTPResponse(200).
DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
// 等待转存结果或者超时
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
select {
case <-time.After(time.Duration(waitTimeout) * time.Second):
return ErrWaitResultTimeout
case msg := <-resChan:
if msg.Event != serializer.SlaveTransferSuccess {
return errors.New(msg.Content.(serializer.SlaveTransferResult).Error)
}
}
return nil
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
return nil, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}
// 取消上传凭证
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@@ -0,0 +1,358 @@
package upyun
import (
"context"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/upyun/go-sdk/upyun"
)
// UploadPolicy 又拍云上传策略
type UploadPolicy struct {
Bucket string `json:"bucket"`
SaveKey string `json:"save-key"`
Expiration int64 `json:"expiration"`
CallbackURL string `json:"notify-url"`
ContentLength uint64 `json:"content-length"`
ContentLengthRange string `json:"content-length-range,omitempty"`
AllowFileType string `json:"allow-file-type,omitempty"`
}
// Driver 又拍云策略适配器
type Driver struct {
Policy *model.Policy
}
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
// 用于接受SDK返回对象的chan
objChan := make(chan *upyun.FileInfo)
objects := []*upyun.FileInfo{}
// 列取配置
listConf := &upyun.GetObjectsConfig{
Path: "/" + base,
ObjectsChan: objChan,
MaxListTries: 1,
}
// 递归列取时不限制递归次数
if recursive {
listConf.MaxListLevel = -1
}
// 启动一个goroutine收集列取的对象信
wg := &sync.WaitGroup{}
wg.Add(1)
go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) {
defer wg.Done()
for {
file, ok := <-input
if !ok {
return
}
*output = append(*output, file)
}
}(objChan, &objects, wg)
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
err := up.List(listConf)
if err != nil {
return nil, err
}
wg.Wait()
// 汇总处理列取结果
res := make([]response.Object, 0, len(objects))
for _, object := range objects {
res = append(res, response.Object{
Name: path.Base(object.Name),
RelativePath: object.Name,
Source: path.Join(base, object.Name),
Size: uint64(object.Size),
IsDir: object.IsDir,
LastModify: object.Time,
})
}
return res, nil
}
// Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
err := up.Put(&upyun.PutObjectConfig{
Path: file.Info().SavePath,
Reader: file,
})
return err
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
var (
failed = make([]string, 0, len(files))
lastErr error
currentIndex = 0
indexLock sync.Mutex
failedLock sync.Mutex
wg sync.WaitGroup
routineNum = 4
)
wg.Add(routineNum)
// upyun不支持批量操作这里开四个协程并行操作
for i := 0; i < routineNum; i++ {
go func() {
for {
// 取得待删除文件
indexLock.Lock()
if currentIndex >= len(files) {
// 所有文件处理完成
wg.Done()
indexLock.Unlock()
return
}
path := files[currentIndex]
currentIndex++
indexLock.Unlock()
// 发送异步删除请求
err := up.Delete(&upyun.DeleteObjectConfig{
Path: path,
Async: true,
})
// 处理错误
if err != nil {
failedLock.Lock()
lastErr = err
failed = append(failed, path)
failedLock.Unlock()
}
}
}()
}
wg.Wait()
return failed, lastErr
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://help.upyun.com/knowledge-base/image/
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("!/fwfh/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: thumbURL,
}, nil
}
// Source 获取外链URL
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
sourceURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
fileKey, err := url.Parse(url.PathEscape(path))
if err != nil {
return "", err
}
sourceURL = sourceURL.ResolveReference(fileKey)
// 如果是下载文件URL
if isDownload {
query := sourceURL.Query()
query.Add("_upd", fileName)
sourceURL.RawQuery = query.Encode()
}
return handler.signURL(ctx, sourceURL, ttl)
}
func (handler Driver) signURL(ctx context.Context, path *url.URL, TTL int64) (string, error) {
if !handler.Policy.IsPrivate {
// 未开启Token防盗链时直接返回
return path.String(), nil
}
etime := time.Now().Add(time.Duration(TTL) * time.Second).Unix()
signStr := fmt.Sprintf(
"%s&%d&%s",
handler.Policy.OptionsSerialized.Token,
etime,
path.Path,
)
signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10)
// 将签名添加到URL中
query := path.Query()
query.Add("_upt", finalSign)
path.RawQuery = query.Encode()
return path.String(), nil
}
// Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/upyun/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 上传策略
fileInfo := file.Info()
putPolicy := UploadPolicy{
Bucket: handler.Policy.BucketName,
// TODO escape
SaveKey: fileInfo.SavePath,
Expiration: time.Now().Add(time.Duration(ttl) * time.Second).Unix(),
CallbackURL: apiURL.String(),
ContentLength: fileInfo.Size,
ContentLengthRange: fmt.Sprintf("0,%d", fileInfo.Size),
AllowFileType: strings.Join(handler.Policy.OptionsSerialized.FileType, ","),
}
// 生成上传凭证
policyJSON, err := json.Marshal(putPolicy)
if err != nil {
return nil, err
}
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
// 生成签名
elements := []string{"POST", "/" + handler.Policy.BucketName, policyEncoded}
signStr := handler.Sign(ctx, elements)
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
Policy: policyEncoded,
Credential: signStr,
UploadURLs: []string{"https://v0.api.upyun.com/" + handler.Policy.BucketName},
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
// Sign 计算又拍云的签名头
func (handler Driver) Sign(ctx context.Context, elements []string) string {
password := fmt.Sprintf("%x", md5.Sum([]byte(handler.Policy.SecretKey)))
mac := hmac.New(sha1.New, []byte(password))
value := strings.Join(elements, "&")
mac.Write([]byte(value))
signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil)))
return fmt.Sprintf("UPYUN %s:%s", handler.Policy.AccessKey, signStr)
}

25
pkg/filesystem/errors.go Normal file
View File

@@ -0,0 +1,25 @@
package filesystem
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil)
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil)
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil)
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil)
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil)
ErrClientCanceled = errors.New("Client canceled operation")
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil)
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil)
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil)
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil)
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil)
ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil)
ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil)
ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil)
ErrOneObjectOnly = serializer.ParamErr("You can only copy one object at the same time", nil)
)

387
pkg/filesystem/file.go Normal file
View File

@@ -0,0 +1,387 @@
package filesystem
import (
"context"
"fmt"
"io"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/juju/ratelimit"
)
/* ============
文件相关
============
*/
// 限速后的ReaderSeeker
type lrs struct {
response.RSCloser
r io.Reader
}
func (r lrs) Read(p []byte) (int, error) {
return r.r.Read(p)
}
// withSpeedLimit 给原有的ReadSeeker加上限速
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
// 如果用户组有速度限制就返回限制流速的ReaderSeeker
if fs.User.Group.SpeedLimit != 0 {
speed := fs.User.Group.SpeedLimit
bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed))
lrs := lrs{rs, ratelimit.Reader(rs, bucket)}
return lrs
}
// 否则返回原始流
return rs
}
// AddFile 新增文件记录
func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fsctx.FileHeader) (*model.File, error) {
// 添加文件记录前的钩子
err := fs.Trigger(ctx, "BeforeAddFile", file)
if err != nil {
return nil, err
}
uploadInfo := file.Info()
newFile := model.File{
Name: uploadInfo.FileName,
SourceName: uploadInfo.SavePath,
UserID: fs.User.ID,
Size: uploadInfo.Size,
FolderID: parent.ID,
PolicyID: fs.Policy.ID,
MetadataSerialized: uploadInfo.Metadata,
UploadSessionID: uploadInfo.UploadSessionID,
}
err = newFile.Create()
if err != nil {
if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil {
util.Log().Debug("AfterValidateFailed hook execution failed: %s", err)
}
return nil, ErrFileExisted.WithError(err)
}
fs.User.Storage += newFile.Size
return &newFile, nil
}
// GetPhysicalFileContent 根据文件物理路径获取文件流
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
// 重设上传策略
fs.Policy = &model.Policy{Type: "local"}
_ = fs.DispatchHandler()
// 获取文件流
rs, err := fs.Handler.Get(ctx, path)
if err != nil {
return nil, err
}
return fs.withSpeedLimit(rs), nil
}
// Preview 预览文件
//
// path - 文件虚拟路径
// isText - 是否为文本文件,文本文件会忽略重定向,直接由
// 服务端拉取中转给用户,故会对文件大小进行限制
func (fs *FileSystem) Preview(ctx context.Context, id uint, isText bool) (*response.ContentResponse, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, err
}
// 如果是文本文件预览,需要检查大小限制
sizeLimit := model.GetIntSetting("maxEditSize", 2<<20)
if isText && fs.FileTarget[0].Size > uint64(sizeLimit) {
return nil, ErrFileSizeTooBig
}
// 是否直接返回文件内容
if isText || fs.Policy.IsDirectlyPreview() {
resp, err := fs.GetDownloadContent(ctx, id)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: false,
Content: resp,
}, nil
}
// 否则重定向到签名的预览URL
ttl := model.GetIntSetting("preview_timeout", 60)
previewURL, err := fs.SignURL(ctx, &fs.FileTarget[0], int64(ttl), false)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: previewURL,
MaxAge: ttl,
}, nil
}
// GetDownloadContent 获取用于下载的文件流
func (fs *FileSystem) GetDownloadContent(ctx context.Context, id uint) (response.RSCloser, error) {
// 获取原始文件流
rs, err := fs.GetContent(ctx, id)
if err != nil {
return nil, err
}
// 返回限速处理后的文件流
return fs.withSpeedLimit(rs), nil
}
// GetContent 获取文件内容path为虚拟路径
func (fs *FileSystem) GetContent(ctx context.Context, id uint) (response.RSCloser, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 获取文件流
rs, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
if err != nil {
return nil, ErrIO.WithError(err)
}
return rs, nil
}
// deleteGroupedFile 对分组好的文件执行删除操作,
// 返回每个分组失败的文件列表
func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*model.File) map[uint][]string {
// 失败的文件列表
// TODO 并行删除
failed := make(map[uint][]string, len(files))
thumbs := make([]string, 0)
for policyID, toBeDeletedFiles := range files {
// 列举出需要物理删除的文件的物理路径
sourceNamesAll := make([]string, 0, len(toBeDeletedFiles))
uploadSessions := make([]*serializer.UploadSession, 0, len(toBeDeletedFiles))
for i := 0; i < len(toBeDeletedFiles); i++ {
sourceNamesAll = append(sourceNamesAll, toBeDeletedFiles[i].SourceName)
if toBeDeletedFiles[i].UploadSessionID != nil {
if session, ok := cache.Get(UploadSessionCachePrefix + *toBeDeletedFiles[i].UploadSessionID); ok {
uploadSession := session.(serializer.UploadSession)
uploadSessions = append(uploadSessions, &uploadSession)
}
}
// Check if sidecar thumb file exist
if model.IsTrueVal(toBeDeletedFiles[i].MetadataSerialized[model.ThumbSidecarMetadataKey]) {
thumbs = append(thumbs, toBeDeletedFiles[i].ThumbFile())
}
}
// 切换上传策略
fs.Policy = toBeDeletedFiles[0].GetPolicy()
err := fs.DispatchHandler()
if err != nil {
failed[policyID] = sourceNamesAll
continue
}
// 取消上传会话
for _, upSession := range uploadSessions {
if err := fs.Handler.CancelToken(ctx, upSession); err != nil {
util.Log().Warning("Failed to cancel upload session for %q: %s", upSession.Name, err)
}
cache.Deletes([]string{upSession.Key}, UploadSessionCachePrefix)
}
// 执行删除
toBeDeletedSrcs := append(sourceNamesAll, thumbs...)
failedFile, _ := fs.Handler.Delete(ctx, toBeDeletedSrcs)
// Exclude failed results related to thumb file
failed[policyID] = util.SliceDifference(failedFile, thumbs)
}
return failed
}
// GroupFileByPolicy 将目标文件按照存储策略分组
func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) map[uint][]*model.File {
var policyGroup = make(map[uint][]*model.File)
for key := range files {
if file, ok := policyGroup[files[key].PolicyID]; ok {
// 如果已存在分组,直接追加
policyGroup[files[key].PolicyID] = append(file, &files[key])
} else {
// 分组不存在,创建
policyGroup[files[key].PolicyID] = make([]*model.File, 0)
policyGroup[files[key].PolicyID] = append(policyGroup[files[key].PolicyID], &files[key])
}
}
return policyGroup
}
// GetDownloadURL 创建文件下载链接, timeout 为数据库中存储过期时间的字段
func (fs *FileSystem) GetDownloadURL(ctx context.Context, id uint, timeout string) (string, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return "", err
}
fileTarget := &fs.FileTarget[0]
// 生成下載地址
ttl := model.GetIntSetting(timeout, 60)
source, err := fs.SignURL(
ctx,
fileTarget,
int64(ttl),
true,
)
if err != nil {
return "", err
}
return source, nil
}
// GetSource 获取可直接访问文件的外链地址
func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) {
// 查找文件记录
err := fs.resetFileIDIfNotExist(ctx, fileID)
if err != nil {
return "", ErrObjectNotExist.WithError(err)
}
// 检查存储策略是否可以获得外链
if !fs.Policy.IsOriginLinkEnable {
return "", serializer.NewError(
serializer.CodePolicyNotAllowed,
"This policy is not enabled for getting source link",
nil,
)
}
source, err := fs.SignURL(ctx, &fs.FileTarget[0], 0, false)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
}
return source, nil
}
// SignURL 签名文件原始 URL
func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) {
fs.FileTarget = []model.File{*file}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file)
err := fs.resetPolicyToFirstFile(ctx)
if err != nil {
return "", err
}
// 签名最终URL
// 生成外链地址
source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
}
return source, nil
}
// ResetFileIfNotExist 重设当前目标文件为 path如果当前目标为空
func (fs *FileSystem) ResetFileIfNotExist(ctx context.Context, path string) error {
// 找到文件
if len(fs.FileTarget) == 0 {
exist, file := fs.IsFileExist(path)
if !exist {
return ErrObjectNotExist
}
fs.FileTarget = []model.File{*file}
}
// 将当前存储策略重设为文件使用的
return fs.resetPolicyToFirstFile(ctx)
}
// ResetFileIfNotExist 重设当前目标文件为 id如果当前目标为空
func (fs *FileSystem) resetFileIDIfNotExist(ctx context.Context, id uint) error {
// 找到文件
if len(fs.FileTarget) == 0 {
file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID)
if err != nil || len(file) == 0 {
return ErrObjectNotExist
}
fs.FileTarget = []model.File{file[0]}
}
// 如果上下文限制了父目录,则进行检查
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
if parent.ID != fs.FileTarget[0].FolderID {
return ErrObjectNotExist
}
}
// 将当前存储策略重设为文件使用的
return fs.resetPolicyToFirstFile(ctx)
}
// resetPolicyToFirstFile 将当前存储策略重设为第一个目标文件文件使用的
func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error {
if len(fs.FileTarget) == 0 {
return ErrObjectNotExist
}
// 从机模式不进行操作
if conf.SystemConfig.Mode == "slave" {
return nil
}
fs.Policy = fs.FileTarget[0].GetPolicy()
err := fs.DispatchHandler()
if err != nil {
return err
}
return nil
}
// Search 搜索文件
func (fs *FileSystem) Search(ctx context.Context, keywords ...interface{}) ([]serializer.Object, error) {
parents := make([]uint, 0)
// 如果限定了根目录,则只在这个根目录下搜索。
if fs.Root != nil {
allFolders, err := model.GetRecursiveChildFolder([]uint{fs.Root.ID}, fs.User.ID, true)
if err != nil {
return nil, fmt.Errorf("failed to list all folders: %w", err)
}
for _, folder := range allFolders {
parents = append(parents, folder.ID)
}
}
files, _ := model.GetFilesByKeywords(fs.User.ID, parents, keywords...)
fs.SetTargetFile(&files)
return fs.listObjects(ctx, "/", files, nil, nil), nil
}

View File

@@ -0,0 +1,295 @@
package filesystem
import (
"errors"
"fmt"
"net/http"
"net/url"
"sync"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/qiniu"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
cossdk "github.com/tencentyun/cos-go-sdk-v5"
)
// FSPool 文件系统资源池
var FSPool = sync.Pool{
New: func() interface{} {
return &FileSystem{}
},
}
// FileSystem 管理文件的文件系统
type FileSystem struct {
// 文件系统所有者
User *model.User
// 操作文件使用的存储策略
Policy *model.Policy
// 当前正在处理的文件对象
FileTarget []model.File
// 当前正在处理的目录对象
DirTarget []model.Folder
// 相对根目录
Root *model.Folder
// 互斥锁
Lock sync.Mutex
/*
钩子函数
*/
Hooks map[string][]Hook
/*
文件系统处理适配器
*/
Handler driver.Handler
// 回收锁
recycleLock sync.Mutex
}
// getEmptyFS 从pool中获取新的FileSystem
func getEmptyFS() *FileSystem {
fs := FSPool.Get().(*FileSystem)
return fs
}
// Recycle 回收FileSystem资源
func (fs *FileSystem) Recycle() {
fs.recycleLock.Lock()
fs.reset()
FSPool.Put(fs)
}
// reset 重设文件系统,以便回收使用
func (fs *FileSystem) reset() {
fs.User = nil
fs.CleanTargets()
fs.Policy = nil
fs.Hooks = nil
fs.Handler = nil
fs.Root = nil
fs.Lock = sync.Mutex{}
fs.recycleLock = sync.Mutex{}
}
// NewFileSystem 初始化一个文件系统
func NewFileSystem(user *model.User) (*FileSystem, error) {
fs := getEmptyFS()
fs.User = user
fs.Policy = user.GetPolicyID(nil)
// 分配存储策略适配器
err := fs.DispatchHandler()
return fs, err
}
// NewAnonymousFileSystem 初始化匿名文件系统
func NewAnonymousFileSystem() (*FileSystem, error) {
fs := getEmptyFS()
fs.User = &model.User{}
// 如果是主机模式下,则为匿名文件系统分配游客用户组
if conf.SystemConfig.Mode == "master" {
anonymousGroup, err := model.GetGroupByID(3)
if err != nil {
return nil, err
}
fs.User.Group = anonymousGroup
} else {
// 从机模式下,分配本地策略处理器
fs.Handler = local.Driver{}
}
return fs, nil
}
// DispatchHandler 根据存储策略分配文件适配器
func (fs *FileSystem) DispatchHandler() error {
handler, err := getNewPolicyHandler(fs.Policy)
fs.Handler = handler
return err
}
// getNewPolicyHandler 根据存储策略类型字段获取处理器
func getNewPolicyHandler(policy *model.Policy) (driver.Handler, error) {
if policy == nil {
return nil, ErrUnknownPolicyType
}
switch policy.Type {
case "mock", "anonymous":
return nil, nil
case "local":
return local.Driver{
Policy: policy,
}, nil
case "remote":
return remote.NewDriver(policy)
case "qiniu":
return qiniu.NewDriver(policy), nil
case "oss":
return oss.NewDriver(policy)
case "upyun":
return upyun.Driver{
Policy: policy,
}, nil
case "onedrive":
return onedrive.NewDriver(policy)
case "cos":
u, _ := url.Parse(policy.Server)
b := &cossdk.BaseURL{BucketURL: u}
return cos.Driver{
Policy: policy,
Client: cossdk.NewClient(b, &http.Client{
Transport: &cossdk.AuthorizationTransport{
SecretID: policy.AccessKey,
SecretKey: policy.SecretKey,
},
}),
HTTPClient: request.NewClient(),
}, nil
case "s3":
return s3.NewDriver(policy)
case "googledrive":
return googledrive.NewDriver(policy)
default:
return nil, ErrUnknownPolicyType
}
}
// NewFileSystemFromContext 从gin.Context创建文件系统
func NewFileSystemFromContext(c *gin.Context) (*FileSystem, error) {
user, exist := c.Get("user")
if !exist {
return NewAnonymousFileSystem()
}
fs, err := NewFileSystem(user.(*model.User))
return fs, err
}
// NewFileSystemFromCallback 从gin.Context创建回调用文件系统
func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
fs, err := NewFileSystemFromContext(c)
if err != nil {
return nil, err
}
// 获取回调会话
callbackSessionRaw, ok := c.Get(UploadSessionCtx)
if !ok {
return nil, errors.New("upload session not exist")
}
callbackSession := callbackSessionRaw.(*serializer.UploadSession)
// 重新指向上传策略
fs.Policy = &callbackSession.Policy
err = fs.DispatchHandler()
return fs, err
}
// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点
func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) {
fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, fs.Policy)
}
// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器
func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) {
switch fs.Policy.Type {
case "local":
fs.Policy.Type = "remote"
fs.Policy.Server = masterURL
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
fs.Policy.SecretKey = master.DBModel().MasterKey
fs.DispatchHandler()
case "onedrive":
fs.Policy.MasterID = masterID
}
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy)
}
// SetTargetFile 设置当前处理的目标文件
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
if len(fs.FileTarget) == 0 {
fs.FileTarget = *files
} else {
fs.FileTarget = append(fs.FileTarget, *files...)
}
}
// SetTargetDir 设置当前处理的目标目录
func (fs *FileSystem) SetTargetDir(dirs *[]model.Folder) {
if len(fs.DirTarget) == 0 {
fs.DirTarget = *dirs
} else {
fs.DirTarget = append(fs.DirTarget, *dirs...)
}
}
// SetTargetFileByIDs 根据文件ID设置目标文件忽略用户ID
func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error {
files, err := model.GetFilesByIDs(ids, 0)
if err != nil || len(files) == 0 {
return ErrFileExisted.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// SetTargetByInterface 根据 model.File 或者 model.Folder 设置目标对象
// TODO 测试
func (fs *FileSystem) SetTargetByInterface(target interface{}) error {
if file, ok := target.(*model.File); ok {
fs.SetTargetFile(&[]model.File{*file})
return nil
}
if folder, ok := target.(*model.Folder); ok {
fs.SetTargetDir(&[]model.Folder{*folder})
return nil
}
return ErrObjectNotExist
}
// CleanTargets 清空目标
func (fs *FileSystem) CleanTargets() {
fs.FileTarget = fs.FileTarget[:0]
fs.DirTarget = fs.DirTarget[:0]
}
// SetPolicyFromPath 根据给定路径尝试设定偏好存储策略
func (fs *FileSystem) SetPolicyFromPath(filePath string) error {
_, parent := fs.getClosedParent(filePath)
// 尝试获取并重设存储策略
fs.Policy = fs.User.GetPolicyID(parent)
return fs.DispatchHandler()
}
// SetPolicyFromPreference 尝试设定偏好存储策略
func (fs *FileSystem) SetPolicyFromPreference(preference uint) error {
// 尝试获取并重设存储策略
fs.Policy = fs.User.GetPolicyByPreference(preference)
return fs.DispatchHandler()
}

View File

@@ -0,0 +1,44 @@
package fsctx
type key int
const (
// GinCtx Gin的上下文
GinCtx key = iota
// PathCtx 文件或目录的虚拟路径
PathCtx
// FileModelCtx 文件数据库模型
FileModelCtx
// FolderModelCtx 目录数据库模型
FolderModelCtx
// HTTPCtx HTTP请求的上下文
HTTPCtx
// UploadPolicyCtx 上传策略一般为slave模式下使用
UploadPolicyCtx
// UserCtx 用户
UserCtx
// ThumbSizeCtx 缩略图尺寸
ThumbSizeCtx
// FileSizeCtx 文件大小
FileSizeCtx
// ShareKeyCtx 分享文件的 HashID
ShareKeyCtx
// LimitParentCtx 限制父目录
LimitParentCtx
// IgnoreDirectoryConflictCtx 忽略目录重名冲突
IgnoreDirectoryConflictCtx
// RetryCtx 失败重试次数
RetryCtx
// ForceUsePublicEndpointCtx 强制使用公网 Endpoint
ForceUsePublicEndpointCtx
// CancelFuncCtx Context 取消函數
CancelFuncCtx
// 文件在从机节点中的路径
SlaveSrcPath
// Webdav目标名称
WebdavDstName
// WebDAVCtx WebDAV
WebDAVCtx
// WebDAV反代Url
WebDAVProxyUrlCtx
)

View File

@@ -0,0 +1,123 @@
package fsctx
import (
"errors"
"github.com/HFO4/aliyun-oss-go-sdk/oss"
"io"
"time"
)
type WriteMode int
const (
Overwrite WriteMode = 0x00001
// Append 只适用于本地策略
Append WriteMode = 0x00002
Nop WriteMode = 0x00004
)
type UploadTaskInfo struct {
Size uint64
MimeType string
FileName string
VirtualPath string
Mode WriteMode
Metadata map[string]string
LastModified *time.Time
SavePath string
UploadSessionID *string
AppendStart uint64
Model interface{}
Src string
}
// Get mimetype of uploaded file, if it's not defined, detect it from file name
func (u *UploadTaskInfo) DetectMimeType() string {
if u.MimeType != "" {
return u.MimeType
}
return oss.TypeByExtension(u.FileName)
}
// FileHeader 上传来的文件数据处理器
type FileHeader interface {
io.Reader
io.Closer
io.Seeker
Info() *UploadTaskInfo
SetSize(uint64)
SetModel(fileModel interface{})
Seekable() bool
}
// FileStream 用户传来的文件
type FileStream struct {
Mode WriteMode
LastModified *time.Time
Metadata map[string]string
File io.ReadCloser
Seeker io.Seeker
Size uint64
VirtualPath string
Name string
MimeType string
SavePath string
UploadSessionID *string
AppendStart uint64
Model interface{}
Src string
}
func (file *FileStream) Read(p []byte) (n int, err error) {
if file.File != nil {
return file.File.Read(p)
}
return 0, io.EOF
}
func (file *FileStream) Close() error {
if file.File != nil {
return file.File.Close()
}
return nil
}
func (file *FileStream) Seek(offset int64, whence int) (int64, error) {
if file.Seekable() {
return file.Seeker.Seek(offset, whence)
}
return 0, errors.New("no seeker")
}
func (file *FileStream) Seekable() bool {
return file.Seeker != nil
}
func (file *FileStream) Info() *UploadTaskInfo {
return &UploadTaskInfo{
Size: file.Size,
MimeType: file.MimeType,
FileName: file.Name,
VirtualPath: file.VirtualPath,
Mode: file.Mode,
Metadata: file.Metadata,
LastModified: file.LastModified,
SavePath: file.SavePath,
UploadSessionID: file.UploadSessionID,
AppendStart: file.AppendStart,
Model: file.Model,
Src: file.Src,
}
}
func (file *FileStream) SetSize(size uint64) {
file.Size = size
}
func (file *FileStream) SetModel(fileModel interface{}) {
file.Model = fileModel
}

320
pkg/filesystem/hooks.go Normal file
View File

@@ -0,0 +1,320 @@
package filesystem
import (
"context"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Hook 钩子函数
type Hook func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error
// Use 注入钩子
func (fs *FileSystem) Use(name string, hook Hook) {
if fs.Hooks == nil {
fs.Hooks = make(map[string][]Hook)
}
if _, ok := fs.Hooks[name]; ok {
fs.Hooks[name] = append(fs.Hooks[name], hook)
return
}
fs.Hooks[name] = []Hook{hook}
}
// CleanHooks 清空钩子,name为空表示全部清空
func (fs *FileSystem) CleanHooks(name string) {
if name == "" {
fs.Hooks = nil
} else {
delete(fs.Hooks, name)
}
}
// Trigger 触发钩子,遇到第一个错误时
// 返回错误,后续钩子不会继续执行
func (fs *FileSystem) Trigger(ctx context.Context, name string, file fsctx.FileHeader) error {
if hooks, ok := fs.Hooks[name]; ok {
for _, hook := range hooks {
err := hook(ctx, fs, file)
if err != nil {
util.Log().Warning("Failed to execute hook%s", err)
return err
}
}
}
return nil
}
// HookValidateFile 一系列对文件检验的集合
func HookValidateFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
fileInfo := file.Info()
// 验证单文件尺寸
if !fs.ValidateFileSize(ctx, fileInfo.Size) {
return ErrFileSizeTooBig
}
// 验证文件名
if !fs.ValidateLegalName(ctx, fileInfo.FileName) {
return ErrIllegalObjectName
}
// 验证扩展名
if !fs.ValidateExtension(ctx, fileInfo.FileName) {
return ErrFileExtensionNotAllowed
}
return nil
}
// HookResetPolicy 重设存储策略为上下文已有文件
func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
fs.Policy = originFile.GetPolicy()
return fs.DispatchHandler()
}
// HookValidateCapacity 验证用户容量
func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 验证并扣除容量
if fs.User.GetRemainingCapacity() < file.Info().Size {
return ErrInsufficientCapacity
}
return nil
}
// HookValidateCapacityDiff 根据原有文件和新文件的大小验证用户容量
func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
originFile := ctx.Value(fsctx.FileModelCtx).(model.File)
newFileSize := newFile.Info().Size
if newFileSize > originFile.Size {
return HookValidateCapacity(ctx, fs, newFile)
}
return nil
}
// HookDeleteTempFile 删除已保存的临时文件
func HookDeleteTempFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 删除临时文件
_, err := fs.Handler.Delete(ctx, []string{file.Info().SavePath})
if err != nil {
util.Log().Warning("Failed to clean-up temp files: %s", err)
}
return nil
}
// HookCleanFileContent 清空文件内容
func HookCleanFileContent(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 清空内容
return fs.Handler.Put(ctx, &fsctx.FileStream{
File: ioutil.NopCloser(strings.NewReader("")),
SavePath: file.Info().SavePath,
Size: 0,
Mode: fsctx.Overwrite,
})
}
// HookClearFileSize 将原始文件的尺寸设为0
func HookClearFileSize(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
return originFile.UpdateSize(0)
}
// HookCancelContext 取消上下文
func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc)
if ok {
cancelFunc()
}
return nil
}
// HookUpdateSourceName 更新文件SourceName
func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
return originFile.UpdateSourceName(originFile.SourceName)
}
// GenericAfterUpdate 文件内容更新后
func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
// 更新文件尺寸
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
newFile.SetModel(&originFile)
err := originFile.UpdateSize(newFile.Info().Size)
if err != nil {
return err
}
return nil
}
// SlaveAfterUpload Slave模式下上传完成钩子
func SlaveAfterUpload(session *serializer.UploadSession) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
if session.Callback == "" {
return nil
}
// 发送回调请求
callbackBody := serializer.UploadCallback{}
return cluster.RemoteCallback(session.Callback, callbackBody)
}
}
// GenericAfterUpload 文件上传完成后,包含数据库操作
func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 创建或查找根目录
folder, err := fs.CreateDirectory(ctx, fileInfo.VirtualPath)
if err != nil {
return err
}
// 检查文件是否存在
if ok, file := fs.IsChildFileExist(
folder,
fileInfo.FileName,
); ok {
if file.UploadSessionID != nil {
return ErrFileUploadSessionExisted
}
return ErrFileExisted
}
// 向数据库中插入记录
file, err := fs.AddFile(ctx, folder, fileHeader)
if err != nil {
return ErrInsertFileRecord
}
fileHeader.SetModel(file)
return nil
}
// HookGenerateThumb 生成缩略图
// func HookGenerateThumb(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
// // 异步尝试生成缩略图
// fileMode := fileHeader.Info().Model.(*model.File)
// if fs.Policy.IsThumbGenerateNeeded() {
// fs.recycleLock.Lock()
// go func() {
// defer fs.recycleLock.Unlock()
// _, _ = fs.Handler.Delete(ctx, []string{fileMode.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
// fs.GenerateThumbnail(ctx, fileMode)
// }()
// }
// return nil
// }
// HookClearFileHeaderSize 将FileHeader大小设定为0
func HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileHeader.SetSize(0)
return nil
}
// HookTruncateFileTo 将物理文件截断至 size
func HookTruncateFileTo(size uint64) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
if handler, ok := fs.Handler.(local.Driver); ok {
return handler.Truncate(ctx, fileHeader.Info().SavePath, size)
}
return nil
}
}
// HookChunkUploadFinished 单个分片上传结束后
func HookChunkUploaded(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 更新文件大小
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart + fileInfo.Size)
}
// HookChunkUploadFailed 单个分片上传失败后
func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 更新文件大小
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart)
}
// HookPopPlaceholderToFile 将占位文件提升为正式文件
func HookPopPlaceholderToFile(picInfo string) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
fileModel := fileInfo.Model.(*model.File)
return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo)
}
}
// HookChunkUploadFinished 分片上传结束后处理文件
func HookDeleteUploadSession(id string) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
cache.Deletes([]string{id}, UploadSessionCachePrefix)
return nil
}
}
// NewWebdavAfterUploadHook 每次创建一个新的钩子函数 rclone 在 PUT 请求里有 OC-Checksum 字符串
// 和 X-OC-Mtime
func NewWebdavAfterUploadHook(request *http.Request) func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
var modtime time.Time
if timeVal := request.Header.Get("X-OC-Mtime"); timeVal != "" {
timeUnix, err := strconv.ParseInt(timeVal, 10, 64)
if err == nil {
modtime = time.Unix(timeUnix, 0)
}
}
checksum := request.Header.Get("OC-Checksum")
return func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
file := newFile.Info().Model.(*model.File)
if !modtime.IsZero() {
err := model.DB.Model(file).UpdateColumn("updated_at", modtime).Error
if err != nil {
return err
}
}
if checksum != "" {
return file.UpdateMetadata(map[string]string{
model.ChecksumMetadataKey: checksum,
})
}
return nil
}
}

218
pkg/filesystem/image.go Normal file
View File

@@ -0,0 +1,218 @@
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"sync"
"runtime"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ================
图像处理相关
================
*/
// GetThumb 获取文件的缩略图
func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentResponse, error) {
// 根据 ID 查找文件
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, ErrObjectNotExist
}
file := fs.FileTarget[0]
if !file.ShouldLoadThumb() {
return nil, ErrObjectNotExist
}
w, h := fs.GenerateThumbnailSize(0, 0)
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h})
ctx = context.WithValue(ctx, fsctx.FileModelCtx, file)
res, err := fs.Handler.Thumb(ctx, &file)
if errors.Is(err, driver.ErrorThumbNotExist) {
// Regenerate thumb if the thumb is not initialized yet
if generateErr := fs.generateThumbnail(ctx, &file); generateErr == nil {
res, err = fs.Handler.Thumb(ctx, &file)
} else {
err = generateErr
}
} else if errors.Is(err, driver.ErrorThumbNotSupported) {
// Policy handler explicitly indicates thumb not available, check if proxy is enabled
if fs.Policy.CouldProxyThumb() {
// if thumb id marked as existed, redirect to "sidecar" thumb file.
if file.MetadataSerialized != nil &&
file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusExist {
// redirect to sidecar file
res = &response.ContentResponse{
Redirect: true,
}
res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
} else {
// if not exist, generate and upload the sidecar thumb.
if err = fs.generateThumbnail(ctx, &file); err == nil {
return fs.GetThumb(ctx, id)
}
}
} else {
// thumb not supported and proxy is disabled, mark as not available
_ = updateThumbStatus(&file, model.ThumbStatusNotAvailable)
}
}
if err == nil && conf.SystemConfig.Mode == "master" {
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
}
return res, err
}
// thumbPool 要使用的任务池
var thumbPool *Pool
var once sync.Once
// Pool 带有最大配额的任务池
type Pool struct {
// 容量
worker chan int
}
// Init 初始化任务池
func getThumbWorker() *Pool {
once.Do(func() {
maxWorker := model.GetIntSetting("thumb_max_task_count", -1)
if maxWorker <= 0 {
maxWorker = runtime.GOMAXPROCS(0)
}
thumbPool = &Pool{
worker: make(chan int, maxWorker),
}
util.Log().Debug("Initialize thumbnails task queue with: WorkerNum = %d", maxWorker)
})
return thumbPool
}
func (pool *Pool) addWorker() {
pool.worker <- 1
util.Log().Debug("Worker added to thumbnails task queue.")
}
func (pool *Pool) releaseWorker() {
util.Log().Debug("Worker released from thumbnails task queue.")
<-pool.worker
}
// generateThumbnail generates thumb for given file, upload the thumb file back with given suffix
func (fs *FileSystem) generateThumbnail(ctx context.Context, file *model.File) error {
// 新建上下文
newCtx, cancel := context.WithCancel(context.Background())
defer cancel()
// TODO: check file size
if file.Size > uint64(model.GetIntSetting("thumb_max_src_size", 31457280)) {
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
return errors.New("file too large")
}
getThumbWorker().addWorker()
defer getThumbWorker().releaseWorker()
// 获取文件数据
source, err := fs.Handler.Get(newCtx, file.SourceName)
if err != nil {
return fmt.Errorf("faield to fetch original file %q: %w", file.SourceName, err)
}
defer source.Close()
// Provide file source path for local policy files
src := ""
if conf.SystemConfig.Mode == "slave" || file.GetPolicy().Type == "local" {
src = file.SourceName
}
thumbRes, err := thumb.Generators.Generate(ctx, source, src, file.Name, model.GetSettingByNames(
"thumb_width",
"thumb_height",
"thumb_builtin_enabled",
"thumb_vips_enabled",
"thumb_ffmpeg_enabled",
"thumb_libreoffice_enabled",
))
if err != nil {
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
return fmt.Errorf("failed to generate thumb for %q: %w", file.Name, err)
}
defer os.Remove(thumbRes.Path)
thumbFile, err := os.Open(thumbRes.Path)
if err != nil {
return fmt.Errorf("failed to open temp thumb %q: %w", thumbRes.Path, err)
}
defer thumbFile.Close()
fileInfo, err := thumbFile.Stat()
if err != nil {
return fmt.Errorf("failed to stat temp thumb %q: %w", thumbRes.Path, err)
}
if err = fs.Handler.Put(newCtx, &fsctx.FileStream{
Mode: fsctx.Overwrite,
File: thumbFile,
Seeker: thumbFile,
Size: uint64(fileInfo.Size()),
SavePath: file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"),
}); err != nil {
return fmt.Errorf("failed to save thumb for %q: %w", file.Name, err)
}
if model.IsTrueVal(model.GetSettingByName("thumb_gc_after_gen")) {
util.Log().Debug("generateThumbnail runtime.GC")
runtime.GC()
}
// Mark this file as thumb available
err = updateThumbStatus(file, model.ThumbStatusExist)
// 失败时删除缩略图文件
if err != nil {
_, _ = fs.Handler.Delete(newCtx, []string{file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
}
return nil
}
// GenerateThumbnailSize 获取要生成的缩略图的尺寸
func (fs *FileSystem) GenerateThumbnailSize(w, h int) (uint, uint) {
return uint(model.GetIntSetting("thumb_width", 400)), uint(model.GetIntSetting("thumb_height", 300))
}
func updateThumbStatus(file *model.File, status string) error {
if file.Model.ID > 0 {
meta := map[string]string{
model.ThumbStatusMetadataKey: status,
}
if status == model.ThumbStatusExist {
meta[model.ThumbSidecarMetadataKey] = "true"
}
return file.UpdateMetadata(meta)
} else {
if file.MetadataSerialized == nil {
file.MetadataSerialized = map[string]string{}
}
file.MetadataSerialized[model.ThumbStatusMetadataKey] = status
}
return nil
}

479
pkg/filesystem/manage.go Normal file
View File

@@ -0,0 +1,479 @@
package filesystem
import (
"context"
"fmt"
"path"
"strings"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* =================
文件/目录管理
=================
*/
// Rename 重命名对象
func (fs *FileSystem) Rename(ctx context.Context, dir, file []uint, new string) (err error) {
// 验证新名字
if !fs.ValidateLegalName(ctx, new) || (len(file) > 0 && !fs.ValidateExtension(ctx, new)) {
return ErrIllegalObjectName
}
// 如果源对象是文件
if len(file) > 0 {
fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID)
if err != nil || len(fileObject) == 0 {
return ErrPathNotExist
}
err = fileObject[0].Rename(new)
if err != nil {
return ErrFileExisted
}
return nil
}
if len(dir) > 0 {
folderObject, err := model.GetFoldersByIDs([]uint{dir[0]}, fs.User.ID)
if err != nil || len(folderObject) == 0 {
return ErrPathNotExist
}
err = folderObject[0].Rename(new)
if err != nil {
return ErrFileExisted
}
return nil
}
return ErrPathNotExist
}
// Copy 复制src目录下的文件或目录到dst
// 暂时只支持单文件
func (fs *FileSystem) Copy(ctx context.Context, dirs, files []uint, src, dst string) error {
// 获取目的目录
isDstExist, dstFolder := fs.IsPathExist(dst)
isSrcExist, srcFolder := fs.IsPathExist(src)
// 不存在时返回空的结果
if !isDstExist || !isSrcExist {
return ErrPathNotExist
}
// 记录复制的文件的总容量
var newUsedStorage uint64
// 设置webdav目标名
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
dstFolder.WebdavDstName = dstName
}
// 复制目录
if len(dirs) > 0 {
subFileSizes, err := srcFolder.CopyFolderTo(dirs[0], dstFolder)
if err != nil {
return ErrObjectNotExist.WithError(err)
}
newUsedStorage += subFileSizes
}
// 复制文件
if len(files) > 0 {
subFileSizes, err := srcFolder.MoveOrCopyFileTo(files, dstFolder, true)
if err != nil {
return ErrObjectNotExist.WithError(err)
}
newUsedStorage += subFileSizes
}
// 扣除容量
fs.User.IncreaseStorageWithoutCheck(newUsedStorage)
return nil
}
// Move 移动文件和目录, 将id列表dirs和files从src移动至dst
func (fs *FileSystem) Move(ctx context.Context, dirs, files []uint, src, dst string) error {
// 获取目的目录
isDstExist, dstFolder := fs.IsPathExist(dst)
isSrcExist, srcFolder := fs.IsPathExist(src)
// 不存在时返回空的结果
if !isDstExist || !isSrcExist {
return ErrPathNotExist
}
// 设置webdav目标名
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
dstFolder.WebdavDstName = dstName
}
// 处理目录及子文件移动
err := srcFolder.MoveFolderTo(dirs, dstFolder)
if err != nil {
return ErrFileExisted.WithError(err)
}
// 处理文件移动
_, err = srcFolder.MoveOrCopyFileTo(files, dstFolder, false)
if err != nil {
return ErrFileExisted.WithError(err)
}
// 移动文件
return err
}
// Delete 递归删除对象, force 为 true 时强制删除文件记录,忽略物理删除是否成功;
// unlink 为 true 时只删除虚拟文件系统的文件记录,不删除物理文件。
func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force, unlink bool) error {
// 已删除的文件ID
var deletedFiles = make([]*model.File, 0, len(fs.FileTarget))
// 删除失败的文件的父目录ID
// 所有文件的ID
var allFiles = make([]*model.File, 0, len(fs.FileTarget))
// 列出要删除的目录
if len(dirs) > 0 {
err := fs.ListDeleteDirs(ctx, dirs)
if err != nil {
return err
}
}
// 列出要删除的文件
if len(files) > 0 {
err := fs.ListDeleteFiles(ctx, files)
if err != nil {
return err
}
}
// 去除待删除文件中包含软连接的部分
filesToBeDelete, err := model.RemoveFilesWithSoftLinks(fs.FileTarget)
if err != nil {
return ErrDBListObjects.WithError(err)
}
// 根据存储策略将文件分组
policyGroup := fs.GroupFileByPolicy(ctx, filesToBeDelete)
// 按照存储策略分组删除对象
failed := make(map[uint][]string)
if !unlink {
failed = fs.deleteGroupedFile(ctx, policyGroup)
}
// 整理删除结果
for i := 0; i < len(fs.FileTarget); i++ {
if !util.ContainsString(failed[fs.FileTarget[i].PolicyID], fs.FileTarget[i].SourceName) {
// 已成功删除的文件
deletedFiles = append(deletedFiles, &fs.FileTarget[i])
}
// 全部文件
allFiles = append(allFiles, &fs.FileTarget[i])
}
// 如果强制删除,则将全部文件视为删除成功
if force {
deletedFiles = allFiles
}
// 删除文件记录
err = model.DeleteFiles(deletedFiles, fs.User.ID)
if err != nil {
return ErrDBDeleteObjects.WithError(err)
}
// 删除文件记录对应的分享记录
// TODO 先取消分享再删除文件
deletedFileIDs := make([]uint, len(deletedFiles))
for k, file := range deletedFiles {
deletedFileIDs[k] = file.ID
}
model.DeleteShareBySourceIDs(deletedFileIDs, false)
// 如果文件全部删除成功,继续删除目录
if len(deletedFiles) == len(allFiles) {
var allFolderIDs = make([]uint, 0, len(fs.DirTarget))
for _, value := range fs.DirTarget {
allFolderIDs = append(allFolderIDs, value.ID)
}
err = model.DeleteFolderByIDs(allFolderIDs)
if err != nil {
return ErrDBDeleteObjects.WithError(err)
}
// 删除目录记录对应的分享记录
model.DeleteShareBySourceIDs(allFolderIDs, true)
}
if notDeleted := len(fs.FileTarget) - len(deletedFiles); notDeleted > 0 {
return serializer.NewError(
serializer.CodeNotFullySuccess,
fmt.Sprintf("Failed to delete %d file(s).", notDeleted),
nil,
)
}
return nil
}
// ListDeleteDirs 递归列出要删除目录,及目录下所有文件
func (fs *FileSystem) ListDeleteDirs(ctx context.Context, ids []uint) error {
// 列出所有递归子目录
folders, err := model.GetRecursiveChildFolder(ids, fs.User.ID, true)
if err != nil {
return ErrDBListObjects.WithError(err)
}
// 忽略根目录
for i := 0; i < len(folders); i++ {
if folders[i].ParentID == nil {
folders = append(folders[:i], folders[i+1:]...)
break
}
}
fs.SetTargetDir(&folders)
// 检索目录下的子文件
files, err := model.GetChildFilesOfFolders(&folders)
if err != nil {
return ErrDBListObjects.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// ListDeleteFiles 根据给定的路径列出要删除的文件
func (fs *FileSystem) ListDeleteFiles(ctx context.Context, ids []uint) error {
files, err := model.GetFilesByIDs(ids, fs.User.ID)
if err != nil {
return ErrDBListObjects.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// List 列出路径下的内容,
// pathProcessor为最终对象路径的处理钩子。
// 有些情况下(如在分享页面列对象)时,
// 路径需要截取掉被分享目录路径之前的部分。
func (fs *FileSystem) List(ctx context.Context, dirPath string, pathProcessor func(string) string) ([]serializer.Object, error) {
// 获取父目录
isExist, folder := fs.IsPathExist(dirPath)
if !isExist {
return nil, ErrPathNotExist
}
fs.SetTargetDir(&[]model.Folder{*folder})
var parentPath = path.Join(folder.Position, folder.Name)
var childFolders []model.Folder
var childFiles []model.File
// 获取子目录
childFolders, _ = folder.GetChildFolder()
// 获取子文件
childFiles, _ = folder.GetChildFiles()
return fs.listObjects(ctx, parentPath, childFiles, childFolders, pathProcessor), nil
}
// ListPhysical 列出存储策略中的外部目录
// TODO:测试
func (fs *FileSystem) ListPhysical(ctx context.Context, dirPath string) ([]serializer.Object, error) {
if err := fs.DispatchHandler(); fs.Policy == nil || err != nil {
return nil, ErrUnknownPolicyType
}
// 存储策略不支持列取时,返回空结果
if !fs.Policy.CanStructureBeListed() {
return nil, nil
}
// 列取路径
objects, err := fs.Handler.List(ctx, dirPath, false)
if err != nil {
return nil, err
}
var (
folders []model.Folder
)
for _, object := range objects {
if object.IsDir {
folders = append(folders, model.Folder{
Name: object.Name,
})
}
}
return fs.listObjects(ctx, dirPath, nil, folders, nil), nil
}
func (fs *FileSystem) listObjects(ctx context.Context, parent string, files []model.File, folders []model.Folder, pathProcessor func(string) string) []serializer.Object {
// 分享文件的ID
shareKey := ""
if key, ok := ctx.Value(fsctx.ShareKeyCtx).(string); ok {
shareKey = key
}
// 汇总处理结果
objects := make([]serializer.Object, 0, len(files)+len(folders))
// 所有对象的父目录
var processedPath string
for _, subFolder := range folders {
// 路径处理钩子,
// 所有对象父目录都是一样的,所以只处理一次
if processedPath == "" {
if pathProcessor != nil {
processedPath = pathProcessor(parent)
} else {
processedPath = parent
}
}
objects = append(objects, serializer.Object{
ID: hashid.HashID(subFolder.ID, hashid.FolderID),
Name: subFolder.Name,
Path: processedPath,
Size: 0,
Type: "dir",
Date: subFolder.UpdatedAt,
CreateDate: subFolder.CreatedAt,
})
}
for _, file := range files {
if processedPath == "" {
if pathProcessor != nil {
processedPath = pathProcessor(parent)
} else {
processedPath = parent
}
}
if file.UploadSessionID == nil {
newFile := serializer.Object{
ID: hashid.HashID(file.ID, hashid.FileID),
Name: file.Name,
Path: processedPath,
Thumb: file.ShouldLoadThumb(),
Size: file.Size,
Type: "file",
Date: file.UpdatedAt,
SourceEnabled: file.GetPolicy().IsOriginLinkEnable,
CreateDate: file.CreatedAt,
}
if shareKey != "" {
newFile.Key = shareKey
}
objects = append(objects, newFile)
}
}
return objects
}
// CreateDirectory 根据给定的完整创建目录,支持递归创建。如果目录已存在,则直接
// 返回已存在的目录。
func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*model.Folder, error) {
if fullPath == "." || fullPath == "" {
return nil, ErrRootProtected
}
if fullPath == "/" {
if fs.Root != nil {
return fs.Root, nil
}
return fs.User.Root()
}
// 获取要创建目录的父路径和目录名
fullPath = path.Clean(fullPath)
base := path.Dir(fullPath)
dir := path.Base(fullPath)
// 去掉结尾空格
dir = strings.TrimRight(dir, " ")
// 检查目录名是否合法
if !fs.ValidateLegalName(ctx, dir) {
return nil, ErrIllegalObjectName
}
// 父目录是否存在
isExist, parent := fs.IsPathExist(base)
if !isExist {
newParent, err := fs.CreateDirectory(ctx, base)
if err != nil {
return nil, err
}
parent = newParent
}
// 是否有同名文件
if ok, _ := fs.IsChildFileExist(parent, dir); ok {
return nil, ErrFileExisted
}
// 创建目录
newFolder := model.Folder{
Name: dir,
ParentID: &parent.ID,
OwnerID: fs.User.ID,
}
_, err := newFolder.Create()
if err != nil {
return nil, fmt.Errorf("failed to create folder: %w", err)
}
return &newFolder, nil
}
// SaveTo 将别人分享的文件转存到目标路径下
func (fs *FileSystem) SaveTo(ctx context.Context, path string) error {
// 获取父目录
isExist, folder := fs.IsPathExist(path)
if !isExist {
return ErrPathNotExist
}
var (
totalSize uint64
err error
)
if len(fs.DirTarget) > 0 {
totalSize, err = fs.DirTarget[0].CopyFolderTo(fs.DirTarget[0].ID, folder)
} else {
parent := model.Folder{
OwnerID: fs.FileTarget[0].UserID,
}
parent.ID = fs.FileTarget[0].FolderID
totalSize, err = parent.MoveOrCopyFileTo([]uint{fs.FileTarget[0].ID}, folder, true)
}
// 扣除用户容量
fs.User.IncreaseStorageWithoutCheck(totalSize)
if err != nil {
return ErrFileExisted.WithError(err)
}
return nil
}

View File

@@ -0,0 +1,25 @@
package oauth
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}

View File

@@ -0,0 +1,8 @@
package oauth
import "context"
type TokenProvider interface {
UpdateCredential(ctx context.Context, isSlave bool) error
AccessToken() string
}

84
pkg/filesystem/path.go Normal file
View File

@@ -0,0 +1,84 @@
package filesystem
import (
"path"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* =================
路径/目录相关
=================
*/
// IsPathExist 返回给定目录是否存在
// 如果存在就返回目录
func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) {
tracedEnd, currentFolder := fs.getClosedParent(path)
if tracedEnd {
return true, currentFolder
}
return false, nil
}
func (fs *FileSystem) getClosedParent(path string) (bool, *model.Folder) {
pathList := util.SplitPath(path)
if len(pathList) == 0 {
return false, nil
}
// 递归步入目录
var currentFolder *model.Folder
// 如果已设定跟目录对象,则从给定目录向下遍历
if fs.Root != nil {
currentFolder = fs.Root
}
for _, folderName := range pathList {
var err error
// 根目录
if folderName == "/" {
if currentFolder != nil {
continue
}
currentFolder, err = fs.User.Root()
if err != nil {
return false, nil
}
} else {
nextFolder, err := currentFolder.GetChild(folderName)
if err != nil {
return false, currentFolder
}
currentFolder = nextFolder
}
}
return true, currentFolder
}
// IsFileExist 返回给定路径的文件是否存在
func (fs *FileSystem) IsFileExist(fullPath string) (bool, *model.File) {
basePath := path.Dir(fullPath)
fileName := path.Base(fullPath)
// 获得父目录
exist, parent := fs.IsPathExist(basePath)
if !exist {
return false, nil
}
file, err := parent.GetChildFile(fileName)
return err == nil, file
}
// IsChildFileExist 确定folder目录下是否有名为name的文件
func (fs *FileSystem) IsChildFileExist(folder *model.Folder, name string) (bool, *model.File) {
file, err := folder.GetChildFile(name)
return err == nil, file
}

102
pkg/filesystem/relocate.go Normal file
View File

@@ -0,0 +1,102 @@
package filesystem
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ================
存储策略迁移
================
*/
// Relocate 将目标文件转移到当前存储策略下
func (fs *FileSystem) Relocate(ctx context.Context, files []model.File, policy *model.Policy) error {
// 重设存储策略为要转移的目的策略
fs.Policy = policy
if err := fs.DispatchHandler(); err != nil {
return err
}
// 将目前文件根据存储策略分组
fileGroup := fs.GroupFileByPolicy(ctx, files)
// 按照存储策略分组处理每个文件
for _, fileList := range fileGroup {
// 如果存储策略一样,则跳过
if fileList[0].GetPolicy().ID == fs.Policy.ID {
util.Log().Debug("Skip relocating %d file(s), since they are already in desired policy.",
len(fileList))
continue
}
// 获取当前存储策略的处理器
currentPolicy, _ := model.GetPolicyByID(fileList[0].PolicyID)
currentHandler, err := getNewPolicyHandler(&currentPolicy)
if err != nil {
return err
}
// 记录转移完毕需要删除的文件
toBeDeleted := make([]model.File, 0, len(fileList))
// 循环处理每一个文件
// for id, r := 0, len(fileList); id < r; id++ {
for id, _ := range fileList {
// 验证文件是否符合新存储策略的规定
if err := HookValidateFile(ctx, fs, fileList[id]); err != nil {
util.Log().Debug("File %q failed to pass validators in new policy %q, skipping.",
fileList[id].Name, err)
continue
}
// 为文件生成新存储策略下的物理路径
savePath := fs.GenerateSavePath(ctx, fileList[id])
// 获取原始文件
src, err := currentHandler.Get(ctx, fileList[id].SourceName)
if err != nil {
util.Log().Debug("Failed to get file %q: %s, skipping.",
fileList[id].Name, err)
continue
}
// 转存到新的存储策略
if err := fs.Handler.Put(ctx, &fsctx.FileStream{
File: src,
SavePath: savePath,
Size: fileList[id].Size,
}); err != nil {
util.Log().Debug("Failed to migrate file %q: %s, skipping.",
fileList[id].Name, err)
continue
}
toBeDeleted = append(toBeDeleted, *fileList[id])
// 更新文件信息
fileList[id].Relocate(savePath, fs.Policy.ID)
}
// 排除带有软链接的文件
toBeDeletedClean, err := model.RemoveFilesWithSoftLinks(toBeDeleted)
if err != nil {
util.Log().Warning("Failed to check soft links: %s", err)
}
deleteSourceNames := make([]string, 0, len(toBeDeleted))
for i := 0; i < len(toBeDeletedClean); i++ {
deleteSourceNames = append(deleteSourceNames, toBeDeletedClean[i].SourceName)
}
// 删除原始策略中的文件
if _, err := currentHandler.Delete(ctx, deleteSourceNames); err != nil {
util.Log().Warning("Cannot delete files in origin policy after relocating: %s", err)
}
}
return nil
}

View File

@@ -0,0 +1,32 @@
package response
import (
"io"
"time"
)
// ContentResponse 获取文件内容类方法的通用返回值。
// 有些上传策略需要重定向,
// 有些直接写文件数据到浏览器
type ContentResponse struct {
Redirect bool
Content RSCloser
URL string
MaxAge int
}
// RSCloser 存储策略适配器返回的文件流有些策略需要带有Closer
type RSCloser interface {
io.ReadSeeker
io.Closer
}
// Object 列出文件、目录时返回的对象
type Object struct {
Name string `json:"name"`
RelativePath string `json:"relative_path"`
Source string `json:"source"`
Size uint64 `json:"size"`
IsDir bool `json:"is_dir"`
LastModify time.Time `json:"last_modify"`
}

View File

View File

Binary file not shown.

243
pkg/filesystem/upload.go Normal file
View File

@@ -0,0 +1,243 @@
package filesystem
import (
"context"
"os"
"path"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
)
/* ================
上传处理相关
================
*/
const (
UploadSessionMetaKey = "upload_session"
UploadSessionCtx = "uploadSession"
UserCtx = "user"
UploadSessionCachePrefix = "callback_"
)
// Upload 上传文件
func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err error) {
// 上传前的钩子
err = fs.Trigger(ctx, "BeforeUpload", file)
if err != nil {
request.BlackHole(file)
return err
}
// 生成文件名和路径,
var savePath string
if file.SavePath == "" {
// 如果是更新操作就从上下文中获取
if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
savePath = originFile.SourceName
} else {
savePath = fs.GenerateSavePath(ctx, file)
}
file.SavePath = savePath
}
// 保存文件
if file.Mode&fsctx.Nop != fsctx.Nop {
// 处理客户端未完成上传时,关闭连接
go fs.CancelUpload(ctx, savePath, file)
err = fs.Handler.Put(ctx, file)
if err != nil {
fs.Trigger(ctx, "AfterUploadFailed", file)
return err
}
}
// 上传完成后的钩子
err = fs.Trigger(ctx, "AfterUpload", file)
if err != nil {
// 上传完成后续处理失败
followUpErr := fs.Trigger(ctx, "AfterValidateFailed", file)
// 失败后再失败...
if followUpErr != nil {
util.Log().Debug("AfterValidateFailed hook execution failed: %s", followUpErr)
}
return err
}
return nil
}
// GenerateSavePath 生成要存放文件的路径
// TODO 完善测试
func (fs *FileSystem) GenerateSavePath(ctx context.Context, file fsctx.FileHeader) string {
fileInfo := file.Info()
return path.Join(
fs.Policy.GeneratePath(
fs.User.Model.ID,
fileInfo.VirtualPath,
),
fs.Policy.GenerateFileName(
fs.User.Model.ID,
fileInfo.FileName,
),
)
}
// CancelUpload 监测客户端取消上传
func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file fsctx.FileHeader) {
var reqContext context.Context
if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok {
reqContext = ginCtx.Request.Context()
} else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok {
reqContext = reqCtx
} else {
return
}
select {
case <-reqContext.Done():
select {
case <-ctx.Done():
// 客户端正常关闭,不执行操作
default:
// 客户端取消上传,删除临时文件
util.Log().Debug("Client canceled upload.")
if fs.Hooks["AfterUploadCanceled"] == nil {
return
}
err := fs.Trigger(ctx, "AfterUploadCanceled", file)
if err != nil {
util.Log().Debug("AfterUploadCanceled hook execution failed: %s", err)
}
}
}
}
// CreateUploadSession 创建上传会话
func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileStream) (*serializer.UploadCredential, error) {
// 获取相关有效期设置
callBackSessionTTL := model.GetIntSetting("upload_session_timeout", 86400)
callbackKey := uuid.Must(uuid.NewV4()).String()
fileSize := file.Size
// 创建占位的文件,同时校验文件信息
file.Mode = fsctx.Nop
if callbackKey != "" {
file.UploadSessionID = &callbackKey
}
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
// 验证文件规格
if err := fs.Upload(ctx, file); err != nil {
return nil, err
}
uploadSession := &serializer.UploadSession{
Key: callbackKey,
UID: fs.User.ID,
Policy: *fs.Policy,
VirtualPath: file.VirtualPath,
Name: file.Name,
Size: fileSize,
SavePath: file.SavePath,
LastModified: file.LastModified,
CallbackSecret: util.RandStringRunes(32),
}
// 获取上传凭证
credential, err := fs.Handler.Token(ctx, int64(callBackSessionTTL), uploadSession, file)
if err != nil {
return nil, err
}
// 创建占位符
if !fs.Policy.IsUploadPlaceholderWithSize() {
fs.Use("AfterUpload", HookClearFileHeaderSize)
}
fs.Use("AfterUpload", GenericAfterUpload)
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
if err := fs.Upload(ctx, file); err != nil {
return nil, err
}
// 创建回调会话
err = cache.Set(
UploadSessionCachePrefix+callbackKey,
*uploadSession,
callBackSessionTTL,
)
if err != nil {
return nil, err
}
// 补全上传凭证其他信息
credential.Expires = time.Now().Add(time.Duration(callBackSessionTTL) * time.Second).Unix()
return credential, nil
}
// UploadFromStream 从文件流上传文件
func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error {
// 给文件系统分配钩子
fs.Lock.Lock()
if resetPolicy {
err := fs.SetPolicyFromPath(file.VirtualPath)
if err != nil {
return err
}
}
if fs.Hooks == nil {
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
fs.Use("AfterUploadCanceled", HookDeleteTempFile)
fs.Use("AfterUpload", GenericAfterUpload)
fs.Use("AfterValidateFailed", HookDeleteTempFile)
}
fs.Lock.Unlock()
// 开始上传
return fs.Upload(ctx, file)
}
// UploadFromPath 将本机已有文件上传到用户的文件系统
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error {
file, err := os.Open(util.RelativePath(src))
if err != nil {
return err
}
defer file.Close()
// 获取源文件大小
fi, err := file.Stat()
if err != nil {
return err
}
size := fi.Size()
// 开始上传
return fs.UploadFromStream(ctx, &fsctx.FileStream{
File: file,
Seeker: file,
Size: uint64(size),
Name: path.Base(dst),
VirtualPath: path.Dir(dst),
Mode: mode,
}, true)
}

View File

@@ -0,0 +1,66 @@
package filesystem
import (
"context"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ==========
验证器
==========
*/
// 文件/路径名保留字符
var reservedCharacter = []string{"\\", "?", "*", "<", "\"", ":", ">", "/", "|"}
// ValidateLegalName 验证文件名/文件夹名是否合法
func (fs *FileSystem) ValidateLegalName(ctx context.Context, name string) bool {
// 是否包含保留字符
for _, value := range reservedCharacter {
if strings.Contains(name, value) {
return false
}
}
// 是否超出长度限制
if len(name) >= 256 {
return false
}
// 是否为空限制
if len(name) == 0 {
return false
}
// 结尾不能是空格
if strings.HasSuffix(name, " ") {
return false
}
return true
}
// ValidateFileSize 验证上传的文件大小是否超出限制
func (fs *FileSystem) ValidateFileSize(ctx context.Context, size uint64) bool {
if fs.Policy.MaxSize == 0 {
return true
}
return size <= fs.Policy.MaxSize
}
// ValidateCapacity 验证并扣除用户容量
func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool {
return fs.User.IncreaseStorage(size)
}
// ValidateExtension 验证文件扩展名
func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bool {
// 不需要验证
if len(fs.Policy.OptionsSerialized.FileType) == 0 {
return true
}
return util.IsInExtensionList(fs.Policy.OptionsSerialized.FileType, fileName)
}