commit 4947f39e74
2024-02-25 08:30:34 +08:00
273 changed files with 45396 additions and 0 deletions

pkg/aria2/aria2.go

@@ -0,0 +1,67 @@
package aria2
import (
"context"
"fmt"
"net/url"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
)
// Instance is the default aria2 handler instance in use
var Instance common.Aria2 = &common.DummyAria2{}
// LB is the load balancer for aria2 nodes
var LB balancer.Balancer
// Lock is the read-write lock guarding Instance
var Lock sync.RWMutex
// GetLoadBalancer returns the load balancer used for aria2 nodes
func GetLoadBalancer() balancer.Balancer {
Lock.RLock()
defer Lock.RUnlock()
return LB
}
// Init initializes the aria2 module
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
Lock.Lock()
LB = balancer.NewBalancer("RoundRobin")
Lock.Unlock()
if !isReload {
// Load unfinished download tasks from the database and create monitors for them
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
for i := 0; i < len(unfinished); i++ {
// Create a monitor for the task
monitor.NewMonitor(&unfinished[i], pool, mqClient)
}
}
}
// TestRPCConnection sends a test RPC request to check connectivity to the aria2 service
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
// Parse the RPC server address
rpcServer, err := url.Parse(server)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
}
rpcServer.Path = "/jsonrpc"
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
if err != nil {
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
}
return caller.GetVersion()
}
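// exampleTestConnection is a minimal usage sketch of TestRPCConnection: it
// probes a placeholder local aria2 daemon and prints the reported version.
// The address, secret and timeout are illustrative assumptions, and
// VersionInfo.Version is assumed to carry aria2's "version" response field.
func exampleTestConnection() {
	info, err := TestRPCConnection("http://127.0.0.1:6800", "my-rpc-secret", 10)
	if err != nil {
		fmt.Printf("aria2 connection test failed: %s\n", err)
		return
	}
	fmt.Printf("connected to aria2 %s\n", info.Version)
}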

pkg/aria2/common/common.go

@@ -0,0 +1,119 @@
package common
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Aria2 is the offline download handler interface
type Aria2 interface {
// Init initializes the client connection
Init() error
// CreateTask creates a new download task
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
// Status returns the status info of the given task
Status(task *model.Download) (rpc.StatusInfo, error)
// Cancel cancels the given task
Cancel(task *model.Download) error
// Select chooses which files of the task to download
Select(task *model.Download, files []int) error
// GetConfig returns the offline download configuration
GetConfig() model.Aria2Option
// DeleteTempFile deletes the temporary files of the download
DeleteTempFile(*model.Download) error
}
const (
// URLTask is a task added from a URL
URLTask = iota
// TorrentTask is a task added from a torrent file
TorrentTask
)
const (
// Ready means the task is ready to start
Ready = iota
// Downloading means the task is downloading
Downloading
// Paused means the task is paused
Paused
// Error means the task failed
Error
// Complete means the task is finished
Complete
// Canceled means the task was canceled or stopped
Canceled
// Unknown means the task is in an unknown state
Unknown
// Seeding means the task is seeding
Seeding
)
var (
// ErrNotEnabled indicates the offline download feature is not enabled
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
// ErrUserNotFound indicates the creator of the download task cannot be found
ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
)
// DummyAria2 is the default handler used when the aria2 feature is not enabled
type DummyAria2 struct {
}
func (instance *DummyAria2) Init() error {
return nil
}
// CreateTask creates a new task; it simply returns the not-enabled error here
func (instance *DummyAria2) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
return "", ErrNotEnabled
}
// Status returns the not-enabled error
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
return rpc.StatusInfo{}, ErrNotEnabled
}
// Cancel returns the not-enabled error
func (instance *DummyAria2) Cancel(task *model.Download) error {
return ErrNotEnabled
}
// Select returns the not-enabled error
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
return ErrNotEnabled
}
// GetConfig returns an empty configuration
func (instance *DummyAria2) GetConfig() model.Aria2Option {
return model.Aria2Option{}
}
// DeleteTempFile returns the not-enabled error
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
return ErrNotEnabled
}
// GetStatus converts the given status info into a status constant
func GetStatus(status rpc.StatusInfo) int {
switch status.Status {
case "complete":
return Complete
case "active":
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
return Seeding
}
return Downloading
case "waiting":
return Ready
case "paused":
return Paused
case "error":
return Error
case "removed":
return Canceled
default:
return Unknown
}
}
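// exampleGetStatus is a usage sketch for GetStatus: an "active" BitTorrent
// task whose completed length equals its total length maps to Seeding rather
// than Downloading. Only the fields GetStatus actually inspects are filled;
// the literal values are illustrative assumptions.
func exampleGetStatus() int {
	status := rpc.StatusInfo{
		Status:          "active",
		TotalLength:     "1048576",
		CompletedLength: "1048576",
	}
	status.BitTorrent.Mode = "single"
	return GetStatus(status) // Seeding
}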

@@ -0,0 +1,320 @@
package monitor
import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"strconv"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Monitor tracks the status of an offline download task
type Monitor struct {
Task *model.Download
Interval time.Duration
notifier <-chan mq.Message
node cluster.Node
retried int
}
var MAX_RETRY = 10
// NewMonitor creates a new status monitor for an offline download task
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
monitor := &Monitor{
Task: task,
notifier: make(chan mq.Message),
node: pool.GetNodeByID(task.GetNodeID()),
}
if monitor.node != nil {
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
go monitor.Loop(mqClient)
} else {
monitor.setErrorStatus(errors.New("node not available"))
}
}
// Loop starts the monitoring loop
func (monitor *Monitor) Loop(mqClient mq.MQ) {
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
// Update the status immediately on the first iteration
interval := 50 * time.Millisecond
for {
select {
case <-monitor.notifier:
if monitor.Update() {
return
}
case <-time.After(interval):
interval = monitor.Interval
if monitor.Update() {
return
}
}
}
}
// Update refreshes the task status; the return value indicates whether monitoring should stop
func (monitor *Monitor) Update() bool {
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
if err != nil {
monitor.retried++
util.Log().Warning("Cannot get status of download task %q: %s", monitor.Task.GID, err)
// After more than MAX_RETRY consecutive failures, the task is considered failed
if monitor.retried > MAX_RETRY {
util.Log().Warning("Cannot get status of download task %q, exceeded maximum retry threshold: %s",
monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
return false
}
monitor.retried = 0
// Magnet link downloads are redirected to a new GID that must be followed
if len(status.FollowedBy) > 0 {
util.Log().Debug("Redirected download task from %q to %q.", monitor.Task.GID, status.FollowedBy[0])
monitor.Task.GID = status.FollowedBy[0]
monitor.Task.Save()
return false
}
// Update task information
if err := monitor.UpdateTaskInfo(status); err != nil {
util.Log().Warning("Failed to update status of download task %q: %s", monitor.Task.GID, err)
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status)
switch common.GetStatus(status) {
case common.Complete, common.Seeding:
return monitor.Complete(task.TaskPoll)
case common.Error:
return monitor.Error(status)
case common.Downloading, common.Ready, common.Paused:
return false
case common.Canceled:
monitor.Task.Status = common.Canceled
monitor.Task.Save()
monitor.RemoveTempFolder()
return true
default:
util.Log().Warning("Download task %q returns unknown status %q.", monitor.Task.GID, status.Status)
return true
}
}
// UpdateTaskInfo updates the task record in the database
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
originSize := monitor.Task.TotalSize
monitor.Task.GID = status.Gid
monitor.Task.Status = common.GetStatus(status)
// Total file size and downloaded size
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
if err != nil {
total = 0
}
downloaded, err := strconv.ParseUint(status.CompletedLength, 10, 64)
if err != nil {
downloaded = 0
}
monitor.Task.TotalSize = total
monitor.Task.DownloadedSize = downloaded
monitor.Task.GID = status.Gid
monitor.Task.Parent = status.Dir
// Download speed
speed, err := strconv.Atoi(status.DownloadSpeed)
if err != nil {
speed = 0
}
monitor.Task.Speed = speed
attrs, _ := json.Marshal(status)
monitor.Task.Attrs = string(attrs)
if err := monitor.Task.Save(); err != nil {
return err
}
if originSize != monitor.Task.TotalSize {
// Re-validate file restrictions after the total size changes
if err := monitor.ValidateFile(); err != nil {
// Cancel the task if validation fails
monitor.node.GetAria2Instance().Cancel(monitor.Task)
return err
}
}
return nil
}
// ValidateFile validates file size and file name restrictions while the download is in progress
func (monitor *Monitor) ValidateFile() error {
// Find the task owner
user := monitor.Task.GetOwner()
if user == nil {
return common.ErrUserNotFound
}
// Create a file system for the owner
fs, err := filesystem.NewFileSystem(user)
if err != nil {
return err
}
defer fs.Recycle()
if err := fs.SetPolicyFromPath(monitor.Task.Dst); err != nil {
return fmt.Errorf("failed to switch policy to target dir: %w", err)
}
// Build the file stream context
file := &fsctx.FileStream{
Size: monitor.Task.TotalSize,
}
// Validate the user's storage capacity
if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil {
return err
}
// Validate each selected file
for _, fileInfo := range monitor.Task.StatusInfo.Files {
if fileInfo.Selected == "true" {
// Build the file stream context for this file
fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
file := &fsctx.FileStream{
Size: fileSize,
Name: filepath.Base(fileInfo.Path),
}
if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil {
return err
}
}
}
return nil
}
// Error handles a download failure; the return value indicates whether monitoring should stop
func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
monitor.setErrorStatus(errors.New(status.ErrorMessage))
// Clean up temporary files
monitor.RemoveTempFolder()
return true
}
// RemoveTempFolder removes the temporary download directory
func (monitor *Monitor) RemoveTempFolder() {
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
}
// Complete handles a finished download; the return value indicates whether monitoring should stop
func (monitor *Monitor) Complete(pool task.Pool) bool {
// If the transfer has not started yet, submit a transfer task
if monitor.Task.TaskID == 0 {
return monitor.transfer(pool)
}
// Seeding finished
if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// Transfer finished; recycle the download directory
if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
job, err := task.NewRecycleTask(monitor.Task)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// Submit the recycle task
pool.Submit(job)
return true
}
}
return false
}
func (monitor *Monitor) transfer(pool task.Pool) bool {
// Build the transfer task
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
fileInfo := monitor.Task.StatusInfo.Files[i]
if fileInfo.Selected == "true" {
file = append(file, fileInfo.Path)
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
sizes[fileInfo.Path] = size
}
}
job, err := task.NewTransferTask(
monitor.Task.UserID,
file,
monitor.Task.Dst,
monitor.Task.Parent,
true,
monitor.node.ID(),
sizes,
)
if err != nil {
monitor.setErrorStatus(err)
monitor.RemoveTempFolder()
return true
}
// Submit the transfer task
pool.Submit(job)
// Record the ID of the transfer task
monitor.Task.TaskID = job.Model().ID
monitor.Task.Save()
return false
}
func (monitor *Monitor) setErrorStatus(err error) {
monitor.Task.Status = common.Error
monitor.Task.Error = err.Error()
monitor.Task.Save()
}

pkg/aria2/rpc/README.md

@@ -0,0 +1,257 @@
# PACKAGE DOCUMENTATION
**package rpc**
import "github.com/matzoe/argo/rpc"
## FUNCTIONS
```
func Call(address, method string, params, reply interface{}) error
```
## TYPES
```
type Client struct {
// contains filtered or unexported fields
}
```
```
func New(uri string) *Client
```
```
func (id *Client) AddMetalink(uri string, options ...interface{}) (gid string, err error)
```
`aria2.addMetalink(metalink[, options[, position]])` This method adds Metalink download by uploading ".metalink" file. `metalink` is of type base64 which contains Base64-encoded ".metalink" file. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the
waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns array of GID of registered download. If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.metalink. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by `--save-session`.
```
func (id *Client) AddTorrent(filename string, options ...interface{}) (gid string, err error)
```
`aria2.addTorrent(torrent[, uris[, options[, position]]])` This method adds BitTorrent download by uploading ".torrent" file. If you want to add BitTorrent Magnet URI, use `aria2.addUri()` method instead. torrent is of type base64 which contains Base64-encoded ".torrent" file. `uris` is of type array and its element is URI which is of type string. `uris` is used for Web-seeding. For single file torrents, URI can be a complete URI pointing to the resource or if URI ends with /, name in torrent file is added. For multi-file torrents, name and path in torrent are added to form a URI for each file. options is of type struct and its members are
a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download. If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".torrent" in the
directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.torrent. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by `--save-session`.
```
func (id *Client) AddUri(uri string, options ...interface{}) (gid string, err error)
```
`aria2.addUri(uris[, options[, position]])` This method adds new HTTP(S)/FTP/BitTorrent Magnet URI. `uris` is of type array and its element is URI which is of type string. For BitTorrent Magnet URI, `uris` must have only one element and it should be BitTorrent Magnet URI. URIs in uris must point to the same file. If you mix other URIs which point to another file, aria2 does not complain but download may
fail. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at position in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download.
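
A minimal usage sketch of adding a download, assuming the `New` and `AddUri` signatures documented above; the endpoint, URI and `dir` option are placeholder values:

```
package main

import (
	"log"

	"github.com/matzoe/argo/rpc"
)

func main() {
	// Placeholder endpoint; point this at your aria2 JSON-RPC address.
	client := rpc.New("http://127.0.0.1:6800/jsonrpc")
	gid, err := client.AddUri("https://example.org/file.iso", map[string]interface{}{
		"dir": "/tmp/downloads", // illustrative download directory
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("download registered, GID: %s", gid)
}
```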
```
func (id *Client) ChangeGlobalOption(options map[string]interface{}) (g string, err error)
```
`aria2.changeGlobalOption(options)` This method changes global options dynamically. `options` is of type struct. The following `options` are available:
download-result
log
log-level
max-concurrent-downloads
max-download-result
max-overall-download-limit
max-overall-upload-limit
save-cookies
save-session
server-stat-of
In addition to them, options listed in Input File subsection are available, except for following options: `checksum`, `index-out`, `out`, `pause` and `select-file`. Using `log` option, you can dynamically start logging or change log file. To stop logging, give empty string("") as a parameter value. Note that log file is always opened in append mode. This method returns OK for success.
```
func (id *Client) ChangeOption(gid string, options map[string]interface{}) (g string, err error)
```
`aria2.changeOption(gid, options)` This method changes options of the download denoted by `gid` dynamically. `gid` is of type string. `options` is of type struct. The following `options` are available for active downloads:
bt-max-peers
bt-request-peer-speed-limit
bt-remove-unselected-file
force-save
max-download-limit
max-upload-limit
For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option. This method returns OK for success.
```
func (id *Client) ChangePosition(gid string, pos int, how string) (p int, err error)
```
`aria2.changePosition(gid, pos, how)` This method changes the position of the download denoted by `gid`. `pos` is of type integer. `how` is of type string. If `how` is `POS_SET`, it moves the download to a position relative to the beginning of the queue. If `how` is `POS_CUR`, it moves the download to a position relative to the current position. If `how` is `POS_END`, it moves the download to a position relative to the end of the queue. If the destination position is less than 0 or beyond the end
of the queue, it moves the download to the beginning or the end of the queue respectively. The response is of type integer and it is the destination position.
```
func (id *Client) ChangeUri(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error)
```
`aria2.changeUri(gid, fileIndex, delUris, addUris[, position])` This method removes URIs in `delUris` from and appends URIs in `addUris` to download denoted by gid. `delUris` and `addUris` are list of string. A download can contain multiple files and URIs are attached to each file. `fileIndex` is used to select which file to remove/attach given URIs. `fileIndex` is 1-based. `position` is used to specify where URIs are inserted in the existing waiting URI list. `position` is 0-based. When
`position` is omitted, URIs are appended to the back of the list. This method first executes the removal and then the addition. `position` is the `position` after URIs are removed, not the `position` when this method is called. When removing a URI, if the same URI exists in the download, only one of them is removed for each URI in delUris. In other words, if there are three URIs http://example.org/aria2 and you want to remove them all, you
have to specify (at least) 3 http://example.org/aria2 in delUris. This method returns a list which contains 2 integers. The first integer is the number of URIs deleted. The second integer is the number of URIs added.
```
func (id *Client) ForcePause(gid string) (g string, err error)
```
`aria2.forcePause(gid)` This method pauses the download denoted by `gid`. This method behaves just like aria2.pause() except that it pauses the download without performing any actions which take time, such as contacting the BitTorrent tracker.
```
func (id *Client) ForcePauseAll() (g string, err error)
```
`aria2.forcePauseAll()` This method is equal to calling `aria2.forcePause()` for every active/waiting download. This method returns OK for success.
```
func (id *Client) ForceRemove(gid string) (g string, err error)
```
`aria2.forceRemove(gid)` This method removes the download denoted by `gid`. This method behaves just like aria2.remove() except that this method removes download without any action which takes time such as contacting BitTorrent tracker.
```
func (id *Client) ForceShutdown() (g string, err error)
```
`aria2.forceShutdown()` This method shuts down aria2. This method behaves like `aria2.shutdown()` except that any actions which take time, such as contacting the BitTorrent tracker, are skipped. This method returns OK.
```
func (id *Client) GetFiles(gid string) (m map[string]interface{}, err error)
```
`aria2.getFiles(gid)` This method returns file list of the download denoted by `gid`. `gid` is of type string.
```
func (id *Client) GetGlobalOption() (m map[string]interface{}, err error)
```
`aria2.getGlobalOption()` This method returns global options. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods. Because global options are used as a template for the options of newly added download, the response contains
keys returned by `aria2.getOption()` method.
```
func (id *Client) GetGlobalStat() (m map[string]interface{}, err error)
```
`aria2.getGlobalStat()` This method returns global statistics such as overall download and upload speed.
```
func (id *Client) GetOption(gid string) (m map[string]interface{}, err error)
```
`aria2.getOption(gid)` This method returns options of the download denoted by `gid`. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods.
```
func (id *Client) GetPeers(gid string) (m []map[string]interface{}, err error)
```
`aria2.getPeers(gid)` This method returns peer list of the download denoted by `gid`. `gid` is of type string. This method is for BitTorrent only.
```
func (id *Client) GetServers(gid string) (m []map[string]interface{}, err error)
```
`aria2.getServers(gid)` This method returns currently connected HTTP(S)/FTP servers of the download denoted by `gid`. `gid` is of type string.
```
func (id *Client) GetSessionInfo() (m map[string]interface{}, err error)
```
`aria2.getSessionInfo()` This method returns session information.
```
func (id *Client) GetUris(gid string) (m map[string]interface{}, err error)
```
`aria2.getUris(gid)` This method returns URIs used in the download denoted by `gid`. `gid` is of type string.
```
func (id *Client) GetVersion() (m map[string]interface{}, err error)
```
`aria2.getVersion()` This method returns version of the program and the list of enabled features.
```
func (id *Client) Multicall(methods []map[string]interface{}) (r []interface{}, err error)
```
`system.multicall(methods)` This method encapsulates multiple method calls in a single request. `methods` is of type array and its element is struct. The struct contains two keys: `methodName` and `params`. `methodName` is the method name to call and `params` is array containing parameters to the method. This method returns array of responses. The element of array will either be a one-item array containing the return value of each method call or struct of fault element if an encapsulated method call fails.
```
func (id *Client) Pause(gid string) (g string, err error)
```
`aria2.pause(gid)` This method pauses the download denoted by `gid`. `gid` is of type string. The status of paused download becomes paused. If the download is active, the download is placed on the first position of waiting queue. As long as the status is paused, the download is not started. To change status to waiting, use `aria2.unpause()` method. This method returns GID of paused download.
```
func (id *Client) PauseAll() (g string, err error)
```
`aria2.pauseAll()` This method is equal to calling `aria2.pause()` for every active/waiting download. This method returns OK for success.
```
func (id *Client) PurgeDowloadResult() (g string, err error)
```
`aria2.purgeDownloadResult()` This method purges completed/error/removed downloads to free memory. This method returns OK.
```
func (id *Client) Remove(gid string) (g string, err error)
```
`aria2.remove(gid)` This method removes the download denoted by gid. `gid` is of type string. If specified download is in progress, it is stopped at first. The status of removed download becomes removed. This method returns GID of removed download.
```
func (id *Client) RemoveDownloadResult(gid string) (g string, err error)
```
`aria2.removeDownloadResult(gid)` This method removes completed/error/removed download denoted by `gid` from memory. This method returns OK for success.
```
func (id *Client) Shutdown() (g string, err error)
```
`aria2.shutdown()` This method shuts down aria2. This method returns OK.
```
func (id *Client) TellActive(keys ...string) (m []map[string]interface{}, err error)
```
`aria2.tellActive([keys])` This method returns the list of active downloads. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method. For `keys` parameter, please refer to `aria2.tellStatus()` method.
```
func (id *Client) TellStatus(gid string, keys ...string) (m map[string]interface{}, err error)
```
`aria2.tellStatus(gid[, keys])` This method returns download progress of the download denoted by `gid`. `gid` is of type string. `keys` is array of string. If it is specified, the response contains only keys in `keys` array. If `keys` is empty or not specified, the response contains all keys. This is useful when you just want specific keys and avoid unnecessary transfers. For example, `aria2.tellStatus("2089b05ecca3d829", ["gid", "status"])` returns `gid` and `status` key.
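
As a sketch, reusing the `client` from the earlier example (the GID is a placeholder), restricting the response to a few keys keeps status polling cheap:

```
status, err := client.TellStatus("2089b05ecca3d829", "gid", "status", "completedLength", "totalLength")
if err != nil {
	log.Fatal(err)
}
log.Printf("task %v is %v", status["gid"], status["status"])
```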
```
func (id *Client) TellStopped(offset, num int, keys ...string) (m []map[string]interface{}, err error)
```
`aria2.tellStopped(offset, num[, keys])` This method returns the list of stopped downloads. `offset` is of type integer and specifies the `offset` from the oldest download. `num` is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to `aria2.tellStatus()` method. `offset` and `num` have the same semantics as `aria2.tellWaiting()` method. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method.
```
func (id *Client) TellWaiting(offset, num int, keys ...string) (m []map[string]interface{}, err error)
```
`aria2.tellWaiting(offset, num[, keys])` This method returns the list of waiting downloads, including paused downloads. `offset` is of type integer and specifies the `offset` from the download waiting at the front. num is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to aria2.tellStatus() method. If `offset` is a positive integer, this method returns downloads
in the range of `[offset, offset + num)`. `offset` can be a negative integer. `offset == -1` points last download in the waiting queue and `offset == -2` points the download before the last download, and so on. The downloads in the response are in reversed order. For example, imagine that three downloads "A","B" and "C" are waiting in this order.
aria2.tellWaiting(0, 1) returns ["A"].
aria2.tellWaiting(1, 2) returns ["B", "C"].
aria2.tellWaiting(-1, 2) returns ["C", "B"].
The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method.
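
A sketch of paging through the waiting queue with this method, again reusing the `client` from the first example:

```
// Walk the waiting queue 50 entries at a time.
for offset := 0; ; offset += 50 {
	waiting, err := client.TellWaiting(offset, 50, "gid", "status")
	if err != nil || len(waiting) == 0 {
		break
	}
	for _, w := range waiting {
		log.Printf("waiting: %v (%v)", w["gid"], w["status"])
	}
}
```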
```
func (id *Client) Unpause(gid string) (g string, err error)
```
`aria2.unpause(gid)` This method changes the status of the download denoted by `gid` from paused to waiting. This makes the download eligible to restart. `gid` is of type string. This method returns GID of unpaused download.
```
func (id *Client) UnpauseAll() (g string, err error)
```
`aria2.unpauseAll()` This method is equal to calling `aria2.unpause()` for every active/waiting download. This method returns OK for success.

pkg/aria2/rpc/call.go

@@ -0,0 +1,274 @@
package rpc
import (
"context"
"errors"
"log"
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
)
type caller interface {
// Call sends an RPC request to the aria2 daemon
Call(method string, params, reply interface{}) (err error)
Close() error
}
type httpCaller struct {
uri string
c *http.Client
cancel context.CancelFunc
wg *sync.WaitGroup
once sync.Once
}
func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller {
c := &http.Client{
Transport: &http.Transport{
MaxIdleConnsPerHost: 1,
MaxConnsPerHost: 1,
// TLSClientConfig: tlsConfig,
Dial: (&net.Dialer{
Timeout: timeout,
KeepAlive: 60 * time.Second,
}).Dial,
TLSHandshakeTimeout: 3 * time.Second,
ResponseHeaderTimeout: timeout,
},
}
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(ctx)
h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg}
if notifer != nil {
h.setNotifier(ctx, *u, notifer)
}
return h
}
func (h *httpCaller) Close() (err error) {
h.once.Do(func() {
h.cancel()
h.wg.Wait()
})
return
}
func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) {
u.Scheme = "ws"
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
return
}
h.wg.Add(1)
go func() {
defer h.wg.Done()
defer conn.Close()
select {
case <-ctx.Done():
conn.SetWriteDeadline(time.Now().Add(time.Second))
if err := conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
log.Printf("sending websocket close message: %v", err)
}
return
}
}()
h.wg.Add(1)
go func() {
defer h.wg.Done()
var request websocketResponse
var err error
for {
select {
case <-ctx.Done():
return
default:
}
if err = conn.ReadJSON(&request); err != nil {
select {
case <-ctx.Done():
return
default:
}
log.Printf("conn.ReadJSON|err:%v", err.Error())
return
}
switch request.Method {
case "aria2.onDownloadStart":
notifer.OnDownloadStart(request.Params)
case "aria2.onDownloadPause":
notifer.OnDownloadPause(request.Params)
case "aria2.onDownloadStop":
notifer.OnDownloadStop(request.Params)
case "aria2.onDownloadComplete":
notifer.OnDownloadComplete(request.Params)
case "aria2.onDownloadError":
notifer.OnDownloadError(request.Params)
case "aria2.onBtDownloadComplete":
notifer.OnBtDownloadComplete(request.Params)
default:
log.Printf("unexpected notification: %s", request.Method)
}
}
}()
return
}
func (h httpCaller) Call(method string, params, reply interface{}) (err error) {
payload, err := EncodeClientRequest(method, params)
if err != nil {
return
}
r, err := h.c.Post(h.uri, "application/json", payload)
if err != nil {
return
}
err = DecodeClientResponse(r.Body, &reply)
r.Body.Close()
return
}
type websocketCaller struct {
conn *websocket.Conn
sendChan chan *sendRequest
cancel context.CancelFunc
wg *sync.WaitGroup
once sync.Once
timeout time.Duration
}
func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
var header = http.Header{}
conn, _, err := websocket.DefaultDialer.Dial(uri, header)
if err != nil {
return nil, err
}
sendChan := make(chan *sendRequest, 16)
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(ctx)
w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
processor := NewResponseProcessor()
wg.Add(1)
go func() { // routine:recv
defer wg.Done()
defer cancel()
for {
select {
case <-ctx.Done():
return
default:
}
var resp websocketResponse
if err := conn.ReadJSON(&resp); err != nil {
select {
case <-ctx.Done():
return
default:
}
log.Printf("conn.ReadJSON|err:%v", err.Error())
return
}
if resp.Id == nil { // RPC notifications
if notifier != nil {
switch resp.Method {
case "aria2.onDownloadStart":
notifier.OnDownloadStart(resp.Params)
case "aria2.onDownloadPause":
notifier.OnDownloadPause(resp.Params)
case "aria2.onDownloadStop":
notifier.OnDownloadStop(resp.Params)
case "aria2.onDownloadComplete":
notifier.OnDownloadComplete(resp.Params)
case "aria2.onDownloadError":
notifier.OnDownloadError(resp.Params)
case "aria2.onBtDownloadComplete":
notifier.OnBtDownloadComplete(resp.Params)
default:
log.Printf("unexpected notification: %s", resp.Method)
}
}
continue
}
processor.Process(resp.clientResponse)
}
}()
wg.Add(1)
go func() { // routine:send
defer wg.Done()
defer cancel()
defer w.conn.Close()
for {
select {
case <-ctx.Done():
if err := w.conn.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
log.Printf("sending websocket close message: %v", err)
}
return
case req := <-sendChan:
processor.Add(req.request.Id, func(resp clientResponse) error {
err := resp.decode(req.reply)
req.cancel()
return err
})
w.conn.SetWriteDeadline(time.Now().Add(timeout))
w.conn.WriteJSON(req.request)
}
}
}()
return w, nil
}
func (w *websocketCaller) Close() (err error) {
w.once.Do(func() {
w.cancel()
w.wg.Wait()
})
return
}
func (w websocketCaller) Call(method string, params, reply interface{}) (err error) {
ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
defer cancel()
select {
case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{
Version: "2.0",
Method: method,
Params: params,
Id: reqid(),
}, reply: reply}:
default:
return errors.New("sending channel blocking")
}
select {
case <-ctx.Done():
if err := ctx.Err(); err == context.DeadlineExceeded {
return err
}
}
return
}
type sendRequest struct {
cancel context.CancelFunc
request *clientRequest
reply interface{}
}
var reqid = func() func() uint64 {
var id = uint64(time.Now().UnixNano())
return func() uint64 {
return atomic.AddUint64(&id, 1)
}
}()

pkg/aria2/rpc/client.go

@@ -0,0 +1,656 @@
package rpc
import (
"context"
"encoding/base64"
"errors"
"io/ioutil"
"net/url"
"time"
)
// Option is a container for specifying Call parameters and returning results
type Option map[string]interface{}
type Client interface {
Protocol
Close() error
}
type client struct {
caller
url *url.URL
token string
}
var (
errInvalidParameter = errors.New("invalid parameter")
errNotImplemented = errors.New("not implemented")
errConnTimeout = errors.New("connect to aria2 daemon timeout")
)
// New returns an instance of Client
func New(ctx context.Context, uri string, token string, timeout time.Duration, notifier Notifier) (Client, error) {
u, err := url.Parse(uri)
if err != nil {
return nil, err
}
var caller caller
switch u.Scheme {
case "http", "https":
caller = newHTTPCaller(ctx, u, timeout, notifier)
case "ws", "wss":
caller, err = newWebsocketCaller(ctx, u.String(), timeout, notifier)
if err != nil {
return nil, err
}
default:
return nil, errInvalidParameter
}
c := &client{caller: caller, url: u, token: token}
return c, nil
}
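// exampleNewClient is a usage sketch for New: connect to a placeholder local
// aria2 daemon over websocket with a 10 second per-call timeout and no
// notification handler. The URI, token and timeout are illustrative values.
func exampleNewClient() (Client, error) {
	return New(context.Background(), "ws://127.0.0.1:6800/jsonrpc", "my-rpc-secret", 10*time.Second, nil)
}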
// `aria2.addUri([secret, ]uris[, options[, position]])`
// This method adds a new download. uris is an array of HTTP/FTP/SFTP/BitTorrent URIs (strings) pointing to the same resource.
// If you mix URIs pointing to different resources, then the download may fail or be corrupted without aria2 complaining.
// When adding BitTorrent Magnet URIs, uris must have only one element and it should be BitTorrent Magnet URI.
// options is a struct and its members are pairs of option name and value.
// If position is given, it must be an integer starting from 0.
// The new download will be inserted at position in the waiting queue.
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
// This method returns the GID of the newly registered download.
func (c *client) AddURI(uri string, options ...interface{}) (gid string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, []string{uri})
if options != nil {
params = append(params, options...)
}
err = c.Call(aria2AddURI, params, &gid)
return
}
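// exampleAddURI is a usage sketch for AddURI: register a download with
// per-task options passed as an Option map. The URI and option values are
// illustrative assumptions ("dir" and "out" are standard aria2 input-file
// options), and Client is assumed to expose AddURI via the Protocol
// interface defined elsewhere in this package.
func exampleAddURI(c Client) (gid string, err error) {
	return c.AddURI("https://example.org/file.iso", Option{
		"dir": "/data/aria2_tmp", // illustrative temporary download directory
		"out": "file.iso",        // illustrative output file name
	})
}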
// `aria2.addTorrent([secret, ]torrent[, uris[, options[, position]]])`
// This method adds a BitTorrent download by uploading a ".torrent" file.
// If you want to add a BitTorrent Magnet URI, use the aria2.addUri() method instead.
// torrent must be a base64-encoded string containing the contents of the ".torrent" file.
// uris is an array of URIs (string). uris is used for Web-seeding.
// For single file torrents, the URI can be a complete URI pointing to the resource; if URI ends with /, name in torrent file is added.
// For multi-file torrents, name and path in torrent are added to form a URI for each file. options is a struct and its members are pairs of option name and value.
// If position is given, it must be an integer starting from 0.
// The new download will be inserted at position in the waiting queue.
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
// This method returns the GID of the newly registered download.
// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named as the hex string of SHA-1 hash of data plus ".torrent" in the directory specified by --dir option.
// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.torrent.
// If a file with the same name already exists, it is overwritten!
// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session.
func (c *client) AddTorrent(filename string, options ...interface{}) (gid string, err error) {
co, err := ioutil.ReadFile(filename)
if err != nil {
return
}
file := base64.StdEncoding.EncodeToString(co)
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, file)
if options != nil {
params = append(params, options...)
}
err = c.Call(aria2AddTorrent, params, &gid)
return
}
// `aria2.addMetalink([secret, ]metalink[, options[, position]])`
// This method adds a Metalink download by uploading a ".metalink" file.
// metalink is a base64-encoded string which contains the contents of the ".metalink" file.
// options is a struct and its members are pairs of option name and value.
// If position is given, it must be an integer starting from 0.
// The new download will be inserted at position in the waiting queue.
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
// This method returns an array of GIDs of newly registered downloads.
// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by --dir option.
// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.metalink.
// If a file with the same name already exists, it is overwritten!
// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session.
func (c *client) AddMetalink(filename string, options ...interface{}) (gid []string, err error) {
co, err := ioutil.ReadFile(filename)
if err != nil {
return
}
file := base64.StdEncoding.EncodeToString(co)
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, file)
if options != nil {
params = append(params, options...)
}
err = c.Call(aria2AddMetalink, params, &gid)
return
}
// `aria2.remove([secret, ]gid)`
// This method removes the download denoted by gid (string).
// If the specified download is in progress, it is first stopped.
// The status of the removed download becomes removed.
// This method returns GID of removed download.
func (c *client) Remove(gid string) (g string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2Remove, params, &g)
return
}
// `aria2.forceRemove([secret, ]gid)`
// This method removes the download denoted by gid.
// This method behaves just like aria2.remove() except that this method removes the download without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first.
func (c *client) ForceRemove(gid string) (g string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2ForceRemove, params, &g)
return
}
// `aria2.pause([secret, ]gid)`
// This method pauses the download denoted by gid (string).
// The status of paused download becomes paused.
// If the download was active, the download is placed in the front of waiting queue.
// While the status is paused, the download is not started.
// To change status to waiting, use the aria2.unpause() method.
// This method returns GID of paused download.
func (c *client) Pause(gid string) (g string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2Pause, params, &g)
return
}
// `aria2.pauseAll([secret])`
// This method is equal to calling aria2.pause() for every active/waiting download.
// This methods returns OK.
func (c *client) PauseAll() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2PauseAll, params, &ok)
return
}
// `aria2.forcePause([secret, ]gid)`
// This method pauses the download denoted by gid.
// This method behaves just like aria2.pause() except that this method pauses downloads without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first.
func (c *client) ForcePause(gid string) (g string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2ForcePause, params, &g)
return
}
// `aria2.forcePauseAll([secret])`
// This method is equal to calling aria2.forcePause() for every active/waiting download.
// This methods returns OK.
func (c *client) ForcePauseAll() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2ForcePauseAll, params, &ok)
return
}
// `aria2.unpause([secret, ]gid)`
// This method changes the status of the download denoted by gid (string) from paused to waiting, making the download eligible to be restarted.
// This method returns the GID of the unpaused download.
func (c *client) Unpause(gid string) (g string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2Unpause, params, &g)
return
}
// `aria2.unpauseAll([secret])`
// This method is equal to calling aria2.unpause() for every active/waiting download.
// This methods returns OK.
func (c *client) UnpauseAll() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2UnpauseAll, params, &ok)
return
}
// `aria2.tellStatus([secret, ]gid[, keys])`
// This method returns the progress of the download denoted by gid (string).
// keys is an array of strings.
// If specified, the response contains only keys in the keys array.
// If keys is empty or omitted, the response contains all keys.
// This is useful when you just want specific keys and avoid unnecessary transfers.
// For example, aria2.tellStatus("2089b05ecca3d829", ["gid", "status"]) returns the gid and status keys only.
// The response is a struct and contains following keys. Values are strings.
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.tellStatus
func (c *client) TellStatus(gid string, keys ...string) (info StatusInfo, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
if keys != nil {
params = append(params, keys)
}
err = c.Call(aria2TellStatus, params, &info)
return
}
// `aria2.getUris([secret, ]gid)`
// This method returns the URIs used in the download denoted by gid (string).
// The response is an array of structs and it contains following keys. Values are string.
// uri URI
// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
func (c *client) GetURIs(gid string) (infos []URIInfo, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2GetURIs, params, &infos)
return
}
// `aria2.getFiles([secret, ]gid)`
// This method returns the file list of the download denoted by gid (string).
// The response is an array of structs which contain following keys. Values are strings.
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getFiles
func (c *client) GetFiles(gid string) (infos []FileInfo, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2GetFiles, params, &infos)
return
}
// `aria2.getPeers([secret, ]gid)`
// This method returns a list of peers of the download denoted by gid (string).
// This method is for BitTorrent only.
// The response is an array of structs and contains the following keys. Values are strings.
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getPeers
func (c *client) GetPeers(gid string) (infos []PeerInfo, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2GetPeers, params, &infos)
return
}
// `aria2.getServers([secret, ]gid)`
// This method returns currently connected HTTP(S)/FTP/SFTP servers of the download denoted by gid (string).
// The response is an array of structs and contains the following keys. Values are strings.
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getServers
func (c *client) GetServers(gid string) (infos []ServerInfo, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2GetServers, params, &infos)
return
}
// `aria2.tellActive([secret][, keys])`
// This method returns a list of active downloads.
// The response is an array of the same structs as returned by the aria2.tellStatus() method.
// For the keys parameter, please refer to the aria2.tellStatus() method.
func (c *client) TellActive(keys ...string) (infos []StatusInfo, err error) {
params := make([]interface{}, 0, 1)
if c.token != "" {
params = append(params, "token:"+c.token)
}
if keys != nil {
params = append(params, keys)
}
err = c.Call(aria2TellActive, params, &infos)
return
}
// `aria2.tellWaiting([secret, ]offset, num[, keys])`
// This method returns a list of waiting downloads, including paused ones.
// offset is an integer and specifies the offset from the download waiting at the front.
// num is an integer and specifies the max. number of downloads to be returned.
// For the keys parameter, please refer to the aria2.tellStatus() method.
// If offset is a positive integer, this method returns downloads in the range of [offset, offset + num).
// offset can be a negative integer. offset == -1 points last download in the waiting queue and offset == -2 points the download before the last download, and so on.
// Downloads in the response are in reversed order then.
// For example, imagine three downloads "A","B" and "C" are waiting in this order.
// aria2.tellWaiting(0, 1) returns ["A"].
// aria2.tellWaiting(1, 2) returns ["B", "C"].
// aria2.tellWaiting(-1, 2) returns ["C", "B"].
// The response is an array of the same structs as returned by aria2.tellStatus() method.
func (c *client) TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error) {
params := make([]interface{}, 0, 3)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, offset)
params = append(params, num)
if keys != nil {
params = append(params, keys)
}
err = c.Call(aria2TellWaiting, params, &infos)
return
}
// `aria2.tellStopped([secret, ]offset, num[, keys])`
// This method returns a list of stopped downloads.
// offset is an integer and specifies the offset from the least recently stopped download.
// num is an integer and specifies the max. number of downloads to be returned.
// For the keys parameter, please refer to the aria2.tellStatus() method.
// offset and num have the same semantics as described in the aria2.tellWaiting() method.
// The response is an array of the same structs as returned by the aria2.tellStatus() method.
func (c *client) TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error) {
params := make([]interface{}, 0, 3)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, offset)
params = append(params, num)
if keys != nil {
params = append(params, keys)
}
err = c.Call(aria2TellStopped, params, &infos)
return
}
// `aria2.changePosition([secret, ]gid, pos, how)`
// This method changes the position of the download denoted by gid in the queue.
// pos is an integer. how is a string.
// If how is POS_SET, it moves the download to a position relative to the beginning of the queue.
// If how is POS_CUR, it moves the download to a position relative to the current position.
// If how is POS_END, it moves the download to a position relative to the end of the queue.
// If the destination position is less than 0 or beyond the end of the queue, it moves the download to the beginning or the end of the queue respectively.
// The response is an integer denoting the resulting position.
// For example, if GID#2089b05ecca3d829 is currently in position 3, aria2.changePosition('2089b05ecca3d829', -1, 'POS_CUR') will change its position to 2. Additionally aria2.changePosition('2089b05ecca3d829', 0, 'POS_SET') will change its position to 0 (the beginning of the queue).
func (c *client) ChangePosition(gid string, pos int, how string) (p int, err error) {
params := make([]interface{}, 0, 3)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
params = append(params, pos)
params = append(params, how)
err = c.Call(aria2ChangePosition, params, &p)
return
}
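// exampleMoveToFront is a usage sketch for ChangePosition: move a download
// (placeholder GID) to the very front of the waiting queue using POS_SET
// with position 0, as described above.
func exampleMoveToFront(c Client) (int, error) {
	return c.ChangePosition("2089b05ecca3d829", 0, "POS_SET")
}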
// `aria2.changeUri([secret, ]gid, fileIndex, delUris, addUris[, position])`
// This method removes the URIs in delUris from and appends the URIs in addUris to download denoted by gid.
// delUris and addUris are lists of strings.
// A download can contain multiple files and URIs are attached to each file.
// fileIndex is used to select which file to remove/attach given URIs. fileIndex is 1-based.
// position is used to specify where URIs are inserted in the existing waiting URI list. position is 0-based.
// When position is omitted, URIs are appended to the back of the list.
// This method first executes the removal and then the addition.
// position is the position after URIs are removed, not the position when this method is called.
// When removing an URI, if the same URIs exist in download, only one of them is removed for each URI in delUris.
// In other words, if there are three URIs http://example.org/aria2 and you want remove them all, you have to specify (at least) 3 http://example.org/aria2 in delUris.
// This method returns a list which contains two integers.
// The first integer is the number of URIs deleted.
// The second integer is the number of URIs added.
func (c *client) ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) {
params := make([]interface{}, 0, 5)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
params = append(params, fileindex)
params = append(params, delUris)
params = append(params, addUris)
if position != nil {
params = append(params, position[0])
}
err = c.Call(aria2ChangeURI, params, &p)
return
}
// `aria2.getOption([secret, ]gid)`
// This method returns options of the download denoted by gid.
// The response is a struct where keys are the names of options.
// The values are strings.
// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods.
func (c *client) GetOption(gid string) (m Option, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2GetOption, params, &m)
return
}
// `aria2.changeOption([secret, ]gid, options)`
// This method changes options of the download denoted by gid (string) dynamically. options is a struct.
// The following options are available for active downloads:
// bt-max-peers
// bt-request-peer-speed-limit
// bt-remove-unselected-file
// force-save
// max-download-limit
// max-upload-limit
// For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option.
// This method returns OK for success.
func (c *client) ChangeOption(gid string, option Option) (ok string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
if option != nil {
params = append(params, option)
}
err = c.Call(aria2ChangeOption, params, &ok)
return
}
// `aria2.getGlobalOption([secret])`
// This method returns the global options.
// The response is a struct.
// Its keys are the names of options.
// Values are strings.
// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods. Because global options are used as a template for the options of newly added downloads, the response contains keys returned by the aria2.getOption() method.
func (c *client) GetGlobalOption() (m Option, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2GetGlobalOption, params, &m)
return
}
// `aria2.changeGlobalOption([secret, ]options)`
// This method changes global options dynamically.
// options is a struct.
// The following options are available:
// bt-max-open-files
// download-result
// log
// log-level
// max-concurrent-downloads
// max-download-result
// max-overall-download-limit
// max-overall-upload-limit
// save-cookies
// save-session
// server-stat-of
// In addition, options listed in the Input File subsection are available, except for following options: checksum, index-out, out, pause and select-file.
// With the log option, you can dynamically start logging or change log file.
// To stop logging, specify an empty string("") as the parameter value.
// Note that log file is always opened in append mode.
// This method returns OK for success.
func (c *client) ChangeGlobalOption(options Option) (ok string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, options)
err = c.Call(aria2ChangeGlobalOption, params, &ok)
return
}
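// exampleThrottle is a usage sketch for ChangeGlobalOption: cap the daemon's
// overall download speed using the max-overall-download-limit option listed
// above. The "1M" value uses aria2's usual size-suffix notation and is an
// illustrative choice.
func exampleThrottle(c Client) (ok string, err error) {
	return c.ChangeGlobalOption(Option{"max-overall-download-limit": "1M"})
}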
// `aria2.getGlobalStat([secret])`
// This method returns global statistics such as the overall download and upload speeds.
// The response is a struct and contains the following keys. Values are strings.
// downloadSpeed Overall download speed (byte/sec).
// uploadSpeed Overall upload speed(byte/sec).
// numActive The number of active downloads.
// numWaiting The number of waiting downloads.
// numStopped The number of stopped downloads in the current session.
// This value is capped by the --max-download-result option.
// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option.
func (c *client) GetGlobalStat() (info GlobalStatInfo, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2GetGlobalStat, params, &info)
return
}
// `aria2.purgeDownloadResult([secret])`
// This method purges completed/error/removed downloads to free memory.
// This method returns OK.
func (c *client) PurgeDownloadResult() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2PurgeDownloadResult, params, &ok)
return
}
// `aria2.removeDownloadResult([secret, ]gid)`
// This method removes a completed/error/removed download denoted by gid from memory.
// This method returns OK for success.
func (c *client) RemoveDownloadResult(gid string) (ok string, err error) {
params := make([]interface{}, 0, 2)
if c.token != "" {
params = append(params, "token:"+c.token)
}
params = append(params, gid)
err = c.Call(aria2RemoveDownloadResult, params, &ok)
return
}
// `aria2.getVersion([secret])`
// This method returns the version of aria2 and the list of enabled features.
// The response is a struct and contains following keys.
// version Version number of aria2 as a string.
// enabledFeatures List of enabled features. Each feature is given as a string.
func (c *client) GetVersion() (info VersionInfo, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2GetVersion, params, &info)
return
}
// `aria2.getSessionInfo([secret])`
// This method returns session information.
// The response is a struct and contains following key.
// sessionId Session ID, which is generated each time when aria2 is invoked.
func (c *client) GetSessionInfo() (info SessionInfo, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2GetSessionInfo, params, &info)
return
}
// `aria2.shutdown([secret])`
// This method shuts down aria2.
// This method returns OK.
func (c *client) Shutdown() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2Shutdown, params, &ok)
return
}
// `aria2.forceShutdown([secret])`
// This method shuts down aria2.
// This method behaves like aria2.shutdown() without performing any actions which take time, such as contacting BitTorrent trackers to unregister downloads first.
// This method returns OK.
func (c *client) ForceShutdown() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2ForceShutdown, params, &ok)
return
}
// `aria2.saveSession([secret])`
// This method saves the current session to a file specified by the --save-session option.
// This method returns OK if it succeeds.
func (c *client) SaveSession() (ok string, err error) {
params := []string{}
if c.token != "" {
params = append(params, "token:"+c.token)
}
err = c.Call(aria2SaveSession, params, &ok)
return
}
// `system.multicall(methods)`
// This method encapsulates multiple method calls in a single request.
// methods is an array of structs.
// The structs contain two keys: methodName and params.
// methodName is the method name to call and params is an array containing parameters to the method call.
// This method returns an array of responses.
// The elements will be either a one-item array containing the return value of the method call, or a fault struct if an encapsulated method call fails.
func (c *client) Multicall(methods []Method) (r []interface{}, err error) {
if len(methods) == 0 {
err = errInvalidParameter
return
}
err = c.Call(aria2Multicall, methods, &r)
return
}
// `system.listMethods()`
// This method returns all available RPC methods in an array of strings.
// Unlike other methods, this method does not require a secret token.
// This is safe because this method just returns the available method names.
func (c *client) ListMethods() (methods []string, err error) {
err = c.Call(aria2ListMethods, []string{}, &methods)
return
}
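
A minimal usage sketch of the client above, assuming an aria2 daemon listening on the placeholder endpoint below; rpc.New and Close are the constructor and teardown used by this package's callers elsewhere in the commit.

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
)

func main() {
    // Endpoint and secret are placeholders for illustration only.
    caller, err := rpc.New(context.Background(), "http://127.0.0.1:6800/jsonrpc", "my-secret", 10*time.Second, nil)
    if err != nil {
        panic(err)
    }
    defer caller.Close()

    // aria2.getVersion: version number and enabled features.
    version, err := caller.GetVersion()
    if err != nil {
        panic(err)
    }
    fmt.Println("aria2", version.Version, version.Features)

    // aria2.getGlobalStat: overall speeds and queue counters, all returned as strings.
    stat, err := caller.GetGlobalStat()
    if err != nil {
        panic(err)
    }
    fmt.Println("active:", stat.NumActive, "download speed:", stat.DownloadSpeed, "byte/sec")
}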

39
pkg/aria2/rpc/const.go Normal file
View File

@ -0,0 +1,39 @@
package rpc
const (
aria2AddURI = "aria2.addUri"
aria2AddTorrent = "aria2.addTorrent"
aria2AddMetalink = "aria2.addMetalink"
aria2Remove = "aria2.remove"
aria2ForceRemove = "aria2.forceRemove"
aria2Pause = "aria2.pause"
aria2PauseAll = "aria2.pauseAll"
aria2ForcePause = "aria2.forcePause"
aria2ForcePauseAll = "aria2.forcePauseAll"
aria2Unpause = "aria2.unpause"
aria2UnpauseAll = "aria2.unpauseAll"
aria2TellStatus = "aria2.tellStatus"
aria2GetURIs = "aria2.getUris"
aria2GetFiles = "aria2.getFiles"
aria2GetPeers = "aria2.getPeers"
aria2GetServers = "aria2.getServers"
aria2TellActive = "aria2.tellActive"
aria2TellWaiting = "aria2.tellWaiting"
aria2TellStopped = "aria2.tellStopped"
aria2ChangePosition = "aria2.changePosition"
aria2ChangeURI = "aria2.changeUri"
aria2GetOption = "aria2.getOption"
aria2ChangeOption = "aria2.changeOption"
aria2GetGlobalOption = "aria2.getGlobalOption"
aria2ChangeGlobalOption = "aria2.changeGlobalOption"
aria2GetGlobalStat = "aria2.getGlobalStat"
aria2PurgeDownloadResult = "aria2.purgeDownloadResult"
aria2RemoveDownloadResult = "aria2.removeDownloadResult"
aria2GetVersion = "aria2.getVersion"
aria2GetSessionInfo = "aria2.getSessionInfo"
aria2Shutdown = "aria2.shutdown"
aria2ForceShutdown = "aria2.forceShutdown"
aria2SaveSession = "aria2.saveSession"
aria2Multicall = "system.multicall"
aria2ListMethods = "system.listMethods"
)

116
pkg/aria2/rpc/json2.go Normal file
View File

@ -0,0 +1,116 @@
package rpc
// based on "github.com/gorilla/rpc/v2/json2"
// Copyright 2009 The Go Authors. All rights reserved.
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
import (
"bytes"
"encoding/json"
"errors"
"io"
)
// ----------------------------------------------------------------------------
// Request and Response
// ----------------------------------------------------------------------------
// clientRequest represents a JSON-RPC request sent by a client.
type clientRequest struct {
// JSON-RPC protocol.
Version string `json:"jsonrpc"`
// A String containing the name of the method to be invoked.
Method string `json:"method"`
// Object to pass as request parameter to the method.
Params interface{} `json:"params"`
// The request id. This can be of any type. It is used to match the
// response with the request that it is replying to.
Id uint64 `json:"id"`
}
// clientResponse represents a JSON-RPC response returned to a client.
type clientResponse struct {
Version string `json:"jsonrpc"`
Result *json.RawMessage `json:"result"`
Error *json.RawMessage `json:"error"`
Id *uint64 `json:"id"`
}
// EncodeClientRequest encodes parameters for a JSON-RPC client request.
func EncodeClientRequest(method string, args interface{}) (*bytes.Buffer, error) {
var buf bytes.Buffer
c := &clientRequest{
Version: "2.0",
Method: method,
Params: args,
Id: reqid(),
}
if err := json.NewEncoder(&buf).Encode(c); err != nil {
return nil, err
}
return &buf, nil
}
func (c clientResponse) decode(reply interface{}) error {
if c.Error != nil {
jsonErr := &Error{}
if err := json.Unmarshal(*c.Error, jsonErr); err != nil {
return &Error{
Code: E_SERVER,
Message: string(*c.Error),
}
}
return jsonErr
}
if c.Result == nil {
return ErrNullResult
}
return json.Unmarshal(*c.Result, reply)
}
// DecodeClientResponse decodes the response body of a client request into
// the interface reply.
func DecodeClientResponse(r io.Reader, reply interface{}) error {
var c clientResponse
if err := json.NewDecoder(r).Decode(&c); err != nil {
return err
}
return c.decode(reply)
}
type ErrorCode int
const (
E_PARSE ErrorCode = -32700
E_INVALID_REQ ErrorCode = -32600
E_NO_METHOD ErrorCode = -32601
E_BAD_PARAMS ErrorCode = -32602
E_INTERNAL ErrorCode = -32603
E_SERVER ErrorCode = -32000
)
var ErrNullResult = errors.New("result is null")
type Error struct {
// A Number that indicates the error type that occurred.
Code ErrorCode `json:"code"` /* required */
// A String providing a short description of the error.
// The message SHOULD be limited to a concise single sentence.
Message string `json:"message"` /* required */
// A Primitive or Structured value that contains additional information about the error.
Data interface{} `json:"data"` /* optional */
}
func (e *Error) Error() string {
return e.Message
}
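
A small round-trip sketch of the two exported helpers above; the JSON response literal is hand-written for illustration, and the request id is generated internally.

package main

import (
    "fmt"
    "strings"

    "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
)

func main() {
    // Encode a JSON-RPC 2.0 request body; the aria2 secret token travels as the first parameter.
    body, err := rpc.EncodeClientRequest("aria2.getVersion", []string{"token:my-secret"})
    if err != nil {
        panic(err)
    }
    fmt.Print(body.String())

    // Decode a response into the struct matching its "result" member.
    resp := `{"jsonrpc":"2.0","id":1,"result":{"version":"1.36.0","enabledFeatures":["BitTorrent"]}}`
    var version rpc.VersionInfo
    if err := rpc.DecodeClientResponse(strings.NewReader(resp), &version); err != nil {
        panic(err)
    }
    fmt.Println(version.Version, version.Features) // 1.36.0 [BitTorrent]
}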

View File

@ -0,0 +1,44 @@
package rpc
import (
"log"
)
type Event struct {
Gid string `json:"gid"` // GID of the download
}
// The RPC server might send notifications to the client.
// Notifications are unidirectional; the client which receives a notification must not respond to it.
// The method signature of a notification is much like a normal method request, but it lacks the id key.
type websocketResponse struct {
clientResponse
Method string `json:"method"`
Params []Event `json:"params"`
}
// Notifier handles rpc notifications from the aria2 server
type Notifier interface {
// OnDownloadStart is called when a download is started.
OnDownloadStart([]Event)
// OnDownloadPause is called when a download is paused.
OnDownloadPause([]Event)
// OnDownloadStop is called when a download is stopped by the user.
OnDownloadStop([]Event)
// OnDownloadComplete is called when a download is complete. For BitTorrent downloads, this notification is sent when the download is complete and seeding is over.
OnDownloadComplete([]Event)
// OnDownloadError is called when a download is stopped due to an error.
OnDownloadError([]Event)
// OnBtDownloadComplete is called when a torrent download is complete but seeding is still going on.
OnBtDownloadComplete([]Event)
}
type DummyNotifier struct{}
func (DummyNotifier) OnDownloadStart(events []Event) { log.Printf("%s started.", events) }
func (DummyNotifier) OnDownloadPause(events []Event) { log.Printf("%s paused.", events) }
func (DummyNotifier) OnDownloadStop(events []Event) { log.Printf("%s stopped.", events) }
func (DummyNotifier) OnDownloadComplete(events []Event) { log.Printf("%s completed.", events) }
func (DummyNotifier) OnDownloadError(events []Event) { log.Printf("%s error.", events) }
func (DummyNotifier) OnBtDownloadComplete(events []Event) { log.Printf("bt %s completed.", events) }
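
The DummyNotifier above only logs; a consumer can supply its own implementation of the interface. A hedged sketch of a channel-backed Notifier follows (the type name is illustrative, and wiring it into the RPC client is assumed to happen where the websocket caller is constructed).

package main

import (
    "fmt"

    "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
)

// chanNotifier forwards the GIDs of completed downloads to a channel and ignores other events.
type chanNotifier struct{ done chan string }

func (n chanNotifier) OnDownloadStart([]rpc.Event)      {}
func (n chanNotifier) OnDownloadPause([]rpc.Event)      {}
func (n chanNotifier) OnDownloadStop([]rpc.Event)       {}
func (n chanNotifier) OnDownloadError([]rpc.Event)      {}
func (n chanNotifier) OnBtDownloadComplete([]rpc.Event) {}

func (n chanNotifier) OnDownloadComplete(events []rpc.Event) {
    for _, e := range events {
        n.done <- e.Gid
    }
}

func main() {
    var _ rpc.Notifier = chanNotifier{} // compile-time check that the interface is satisfied
    fmt.Println("chanNotifier implements rpc.Notifier")
}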

42
pkg/aria2/rpc/proc.go Normal file
View File

@ -0,0 +1,42 @@
package rpc
import "sync"
type ResponseProcFn func(resp clientResponse) error
type ResponseProcessor struct {
cbs map[uint64]ResponseProcFn
mu *sync.RWMutex
}
func NewResponseProcessor() *ResponseProcessor {
return &ResponseProcessor{
make(map[uint64]ResponseProcFn),
&sync.RWMutex{},
}
}
func (r *ResponseProcessor) Add(id uint64, fn ResponseProcFn) {
r.mu.Lock()
r.cbs[id] = fn
r.mu.Unlock()
}
func (r *ResponseProcessor) remove(id uint64) {
r.mu.Lock()
delete(r.cbs, id)
r.mu.Unlock()
}
// Process is called by the recv routine; it dispatches the response to the callback registered for its id and then removes that callback.
func (r *ResponseProcessor) Process(resp clientResponse) error {
id := *resp.Id
r.mu.RLock()
fn, ok := r.cbs[id]
r.mu.RUnlock()
if ok && fn != nil {
defer r.remove(id)
return fn(resp)
}
return nil
}
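
The processor above is a callback registry keyed by request id: the send path registers a one-shot handler, the receive loop dispatches to it, and the handler is dropped after firing. Since clientResponse is unexported, the same pattern is sketched below with a plain string payload.

package main

import (
    "fmt"
    "sync"
)

type registry struct {
    mu  sync.Mutex
    cbs map[uint64]func(payload string) error
}

// Add registers a one-shot handler for the given request id.
func (r *registry) Add(id uint64, fn func(string) error) {
    r.mu.Lock()
    r.cbs[id] = fn
    r.mu.Unlock()
}

// Process dispatches a response payload to its handler, then drops the handler.
func (r *registry) Process(id uint64, payload string) error {
    r.mu.Lock()
    fn, ok := r.cbs[id]
    delete(r.cbs, id)
    r.mu.Unlock()
    if ok && fn != nil {
        return fn(payload)
    }
    return nil
}

func main() {
    r := &registry{cbs: make(map[uint64]func(string) error)}
    r.Add(1, func(p string) error { fmt.Println("response for id 1:", p); return nil })
    _ = r.Process(1, `{"result":"ok"}`) // handler fires once
    _ = r.Process(1, "ignored")         // no handler left for this id
}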

40
pkg/aria2/rpc/proto.go Normal file
View File

@ -0,0 +1,40 @@
package rpc
// Protocol is a set of rpc methods that aria2 daemon supports
type Protocol interface {
AddURI(uri string, options ...interface{}) (gid string, err error)
AddTorrent(filename string, options ...interface{}) (gid string, err error)
AddMetalink(filename string, options ...interface{}) (gid []string, err error)
Remove(gid string) (g string, err error)
ForceRemove(gid string) (g string, err error)
Pause(gid string) (g string, err error)
PauseAll() (ok string, err error)
ForcePause(gid string) (g string, err error)
ForcePauseAll() (ok string, err error)
Unpause(gid string) (g string, err error)
UnpauseAll() (ok string, err error)
TellStatus(gid string, keys ...string) (info StatusInfo, err error)
GetURIs(gid string) (infos []URIInfo, err error)
GetFiles(gid string) (infos []FileInfo, err error)
GetPeers(gid string) (infos []PeerInfo, err error)
GetServers(gid string) (infos []ServerInfo, err error)
TellActive(keys ...string) (infos []StatusInfo, err error)
TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error)
TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error)
ChangePosition(gid string, pos int, how string) (p int, err error)
ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error)
GetOption(gid string) (m Option, err error)
ChangeOption(gid string, option Option) (ok string, err error)
GetGlobalOption() (m Option, err error)
ChangeGlobalOption(options Option) (ok string, err error)
GetGlobalStat() (info GlobalStatInfo, err error)
PurgeDownloadResult() (ok string, err error)
RemoveDownloadResult(gid string) (ok string, err error)
GetVersion() (info VersionInfo, err error)
GetSessionInfo() (info SessionInfo, err error)
Shutdown() (ok string, err error)
ForceShutdown() (ok string, err error)
SaveSession() (ok string, err error)
Multicall(methods []Method) (r []interface{}, err error)
ListMethods() (methods []string, err error)
}

104
pkg/aria2/rpc/resp.go Normal file
View File

@ -0,0 +1,104 @@
//go:generate easyjson -all
package rpc
// StatusInfo represents response of aria2.tellStatus
type StatusInfo struct {
Gid string `json:"gid"` // GID of the download.
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
UploadSpeed string `json:"uploadSpeed"` // Upload speed of this download measured in bytes/sec.
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
PieceLength string `json:"pieceLength"` // Piece length in bytes.
NumPieces string `json:"numPieces"` // The number of pieces.
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
Dir string `json:"dir"` // Directory to save files.
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
}
// URIInfo represents an element of response of aria2.getUris
type URIInfo struct {
URI string `json:"uri"` // URI
Status string `json:"status"` // 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
}
// FileInfo represents an element of response of aria2.getFiles
type FileInfo struct {
Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file torrent.
Path string `json:"path"` // File path.
Length string `json:"length"` // File size in bytes.
CompletedLength string `json:"completedLength"` // Completed length of this file in bytes. Please note that it is possible that sum of completedLength is less than the completedLength returned by the aria2.tellStatus() method. This is because completedLength in aria2.getFiles() only includes completed pieces. On the other hand, completedLength in aria2.tellStatus() also includes partially completed pieces.
Selected string `json:"selected"` // true if this file is selected by --select-file option. If --select-file is not specified or this is single-file torrent or not a torrent download at all, this value is always true. Otherwise false.
URIs []URIInfo `json:"uris"` // Returns a list of URIs for this file. The element type is the same struct used in the aria2.getUris() method.
}
// PeerInfo represents an element of response of aria2.getPeers
type PeerInfo struct {
PeerId string `json:"peerId"` // Percent-encoded peer ID.
IP string `json:"ip"` // IP address of the peer.
Port string `json:"port"` // Port number of the peer.
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress of the peer. The highest bit corresponds to the piece at index 0. Set bits indicate the piece is available and unset bits indicate the piece is missing. Any spare bits at the end are set to zero.
AmChoking string `json:"amChoking"` // true if aria2 is choking the peer. Otherwise false.
PeerChoking string `json:"peerChoking"` // true if the peer is choking aria2. Otherwise false.
DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec) that this client obtains from the peer.
UploadSpeed string `json:"uploadSpeed"` // Upload speed (byte/sec) that this client uploads to the peer.
Seeder string `json:"seeder"` // true if this peer is a seeder. Otherwise false.
}
// ServerInfo represents an element of response of aria2.getServers
type ServerInfo struct {
Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file metalink.
Servers []struct {
URI string `json:"uri"` // Original URI.
CurrentURI string `json:"currentUri"` // This is the URI currently used for downloading. If redirection is involved, currentUri and uri may differ.
DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec)
} `json:"servers"` // A list of structs which contain the following keys.
}
// GlobalStatInfo represents response of aria2.getGlobalStat
type GlobalStatInfo struct {
DownloadSpeed string `json:"downloadSpeed"` // Overall download speed (byte/sec).
UploadSpeed string `json:"uploadSpeed"` // Overall upload speed(byte/sec).
NumActive string `json:"numActive"` // The number of active downloads.
NumWaiting string `json:"numWaiting"` // The number of waiting downloads.
NumStopped string `json:"numStopped"` // The number of stopped downloads in the current session. This value is capped by the --max-download-result option.
NumStoppedTotal string `json:"numStoppedTotal"` // The number of stopped downloads in the current session and not capped by the --max-download-result option.
}
// VersionInfo represents response of aria2.getVersion
type VersionInfo struct {
Version string `json:"version"` // Version number of aria2 as a string.
Features []string `json:"enabledFeatures"` // List of enabled features. Each feature is given as a string.
}
// SessionInfo represents response of aria2.getSessionInfo
type SessionInfo struct {
Id string `json:"sessionId"` // Session ID, which is generated each time when aria2 is invoked.
}
// Method is an element of parameters used in system.multicall
type Method struct {
Name string `json:"methodName"` // Method name to call
Params []interface{} `json:"params"` // Array containing parameters to the method call
}
type BitTorrentInfo struct {
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
Info struct {
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
}

145
pkg/auth/auth.go Normal file
View File

@ -0,0 +1,145 @@
package auth
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
var (
ErrAuthFailed = serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
)
const CrHeaderPrefix = "X-Cr-"
// General 通用的认证接口
var General Auth
// Auth 鉴权认证
type Auth interface {
// 对给定Body进行签名,expires为0表示永不过期
Sign(body string, expires int64) string
// 对给定Body和Sign进行检查
Check(body string, sign string) error
}
// SignRequest 对 PUT/POST 等复杂 HTTP 请求签名,只会对 URI 部分、
// 请求正文、`X-Cr-`开头的header进行签名
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
// 处理有效期
if expires > 0 {
expires += time.Now().Unix()
}
// 生成签名
sign := instance.Sign(getSignContent(r), expires)
// 将签名加到请求Header中
r.Header["Authorization"] = []string{"Bearer " + sign}
return r
}
// CheckRequest 对复杂请求进行签名验证
func CheckRequest(instance Auth, r *http.Request) error {
var (
sign []string
ok bool
)
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
return ErrAuthHeaderMissing
}
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
return instance.Check(getSignContent(r), sign[0])
}
// getSignContent 签名请求 path、正文、以 `X-Cr-` 开头的 Header。如果请求 path 为从机上传 API,
// 则不对正文签名。返回待签名/验证的字符串
func getSignContent(r *http.Request) (rawSignString string) {
// 读取所有body正文
var body = []byte{}
if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
if r.Body != nil {
body, _ = ioutil.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = ioutil.NopCloser(bytes.NewReader(body))
}
}
// 决定要签名的header
var signedHeader []string
for k := range r.Header {
if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
}
}
sort.Strings(signedHeader)
// 读取所有待签名Header
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
return rawSignString
}
// SignURI 对URI进行签名,签名只针对Path部分query部分不做验证
func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
// 处理有效期
if expires != 0 {
expires += time.Now().Unix()
}
base, err := url.Parse(uri)
if err != nil {
return nil, err
}
// 生成签名
sign := instance.Sign(base.Path, expires)
// 将签名加到URI中
queries := base.Query()
queries.Set("sign", sign)
base.RawQuery = queries.Encode()
return base, nil
}
// CheckURI 对URI进行鉴权
func CheckURI(instance Auth, url *url.URL) error {
//获取待验证的签名正文
queries := url.Query()
sign := queries.Get("sign")
queries.Del("sign")
url.RawQuery = queries.Encode()
return instance.Check(url.Path, sign)
}
// Init 初始化通用鉴权器
func Init() {
var secretKey string
if conf.SystemConfig.Mode == "master" {
secretKey = model.GetSettingByName("secret_key")
} else {
secretKey = conf.SlaveConfig.Secret
if secretKey == "" {
util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
}
}
General = HMACAuth{
SecretKey: []byte(secretKey),
}
}
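
A minimal sketch of the URI signing round trip; the secret and path are placeholders, and in the real flow the signer is the package-level General populated by Init above.

package main

import (
    "fmt"

    "github.com/cloudreve/Cloudreve/v3/pkg/auth"
)

func main() {
    signer := auth.HMACAuth{SecretKey: []byte("example-secret")}

    // Sign only the path; the signature (with its expiry) is appended as the "sign" query key.
    signed, err := auth.SignURI(signer, "/api/v3/slave/download/demo", 60)
    if err != nil {
        panic(err)
    }
    fmt.Println(signed.String())

    // Verification removes the "sign" key and checks the remaining path against it.
    if err := auth.CheckURI(signer, signed); err != nil {
        fmt.Println("rejected:", err)
    } else {
        fmt.Println("accepted")
    }
}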

54
pkg/auth/hmac.go Normal file
View File

@ -0,0 +1,54 @@
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"io"
"strconv"
"strings"
"time"
)
// HMACAuth HMAC算法鉴权
type HMACAuth struct {
SecretKey []byte
}
// Sign 对给定Body生成expires后失效的签名expires为过期时间戳
// 填写为0表示不限制有效期
func (auth HMACAuth) Sign(body string, expires int64) string {
h := hmac.New(sha256.New, auth.SecretKey)
expireTimeStamp := strconv.FormatInt(expires, 10)
_, err := io.WriteString(h, body+":"+expireTimeStamp)
if err != nil {
return ""
}
return base64.URLEncoding.EncodeToString(h.Sum(nil)) + ":" + expireTimeStamp
}
// Check 对给定Body和Sign进行鉴权包括对expires的检查
func (auth HMACAuth) Check(body string, sign string) error {
signSlice := strings.Split(sign, ":")
// 如果未携带expires字段
if signSlice[len(signSlice)-1] == "" {
return ErrExpiresMissing
}
// 验证是否过期
expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64)
if err != nil {
return ErrAuthFailed.WithError(err)
}
// 如果签名过期
if expires < time.Now().Unix() && expires != 0 {
return ErrExpired
}
// 验证签名
if auth.Sign(body, expires) != sign {
return ErrAuthFailed
}
return nil
}
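
A short sketch of the primitive itself (placeholder secret), showing the signature:expires layout and the behaviour of the expiry check.

package main

import (
    "fmt"
    "time"

    "github.com/cloudreve/Cloudreve/v3/pkg/auth"
)

func main() {
    a := auth.HMACAuth{SecretKey: []byte("example-secret")}

    // expires is an absolute Unix timestamp; 0 means the signature never expires.
    sign := a.Sign("hello", time.Now().Add(time.Minute).Unix())
    fmt.Println(sign) // base64(HMAC-SHA256(body:expires)):expires

    fmt.Println(a.Check("hello", sign))               // <nil>
    fmt.Println(a.Check("tampered", sign))            // invalid sign
    fmt.Println(a.Check("hello", a.Sign("hello", 1))) // signature expired
}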

16
pkg/authn/auth.go Normal file
View File

@ -0,0 +1,16 @@
package authn
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/duo-labs/webauthn/webauthn"
)
// NewAuthnInstance 新建Authn实例
func NewAuthnInstance() (*webauthn.WebAuthn, error) {
base := model.GetSiteURL()
return webauthn.New(&webauthn.Config{
RPDisplayName: model.GetSettingByName("siteName"), // Display Name for your site
RPID: base.Hostname(), // Generally the FQDN for your site
RPOrigin: base.String(), // The origin URL for WebAuthn requests
})
}

15
pkg/balancer/balancer.go Normal file
View File

@ -0,0 +1,15 @@
package balancer
type Balancer interface {
NextPeer(nodes interface{}) (error, interface{})
}
// NewBalancer 根据策略标识返回新的负载均衡器
func NewBalancer(strategy string) Balancer {
switch strategy {
case "RoundRobin":
return &RoundRobin{}
default:
return &RoundRobin{}
}
}
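
A small sketch of the balancer on a plain slice; note the (error, value) return order of NextPeer, and that any unrecognized strategy string also falls back to round-robin.

package main

import (
    "fmt"

    "github.com/cloudreve/Cloudreve/v3/pkg/balancer"
)

func main() {
    lb := balancer.NewBalancer("RoundRobin")
    peers := []string{"node-a", "node-b", "node-c"}

    for i := 0; i < 4; i++ {
        err, next := lb.NextPeer(peers)
        if err != nil {
            panic(err)
        }
        fmt.Println(next) // node-b, node-c, node-a, node-b (the counter starts at index 1)
    }
}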

8
pkg/balancer/errors.go Normal file
View File

@ -0,0 +1,8 @@
package balancer
import "errors"
var (
ErrInputNotSlice = errors.New("Input value is not a slice")
ErrNoAvaliableNode = errors.New("No nodes available")
)

View File

@ -0,0 +1,30 @@
package balancer
import (
"reflect"
"sync/atomic"
)
type RoundRobin struct {
current uint64
}
// NextPeer 返回轮盘的下一节点
func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
v := reflect.ValueOf(nodes)
if v.Kind() != reflect.Slice {
return ErrInputNotSlice, nil
}
if v.Len() == 0 {
return ErrNoAvaliableNode, nil
}
next := r.NextIndex(v.Len())
return nil, v.Index(next).Interface()
}
// NextIndex 返回下一个节点下标
func (r *RoundRobin) NextIndex(total int) int {
return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total))
}

104
pkg/cache/driver.go vendored Normal file
View File

@ -0,0 +1,104 @@
package cache
import (
"encoding/gob"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
)
func init() {
gob.Register(map[string]itemWithTTL{})
}
// Store 缓存存储器
var Store Driver = NewMemoStore()
// Init 初始化缓存
func Init() {
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
Store = NewRedisStore(
10,
conf.RedisConfig.Network,
conf.RedisConfig.Server,
conf.RedisConfig.User,
conf.RedisConfig.Password,
conf.RedisConfig.DB,
)
}
}
// Restore restores cache from given disk file
func Restore(persistFile string) {
if err := Store.Restore(persistFile); err != nil {
util.Log().Warning("Failed to restore cache from disk: %s", err)
}
}
func InitSlaveOverwrites() {
err := Store.Sets(conf.OptionOverwrite, "setting_")
if err != nil {
util.Log().Warning("Failed to overwrite database setting: %s", err)
}
}
// Driver 键值缓存存储容器
type Driver interface {
// 设置值ttl为过期时间单位为秒
Set(key string, value interface{}, ttl int) error
// 取值,并返回是否成功
Get(key string) (interface{}, bool)
// 批量取值,返回成功取值的 map 以及未命中的 key 列表
Gets(keys []string, prefix string) (map[string]interface{}, []string)
// 批量设置值所有的key都会加上prefix前缀
Sets(values map[string]interface{}, prefix string) error
// 删除值
Delete(keys []string, prefix string) error
// Save in-memory cache to disk
Persist(path string) error
// Restore cache from disk
Restore(path string) error
}
// Set 设置缓存值
func Set(key string, value interface{}, ttl int) error {
return Store.Set(key, value, ttl)
}
// Get 获取缓存值
func Get(key string) (interface{}, bool) {
return Store.Get(key)
}
// Deletes 删除值
func Deletes(keys []string, prefix string) error {
return Store.Delete(keys, prefix)
}
// GetSettings 根据名称批量获取设置项缓存
func GetSettings(keys []string, prefix string) (map[string]string, []string) {
raw, miss := Store.Gets(keys, prefix)
res := make(map[string]string, len(raw))
for k, v := range raw {
res[k] = v.(string)
}
return res, miss
}
// SetSettings 批量设置站点设置缓存
func SetSettings(values map[string]string, prefix string) error {
var toBeSet = make(map[string]interface{}, len(values))
for key, value := range values {
toBeSet[key] = interface{}(value)
}
return Store.Sets(toBeSet, prefix)
}
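
A minimal sketch of the package-level helpers against the default in-memory store; keys and values below are placeholders, and values should stay gob-encodable so the Redis driver can serialize them.

package main

import (
    "fmt"

    "github.com/cloudreve/Cloudreve/v3/pkg/cache"
)

func main() {
    // ttl is in seconds; 0 means the entry never expires.
    if err := cache.Set("greeting", "hello", 60); err != nil {
        panic(err)
    }

    if v, ok := cache.Get("greeting"); ok {
        fmt.Println(v.(string)) // hello
    }

    _ = cache.Deletes([]string{"greeting"}, "")
}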

181
pkg/cache/memo.go vendored Normal file
View File

@ -0,0 +1,181 @@
package cache
import (
"encoding/gob"
"fmt"
"os"
"sync"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// MemoStore 内存存储驱动
type MemoStore struct {
Store *sync.Map
}
// item 存储的对象
type itemWithTTL struct {
Expires int64
Value interface{}
}
const DefaultCacheFile = "cache_persist.bin"
func newItem(value interface{}, expires int) itemWithTTL {
expires64 := int64(expires)
if expires > 0 {
expires64 = time.Now().Unix() + expires64
}
return itemWithTTL{
Value: value,
Expires: expires64,
}
}
// getValue 从itemWithTTL中取值
func getValue(item interface{}, ok bool) (interface{}, bool) {
if !ok {
return nil, ok
}
var itemObj itemWithTTL
if itemObj, ok = item.(itemWithTTL); !ok {
return item, true
}
if itemObj.Expires > 0 && itemObj.Expires < time.Now().Unix() {
return nil, false
}
return itemObj.Value, ok
}
// GarbageCollect 回收已过期的缓存
func (store *MemoStore) GarbageCollect() {
store.Store.Range(func(key, value interface{}) bool {
if item, ok := value.(itemWithTTL); ok {
if item.Expires > 0 && item.Expires < time.Now().Unix() {
util.Log().Debug("Cache %q is garbage collected.", key.(string))
store.Store.Delete(key)
}
}
return true
})
}
// NewMemoStore 新建内存存储
func NewMemoStore() *MemoStore {
return &MemoStore{
Store: &sync.Map{},
}
}
// Set 存储值
func (store *MemoStore) Set(key string, value interface{}, ttl int) error {
store.Store.Store(key, newItem(value, ttl))
return nil
}
// Get 取值
func (store *MemoStore) Get(key string) (interface{}, bool) {
return getValue(store.Store.Load(key))
}
// Gets 批量取值
func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
var res = make(map[string]interface{})
var notFound = make([]string, 0, len(keys))
for _, key := range keys {
if value, ok := getValue(store.Store.Load(prefix + key)); ok {
res[key] = value
} else {
notFound = append(notFound, key)
}
}
return res, notFound
}
// Sets 批量设置值
func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error {
for key, value := range values {
store.Store.Store(prefix+key, newItem(value, 0))
}
return nil
}
// Delete 批量删除值
func (store *MemoStore) Delete(keys []string, prefix string) error {
for _, key := range keys {
store.Store.Delete(prefix + key)
}
return nil
}
// Persist writes the in-memory store into the given disk file
func (store *MemoStore) Persist(path string) error {
persisted := make(map[string]itemWithTTL)
store.Store.Range(func(key, value interface{}) bool {
// Range already yields the value; keep only entries that have not expired
if _, ok := getValue(value, true); ok {
persisted[key.(string)] = value.(itemWithTTL)
}
return true
})
res, err := serializer(persisted)
if err != nil {
return fmt.Errorf("failed to serialize cache: %s", err)
}
// err = os.WriteFile(path, res, 0644)
file, err := util.CreatNestedFile(path)
if err == nil {
_, err = file.Write(res)
file.Chmod(0644)
file.Close()
}
return err
}
// Restore memory cache from disk file
func (store *MemoStore) Restore(path string) error {
if !util.Exists(path) {
return nil
}
f, err := os.Open(path)
if err != nil {
return fmt.Errorf("failed to read cache file: %s", err)
}
defer func() {
f.Close()
os.Remove(path)
}()
persisted := &item{}
dec := gob.NewDecoder(f)
if err := dec.Decode(&persisted); err != nil {
return fmt.Errorf("unknown cache file format: %s", err)
}
items := persisted.Value.(map[string]itemWithTTL)
loaded := 0
for k, v := range items {
if _, ok := getValue(v, true); ok {
loaded++
store.Store.Store(k, v)
} else {
util.Log().Debug("Persisted cache %q is expired.", k)
}
}
util.Log().Info("Restored %d items from %q into memory cache.", loaded, path)
return nil
}
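
A sketch of the persistence round trip on the in-memory driver; the temp-file path is illustrative, and note that Restore removes the cache file after loading it.

package main

import (
    "fmt"
    "os"
    "path/filepath"

    "github.com/cloudreve/Cloudreve/v3/pkg/cache"
)

func main() {
    store := cache.NewMemoStore()
    _ = store.Set("motd", "hello", 0) // ttl 0 = never expires
    store.GarbageCollect()            // drops entries whose TTL has already passed

    path := filepath.Join(os.TempDir(), "cache_demo.bin")
    if err := store.Persist(path); err != nil {
        panic(err)
    }

    restored := cache.NewMemoStore()
    if err := restored.Restore(path); err != nil {
        panic(err)
    }
    fmt.Println(restored.Get("motd")) // hello true
}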

227
pkg/cache/redis.go vendored Normal file
View File

@ -0,0 +1,227 @@
package cache
import (
"bytes"
"encoding/gob"
"strconv"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gomodule/redigo/redis"
)
// RedisStore redis存储驱动
type RedisStore struct {
pool *redis.Pool
}
type item struct {
Value interface{}
}
func serializer(value interface{}) ([]byte, error) {
var buffer bytes.Buffer
enc := gob.NewEncoder(&buffer)
storeValue := item{
Value: value,
}
err := enc.Encode(storeValue)
if err != nil {
return nil, err
}
return buffer.Bytes(), nil
}
func deserializer(value []byte) (interface{}, error) {
var res item
buffer := bytes.NewReader(value)
dec := gob.NewDecoder(buffer)
err := dec.Decode(&res)
if err != nil {
return nil, err
}
return res.Value, nil
}
// NewRedisStore 创建新的redis存储
func NewRedisStore(size int, network, address, user, password, database string) *RedisStore {
return &RedisStore{
pool: &redis.Pool{
MaxIdle: size,
IdleTimeout: 240 * time.Second,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
Dial: func() (redis.Conn, error) {
db, err := strconv.Atoi(database)
if err != nil {
return nil, err
}
c, err := redis.Dial(
network,
address,
redis.DialDatabase(db),
redis.DialUsername(user),
redis.DialPassword(password),
)
if err != nil {
util.Log().Panic("Failed to create Redis connection: %s", err)
}
return c, nil
},
},
}
}
// Set 存储值
func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
rc := store.pool.Get()
defer rc.Close()
serialized, err := serializer(value)
if err != nil {
return err
}
if rc.Err() != nil {
return rc.Err()
}
if ttl > 0 {
_, err = rc.Do("SETEX", key, ttl, serialized)
} else {
_, err = rc.Do("SET", key, serialized)
}
if err != nil {
return err
}
return nil
}
// Get 取值
func (store *RedisStore) Get(key string) (interface{}, bool) {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return nil, false
}
v, err := redis.Bytes(rc.Do("GET", key))
if err != nil || v == nil {
return nil, false
}
finalValue, err := deserializer(v)
if err != nil {
return nil, false
}
return finalValue, true
}
// Gets 批量取值
func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return nil, keys
}
var queryKeys = make([]string, len(keys))
for key, value := range keys {
queryKeys[key] = prefix + value
}
v, err := redis.ByteSlices(rc.Do("MGET", redis.Args{}.AddFlat(queryKeys)...))
if err != nil {
return nil, keys
}
var res = make(map[string]interface{})
var missed = make([]string, 0, len(keys))
for key, value := range v {
decoded, err := deserializer(value)
if err != nil || decoded == nil {
missed = append(missed, keys[key])
} else {
res[keys[key]] = decoded
}
}
// 解码所得值
return res, missed
}
// Sets 批量设置值
func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return rc.Err()
}
var setValues = make(map[string]interface{})
// 编码待设置值
for key, value := range values {
serialized, err := serializer(value)
if err != nil {
return err
}
setValues[prefix+key] = serialized
}
_, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...)
if err != nil {
return err
}
return nil
}
// Delete 批量删除给定的键
func (store *RedisStore) Delete(keys []string, prefix string) error {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return rc.Err()
}
// 处理前缀
for i := 0; i < len(keys); i++ {
keys[i] = prefix + keys[i]
}
_, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...)
if err != nil {
return err
}
return nil
}
// DeleteAll 删除所有键
func (store *RedisStore) DeleteAll() error {
rc := store.pool.Get()
defer rc.Close()
if rc.Err() != nil {
return rc.Err()
}
_, err := rc.Do("FLUSHDB")
return err
}
// Persist Dummy implementation
func (store *RedisStore) Persist(path string) error {
return nil
}
// Restore dummy implementation
func (store *RedisStore) Restore(path string) error {
return nil
}
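
Wiring the Redis driver by hand mirrors what Init in driver.go does from conf.RedisConfig; all parameters below are placeholders and a reachable Redis instance is assumed.

package main

import "github.com/cloudreve/Cloudreve/v3/pkg/cache"

func main() {
    cache.Store = cache.NewRedisStore(
        10,               // connection pool size (MaxIdle)
        "tcp",            // network
        "127.0.0.1:6379", // address
        "",               // user
        "",               // password
        "0",              // database index, passed as a string
    )
    _ = cache.Set("greeting", "hello", 30) // stored via SETEX with the gob-encoded value
}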

210
pkg/cluster/controller.go Normal file
View File

@ -0,0 +1,210 @@
package cluster
import (
"bytes"
"encoding/gob"
"fmt"
"net/url"
"sync"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/jinzhu/gorm"
)
var DefaultController Controller
// Controller controls communications between master and slave
type Controller interface {
// Handle heartbeat sent from master
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
// Get Aria2 Instance by master node ID
GetAria2Instance(string) (common.Aria2, error)
// Send event change message to master node
SendNotification(string, string, mq.Message) error
// Submit async task into task pool
SubmitTask(string, interface{}, string, func(interface{})) error
// Get master node info
GetMasterInfo(string) (*MasterInfo, error)
// Get master Oauth based policy credential
GetPolicyOauthToken(string, uint) (string, error)
}
type slaveController struct {
masters map[string]MasterInfo
lock sync.RWMutex
}
// MasterInfo holds the info of a master node
type MasterInfo struct {
ID string
TTL int
URL *url.URL
// used to invoke aria2 rpc calls
Instance Node
Client request.Client
jobTracker map[string]bool
}
func InitController() {
DefaultController = &slaveController{
masters: make(map[string]MasterInfo),
}
gob.Register(rpc.StatusInfo{})
}
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
c.lock.Lock()
defer c.lock.Unlock()
req.Node.AfterFind()
// close old node if exist
origin, ok := c.masters[req.SiteID]
if (ok && req.IsUpdate) || !ok {
if ok {
origin.Instance.Kill()
}
masterUrl, err := url.Parse(req.SiteURL)
if err != nil {
return serializer.NodePingResp{}, err
}
c.masters[req.SiteID] = MasterInfo{
ID: req.SiteID,
URL: masterUrl,
TTL: req.CredentialTTL,
Client: request.NewClient(
request.WithEndpoint(masterUrl.String()),
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
request.WithCredential(auth.HMACAuth{
SecretKey: []byte(req.Node.MasterKey),
}, int64(req.CredentialTTL)),
),
jobTracker: make(map[string]bool),
Instance: NewNodeFromDBModel(&model.Node{
Model: gorm.Model{ID: req.Node.ID},
MasterKey: req.Node.MasterKey,
Type: model.MasterNodeType,
Aria2Enabled: req.Node.Aria2Enabled,
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
}),
}
}
return serializer.NodePingResp{}, nil
}
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return node.Instance.GetAria2Instance(), nil
}
return nil, ErrMasterNotFound
}
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
body := bytes.Buffer{}
enc := gob.NewEncoder(&body)
if err := enc.Encode(&msg); err != nil {
return err
}
res, err := node.Client.Request(
"PUT",
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
&body,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
c.lock.RUnlock()
return ErrMasterNotFound
}
// SubmitTask 提交异步任务
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
if _, ok := node.jobTracker[hash]; ok {
// 任务已存在,直接返回
return nil
}
node.jobTracker[hash] = true
submitter(job)
return nil
}
return ErrMasterNotFound
}
// GetMasterInfo 获取主机节点信息
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
c.lock.RLock()
defer c.lock.RUnlock()
if node, ok := c.masters[id]; ok {
return &node, nil
}
return nil, ErrMasterNotFound
}
// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
c.lock.RLock()
if node, ok := c.masters[id]; ok {
c.lock.RUnlock()
res, err := node.Client.Request(
"GET",
fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
nil,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), nil
}
c.lock.RUnlock()
return "", ErrMasterNotFound
}

12
pkg/cluster/errors.go Normal file
View File

@ -0,0 +1,12 @@
package cluster
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specified")
ErrIlegalPath = errors.New("path is out of the boundary of the configured temp folder")
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
)

272
pkg/cluster/master.go Normal file
View File

@ -0,0 +1,272 @@
package cluster
import (
"context"
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
const (
deleteTempFileDuration = 60 * time.Second
statusRetryDuration = 10 * time.Second
)
type MasterNode struct {
Model *model.Node
aria2RPC rpcService
lock sync.RWMutex
}
// RPCService 通过RPC服务的Aria2任务管理器
type rpcService struct {
Caller rpc.Client
Initialized bool
retryDuration time.Duration
deletePaddingDuration time.Duration
parent *MasterNode
options *clientOptions
}
type clientOptions struct {
Options map[string]interface{} // 创建下载时额外添加的设置
}
// Init 初始化节点
func (node *MasterNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
node.aria2RPC.parent = node
node.aria2RPC.retryDuration = statusRetryDuration
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
node.lock.Unlock()
node.lock.RLock()
if node.Model.Aria2Enabled {
node.lock.RUnlock()
node.aria2RPC.Init()
return
}
node.lock.RUnlock()
}
func (node *MasterNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
return &serializer.NodePingResp{}, nil
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
func (node *MasterNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
}
// IsActive 返回节点是否在线
func (node *MasterNode) IsActive() bool {
return true
}
// Kill 结束aria2请求
func (node *MasterNode) Kill() {
if node.aria2RPC.Caller != nil {
node.aria2RPC.Caller.Close()
}
}
// GetAria2Instance 获取主机Aria2实例
func (node *MasterNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
if !node.Model.Aria2Enabled {
node.lock.RUnlock()
return &common.DummyAria2{}
}
if !node.aria2RPC.Initialized {
node.lock.RUnlock()
node.aria2RPC.Init()
return &common.DummyAria2{}
}
defer node.lock.RUnlock()
return &node.aria2RPC
}
func (node *MasterNode) IsMater() bool {
return true
}
func (node *MasterNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
func (r *rpcService) Init() error {
r.parent.lock.Lock()
defer r.parent.lock.Unlock()
r.Initialized = false
// 客户端已存在,则关闭先前连接
if r.Caller != nil {
r.Caller.Close()
}
// 解析RPC服务地址
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
if err != nil {
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
return err
}
server.Path = "/jsonrpc"
// 加载自定义下载配置
var globalOptions map[string]interface{}
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
if err != nil {
util.Log().Warning("Failed to parse aria2 options: %s", err)
return err
}
}
r.options = &clientOptions{
Options: globalOptions,
}
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
r.Caller = caller
r.Initialized = err == nil
return err
}
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
r.parent.lock.RLock()
// 生成存储路径
guid, _ := uuid.NewV4()
path := filepath.Join(
r.parent.Model.Aria2OptionsSerialized.TempPath,
"aria2",
guid.String(),
)
r.parent.lock.RUnlock()
// 创建下载任务
options := map[string]interface{}{
"dir": path,
}
for k, v := range r.options.Options {
options[k] = v
}
for k, v := range groupOptions {
options[k] = v
}
gid, err := r.Caller.AddURI(task.Source, options)
if err != nil || gid == "" {
return "", err
}
return gid, nil
}
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
res, err := r.Caller.TellStatus(task.GID)
if err != nil {
// 失败后重试
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
time.Sleep(r.retryDuration)
res, err = r.Caller.TellStatus(task.GID)
}
return res, err
}
func (r *rpcService) Cancel(task *model.Download) error {
// 取消下载任务
_, err := r.Caller.Remove(task.GID)
if err != nil {
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
}
return err
}
func (r *rpcService) Select(task *model.Download, files []int) error {
var selected = make([]string, len(files))
for i := 0; i < len(files); i++ {
selected[i] = strconv.Itoa(files[i])
}
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
return err
}
func (r *rpcService) GetConfig() model.Aria2Option {
r.parent.lock.RLock()
defer r.parent.lock.RUnlock()
return r.parent.Model.Aria2OptionsSerialized
}
func (s *rpcService) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
// 避免被aria2占用异步执行删除
go func(d time.Duration, src string) {
time.Sleep(d)
err := os.RemoveAll(src)
if err != nil {
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
}
}(s.deletePaddingDuration, task.Parent)
return nil
}

60
pkg/cluster/node.go Normal file
View File

@ -0,0 +1,60 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
type Node interface {
// Init a node from database model
Init(node *model.Node)
// Check if given feature is enabled
IsFeatureEnabled(feature string) bool
// Subscribe node status change to a callback function
SubscribeStatusChange(callback func(isActive bool, id uint))
// Ping the node
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
// Returns if the node is active
IsActive() bool
// Get instances for aria2 calls
GetAria2Instance() common.Aria2
// Returns unique id of this node
ID() uint
// Kill node and recycle resources
Kill()
// Returns if current node is master node
IsMater() bool
// Get auth instance used to check RPC call from slave to master
MasterAuthInstance() auth.Auth
// Get auth instance used to check RPC call from master to slave
SlaveAuthInstance() auth.Auth
// Get node DB model
DBModel() *model.Node
}
// NewNodeFromDBModel creates a new node from a DB model
func NewNodeFromDBModel(node *model.Node) Node {
switch node.Type {
case model.SlaveNodeType:
slave := &SlaveNode{}
slave.Init(node)
return slave
default:
master := &MasterNode{}
master.Init(node)
return master
}
}
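
A sketch of turning a database record into a Node; the fields mirror those used by the controller above and the values are placeholders. A slave-type record would additionally start the ping loop shown in slave.go.

package main

import (
    "fmt"

    model "github.com/cloudreve/Cloudreve/v3/models"
    "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
    "github.com/jinzhu/gorm"
)

func main() {
    node := cluster.NewNodeFromDBModel(&model.Node{
        Model:        gorm.Model{ID: 1},
        Type:         model.MasterNodeType, // any type other than SlaveNodeType yields a MasterNode
        Aria2Enabled: false,
    })
    defer node.Kill()

    fmt.Println(node.ID(), node.IsFeatureEnabled("aria2")) // 1 false
    _ = node.GetAria2Instance()                            // a DummyAria2 while aria2 is disabled
}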

213
pkg/cluster/pool.go Normal file
View File

@ -0,0 +1,213 @@
package cluster
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/samber/lo"
"sync"
)
var Default *NodePool
// 需要分类的节点组
var featureGroup = []string{"aria2"}
// Pool 节点池
type Pool interface {
// Returns active node selected by given feature and load balancer
BalanceNodeByFeature(feature string, lb balancer.Balancer, available []uint, prefer uint) (error, Node)
// Returns node by ID
GetNodeByID(id uint) Node
// Add given node into pool. If node existed, refresh node.
Add(node *model.Node)
// Delete and kill node from pool by given node id
Delete(id uint)
}
// NodePool 通用节点池
type NodePool struct {
active map[uint]Node
inactive map[uint]Node
featureMap map[string][]Node
lock sync.RWMutex
}
// Init 初始化从机节点池
func Init() {
Default = &NodePool{}
Default.Init()
if err := Default.initFromDB(); err != nil {
util.Log().Warning("Failed to initialize node pool: %s", err)
}
}
func (pool *NodePool) Init() {
pool.lock.Lock()
defer pool.lock.Unlock()
pool.featureMap = make(map[string][]Node)
pool.active = make(map[uint]Node)
pool.inactive = make(map[uint]Node)
}
func (pool *NodePool) buildIndexMap() {
pool.lock.Lock()
for _, feature := range featureGroup {
pool.featureMap[feature] = make([]Node, 0)
}
for _, v := range pool.active {
for _, feature := range featureGroup {
if v.IsFeatureEnabled(feature) {
pool.featureMap[feature] = append(pool.featureMap[feature], v)
}
}
}
pool.lock.Unlock()
}
func (pool *NodePool) GetNodeByID(id uint) Node {
pool.lock.RLock()
defer pool.lock.RUnlock()
if node, ok := pool.active[id]; ok {
return node
}
return pool.inactive[id]
}
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
var node Node
pool.lock.Lock()
if n, ok := pool.inactive[id]; ok {
node = n
delete(pool.inactive, id)
} else {
node = pool.active[id]
delete(pool.active, id)
}
if isActive {
pool.active[id] = node
} else {
pool.inactive[id] = node
}
pool.lock.Unlock()
pool.buildIndexMap()
}
func (pool *NodePool) initFromDB() error {
nodes, err := model.GetNodesByStatus(model.NodeActive)
if err != nil {
return err
}
pool.lock.Lock()
for i := 0; i < len(nodes); i++ {
pool.add(&nodes[i])
}
pool.lock.Unlock()
pool.buildIndexMap()
return nil
}
func (pool *NodePool) add(node *model.Node) {
newNode := NewNodeFromDBModel(node)
if newNode.IsActive() {
pool.active[node.ID] = newNode
} else {
pool.inactive[node.ID] = newNode
}
// 订阅节点状态变更
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
pool.nodeStatusChange(isActive, id)
})
}
func (pool *NodePool) Add(node *model.Node) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
var (
old Node
ok bool
)
if old, ok = pool.active[node.ID]; !ok {
old, ok = pool.inactive[node.ID]
}
if old != nil {
go old.Init(node)
return
}
pool.add(node)
}
func (pool *NodePool) Delete(id uint) {
pool.lock.Lock()
defer pool.buildIndexMap()
defer pool.lock.Unlock()
if node, ok := pool.active[id]; ok {
node.Kill()
delete(pool.active, id)
return
}
if node, ok := pool.inactive[id]; ok {
node.Kill()
delete(pool.inactive, id)
return
}
}
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer,
available []uint, prefer uint) (error, Node) {
pool.lock.RLock()
defer pool.lock.RUnlock()
if nodes, ok := pool.featureMap[feature]; ok {
// Find nodes that are allowed to be used in user group
availableNodes := nodes
if len(available) > 0 {
idHash := make(map[uint]struct{}, len(available))
for _, id := range available {
idHash[id] = struct{}{}
}
availableNodes = lo.Filter[Node](nodes, func(node Node, index int) bool {
_, exist := idHash[node.ID()]
return exist
})
}
// Return preferred node if exists
if preferredNode, found := lo.Find[Node](availableNodes, func(node Node) bool {
return node.ID() == prefer
}); found {
return nil, preferredNode
}
err, res := lb.NextPeer(availableNodes)
if err == nil {
return nil, res.(Node)
}
return err, nil
}
return ErrFeatureNotExist, nil
}
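
A sketch of how a caller picks an aria2-capable node from the pool; it assumes cluster.Init has already populated Default during application start, and the allowed IDs and preferred ID are placeholders.

package main

import (
    "fmt"

    "github.com/cloudreve/Cloudreve/v3/pkg/balancer"
    "github.com/cloudreve/Cloudreve/v3/pkg/cluster"
)

func pickAria2Node(allowed []uint, preferred uint) (cluster.Node, error) {
    lb := balancer.NewBalancer("RoundRobin")
    err, node := cluster.Default.BalanceNodeByFeature("aria2", lb, allowed, preferred)
    if err != nil {
        return nil, err // e.g. ErrFeatureNotExist, or the balancer's ErrNoAvaliableNode
    }
    return node, nil
}

func main() {
    if node, err := pickAria2Node([]uint{1, 2, 3}, 0); err == nil {
        fmt.Println("selected node", node.ID())
    }
}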

451
pkg/cluster/slave.go Normal file
View File

@ -0,0 +1,451 @@
package cluster
import (
"bytes"
"encoding/json"
"errors"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/url"
"strings"
"sync"
"time"
)
type SlaveNode struct {
Model *model.Node
Active bool
caller slaveCaller
callback func(bool, uint)
close chan bool
lock sync.RWMutex
}
type slaveCaller struct {
parent *SlaveNode
Client request.Client
}
// Init 初始化节点
func (node *SlaveNode) Init(nodeModel *model.Node) {
node.lock.Lock()
node.Model = nodeModel
// Init http request client
var endpoint *url.URL
if serverURL, err := url.Parse(node.Model.Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
node.caller.Client = request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
request.WithEndpoint(endpoint.String()),
)
node.caller.parent = node
if node.close != nil {
node.lock.Unlock()
node.close <- true
go node.StartPingLoop()
} else {
node.Active = true
node.lock.Unlock()
go node.StartPingLoop()
}
}
// IsFeatureEnabled 查询节点的某项功能是否启用
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
node.lock.RLock()
defer node.lock.RUnlock()
switch feature {
case "aria2":
return node.Model.Aria2Enabled
default:
return false
}
}
// SubscribeStatusChange 订阅节点状态更改
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
node.lock.Lock()
node.callback = callback
node.lock.Unlock()
}
// Ping 从机节点,返回从机负载
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
node.lock.RLock()
defer node.lock.RUnlock()
reqBodyEncoded, err := json.Marshal(req)
if err != nil {
return nil, err
}
bodyReader := strings.NewReader(string(reqBodyEncoded))
resp, err := node.caller.Client.Request(
"POST",
"heartbeat",
bodyReader,
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return nil, err
}
// 处理列取结果
if resp.Code != 0 {
return nil, serializer.NewErrorFromResponse(resp)
}
var res serializer.NodePingResp
if resStr, ok := resp.Data.(string); ok {
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return nil, err
}
}
return &res, nil
}
// IsActive 返回节点是否在线
func (node *SlaveNode) IsActive() bool {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Active
}
// Kill 结束节点内相关循环
func (node *SlaveNode) Kill() {
node.lock.RLock()
defer node.lock.RUnlock()
if node.close != nil {
close(node.close)
}
}
// GetAria2Instance 获取从机Aria2实例
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
node.lock.RLock()
defer node.lock.RUnlock()
if !node.Model.Aria2Enabled {
return &common.DummyAria2{}
}
return &node.caller
}
func (node *SlaveNode) ID() uint {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model.ID
}
func (node *SlaveNode) StartPingLoop() {
node.lock.Lock()
node.close = make(chan bool)
node.lock.Unlock()
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
pingTicker := time.Duration(0)
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
retry := 0
recoverMode := false
isFirstLoop := true
loop:
for {
select {
case <-time.After(pingTicker):
if pingTicker == 0 {
pingTicker = tickDuration
}
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
isFirstLoop = false
if err != nil {
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
retry++
if retry >= model.GetIntSetting("slave_node_retry", 3) {
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
node.changeStatus(false)
if !recoverMode {
// Enter the recovery monitoring loop with a longer interval.
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
pingTicker = recoverDuration
recoverMode = true
}
}
} else {
if recoverMode {
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
pingTicker = tickDuration
recoverMode = false
isFirstLoop = true
}
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
node.changeStatus(true)
retry = 0
}
case <-node.close:
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
break loop
}
}
}
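// The loop above alternates between two intervals: the normal ping interval and a
// longer recovery interval once the retry threshold is exceeded. A minimal sketch of
// the same pattern in isolation (ping, the intervals and maxRetry are hypothetical
// placeholders, not part of this package):
//
// func pingLoop(ping func() error, normal, recovery time.Duration, maxRetry int, stop <-chan bool) {
// interval, retry := normal, 0
// for {
// select {
// case <-time.After(interval):
// if err := ping(); err != nil {
// if retry++; retry >= maxRetry {
// interval = recovery // back off until the peer recovers
// }
// continue
// }
// interval, retry = normal, 0
// case <-stop:
// return
// }
// }
// }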
// IsMater reports whether this node is the master node; always false for SlaveNode.
func (node *SlaveNode) IsMater() bool {
return false
}
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
}
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
node.lock.RLock()
defer node.lock.RUnlock()
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
}
func (node *SlaveNode) DBModel() *model.Node {
node.lock.RLock()
defer node.lock.RUnlock()
return node.Model
}
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
return &serializer.NodePingReq{
SiteURL: model.GetSiteURL().String(),
IsUpdate: isUpdate,
SiteID: model.GetSettingByName("siteID"),
Node: node.Model,
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
}
}
func (node *SlaveNode) changeStatus(isActive bool) {
node.lock.RLock()
id := node.Model.ID
if isActive != node.Active {
node.lock.RUnlock()
node.lock.Lock()
node.Active = isActive
node.lock.Unlock()
node.callback(isActive, id)
} else {
node.lock.RUnlock()
}
}
func (s *slaveCaller) Init() error {
return nil
}
// SendAria2Call send remote aria2 call to slave node
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
reqReader, err := getAria2RequestBody(body)
if err != nil {
return nil, err
}
return s.Client.Request(
"POST",
"aria2/"+scope,
reqReader,
).CheckHTTPResponse(200).DecodeResponse()
}
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
GroupOptions: options,
}
res, err := s.SendAria2Call(req, "task")
if err != nil {
return "", err
}
if res.Code != 0 {
return "", serializer.NewErrorFromResponse(res)
}
return res.Data.(string), err
}
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "status")
if err != nil {
return rpc.StatusInfo{}, err
}
if res.Code != 0 {
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
}
var status rpc.StatusInfo
res.GobDecode(&status)
return status, err
}
func (s *slaveCaller) Cancel(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "cancel")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) Select(task *model.Download, files []int) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
Files: files,
}
res, err := s.SendAria2Call(req, "select")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func (s *slaveCaller) GetConfig() model.Aria2Option {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
return s.parent.Model.Aria2OptionsSerialized
}
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
s.parent.lock.RLock()
defer s.parent.lock.RUnlock()
req := &serializer.SlaveAria2Call{
Task: task,
}
res, err := s.SendAria2Call(req, "delete")
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
return nil
}
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
reqBodyEncoded, err := json.Marshal(body)
if err != nil {
return nil, err
}
return strings.NewReader(string(reqBodyEncoded)), nil
}
// RemoteCallback sends an upload callback request for remote storage policies.
func RemoteCallback(url string, body serializer.UploadCallback) error {
callbackBody, err := json.Marshal(struct {
Data serializer.UploadCallback `json:"data"`
}{
Data: body,
})
if err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
}
resp := request.GeneralClient.Request(
"POST",
url,
bytes.NewReader(callbackBody),
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
)
if resp.Err != nil {
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
}
// Parse the callback response from the master.
response, err := resp.DecodeResponse()
if err != nil {
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
return serializer.NewError(serializer.CodeCallbackError, msg, err)
}
if response.Code != 0 {
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
}
return nil
}

156
pkg/conf/conf.go Normal file
View File

@ -0,0 +1,156 @@
package conf
import (
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-ini/ini"
"github.com/go-playground/validator/v10"
)
// database holds database connection settings.
type database struct {
Type string
User string
Password string
Host string
Name string
TablePrefix string
DBFile string
Port int
Charset string
UnixSocket bool
}
// system holds general system settings.
type system struct {
Mode string `validate:"eq=master|eq=slave"`
Listen string `validate:"required"`
Debug bool
SessionSecret string
HashIDSalt string
GracePeriod int `validate:"gte=0"`
ProxyHeader string `validate:"required_with=Listen"`
}
type ssl struct {
CertPath string `validate:"omitempty,required"`
KeyPath string `validate:"omitempty,required"`
Listen string `validate:"required"`
}
type unix struct {
Listen string
Perm uint32
}
// slave holds the configuration used when running as a slave storage node.
type slave struct {
Secret string `validate:"omitempty,gte=64"`
CallbackTimeout int `validate:"omitempty,gte=1"`
SignatureTTL int `validate:"omitempty,gte=1"`
}
// redis holds Redis connection settings.
type redis struct {
Network string
Server string
User string
Password string
DB string
}
// cors holds cross-origin resource sharing (CORS) settings.
type cors struct {
AllowOrigins []string
AllowMethods []string
AllowHeaders []string
AllowCredentials bool
ExposeHeaders []string
SameSite string
Secure bool
}
var cfg *ini.File
const defaultConf = `[System]
Debug = false
Mode = master
Listen = :5212
SessionSecret = {SessionSecret}
HashIDSalt = {HashIDSalt}
`
// Init loads the configuration file at path, creating a default one if it does not exist.
func Init(path string) {
var err error
if path == "" || !util.Exists(path) {
// 创建初始配置文件
confContent := util.Replace(map[string]string{
"{SessionSecret}": util.RandStringRunes(64),
"{HashIDSalt}": util.RandStringRunes(64),
}, defaultConf)
f, err := util.CreatNestedFile(path)
if err != nil {
util.Log().Panic("Failed to create config file: %s", err)
}
// 写入配置文件
_, err = f.WriteString(confContent)
if err != nil {
util.Log().Panic("Failed to write config file: %s", err)
}
f.Close()
}
cfg, err = ini.Load(path)
if err != nil {
util.Log().Panic("Failed to parse config file %q: %s", path, err)
}
sections := map[string]interface{}{
"Database": DatabaseConfig,
"System": SystemConfig,
"SSL": SSLConfig,
"UnixSocket": UnixConfig,
"Redis": RedisConfig,
"CORS": CORSConfig,
"Slave": SlaveConfig,
}
for sectionName, sectionStruct := range sections {
err = mapSection(sectionName, sectionStruct)
if err != nil {
util.Log().Panic("Failed to parse config section %q: %s", sectionName, err)
}
}
// Map settings that overwrite database-stored options.
for _, key := range cfg.Section("OptionOverwrite").Keys() {
OptionOverwrite[key.Name()] = key.Value()
}
// Reset the log level.
if !SystemConfig.Debug {
util.Level = util.LevelInformational
util.GloablLogger = nil
util.Log()
}
}
// mapSection maps a section of the config file onto the given struct and validates it.
func mapSection(section string, confStruct interface{}) error {
err := cfg.Section(section).MapTo(confStruct)
if err != nil {
return err
}
// 验证合法性
validate := validator.New()
err = validate.Struct(confStruct)
if err != nil {
return err
}
return nil
}
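// For reference, a fuller config file accepted by the section mapping above might look
// like the sketch below. Values are illustrative only; keys follow the struct field
// names (go-ini's default mapping) and list values are comma separated.
//
// [System]
// Mode = master
// Listen = :5212
// Debug = false
//
// [Database]
// Type = mysql
// Host = 127.0.0.1
// Port = 3306
// User = cloudreve
// Password = secret
// Name = cloudreve
// TablePrefix = cd_
//
// [Redis]
// Server = 127.0.0.1:6379
// DB = 0
//
// [CORS]
// AllowOrigins = https://pan.example.com
// AllowCredentials = true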

55
pkg/conf/defaults.go Normal file
View File

@ -0,0 +1,55 @@
package conf
// RedisConfig is the default Redis server configuration.
var RedisConfig = &redis{
Network: "tcp",
Server: "",
Password: "",
DB: "0",
}
// DatabaseConfig is the default database configuration.
var DatabaseConfig = &database{
Type: "UNSET",
Charset: "utf8",
DBFile: "cloudreve.db",
Port: 3306,
UnixSocket: false,
}
// SystemConfig is the default system configuration.
var SystemConfig = &system{
Debug: false,
Mode: "master",
Listen: ":5212",
ProxyHeader: "X-Forwarded-For",
}
// CORSConfig is the default CORS configuration.
var CORSConfig = &cors{
AllowOrigins: []string{"UNSET"},
AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"},
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
AllowCredentials: false,
ExposeHeaders: nil,
SameSite: "Default",
Secure: false,
}
// SlaveConfig is the default slave-mode configuration.
var SlaveConfig = &slave{
CallbackTimeout: 20,
SignatureTTL: 60,
}
var SSLConfig = &ssl{
Listen: ":443",
CertPath: "",
KeyPath: "",
}
var UnixConfig = &unix{
Listen: "",
}
var OptionOverwrite = map[string]interface{}{}

22
pkg/conf/version.go Normal file
View File

@ -0,0 +1,22 @@
package conf
// plusVersion is the Plus edition version suffix.
const plusVersion = "+1.1"
// BackendVersion is the current backend version.
const BackendVersion = "3.8.3" + plusVersion
// KeyVersion is the license key version.
const KeyVersion = "3.3.1"
// RequiredDBVersion is the database schema version required by this release.
const RequiredDBVersion = "3.8.1+1.0-plus"
// RequiredStaticVersion is the static assets version required by this release.
const RequiredStaticVersion = "3.8.3" + plusVersion
// IsPlus indicates whether this is the Plus edition.
const IsPlus = "true"
// LastCommit is the id of the last commit.
const LastCommit = "88409cc"

99
pkg/crontab/collect.go Normal file
View File

@ -0,0 +1,99 @@
package crontab
import (
"context"
"os"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
func garbageCollect() {
// Clean up temp files produced by batch (archive) downloads.
collectArchiveFile()
// Garbage-collect expired entries in the built-in memory cache.
if store, ok := cache.Store.(*cache.MemoStore); ok {
collectCache(store)
}
util.Log().Info("Crontab job \"cron_garbage_collect\" complete.")
}
func collectArchiveFile() {
// Read expiration and temp directory settings.
tempPath := util.RelativePath(model.GetSettingByName("temp_path"))
expires := model.GetIntSetting("download_timeout", 30)
// Walk the archive temp directory.
root := filepath.Join(tempPath, "archive")
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err == nil && !info.IsDir() &&
strings.HasPrefix(filepath.Base(path), "archive_") &&
time.Since(info.ModTime()).Seconds() > float64(expires) {
util.Log().Debug("Delete expired batch download temp file %q.", path)
// 删除符合条件的文件
if err := os.Remove(path); err != nil {
util.Log().Debug("Failed to delete temp file %q: %s", path, err)
}
}
return nil
})
if err != nil {
util.Log().Debug("Crontab job cannot list temp batch download folder: %s", err)
}
}
func collectCache(store *cache.MemoStore) {
util.Log().Debug("Cleanup memory cache.")
store.GarbageCollect()
}
func uploadSessionCollect() {
placeholders := model.GetUploadPlaceholderFiles(0)
// Group expired upload sessions by their owner.
userToFiles := make(map[uint][]uint)
for _, file := range placeholders {
_, sessionExist := cache.Get(filesystem.UploadSessionCachePrefix + *file.UploadSessionID)
if sessionExist {
continue
}
if _, ok := userToFiles[file.UserID]; !ok {
userToFiles[file.UserID] = make([]uint, 0)
}
userToFiles[file.UserID] = append(userToFiles[file.UserID], file.ID)
}
// Delete placeholder files of expired sessions, per user.
for uid, filesIDs := range userToFiles {
user, err := model.GetUserByID(uid)
if err != nil {
util.Log().Warning("Owner of the upload session cannot be found: %s", err)
continue
}
fs, err := filesystem.NewFileSystem(&user)
if err != nil {
util.Log().Warning("Failed to initialize filesystem: %s", err)
continue
}
if err = fs.Delete(context.Background(), []uint{}, filesIDs, false, false); err != nil {
util.Log().Warning("Failed to delete upload session: %s", err)
}
fs.Recycle()
}
util.Log().Info("Crontab job \"cron_recycle_upload_session\" complete.")
}

53
pkg/crontab/init.go Normal file
View File

@ -0,0 +1,53 @@
package crontab
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/robfig/cron/v3"
)
// Cron is the global crontab scheduler.
var Cron *cron.Cron
// Reload stops and restarts all crontab jobs.
func Reload() {
if Cron != nil {
Cron.Stop()
}
Init()
}
// Init registers crontab jobs from settings and starts the scheduler.
func Init() {
util.Log().Info("Initialize crontab jobs...")
// Read cron schedule settings.
options := model.GetSettingByNames(
"cron_garbage_collect",
"cron_notify_user",
"cron_ban_user",
"cron_recycle_upload_session",
)
Cron = cron.New()
for k, v := range options {
var handler func()
switch k {
case "cron_garbage_collect":
handler = garbageCollect
case "cron_notify_user":
handler = notifyExpiredVAS
case "cron_ban_user":
handler = banOverusedUser
case "cron_recycle_upload_session":
handler = uploadSessionCollect
default:
util.Log().Warning("Unknown crontab job type %q, skipping...", k)
continue
}
if _, err := Cron.AddFunc(v, handler); err != nil {
util.Log().Warning("Failed to start crontab job %q: %s", k, err)
}
}
Cron.Start()
}
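// The schedule values read from the settings above are standard cron specs as understood
// by robfig/cron v3, e.g. "@hourly", "@every 30m" or "0 4 * * *". A minimal sketch of the
// same registration pattern in isolation (the job body is a placeholder):
//
// c := cron.New()
// if _, err := c.AddFunc("@every 30m", func() {
// // periodic work goes here
// }); err != nil {
// util.Log().Warning("Failed to register crontab job: %s", err)
// }
// c.Start()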

83
pkg/crontab/vas.go Normal file
View File

@ -0,0 +1,83 @@
package crontab
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/email"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
func notifyExpiredVAS() {
checkStoragePack()
checkUserGroup()
util.Log().Info("Crontab job \"cron_notify_user\" complete.")
}
// banOverusedUser bans users who are still over quota after the grace period.
func banOverusedUser() {
users := model.GetTolerantExpiredUser()
for _, user := range users {
// Clear the last-notified marker.
user.ClearNotified()
// Check whether storage usage exceeds the quota.
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// Ban the user.
user.SetStatus(model.OveruseBaned)
}
}
}
// checkUserGroup handles users whose paid user group has expired.
func checkUserGroup() {
users := model.GetGroupExpiredUsers()
for _, user := range users {
// Fall the user back to the initial user group.
user.GroupFallback()
// Reload the user.
user, _ = model.GetUserByID(user.ID)
// Check whether storage usage exceeds the quota.
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// Over quota: notify the user.
sendNotification(&user, "用户组过期")
// Update the last-notified date.
user.Notified()
}
}
}
// checkStoragePack handles expired storage packs.
func checkStoragePack() {
packs := model.GetExpiredStoragePack()
for _, pack := range packs {
// Delete the expired storage pack.
pack.Delete()
// Find the owning user.
user, err := model.GetUserByID(pack.UserID)
if err != nil {
util.Log().Warning("Crontab job failed to get user info of [UID=%d]: %s", pack.UserID, err)
continue
}
// Check whether storage usage exceeds the quota.
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
// Over quota: notify the user.
sendNotification(&user, "容量包过期")
// Update the last-notified date.
user.Notified()
}
}
}
func sendNotification(user *model.User, reason string) {
title, body := email.NewOveruseNotification(user.Nick, reason)
if err := email.Send(user.Email, title, body); err != nil {
util.Log().Warning("Failed to send notification email: %s", err)
}
}

52
pkg/email/init.go Normal file
View File

@ -0,0 +1,52 @@
package email
import (
"sync"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Client is the default email sending client.
var Client Driver
// Lock guards Client.
var Lock sync.RWMutex
// Init initializes the email client from database settings.
func Init() {
util.Log().Debug("Initializing email sending queue...")
Lock.Lock()
defer Lock.Unlock()
if Client != nil {
Client.Close()
}
// 读取SMTP设置
options := model.GetSettingByNames(
"fromName",
"fromAdress",
"smtpHost",
"replyTo",
"smtpUser",
"smtpPass",
"smtpEncryption",
)
port := model.GetIntSetting("smtpPort", 25)
keepAlive := model.GetIntSetting("mail_keepalive", 30)
client := NewSMTPClient(SMTPConfig{
Name: options["fromName"],
Address: options["fromAdress"],
ReplyTo: options["replyTo"],
Host: options["smtpHost"],
Port: port,
User: options["smtpUser"],
Password: options["smtpPass"],
Keepalive: keepAlive,
Encryption: model.IsTrueVal(options["smtpEncryption"]),
})
Client = client
}

38
pkg/email/mail.go Normal file
View File

@ -0,0 +1,38 @@
package email
import (
"errors"
"strings"
)
// Driver is the interface implemented by email sending drivers.
type Driver interface {
// Close shuts down the driver.
Close()
// Send delivers an email.
Send(to, title, body string) error
}
var (
// ErrChanNotOpen indicates the email queue is not started.
ErrChanNotOpen = errors.New("email queue is not started")
// ErrNoActiveDriver indicates that no email provider is available.
ErrNoActiveDriver = errors.New("no available email provider")
)
// Send delivers an email using the active driver.
func Send(to, title, body string) error {
// Ignore placeholder addresses generated by QQ login.
if strings.HasSuffix(to, "@login.qq.com") {
return nil
}
Lock.RLock()
defer Lock.RUnlock()
if Client == nil {
return ErrNoActiveDriver
}
return Client.Send(to, title, body)
}
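// Any sender satisfying Driver can be swapped in as the global Client. A minimal sketch
// of an alternative driver that drops messages instead of delivering them (illustrative
// only, not part of Cloudreve):
//
// type nopDriver struct{}
//
// func (d *nopDriver) Close() {}
//
// func (d *nopDriver) Send(to, title, body string) error {
// // log or discard the message here
// return nil
// }
//
// // Swapping it in must happen under the package lock:
// // Lock.Lock(); Client = &nopDriver{}; Lock.Unlock()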

122
pkg/email/smtp.go Normal file
View File

@ -0,0 +1,122 @@
package email
import (
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/go-mail/mail"
"github.com/google/uuid"
)
// SMTP sends emails via the SMTP protocol.
type SMTP struct {
Config SMTPConfig
ch chan *mail.Message
chOpen bool
}
// SMTPConfig is the SMTP sending configuration.
type SMTPConfig struct {
Name string // sender display name
Address string // sender address
ReplyTo string // reply-to address
Host string // server hostname
Port int // server port
User string // username
Password string // password
Encryption bool // whether to enable encryption (SSL)
Keepalive int // how long to keep the SMTP connection alive, in seconds
}
// NewSMTPClient creates a new SMTP sending queue.
func NewSMTPClient(config SMTPConfig) *SMTP {
client := &SMTP{
Config: config,
ch: make(chan *mail.Message, 30),
chOpen: false,
}
client.Init()
return client
}
// Send enqueues an email message.
func (client *SMTP) Send(to, title, body string) error {
if !client.chOpen {
return ErrChanNotOpen
}
m := mail.NewMessage()
m.SetAddressHeader("From", client.Config.Address, client.Config.Name)
m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name)
m.SetHeader("To", to)
m.SetHeader("Subject", title)
m.SetHeader("Message-ID", util.StrConcat(`"<`, uuid.NewString(), `@`, `cloudreveplus`, `>"`))
m.SetBody("text/html", body)
client.ch <- m
return nil
}
// Close shuts down the sending queue.
func (client *SMTP) Close() {
if client.ch != nil {
close(client.ch)
}
}
// Init starts the sending queue worker goroutine.
func (client *SMTP) Init() {
go func() {
defer func() {
if err := recover(); err != nil {
client.chOpen = false
util.Log().Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err)
time.Sleep(time.Duration(10) * time.Second)
client.Init()
}
}()
d := mail.NewDialer(client.Config.Host, client.Config.Port, client.Config.User, client.Config.Password)
d.Timeout = time.Duration(client.Config.Keepalive+5) * time.Second
client.chOpen = true
// Enable SSL when encryption is configured.
d.SSL = client.Config.Encryption
d.StartTLSPolicy = mail.OpportunisticStartTLS
var s mail.SendCloser
var err error
open := false
for {
select {
case m, ok := <-client.ch:
if !ok {
util.Log().Debug("Email queue closing...")
client.chOpen = false
return
}
if !open {
if s, err = d.Dial(); err != nil {
panic(err)
}
open = true
}
if err := mail.Send(s, m); err != nil {
util.Log().Warning("Failed to send email: %s", err)
} else {
util.Log().Debug("Email sent.")
}
// Close the SMTP connection when idle for too long.
case <-time.After(time.Duration(client.Config.Keepalive) * time.Second):
if open {
if err := s.Close(); err != nil {
util.Log().Warning("Failed to close SMTP connection: %s", err)
}
open = false
}
}
}
}()
}
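// Outside of email.Init, the queue can also be constructed directly. A minimal sketch
// with placeholder values (host, credentials and addresses are illustrative):
//
// client := NewSMTPClient(SMTPConfig{
// Name: "Cloudreve",
// Address: "no-reply@example.com",
// Host: "smtp.example.com",
// Port: 587,
// User: "no-reply@example.com",
// Password: "secret",
// Keepalive: 30,
// })
// err := client.Send("user@example.com", "Hello", "<p>Hi</p>")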

50
pkg/email/template.go Normal file
View File

@ -0,0 +1,50 @@
package email
import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// NewOveruseNotification builds the subject and body of an over-quota notification email.
func NewOveruseNotification(userName, reason string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "over_used_template")
replace := map[string]string{
"{siteTitle}": options["siteName"],
"{userName}": userName,
"{notifyReason}": reason,
"{siteUrl}": options["siteURL"],
"{siteSecTitle}": options["siteTitle"],
}
return fmt.Sprintf("【%s】空间容量超额提醒", options["siteName"]),
util.Replace(replace, options["over_used_template"])
}
// NewActivationEmail builds the subject and body of an account activation email.
func NewActivationEmail(userName, activateURL string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template")
replace := map[string]string{
"{siteTitle}": options["siteName"],
"{userName}": userName,
"{activationUrl}": activateURL,
"{siteUrl}": options["siteURL"],
"{siteSecTitle}": options["siteTitle"],
}
return fmt.Sprintf("【%s】注册激活", options["siteName"]),
util.Replace(replace, options["mail_activation_template"])
}
// NewResetEmail builds the subject and body of a password reset email.
func NewResetEmail(userName, resetURL string) (string, string) {
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_reset_pwd_template")
replace := map[string]string{
"{siteTitle}": options["siteName"],
"{userName}": userName,
"{resetUrl}": resetURL,
"{siteUrl}": options["siteURL"],
"{siteSecTitle}": options["siteTitle"],
}
return fmt.Sprintf("【%s】密码重置", options["siteName"]),
util.Replace(replace, options["mail_reset_pwd_template"])
}

309
pkg/filesystem/archive.go Normal file
View File

@ -0,0 +1,309 @@
package filesystem
import (
"archive/zip"
"context"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/mholt/archiver/v4"
)
/* ===============
Compress / Decompress
===============
*/
// Compress writes an archive of the given folders and files to writer.
func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error {
// 查找待压缩目录
folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
if err != nil && len(folderIDs) != 0 {
return ErrDBListObjects
}
// 查找待压缩文件
files, err := model.GetFilesByIDs(fileIDs, fs.User.ID)
if err != nil && len(fileIDs) != 0 {
return ErrDBListObjects
}
// If the context restricts the parent folder, verify the objects belong to it.
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
// 检查目录
for _, folder := range folders {
if *folder.ParentID != parent.ID {
return ErrObjectNotExist
}
}
// 检查文件
for _, file := range files {
if file.FolderID != parent.ID {
return ErrObjectNotExist
}
}
}
// Try to obtain the request context so client cancellation can be detected later.
reqContext := ctx
ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context)
if ok {
reqContext = ginCtx.Request.Context()
}
// Place top-level objects at the root of the archive.
for i := 0; i < len(folders); i++ {
folders[i].Position = ""
}
for i := 0; i < len(files); i++ {
files[i].Position = ""
}
// 创建压缩文件Writer
zipWriter := zip.NewWriter(writer)
defer zipWriter.Close()
ctx = reqContext
// 压缩各个目录及文件
for i := 0; i < len(folders); i++ {
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
default:
fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive)
}
}
for i := 0; i < len(files); i++ {
select {
case <-reqContext.Done():
// 取消压缩请求
return ErrClientCanceled
default:
fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive)
}
}
return nil
}
func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
// 如果对象是文件
if file != nil {
// 切换上传策略
fs.Policy = file.GetPolicy()
err := fs.DispatchHandler()
if err != nil {
util.Log().Warning("Failed to compress file %q: %s", file.Name, err)
return
}
// 获取文件内容
fileToZip, err := fs.Handler.Get(
context.WithValue(ctx, fsctx.FileModelCtx, *file),
file.SourceName,
)
if err != nil {
util.Log().Debug("Failed to open %q: %s", file.Name, err)
return
}
if closer, ok := fileToZip.(io.Closer); ok {
defer closer.Close()
}
// 创建压缩文件头
header := &zip.FileHeader{
Name: filepath.FromSlash(path.Join(file.Position, file.Name)),
Modified: file.UpdatedAt,
UncompressedSize64: file.Size,
}
// 指定是压缩还是归档
if isArchive {
header.Method = zip.Store
} else {
header.Method = zip.Deflate
}
writer, err := zipWriter.CreateHeader(header)
if err != nil {
return
}
_, err = io.Copy(writer, fileToZip)
} else if folder != nil {
// 对象是目录
// 获取子文件
subFiles, err := folder.GetChildFiles()
if err == nil && len(subFiles) > 0 {
for i := 0; i < len(subFiles); i++ {
fs.doCompress(ctx, &subFiles[i], nil, zipWriter, isArchive)
}
}
// 获取子目录,继续递归遍历
subFolders, err := folder.GetChildFolder()
if err == nil && len(subFolders) > 0 {
for i := 0; i < len(subFolders); i++ {
fs.doCompress(ctx, nil, &subFolders[i], zipWriter, isArchive)
}
}
}
}
// Decompress extracts the archive at src into the dst directory.
func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error {
err := fs.ResetFileIfNotExist(ctx, src)
if err != nil {
return err
}
tempZipFilePath := ""
defer func() {
// 结束时删除临时压缩文件
if tempZipFilePath != "" {
if err := os.Remove(tempZipFilePath); err != nil {
util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err)
}
}
}()
// 下载压缩文件到临时目录
fileStream, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
if err != nil {
return err
}
defer fileStream.Close()
tempZipFilePath = filepath.Join(
util.RelativePath(model.GetSettingByName("temp_path")),
"decompress",
fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
)
zipFile, err := util.CreatNestedFile(tempZipFilePath)
if err != nil {
util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err)
tempZipFilePath = ""
return err
}
defer zipFile.Close()
// 下载前先判断是否是可解压的格式
format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream)
if err != nil {
util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err)
return err
}
extractor, ok := format.(archiver.Extractor)
if !ok {
return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName)
}
// Only zip archives support uploading multiple entries concurrently.
var isZip bool
switch extractor.(type) {
case archiver.Zip:
extractor = archiver.Zip{TextEncoding: encoding}
isZip = true
}
// Zip archives must be fully downloaded locally first; other formats can be extracted while streaming.
reader := readStream
if isZip {
_, err = io.Copy(zipFile, readStream)
if err != nil {
util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err)
return err
}
fileStream.Close()
// 设置文件偏移量
zipFile.Seek(0, io.SeekStart)
reader = zipFile
}
var wg sync.WaitGroup
parallel := model.GetIntSetting("max_parallel_transfer", 4)
worker := make(chan int, parallel)
for i := 0; i < parallel; i++ {
worker <- i
}
// 上传文件函数
uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) {
defer func() {
if isZip {
worker <- 1
wg.Done()
}
if err := recover(); err != nil {
util.Log().Warning("Error while uploading files inside of archive file.")
fmt.Println(err)
}
}()
err := fs.UploadFromStream(ctx, &fsctx.FileStream{
File: fileStream,
Size: uint64(size),
Name: path.Base(savePath),
VirtualPath: path.Dir(savePath),
}, true)
fileStream.Close()
if err != nil {
util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err)
}
}
// Extraction callback: returning an error would stop the rest of the extraction, so always return nil.
err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error {
rawPath := util.FormSlash(f.NameInArchive)
savePath := path.Join(dst, rawPath)
// 路径是否合法
if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) {
util.Log().Warning("%s: illegal file path", f.NameInArchive)
return nil
}
// 如果是目录
if f.FileInfo.IsDir() {
fs.CreateDirectory(ctx, savePath)
return nil
}
// 上传文件
fileStream, err := f.Open()
if err != nil {
util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err)
return nil
}
if !isZip {
uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
} else {
<-worker
wg.Add(1)
go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
}
return nil
})
wg.Wait()
return err
}
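// Decompress limits concurrent uploads with a buffered channel used as a semaphore plus
// a WaitGroup. The same pattern in isolation (items, parallel and work are placeholders):
//
// var wg sync.WaitGroup
// worker := make(chan int, parallel)
// for i := 0; i < parallel; i++ {
// worker <- i
// }
// for _, item := range items {
// <-worker // acquire a slot
// wg.Add(1)
// go func(item string) {
// defer func() { worker <- 1; wg.Done() }()
// work(item)
// }(item)
// }
// wg.Wait()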

View File

@ -0,0 +1,74 @@
package backoff
import (
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"net/http"
"strconv"
"time"
)
// Backoff used for retry sleep backoff
type Backoff interface {
Next(err error) bool
Reset()
}
// ConstantBackoff implements Backoff interface with constant sleep time. If the error
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
type ConstantBackoff struct {
Sleep time.Duration
Max int
tried int
}
func (c *ConstantBackoff) Next(err error) bool {
c.tried++
if c.tried > c.Max {
return false
}
var e *RetryableError
if errors.As(err, &e) && e.RetryAfter > 0 {
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
time.Sleep(e.RetryAfter)
} else {
time.Sleep(c.Sleep)
}
return true
}
func (c *ConstantBackoff) Reset() {
c.tried = 0
}
type RetryableError struct {
Err error
RetryAfter time.Duration
}
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
// and existing error.
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
retryAfter := header.Get("retry-after")
if retryAfter == "" {
retryAfter = "0"
}
res := &RetryableError{
Err: err,
}
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
}
return res
}
func (e *RetryableError) Error() string {
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
}
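// A typical use of ConstantBackoff wraps a retryable operation in a loop and stops once
// Next reports the retry budget is exhausted. A minimal sketch (doRequest is a placeholder):
//
// b := &ConstantBackoff{Sleep: time.Second, Max: 3}
// var err error
// for {
// if err = doRequest(); err == nil {
// break
// }
// if !b.Next(err) {
// break // retries exhausted
// }
// }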

View File

@ -0,0 +1,167 @@
package chunk
import (
"context"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"os"
)
const bufferTempPattern = "cdChunk.*.tmp"
// ChunkProcessFunc callback function for processing a chunk
type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
// ChunkGroup manage groups of chunks
type ChunkGroup struct {
file fsctx.FileHeader
chunkSize uint64
backoff backoff.Backoff
enableRetryBuffer bool
fileInfo *fsctx.UploadTaskInfo
currentIndex int
chunkNum uint64
bufferTemp *os.File
}
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer bool) *ChunkGroup {
c := &ChunkGroup{
file: file,
chunkSize: chunkSize,
backoff: backoff,
fileInfo: file.Info(),
currentIndex: -1,
enableRetryBuffer: useBuffer,
}
if c.chunkSize == 0 {
c.chunkSize = c.fileInfo.Size
}
if c.fileInfo.Size == 0 {
c.chunkNum = 1
} else {
c.chunkNum = c.fileInfo.Size / c.chunkSize
if c.fileInfo.Size%c.chunkSize != 0 {
c.chunkNum++
}
}
return c
}
// TempAvailable returns if current chunk temp file is available to be read
func (c *ChunkGroup) TempAvailable() bool {
if c.bufferTemp != nil {
state, _ := c.bufferTemp.Stat()
return state != nil && state.Size() == c.Length()
}
return false
}
// Process a chunk with retry logic
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
reader := io.LimitReader(c.file, c.Length())
// If useBuffer is enabled, tee the reader to a temp file
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern)
reader = io.TeeReader(reader, c.bufferTemp)
}
if c.bufferTemp != nil {
defer func() {
if c.bufferTemp != nil {
c.bufferTemp.Close()
os.Remove(c.bufferTemp.Name())
c.bufferTemp = nil
}
}()
// if temp buffer file is available, use it
if c.TempAvailable() {
if _, err := c.bufferTemp.Seek(0, io.SeekStart); err != nil {
return fmt.Errorf("failed to seek temp file back to chunk start: %w", err)
}
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
reader = io.NopCloser(c.bufferTemp)
}
}
err := processor(c, reader)
if err != nil {
if c.enableRetryBuffer {
request.BlackHole(reader)
}
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
if c.file.Seekable() {
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)
}
}
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
return c.Process(processor)
}
return err
}
util.Log().Debug("Chunk %d processed", c.currentIndex)
return nil
}
// Start returns the byte index of current chunk
func (c *ChunkGroup) Start() int64 {
return int64(uint64(c.Index()) * c.chunkSize)
}
// Total returns the total length
func (c *ChunkGroup) Total() int64 {
return int64(c.fileInfo.Size)
}
// Num returns the total chunk number
func (c *ChunkGroup) Num() int {
return int(c.chunkNum)
}
// RangeHeader returns header value of Content-Range
func (c *ChunkGroup) RangeHeader() string {
return fmt.Sprintf("bytes %d-%d/%d", c.Start(), c.Start()+c.Length()-1, c.Total())
}
// Index returns current chunk index, starts from 0
func (c *ChunkGroup) Index() int {
return c.currentIndex
}
// Next switch to next chunk, returns whether all chunks are processed
func (c *ChunkGroup) Next() bool {
c.currentIndex++
c.backoff.Reset()
return c.currentIndex < int(c.chunkNum)
}
// Length returns the length of current chunk
func (c *ChunkGroup) Length() int64 {
contentLength := c.chunkSize
if c.Index() == int(c.chunkNum-1) {
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
}
return int64(contentLength)
}
// IsLast returns if current chunk is the last one
func (c *ChunkGroup) IsLast() bool {
return c.Index() == int(c.chunkNum-1)
}
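// A caller iterates chunks with Next and hands each one to Process, which retries with the
// configured backoff. A minimal sketch (upload is a placeholder that would send c.Length()
// bytes, e.g. using c.RangeHeader() as the Content-Range value):
//
// chunks := NewChunkGroup(file, 4<<20, &backoff.ConstantBackoff{Sleep: time.Second, Max: 3}, true)
// for chunks.Next() {
// if err := chunks.Process(func(c *ChunkGroup, reader io.Reader) error {
// return upload(c.Index(), c.RangeHeader(), reader)
// }); err != nil {
// return err
// }
// }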

View File

@ -0,0 +1,427 @@
package cos
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/google/go-querystring/query"
cossdk "github.com/tencentyun/cos-go-sdk-v5"
)
// UploadPolicy is the Tencent COS POST upload policy.
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// MetaData holds object metadata.
type MetaData struct {
Size uint64
CallbackKey string
CallbackURL string
}
type urlOption struct {
Speed int `url:"x-cos-traffic-limit,omitempty"`
ContentDescription string `url:"response-content-disposition,omitempty"`
}
// Driver is the Tencent COS storage adapter.
type Driver struct {
Policy *model.Policy
Client *cossdk.Client
HTTPClient request.Client
}
// List lists objects under the given prefix.
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 初始化列目录参数
opt := &cossdk.BucketGetOptions{
Prefix: strings.TrimPrefix(base, "/"),
EncodingType: "",
MaxKeys: 1000,
}
// 是否为递归列出
if !recursive {
opt.Delimiter = "/"
}
// 手动补齐结尾的slash
if opt.Prefix != "" {
opt.Prefix += "/"
}
var (
marker string
objects []cossdk.Object
commons []string
)
for {
res, _, err := handler.Client.Bucket.Get(ctx, opt)
if err != nil {
return nil, err
}
objects = append(objects, res.Contents...)
commons = append(commons, res.CommonPrefixes...)
// If the listing is truncated, continue with the returned marker.
marker = res.NextMarker
// An empty marker means the listing is complete.
if marker == "" {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(opt.Prefix, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(opt.Prefix, object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: time.Now(),
})
}
return res, nil
}
// CORS sets the bucket cross-origin policy.
func (handler Driver) CORS() error {
_, err := handler.Client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{
Rules: []cossdk.BucketCORSRule{{
AllowedMethods: []string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
},
AllowedOrigins: []string{"*"},
AllowedHeaders: []string{"*"},
MaxAgeSeconds: 3600,
ExposeHeaders: []string{},
}},
})
return err
}
// Get fetches the file content.
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put uploads the file stream to its save path.
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
opt := &cossdk.ObjectPutOptions{}
_, err := handler.Client.Object.Put(ctx, file.Info().SavePath, file, opt)
return err
}
// Delete removes one or more objects, returning the keys that could not be
// deleted and the last error encountered.
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
obs := []cossdk.Object{}
for _, v := range files {
obs = append(obs, cossdk.Object{Key: v})
}
opt := &cossdk.ObjectDeleteMultiOptions{
Objects: obs,
Quiet: true,
}
res, _, err := handler.Client.Object.DeleteMulti(context.Background(), opt)
if err != nil {
return files, err
}
// 整理删除结果
failed := make([]string, 0, len(files))
for _, v := range res.Errors {
failed = append(failed, v.Key)
}
if len(failed) == 0 {
return failed, nil
}
return failed, errors.New("delete failed")
}
// Thumb returns a thumbnail of the file.
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://cloud.tencent.com/document/product/436/44893
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (32<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
source, err := handler.signSourceURL(
ctx,
file.SourceName,
int64(model.GetIntSetting("preview_timeout", 60)),
&urlOption{},
)
if err != nil {
return nil, err
}
thumbURL, _ := url.Parse(source)
thumbQuery := thumbURL.Query()
thumbQuery.Add(thumbParam, "")
thumbURL.RawQuery = thumbQuery.Encode()
return &response.ContentResponse{
Redirect: true,
URL: thumbURL.String(),
}, nil
}
// Source returns a (signed) external URL for the object.
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 添加各项设置
options := urlOption{}
if speed > 0 {
if speed < 819200 {
speed = 819200
}
if speed > 838860800 {
speed = 838860800
}
options.Speed = speed
}
if isDownload {
options.ContentDescription = "attachment; filename=\"" + url.PathEscape(fileName) + "\""
}
return handler.signSourceURL(ctx, path, ttl, &options)
}
func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, options *urlOption) (string, error) {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
// 公有空间不需要签名
if !handler.Policy.IsPrivate {
file, err := url.Parse(path)
if err != nil {
return "", err
}
// 非签名URL不支持设置响应header
options.ContentDescription = ""
optionQuery, err := query.Values(*options)
if err != nil {
return "", err
}
file.RawQuery = optionQuery.Encode()
sourceURL := cdnURL.ResolveReference(file)
return sourceURL.String(), nil
}
presignedURL, err := handler.Client.Object.GetPresignedURL(ctx, http.MethodGet, path,
handler.Policy.AccessKey, handler.Policy.SecretKey, time.Duration(ttl)*time.Second, options)
if err != nil {
return "", err
}
// Swap the host of the signed URL for the user-defined CDN domain, if any.
presignedURL.Host = cdnURL.Host
presignedURL.Scheme = cdnURL.Scheme
return presignedURL.String(), nil
}
// Token returns the upload policy and credential for direct client-side upload.
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI).String()
// 上传策略
savePath := file.Info().SavePath
startTime := time.Now()
endTime := startTime.Add(time.Duration(ttl) * time.Second)
keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix())
postPolicy := UploadPolicy{
Expiration: endTime.UTC().Format(time.RFC3339),
Conditions: []interface{}{
map[string]string{"bucket": handler.Policy.BucketName},
map[string]string{"$key": savePath},
map[string]string{"x-cos-meta-callback": apiURL},
map[string]string{"x-cos-meta-key": uploadSession.Key},
map[string]string{"q-sign-algorithm": "sha1"},
map[string]string{"q-ak": handler.Policy.AccessKey},
map[string]string{"q-sign-time": keyTime},
},
}
if handler.Policy.MaxSize > 0 {
postPolicy.Conditions = append(postPolicy.Conditions,
[]interface{}{"content-length-range", 0, handler.Policy.MaxSize})
}
res, err := handler.getUploadCredential(ctx, postPolicy, keyTime, savePath)
if err == nil {
res.SessionID = uploadSession.Key
res.Callback = apiURL
res.UploadURLs = []string{handler.Policy.Server}
}
return res, err
}
// CancelToken cancels a stateful upload credential (no-op for COS).
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
// Meta fetches object metadata.
func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.Client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{})
if err != nil {
return nil, err
}
return &MetaData{
Size: uint64(res.ContentLength),
CallbackKey: res.Header.Get("x-cos-meta-key"),
CallbackURL: res.Header.Get("x-cos-meta-callback"),
}, nil
}
func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string, savePath string) (*serializer.UploadCredential, error) {
// 编码上传策略
policyJSON, err := json.Marshal(policy)
if err != nil {
return nil, err
}
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
// 签名上传策略
hmacSign := hmac.New(sha1.New, []byte(handler.Policy.SecretKey))
_, err = io.WriteString(hmacSign, keyTime)
if err != nil {
return nil, err
}
signKey := fmt.Sprintf("%x", hmacSign.Sum(nil))
sha1Sign := sha1.New()
_, err = sha1Sign.Write(policyJSON)
if err != nil {
return nil, err
}
stringToSign := fmt.Sprintf("%x", sha1Sign.Sum(nil))
// 最终签名
hmacFinalSign := hmac.New(sha1.New, []byte(signKey))
_, err = hmacFinalSign.Write([]byte(stringToSign))
if err != nil {
return nil, err
}
signature := hmacFinalSign.Sum(nil)
return &serializer.UploadCredential{
Policy: policyEncoded,
Path: savePath,
AccessKey: handler.Policy.AccessKey,
Credential: fmt.Sprintf("%x", signature),
KeyTime: keyTime,
}, nil
}

View File

@ -0,0 +1,134 @@
package cos
import (
"archive/zip"
"bytes"
"encoding/base64"
"io"
"io/ioutil"
"net/url"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
scf "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf/v20180416"
)
const scfFunc = `# -*- coding: utf8 -*-
# SCF function: triggered by COS object events, sends a callback request to Cloudreve
from qcloud_cos_v5 import CosConfig
from qcloud_cos_v5 import CosS3Client
from qcloud_cos_v5 import CosServiceError
from qcloud_cos_v5 import CosClientError
import sys
import logging
import requests
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
logger = logging.getLogger()
def main_handler(event, context):
logger.info("start main handler")
for record in event['Records']:
try:
if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']:
logger.info("Cannot find callback URL, skiped.")
return 'Success'
callback = record['cos']['cosObject']['meta']['x-cos-meta-callback']
key = record['cos']['cosObject']['key']
logger.info("Callback URL is " + callback)
r = requests.get(callback)
print(r.text)
except Exception as e:
print(e)
print('Error getting object {} callback url. '.format(key))
raise e
return "Fail"
return "Success"
`
// CreateSCF creates the callback cloud function (SCF) and its COS trigger.
func CreateSCF(policy *model.Policy, region string) error {
// 初始化客户端
credential := common.NewCredential(
policy.AccessKey,
policy.SecretKey,
)
cpf := profile.NewClientProfile()
client, err := scf.NewClient(credential, region, cpf)
if err != nil {
return err
}
// 创建回调代码数据
buff := &bytes.Buffer{}
bs64 := base64.NewEncoder(base64.StdEncoding, buff)
zipWriter := zip.NewWriter(bs64)
header := zip.FileHeader{
Name: "callback.py",
Method: zip.Deflate,
}
writer, err := zipWriter.CreateHeader(&header)
if err != nil {
return err
}
_, err = io.Copy(writer, strings.NewReader(scfFunc))
zipWriter.Close()
// 创建云函数
req := scf.NewCreateFunctionRequest()
funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10)
zipFileBytes, _ := ioutil.ReadAll(buff)
zipFileStr := string(zipFileBytes)
codeSource := "ZipFile"
handler := "callback.main_handler"
desc := "Cloudreve 用回调函数"
timeout := int64(60)
runtime := "Python3.6"
req.FunctionName = &funcName
req.Code = &scf.Code{
ZipFile: &zipFileStr,
}
req.Handler = &handler
req.Description = &desc
req.Timeout = &timeout
req.Runtime = &runtime
req.CodeSource = &codeSource
_, err = client.CreateFunction(req)
if err != nil {
return err
}
time.Sleep(time.Duration(5) * time.Second)
// 创建触发器
server, _ := url.Parse(policy.Server)
triggerType := "cos"
triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}`
enable := "OPEN"
trigger := scf.NewCreateTriggerRequest()
trigger.FunctionName = &funcName
trigger.TriggerName = &server.Host
trigger.Type = &triggerType
trigger.TriggerDesc = &triggerDesc
trigger.Enable = &enable
_, err = client.CreateTrigger(trigger)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,73 @@
package googledrive
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"google.golang.org/api/drive/v3"
)
// Client Google Drive client
type Client struct {
Endpoints *Endpoints
Policy *model.Policy
Credential *Credential
ClientID string
ClientSecret string
Redirect string
Request request.Client
ClusterController cluster.Controller
}
// Endpoints Google Drive API endpoints.
type Endpoints struct {
UserConsentEndpoint string // base URL for OAuth user consent
TokenEndpoint string // base URL for OAuth token exchange
EndpointURL string // base URL for API requests
}
const (
TokenCachePrefix = "googledrive_"
oauthEndpoint = "https://oauth2.googleapis.com/token"
userConsentBase = "https://accounts.google.com/o/oauth2/auth"
v3DriveEndpoint = "https://www.googleapis.com/drive/v3"
)
var (
// Default required scopes
RequiredScope = []string{
drive.DriveScope,
"openid",
"profile",
"https://www.googleapis.com/auth/userinfo.profile",
}
// ErrInvalidRefreshToken indicates the policy has no valid refresh token.
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
)
// NewClient creates a new client from the storage policy.
func NewClient(policy *model.Policy) (*Client, error) {
client := &Client{
Endpoints: &Endpoints{
TokenEndpoint: oauthEndpoint,
UserConsentEndpoint: userConsentBase,
EndpointURL: v3DriveEndpoint,
},
Credential: &Credential{
RefreshToken: policy.AccessKey,
},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OauthRedirect,
Request: request.NewClient(),
ClusterController: cluster.DefaultController,
}
return client, nil
}

View File

@ -0,0 +1,65 @@
package googledrive
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver is the Google Drive storage adapter.
type Driver struct {
Policy *model.Policy
HTTPClient request.Client
}
// NewDriver initializes a new Driver from the storage policy.
func NewDriver(policy *model.Policy) (driver.Handler, error) {
return &Driver{
Policy: policy,
HTTPClient: request.NewClient(),
}, nil
}
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
//TODO implement me
panic("implement me")
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
//TODO implement me
panic("implement me")
}
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
//TODO implement me
panic("implement me")
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
//TODO implement me
panic("implement me")
}

View File

@ -0,0 +1,154 @@
package googledrive
import (
"context"
"encoding/json"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// OAuthURL returns the user consent URL for the OAuth flow.
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
query := url.Values{
"client_id": {client.ClientID},
"scope": {strings.Join(scope, " ")},
"response_type": {"code"},
"redirect_uri": {client.Redirect},
"access_type": {"offline"},
"prompt": {"consent"},
}
u, _ := url.Parse(client.Endpoints.UserConsentEndpoint)
u.RawQuery = query.Encode()
return u.String()
}
// ObtainToken exchanges an authorization code or refresh token for credentials.
func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) {
body := url.Values{
"client_id": {client.ClientID},
"redirect_uri": {client.Redirect},
"client_secret": {client.ClientSecret},
}
if code != "" {
body.Add("grant_type", "authorization_code")
body.Add("code", code)
} else {
body.Add("grant_type", "refresh_token")
body.Add("refresh_token", refreshToken)
}
strBody := body.Encode()
res := client.Request.Request(
"POST",
client.Endpoints.TokenEndpoint,
io.NopCloser(strings.NewReader(strBody)),
request.WithHeader(http.Header{
"Content-Type": {"application/x-www-form-urlencoded"}},
),
request.WithContentLength(int64(len(strBody))),
)
if res.Err != nil {
return nil, res.Err
}
respBody, err := res.GetResponse()
if err != nil {
return nil, err
}
var (
errResp OAuthError
credential Credential
decodeErr error
)
if res.Response.StatusCode != 200 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
} else {
decodeErr = json.Unmarshal([]byte(respBody), &credential)
}
if decodeErr != nil {
return nil, decodeErr
}
if errResp.ErrorType != "" {
return nil, errResp
}
return &credential, nil
}
// UpdateCredential refreshes the credential if it is missing or about to expire.
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
if isSlave {
return client.fetchCredentialFromMaster(ctx)
}
oauth.GlobalMutex.Lock(client.Policy.ID)
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
// If a credential already exists
if client.Credential != nil && client.Credential.AccessToken != "" {
// Check whether it has expired
if client.Credential.ExpiresIn > time.Now().Unix() {
// Still valid, no refresh needed
return nil
}
}
// 尝试从缓存中获取凭证
if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok {
credential := cacheCredential.(Credential)
if credential.ExpiresIn > time.Now().Unix() {
client.Credential = &credential
return nil
}
}
// Obtain a new credential
if client.Credential == nil || client.Credential.RefreshToken == "" {
// No valid refresh token available
util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name)
return ErrInvalidRefreshToken
}
credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken)
if err != nil {
return err
}
// Convert the lifetime into an absolute expiry timestamp.
expires := credential.ExpiresIn - 60
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
// refresh token for Google Drive does not expire in production
credential.RefreshToken = client.Credential.RefreshToken
client.Credential = credential
// 更新缓存
cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires))
return nil
}
func (client *Client) AccessToken() string {
return client.Credential.AccessToken
}
// fetchCredentialFromMaster fetches the access token for this policy from the master node.
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
client.Credential = &Credential{AccessToken: res}
return nil
}
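// The overall flow: send the admin to OAuthURL, exchange the returned code once with
// ObtainToken, persist the refresh token in the policy (stored as AccessKey, see NewClient),
// then call UpdateCredential before issuing API calls. A minimal sketch (code comes from the
// OAuth redirect; storeRefreshToken is a hypothetical persistence step):
//
// authURL := client.OAuthURL(ctx, RequiredScope)
// // ... user consents, Google redirects back with ?code=...
// cred, err := client.ObtainToken(ctx, code, "")
// if err != nil {
// return err
// }
// storeRefreshToken(cred.RefreshToken)
// if err := client.UpdateCredential(ctx, false); err != nil {
// return err
// }
// token := client.AccessToken()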

View File

@ -0,0 +1,43 @@
package googledrive
import "encoding/gob"
// RespError is an API error response.
type RespError struct {
APIError APIError `json:"error"`
}
// APIError is the error payload returned by the API.
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Error implements the error interface.
func (err RespError) Error() string {
return err.APIError.Message
}
// Credential is the token response returned by the OAuth endpoint.
type Credential struct {
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
UserID string `json:"user_id"`
}
// OAuthError is the error response of OAuth endpoints.
type OAuthError struct {
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
}
// Error implements the error interface.
func (err OAuthError) Error() string {
return err.ErrorDescription
}
func init() {
gob.Register(Credential{})
}

View File

@ -0,0 +1,52 @@
package driver
import (
"context"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrorThumbNotExist = fmt.Errorf("thumb not exist")
ErrorThumbNotSupported = fmt.Errorf("thumb not supported")
)
// Handler is the storage policy adapter interface.
type Handler interface {
// Put uploads a file to its save path; the path and size come from the file header.
// When the context is canceled, the upload should be aborted and temporary files cleaned up.
Put(ctx context.Context, file fsctx.FileHeader) error
// Delete removes one or more files at the given paths, returning the paths that failed and an error.
Delete(ctx context.Context, files []string) ([]string, error)
// Get fetches the file content.
Get(ctx context.Context, path string) (response.RSCloser, error)
// Thumb returns a thumbnail, either as a data stream in ContentResponse or as a redirect.
// If the thumbnail does not exist and Cloudreve should generate and upload one as a proxy,
// return ErrorThumbNotExist; the generated thumbnail is stored following the local policy rules.
// If thumbnails are not supported for this file and no further requests should be made,
// return ErrorThumbNotSupported.
Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error)
// Source returns an external/download URL.
// ttl - link validity in seconds,
// isDownload - whether the link forces a download
Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error)
// Token returns an upload credential and signature valid for ttl seconds.
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)
// CancelToken cancels a previously created stateful upload credential.
CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error
// List lists files and folders under path on the remote end (excluding path itself).
// Returned object paths are relative to path.
// recursive - whether to list recursively
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
}

View File

@ -0,0 +1,292 @@
package local
import (
"context"
"errors"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
const (
Perm = 0744
)
// Driver is the local storage adapter.
type Driver struct {
Policy *model.Policy
}
// List walks the given physical path and lists the files and folders under it.
func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
var res []response.Object
// Resolve the starting path.
root := util.RelativePath(filepath.FromSlash(path))
// Walk the files and folders under the path.
err := filepath.Walk(root,
func(path string, info os.FileInfo, err error) error {
// 跳过根目录
if path == root {
return nil
}
if err != nil {
util.Log().Warning("Failed to walk folder %q: %s", path, err)
return filepath.SkipDir
}
// 将遍历对象的绝对路径转换为相对路径
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
res = append(res, response.Object{
Name: info.Name(),
RelativePath: filepath.ToSlash(rel),
Source: path,
Size: uint64(info.Size()),
IsDir: info.IsDir(),
LastModify: info.ModTime(),
})
// 如果非递归,则不步入目录
if !recursive && info.IsDir() {
return filepath.SkipDir
}
return nil
})
return res, err
}
// Get 获取文件内容
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 打开文件
file, err := os.Open(util.RelativePath(path))
if err != nil {
util.Log().Debug("Failed to open file: %s", err)
return nil, err
}
return file, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
dst := util.RelativePath(filepath.FromSlash(fileInfo.SavePath))
// 如果非 Overwrite则检查是否有重名冲突
if fileInfo.Mode&fsctx.Overwrite != fsctx.Overwrite {
if util.Exists(dst) {
util.Log().Warning("File with the same name existed or unavailable: %s", dst)
return errors.New("file with the same name existed or unavailable")
}
}
// 如果目标目录不存在,创建
basePath := filepath.Dir(dst)
if !util.Exists(basePath) {
err := os.MkdirAll(basePath, Perm)
if err != nil {
util.Log().Warning("Failed to create directory: %s", err)
return err
}
}
var (
out *os.File
err error
)
openMode := os.O_CREATE | os.O_RDWR
if fileInfo.Mode&fsctx.Append == fsctx.Append {
openMode |= os.O_APPEND
} else {
openMode |= os.O_TRUNC
}
out, err = os.OpenFile(dst, openMode, Perm)
if err != nil {
util.Log().Warning("Failed to open or create file: %s", err)
return err
}
defer out.Close()
if fileInfo.Mode&fsctx.Append == fsctx.Append {
stat, err := out.Stat()
if err != nil {
util.Log().Warning("Failed to read file info: %s", err)
return err
}
if uint64(stat.Size()) < fileInfo.AppendStart {
return errors.New("size of unfinished uploaded chunks is not as expected")
} else if uint64(stat.Size()) > fileInfo.AppendStart {
out.Close()
if err := handler.Truncate(ctx, dst, fileInfo.AppendStart); err != nil {
return fmt.Errorf("failed to overwrite chunk: %w", err)
}
out, err = os.OpenFile(dst, openMode, Perm)
if err != nil {
util.Log().Warning("Failed to create or open file: %s", err)
return err
}
defer out.Close()
}
}
// 写入文件内容
_, err = io.Copy(out, file)
return err
}
func (handler Driver) Truncate(ctx context.Context, src string, size uint64) error {
util.Log().Warning("Truncate file %q to [%d].", src, size)
out, err := os.OpenFile(src, os.O_WRONLY, Perm)
if err != nil {
util.Log().Warning("Failed to open file: %s", err)
return err
}
defer out.Close()
return out.Truncate(int64(size))
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
deleteFailed := make([]string, 0, len(files))
var retErr error
for _, value := range files {
filePath := util.RelativePath(filepath.FromSlash(value))
if util.Exists(filePath) {
err := os.Remove(filePath)
if err != nil {
util.Log().Warning("Failed to delete file: %s", err)
retErr = err
deleteFailed = append(deleteFailed, value)
}
}
// 尝试删除文件的缩略图(如果有)
_ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")))
}
return deleteFailed, retErr
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// Quick check thumb existence on master.
if conf.SystemConfig.Mode == "master" && file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusNotExist {
// Tell invoker to generate a thumb
return nil, driver.ErrorThumbNotExist
}
thumbFile, err := handler.Get(ctx, file.ThumbFile())
if err != nil {
if errors.Is(err, os.ErrNotExist) {
err = fmt.Errorf("thumb not exist: %w (%w)", err, driver.ErrorThumbNotExist)
}
return nil, err
}
return &response.ContentResponse{
Redirect: false,
Content: thumbFile,
}, nil
}
// Source 获取外链URL
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return "", errors.New("failed to read file model context")
}
var baseURL *url.URL
// 是否启用了CDN
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
baseURL = cdnURL
}
var (
signedURI *url.URL
err error
)
if isDownload {
// 创建下载会话,将文件信息写入缓存
downloadSessionID := util.RandStringRunes(16)
err = cache.Set("download_"+downloadSessionID, file, int(ttl))
if err != nil {
return "", serializer.NewError(serializer.CodeCacheOperation, "Failed to create download session", err)
}
// 签名生成文件记录
signedURI, err = auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID),
ttl,
)
} else {
// 签名生成文件记录
signedURI, err = auth.SignURI(
auth.General,
fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name),
ttl,
)
}
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err)
}
finalURL := signedURI.String()
if baseURL != nil {
finalURL = baseURL.ResolveReference(signedURI).String()
}
return finalURL, nil
}
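// For reference (illustrative): with isDownload=true the handler caches the file model
// under a random 16-character download session and signs /api/v3/file/download/<session>;
// otherwise it signs /api/v3/file/get/<file id>/<file name>. Either URI is resolved
// against the policy BaseURL when a CDN is configured, or returned relative to the site.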
// Token 获取上传策略和认证Token本地策略直接返回空值
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
if util.Exists(uploadSession.SavePath) {
return nil, errors.New("placeholder file already exist")
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@ -0,0 +1,595 @@
package onedrive
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
const (
// SmallFileSize 单文件上传接口最大尺寸
SmallFileSize uint64 = 4 * 1024 * 1024
// ChunkSize 服务端中转分片上传分片大小
ChunkSize uint64 = 10 * 1024 * 1024
// ListRetry 列取请求重试次数
ListRetry = 1
chunkRetrySleep = time.Second * 5
notFoundError = "itemNotFound"
)
// GetSourcePath 获取文件的绝对路径
func (info *FileInfo) GetSourcePath() string {
res, err := url.PathUnescape(info.ParentReference.Path)
if err != nil {
return ""
}
return strings.TrimPrefix(
path.Join(
strings.TrimPrefix(res, "/drive/root:"),
info.Name,
),
"/",
)
}
func (client *Client) getRequestURL(api string, opts ...Option) string {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
base, _ := url.Parse(client.Endpoints.EndpointURL)
if base == nil {
return ""
}
if options.useDriverResource {
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
} else {
base.Path = path.Join(base.Path, api)
}
return base.String()
}
// ListChildren 根据路径列取子对象
func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
var requestURL string
dst := strings.TrimPrefix(path, "/")
if dst == "" {
requestURL = client.getRequestURL("root/children")
} else {
requestURL = client.getRequestURL("root:/" + dst + ":/children")
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
if err != nil {
retried := 0
if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok {
retried = v
}
if retried < ListRetry {
retried++
util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
time.Sleep(time.Duration(5) * time.Second)
return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path)
}
return nil, err
}
var (
decodeErr error
fileInfo ListResponse
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return fileInfo.Value, nil
}
// Meta 根据资源ID或文件路径获取文件元信息
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
var requestURL string
if id != "" {
requestURL = client.getRequestURL("items/" + id)
} else {
dst := strings.TrimPrefix(path, "/")
requestURL = client.getRequestURL("root:/" + dst)
}
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
fileInfo FileInfo
)
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
if decodeErr != nil {
return nil, decodeErr
}
return &fileInfo, nil
}
// CreateUploadSession 创建分片上传会话
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
body := map[string]map[string]interface{}{
"item": {
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
},
}
bodyBytes, _ := json.Marshal(body)
res, err := client.requestWithStr(ctx, "POST", requestURL, string(bodyBytes), 200)
if err != nil {
return "", err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return "", decodeErr
}
return uploadSession.UploadURL, nil
}
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
siteUrlParsed, err := url.Parse(siteUrl)
if err != nil {
return "", err
}
hostName := siteUrlParsed.Hostname()
relativePath := strings.Trim(siteUrlParsed.Path, "/")
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if reqErr != nil {
return "", reqErr
}
var (
decodeErr error
siteInfo Site
)
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
if decodeErr != nil {
return "", decodeErr
}
return siteInfo.ID, nil
}
// GetUploadSessionStatus 查询上传会话状态
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadSession UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadSession, nil
}
// UploadChunk 上传分片
func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
res, err := client.request(
ctx, "PUT", uploadURL, content,
request.WithContentLength(current.Length()),
request.WithHeader(http.Header{
"Content-Range": {current.RangeHeader()},
}),
request.WithoutHeader([]string{"Authorization", "Content-Type"}),
request.WithTimeout(0),
)
if err != nil {
return nil, fmt.Errorf("failed to upload OneDrive chunk #%d: %w", current.Index(), err)
}
if current.IsLast() {
return nil, nil
}
var (
decodeErr error
uploadRes UploadSessionResponse
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// Upload 上传文件
func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
fileInfo := file.Info()
// 决定是否覆盖文件
overwrite := "fail"
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
overwrite = "replace"
}
size := int(fileInfo.Size)
dst := fileInfo.SavePath
// 小文件,使用简单上传接口上传
if size <= int(SmallFileSize) {
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
return err
}
// 大文件,进行分片
// 创建上传会话
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
if err != nil {
return err
}
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
_, err := client.UploadChunk(ctx, uploadURL, content, current)
return err
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
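// For reference (illustrative): with a 10 MiB policy ChunkSize, a 25 MiB (26,214,400 B)
// file becomes three chunks whose Content-Range headers cover bytes 0-10485759,
// 10485760-20971519 and 20971520-26214399 of 26214400; the PUT for the final chunk
// returns the created item rather than a session status, which is why UploadChunk skips
// decoding the response when current.IsLast().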
// DeleteUploadSession 删除上传会话
func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
if err != nil {
return err
}
return nil
}
// SimpleUpload 上传小文件到dst
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/" + dst + ":/content")
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
request.WithTimeout(0),
)
if err != nil {
return nil, err
}
var (
decodeErr error
uploadRes UploadResult
)
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
if decodeErr != nil {
return nil, decodeErr
}
return &uploadRes, nil
}
// BatchDelete deletes the given files in batches of 20 (the per-request limit of Delete),
// returning the files that failed to be deleted and the last error encountered.
// TODO: add tests
func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
groupNum := (len(dst) + 19) / 20 // ceil(len/20), avoids an empty trailing batch
finalRes := make([]string, 0, len(dst))
res := make([]string, 0, 20)
var err error
for i := 0; i < groupNum; i++ {
end := 20*i + 20
if i == groupNum-1 {
end = len(dst)
}
res, err = client.Delete(ctx, dst[20*i:end])
finalRes = append(finalRes, res...)
}
return finalRes, err
}
// Delete removes the given files in a single $batch request, returning the files that
// failed to be deleted and the error encountered. The API limits one batch to 20 items.
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
body := client.makeBatchDeleteRequestsBody(dst)
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
WithDriverResource(false)), body, 200)
if err != nil {
return dst, err
}
var (
decodeErr error
deleteRes BatchResponses
)
decodeErr = json.Unmarshal([]byte(res), &deleteRes)
if decodeErr != nil {
return dst, decodeErr
}
// 取得删除失败的文件
failed := getDeleteFailed(&deleteRes)
if len(failed) != 0 {
return failed, ErrDeleteFile
}
return failed, nil
}
func getDeleteFailed(res *BatchResponses) []string {
var failed = make([]string, 0, len(res.Responses))
for _, v := range res.Responses {
if v.Status != 204 && v.Status != 404 {
failed = append(failed, v.ID)
}
}
return failed
}
// makeBatchDeleteRequestsBody 生成批量删除请求正文
func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
req := BatchRequests{
Requests: make([]BatchRequest, len(files)),
}
for i, v := range files {
v = strings.TrimPrefix(v, "/")
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
filePath.Path = path.Join(filePath.Path, v)
req.Requests[i] = BatchRequest{
ID: v,
Method: "DELETE",
URL: filePath.EscapedPath(),
}
}
res, _ := json.Marshal(req)
return string(res)
}
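// For reference (illustrative): for files ["doc/a.txt", "b.txt"] and the default
// DriverResource "me/drive", the generated batch body looks roughly like
// {"requests":[{"id":"doc/a.txt","method":"DELETE","url":"/me/drive/root:/doc/a.txt"},
// {"id":"b.txt","method":"DELETE","url":"/me/drive/root:/b.txt"}]}.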
// GetThumbURL 获取给定尺寸的缩略图URL
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
dst = strings.TrimPrefix(dst, "/")
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
if err != nil {
return "", err
}
var (
decodeErr error
thumbRes ThumbResponse
)
decodeErr = json.Unmarshal([]byte(res), &thumbRes)
if decodeErr != nil {
return "", decodeErr
}
if thumbRes.URL != "" {
return thumbRes.URL, nil
}
if len(thumbRes.Value) == 1 {
if res, ok := thumbRes.Value[0]["large"]; ok {
return res.(map[string]interface{})["url"].(string), nil
}
}
return "", ErrThumbSizeNotFound
}
// MonitorUpload 监控客户端分片上传进度
func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) {
// 回调完成通知chan
callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1)
defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan)
timeout := model.GetIntSetting("onedrive_monitor_timeout", 600)
interval := model.GetIntSetting("onedrive_callback_check", 20)
for {
select {
case <-callbackChan:
util.Log().Debug("Client finished OneDrive callback.")
return
case <-time.After(time.Duration(ttl) * time.Second):
// 上传会话到期,仍未完成上传,创建占位符
client.DeleteUploadSession(context.Background(), uploadURL)
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
if err != nil {
util.Log().Debug("Failed to create placeholder file: %s", err)
}
return
case <-time.After(time.Duration(timeout) * time.Second):
util.Log().Debug("Checking OneDrive upload status.")
status, err := client.GetUploadSessionStatus(context.Background(), uploadURL)
if err != nil {
if resErr, ok := err.(*RespError); ok {
if resErr.APIError.Code == notFoundError {
util.Log().Debug("Upload completed, will check upload callback later.")
select {
case <-time.After(time.Duration(interval) * time.Second):
util.Log().Warning("No callback is made, file will be deleted.")
cache.Deletes([]string{callbackKey}, "callback_")
_, err = client.Delete(context.Background(), []string{path})
if err != nil {
util.Log().Warning("Failed to delete file without callback: %s", err)
}
case <-callbackChan:
util.Log().Debug("Client finished callback.")
}
return
}
}
util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error())
continue
}
// 成功获取分片上传状态,检查文件大小
if len(status.NextExpectedRanges) == 0 {
continue
}
sizeRange := strings.Split(
status.NextExpectedRanges[len(status.NextExpectedRanges)-1],
"-",
)
if len(sizeRange) != 2 {
continue
}
uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64)
if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size {
util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...")
// 取消上传会话实测OneDrive取消上传会话后客户端还是可以上传
// 所以上传一个空文件占位,阻止客户端上传
client.DeleteUploadSession(context.Background(), uploadURL)
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
if err != nil {
util.Log().Debug("无法创建占位文件,%s", err)
}
return
}
}
}
}
func sysError(err error) *RespError {
return &RespError{APIError: APIError{
Code: "system",
Message: err.Error(),
}}
}
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
// 获取凭证
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
if err != nil {
return "", sysError(err)
}
option = append(option,
request.WithHeader(http.Header{
"Authorization": {"Bearer " + client.Credential.AccessToken},
"Content-Type": {"application/json"},
}),
request.WithContext(ctx),
request.WithTPSLimit(
fmt.Sprintf("policy_%d", client.Policy.ID),
client.Policy.OptionsSerialized.TPSLimit,
client.Policy.OptionsSerialized.TPSLimitBurst,
),
)
// 发送请求
res := client.Request.Request(
method,
url,
body,
option...,
)
if res.Err != nil {
return "", sysError(res.Err)
}
respBody, err := res.GetResponse()
if err != nil {
return "", sysError(err)
}
// 解析请求响应
var (
errResp RespError
decodeErr error
)
// 如果有错误
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
if decodeErr != nil {
util.Log().Debug("Onedrive returns unknown response: %s", respBody)
return "", sysError(decodeErr)
}
if res.Response.StatusCode == 429 {
util.Log().Warning("OneDrive request is throttled.")
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
}
return "", &errResp
}
return respBody, nil
}
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
// 发送请求
bodyReader := io.NopCloser(strings.NewReader(body))
return client.request(ctx, method, url, bodyReader,
request.WithContentLength(int64(len(body))),
)
}

View File

@ -0,0 +1,78 @@
package onedrive
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
)
var (
// ErrAuthEndpoint 无法解析授权端点地址
ErrAuthEndpoint = errors.New("failed to parse endpoint url")
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
// ErrDeleteFile 无法删除文件
ErrDeleteFile = errors.New("cannot delete file")
// ErrClientCanceled 客户端取消操作
ErrClientCanceled = errors.New("client canceled")
// ErrThumbSizeNotFound 所需的缩略图尺寸不可用
ErrThumbSizeNotFound = errors.New("thumb size not found")
)
// Client OneDrive客户端
type Client struct {
Endpoints *Endpoints
Policy *model.Policy
Credential *Credential
ClientID string
ClientSecret string
Redirect string
Request request.Client
ClusterController cluster.Controller
}
// Endpoints OneDrive客户端相关设置
type Endpoints struct {
OAuthURL string // OAuth认证的基URL
OAuthEndpoints *oauthEndpoint
EndpointURL string // 接口请求的基URL
isInChina bool // 是否为世纪互联
DriverResource string // 要使用的驱动器
}
// NewClient 根据存储策略获取新的client
func NewClient(policy *model.Policy) (*Client, error) {
client := &Client{
Endpoints: &Endpoints{
OAuthURL: policy.BaseURL,
EndpointURL: policy.Server,
DriverResource: policy.OptionsSerialized.OdDriver,
},
Credential: &Credential{
RefreshToken: policy.AccessKey,
},
Policy: policy,
ClientID: policy.BucketName,
ClientSecret: policy.SecretKey,
Redirect: policy.OptionsSerialized.OauthRedirect,
Request: request.NewClient(),
ClusterController: cluster.DefaultController,
}
if client.Endpoints.DriverResource == "" {
client.Endpoints.DriverResource = "me/drive"
}
oauthBase := client.getOAuthEndpoint()
if oauthBase == nil {
return nil, ErrAuthEndpoint
}
client.Endpoints.OAuthEndpoints = oauthBase
return client, nil
}
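// Illustrative sketch (assumption, not part of the original file): how policy fields
// map onto the client. For OneDrive policies, BucketName stores the OAuth client ID and
// AccessKey stores the refresh token; the values below are hypothetical.
func exampleClient() (*Client, error) {
	policy := &model.Policy{
		BaseURL:    "https://login.microsoftonline.com/common/oauth2/v2.0",
		Server:     "https://graph.microsoft.com/v1.0",
		BucketName: "application-client-id",
		SecretKey:  "application-client-secret",
		AccessKey:  "stored-refresh-token",
	}
	// DriverResource falls back to "me/drive" because OptionsSerialized.OdDriver is empty.
	return NewClient(policy)
}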

View File

@ -0,0 +1,238 @@
package onedrive
import (
"context"
"errors"
"fmt"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver OneDrive 适配器
type Driver struct {
Policy *model.Policy
Client *Client
HTTPClient request.Client
}
// NewDriver 从存储策略初始化新的Driver实例
func NewDriver(policy *model.Policy) (driver.Handler, error) {
client, err := NewClient(policy)
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 50 << 20 // 50MB
}
return Driver{
Policy: policy,
Client: client,
HTTPClient: request.NewClient(),
}, err
}
// List 列取项目
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
// 列取子项目
objects, _ := handler.Client.ListChildren(ctx, base)
// 获取真实的列取起始根目录
rootPath := base
if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok {
rootPath = realBase
} else {
ctx = context.WithValue(ctx, fsctx.PathCtx, base)
}
// 整理结果
res := make([]response.Object, 0, len(objects))
for _, object := range objects {
source := path.Join(base, object.Name)
rel, err := filepath.Rel(rootPath, source)
if err != nil {
continue
}
res = append(res, response.Object{
Name: object.Name,
RelativePath: filepath.ToSlash(rel),
Source: source,
Size: object.Size,
IsDir: object.Folder != nil,
LastModify: time.Now(),
})
}
// 递归列取子目录
if recursive {
for _, object := range objects {
if object.Folder != nil {
sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive)
res = append(res, sub...)
}
}
}
return res, nil
}
// Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(
ctx,
path,
60,
false,
0,
)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
return handler.Client.Upload(ctx, file)
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return handler.Client.BatchDelete(ctx, files)
}
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
res, err := handler.Client.GetThumbURL(ctx, file.SourceName, thumbSize[0], thumbSize[1])
if err != nil {
var apiErr *RespError
if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) {
// OneDrive cannot generate thumbnail for this file
return nil, driver.ErrorThumbNotSupported
}
}
return &response.ContentResponse{
Redirect: true,
URL: res,
}, err
}
// Source 获取外链URL
func (handler Driver) Source(
ctx context.Context,
path string,
ttl int64,
isDownload bool,
speed int,
) (string, error) {
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
}
// 尝试从缓存中查找
if cachedURL, ok := cache.Get(cacheKey); ok {
return handler.replaceSourceHost(cachedURL.(string))
}
// 缓存不存在,重新获取
res, err := handler.Client.Meta(ctx, "", path)
if err == nil {
// 写入新的缓存
cache.Set(
cacheKey,
res.DownloadURL,
model.GetIntSetting("onedrive_source_timeout", 1800),
)
return handler.replaceSourceHost(res.DownloadURL)
}
return "", err
}
func (handler Driver) replaceSourceHost(origin string) (string, error) {
if handler.Policy.OptionsSerialized.OdProxy != "" {
source, err := url.Parse(origin)
if err != nil {
return "", err
}
cdn, err := url.Parse(handler.Policy.OptionsSerialized.OdProxy)
if err != nil {
return "", err
}
// 替换反代地址
source.Scheme = cdn.Scheme
source.Host = cdn.Host
return source.String(), nil
}
return origin, nil
}
// Token 获取上传会话URL
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
fileInfo := file.Info()
uploadURL, err := handler.Client.CreateUploadSession(ctx, fileInfo.SavePath, WithConflictBehavior("fail"))
if err != nil {
return nil, err
}
// 监控回调及上传
go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl)
uploadSession.UploadURL = uploadURL
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadURLs: []string{uploadURL},
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.Client.DeleteUploadSession(ctx, uploadSession.UploadURL)
}

View File

@ -0,0 +1,25 @@
package onedrive
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}
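// Illustrative sketch (not part of the original file): serializing work per storage
// policy so that two requests for the same policy never refresh its credential at the
// same time. The refresh callback is hypothetical.
func withPolicyLock(policyID uint, refresh func() error) error {
	GlobalMutex.Lock(policyID)
	defer GlobalMutex.Unlock(policyID)
	return refresh()
}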

View File

@ -0,0 +1,192 @@
package onedrive
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Error 实现error接口
func (err OAuthError) Error() string {
return err.ErrorDescription
}
// OAuthURL 获取OAuth认证页面URL
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
query := url.Values{
"client_id": {client.ClientID},
"scope": {strings.Join(scope, " ")},
"response_type": {"code"},
"redirect_uri": {client.Redirect},
}
client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode()
return client.Endpoints.OAuthEndpoints.authorize.String()
}
// getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址
func (client *Client) getOAuthEndpoint() *oauthEndpoint {
base, err := url.Parse(client.Endpoints.OAuthURL)
if err != nil {
return nil
}
var (
token *url.URL
authorize *url.URL
)
switch base.Host {
case "login.live.com":
token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
case "login.chinacloudapi.cn":
client.Endpoints.isInChina = true
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
default:
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
}
return &oauthEndpoint{
token: *token,
authorize: *authorize,
}
}
// ObtainToken 通过code或refresh_token兑换token
func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
options := newDefaultOption()
for _, o := range opts {
o.apply(options)
}
body := url.Values{
"client_id": {client.ClientID},
"redirect_uri": {client.Redirect},
"client_secret": {client.ClientSecret},
}
if options.code != "" {
body.Add("grant_type", "authorization_code")
body.Add("code", options.code)
} else {
body.Add("grant_type", "refresh_token")
body.Add("refresh_token", options.refreshToken)
}
strBody := body.Encode()
res := client.Request.Request(
"POST",
client.Endpoints.OAuthEndpoints.token.String(),
ioutil.NopCloser(strings.NewReader(strBody)),
request.WithHeader(http.Header{
"Content-Type": {"application/x-www-form-urlencoded"}},
),
request.WithContentLength(int64(len(strBody))),
)
if res.Err != nil {
return nil, res.Err
}
respBody, err := res.GetResponse()
if err != nil {
return nil, err
}
var (
errResp OAuthError
credential Credential
decodeErr error
)
if res.Response.StatusCode != 200 {
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
} else {
decodeErr = json.Unmarshal([]byte(respBody), &credential)
}
if decodeErr != nil {
return nil, decodeErr
}
if errResp.ErrorType != "" {
return nil, errResp
}
return &credential, nil
}
// UpdateCredential 更新凭证,并检查有效期
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
if isSlave {
return client.fetchCredentialFromMaster(ctx)
}
oauth.GlobalMutex.Lock(client.Policy.ID)
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
// 如果已存在凭证
if client.Credential != nil && client.Credential.AccessToken != "" {
// 检查已有凭证是否过期
if client.Credential.ExpiresIn > time.Now().Unix() {
// 未过期,不要更新
return nil
}
}
// 尝试从缓存中获取凭证
if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok {
credential := cacheCredential.(Credential)
if credential.ExpiresIn > time.Now().Unix() {
client.Credential = &credential
return nil
}
}
// 获取新的凭证
if client.Credential == nil || client.Credential.RefreshToken == "" {
// 无有效的RefreshToken
util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name)
return ErrInvalidRefreshToken
}
credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken))
if err != nil {
return err
}
// 更新有效期为绝对时间戳
expires := credential.ExpiresIn - 60
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
client.Credential = credential
// 更新存储策略的 RefreshToken
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
// 更新缓存
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
return nil
}
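// For reference (illustrative): a token answered with expires_in=3600 at time t is kept
// with ExpiresIn = t+3540 (60 s of slack), and the same 3540 s window is used as the
// cache TTL, so every later check compares against an absolute Unix timestamp.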
func (client *Client) AccessToken() string {
return client.Credential.AccessToken
}
// fetchCredentialFromMaster 从主节点获取存储策略的凭证
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
if err != nil {
return err
}
client.Credential = &Credential{AccessToken: res}
return nil
}

View File

@ -0,0 +1,59 @@
package onedrive
import "time"
// Option 发送请求的额外设置
type Option interface {
apply(*options)
}
type options struct {
redirect string
code string
refreshToken string
conflictBehavior string
expires time.Time
useDriverResource bool
}
type optionFunc func(*options)
// WithCode 设置接口Code
func WithCode(t string) Option {
return optionFunc(func(o *options) {
o.code = t
})
}
// WithRefreshToken 设置接口RefreshToken
func WithRefreshToken(t string) Option {
return optionFunc(func(o *options) {
o.refreshToken = t
})
}
// WithConflictBehavior 设置文件重名后的处理方式
func WithConflictBehavior(t string) Option {
return optionFunc(func(o *options) {
o.conflictBehavior = t
})
}
// WithDriverResource 设置构建请求 URL 时是否使用驱动器资源路径
func WithDriverResource(t bool) Option {
return optionFunc(func(o *options) {
o.useDriverResource = t
})
}
func (f optionFunc) apply(o *options) {
f(o)
}
func newDefaultOption() *options {
return &options{
conflictBehavior: "fail",
useDriverResource: true,
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
}
}
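// Illustrative sketch (not part of the original file): how callers combine the
// functional options above; the concrete values are hypothetical.
func exampleOptions() *options {
	opts := newDefaultOption() // conflictBehavior "fail", drive resource enabled
	for _, o := range []Option{
		WithConflictBehavior("replace"), // overwrite an existing item instead of failing
		WithDriverResource(false),       // target a non-drive endpoint such as $batch
	} {
		o.apply(opts)
	}
	return opts
}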

View File

@ -0,0 +1,140 @@
package onedrive
import (
"encoding/gob"
"net/url"
)
// RespError 接口返回错误
type RespError struct {
APIError APIError `json:"error"`
}
// APIError 接口返回的错误内容
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// UploadSessionResponse 分片上传会话
type UploadSessionResponse struct {
DataContext string `json:"@odata.context"`
ExpirationDateTime string `json:"expirationDateTime"`
NextExpectedRanges []string `json:"nextExpectedRanges"`
UploadURL string `json:"uploadUrl"`
}
// FileInfo 文件元信息
type FileInfo struct {
Name string `json:"name"`
Size uint64 `json:"size"`
Image imageInfo `json:"image"`
ParentReference parentReference `json:"parentReference"`
DownloadURL string `json:"@microsoft.graph.downloadUrl"`
File *file `json:"file"`
Folder *folder `json:"folder"`
}
type file struct {
MimeType string `json:"mimeType"`
}
type folder struct {
ChildCount int `json:"childCount"`
}
type imageInfo struct {
Height int `json:"height"`
Width int `json:"width"`
}
type parentReference struct {
Path string `json:"path"`
Name string `json:"name"`
ID string `json:"id"`
}
// UploadResult 上传结果
type UploadResult struct {
ID string `json:"id"`
Name string `json:"name"`
Size uint64 `json:"size"`
}
// BatchRequests 批量操作请求
type BatchRequests struct {
Requests []BatchRequest `json:"requests"`
}
// BatchRequest 批量操作单个请求
type BatchRequest struct {
ID string `json:"id"`
Method string `json:"method"`
URL string `json:"url"`
Body interface{} `json:"body,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
}
// BatchResponses 批量操作响应
type BatchResponses struct {
Responses []BatchResponse `json:"responses"`
}
// BatchResponse 批量操作单个响应
type BatchResponse struct {
ID string `json:"id"`
Status int `json:"status"`
}
// ThumbResponse 获取缩略图的响应
type ThumbResponse struct {
Value []map[string]interface{} `json:"value"`
URL string `json:"url"`
}
// ListResponse 列取子项目响应
type ListResponse struct {
Value []FileInfo `json:"value"`
Context string `json:"@odata.context"`
}
// oauthEndpoint OAuth接口地址
type oauthEndpoint struct {
token url.URL
authorize url.URL
}
// Credential 获取token时返回的凭证
type Credential struct {
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
Scope string `json:"scope"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
UserID string `json:"user_id"`
}
// OAuthError OAuth相关接口的错误响应
type OAuthError struct {
ErrorType string `json:"error"`
ErrorDescription string `json:"error_description"`
CorrelationID string `json:"correlation_id"`
}
// Site SharePoint 站点信息
type Site struct {
Description string `json:"description"`
ID string `json:"id"`
Name string `json:"name"`
DisplayName string `json:"displayName"`
WebUrl string `json:"webUrl"`
}
func init() {
gob.Register(Credential{})
}
// Error 实现error接口
func (err RespError) Error() string {
return err.APIError.Message
}

View File

@ -0,0 +1,117 @@
package oss
import (
"bytes"
"crypto"
"crypto/md5"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
)
// GetPublicKey 从回调请求或缓存中获取OSS的回调签名公钥
func GetPublicKey(r *http.Request) ([]byte, error) {
var pubKey []byte
// 尝试从缓存中获取
pub, exist := cache.Get("oss_public_key")
if exist {
return pub.([]byte), nil
}
// 从请求中获取
pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url"))
if err != nil {
return pubKey, err
}
// 确保这个 public key 是由 OSS 颁发的
if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") &&
!strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") {
return pubKey, errors.New("public key url invalid")
}
// 获取公钥
client := request.NewClient()
body, err := client.Request("GET", string(pubURL), nil).
CheckHTTPResponse(200).
GetResponse()
if err != nil {
return pubKey, err
}
// 写入缓存
_ = cache.Set("oss_public_key", []byte(body), 86400*7)
return []byte(body), nil
}
func getRequestMD5(r *http.Request) ([]byte, error) {
var byteMD5 []byte
// 获取请求正文
body, err := ioutil.ReadAll(r.Body)
r.Body.Close()
if err != nil {
return byteMD5, err
}
r.Body = ioutil.NopCloser(bytes.NewReader(body))
strURLPathDecode, err := url.PathUnescape(r.URL.Path)
if err != nil {
return byteMD5, err
}
strAuth := fmt.Sprintf("%s\n%s", strURLPathDecode, string(body))
md5Ctx := md5.New()
md5Ctx.Write([]byte(strAuth))
byteMD5 = md5Ctx.Sum(nil)
return byteMD5, nil
}
// VerifyCallbackSignature 验证OSS回调请求
func VerifyCallbackSignature(r *http.Request) error {
bytePublicKey, err := GetPublicKey(r)
if err != nil {
return err
}
byteMD5, err := getRequestMD5(r)
if err != nil {
return err
}
strAuthorizationBase64 := r.Header.Get("authorization")
if strAuthorizationBase64 == "" {
return errors.New("no authorization field in Request header")
}
authorization, _ := base64.StdEncoding.DecodeString(strAuthorizationBase64)
pubBlock, _ := pem.Decode(bytePublicKey)
if pubBlock == nil {
return errors.New("pubBlock not exist")
}
pubInterface, err := x509.ParsePKIXPublicKey(pubBlock.Bytes)
if (pubInterface == nil) || (err != nil) {
return err
}
pub, ok := pubInterface.(*rsa.PublicKey)
if !ok {
return errors.New("public key is not an RSA key")
}
errorVerifyPKCS1v15 := rsa.VerifyPKCS1v15(pub, crypto.MD5, byteMD5, authorization)
if errorVerifyPKCS1v15 != nil {
return errorVerifyPKCS1v15
}
return nil
}
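// Illustrative sketch (assumption, not part of the original file): guarding an OSS
// callback endpoint with the verification above. The route shape and response body are
// hypothetical; Cloudreve's real callback controller lives elsewhere.
func callbackHandlerExample(w http.ResponseWriter, r *http.Request) {
	if err := VerifyCallbackSignature(r); err != nil {
		// Reject requests that were not signed by OSS.
		http.Error(w, "invalid callback signature", http.StatusUnauthorized)
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(`{"code":0}`))
}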

View File

@ -0,0 +1,501 @@
package oss
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/url"
"path"
"path/filepath"
"strings"
"time"
"github.com/HFO4/aliyun-oss-go-sdk/oss"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// UploadPolicy 阿里云OSS上传策略
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// CallbackPolicy 回调策略
type CallbackPolicy struct {
CallbackURL string `json:"callbackUrl"`
CallbackBody string `json:"callbackBody"`
CallbackBodyType string `json:"callbackBodyType"`
}
// Driver 阿里云OSS策略适配器
type Driver struct {
Policy *model.Policy
client *oss.Client
bucket *oss.Bucket
HTTPClient request.Client
}
type key int
const (
chunkRetrySleep = time.Duration(5) * time.Second
// MultiPartUploadThreshold 服务端使用分片上传的阈值
MultiPartUploadThreshold uint64 = 5 * (1 << 30) // 5GB
// VersionID 文件版本标识
VersionID key = iota
)
func NewDriver(policy *model.Policy) (*Driver, error) {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
driver := &Driver{
Policy: policy,
HTTPClient: request.NewClient(),
}
return driver, driver.InitOSSClient(false)
}
// CORS 创建跨域策略
func (handler *Driver) CORS() error {
return handler.client.SetBucketCORS(handler.Policy.BucketName, []oss.CORSRule{
{
AllowedOrigin: []string{"*"},
AllowedMethod: []string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
},
ExposeHeader: []string{},
AllowedHeader: []string{"*"},
MaxAgeSeconds: 3600,
},
})
}
// InitOSSClient 初始化OSS鉴权客户端
func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error {
if handler.Policy == nil {
return errors.New("empty policy")
}
// 决定是否使用内网 Endpoint
endpoint := handler.Policy.Server
if handler.Policy.OptionsSerialized.ServerSideEndpoint != "" && !forceUsePublicEndpoint {
endpoint = handler.Policy.OptionsSerialized.ServerSideEndpoint
}
// 初始化客户端
client, err := oss.New(endpoint, handler.Policy.AccessKey, handler.Policy.SecretKey)
if err != nil {
return err
}
handler.client = client
// 初始化存储桶
bucket, err := client.Bucket(handler.Policy.BucketName)
if err != nil {
return err
}
handler.bucket = bucket
return nil
}
// List 列出OSS上的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 列取文件
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
var (
delimiter string
marker string
objects []oss.ObjectProperties
commons []string
)
if !recursive {
delimiter = "/"
}
for {
subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base),
oss.MaxKeys(1000), oss.Delimiter(delimiter))
if err != nil {
return nil, err
}
objects = append(objects, subRes.Objects...)
commons = append(commons, subRes.CommonPrefixes...)
marker = subRes.NextMarker
if marker == "" {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(base, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(base, object.Key)
if err != nil {
continue
}
if strings.HasSuffix(object.Key, "/") {
res = append(res, response.Object{
Name: path.Base(object.Key),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
} else {
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Size),
IsDir: false,
LastModify: object.LastModified,
})
}
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 通过VersionID禁止缓存
ctx = context.WithValue(ctx, VersionID, time.Now().UnixNano())
// 尽可能使用私有 Endpoint
ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false)
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.HTTPClient.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
// 凭证有效期
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
// 是否允许覆盖
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
options := []oss.Option{
oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)),
oss.ForbidOverWrite(!overwrite),
}
// 小文件直接上传
if fileInfo.Size < MultiPartUploadThreshold {
return handler.bucket.PutObject(fileInfo.SavePath, file, options...)
}
// 超过阈值时使用分片上传
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
if err != nil {
return fmt.Errorf("failed to initiate multipart upload: %w", err)
}
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
_, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1)
return err
}
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
_, err = handler.bucket.CompleteMultipartUpload(imur, oss.CompleteAll("yes"), oss.ForbidOverWrite(!overwrite))
return err
}
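// For reference (illustrative): only objects of MultiPartUploadThreshold (5 GiB) and
// above take the multipart path; with the default 25 MiB ChunkSize a 5 GiB object is
// split into 205 parts (204 full chunks plus a 20 MiB tail), each retried according to
// the configured constant backoff.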
// Delete 删除一个或多个文件,
// 返回未删除的文件
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// 删除文件
delRes, err := handler.bucket.DeleteObjects(files)
if err != nil {
return files, err
}
// 统计未删除的文件
failed := util.SliceDifference(files, delRes.DeletedObjects)
if len(failed) > 0 {
return failed, errors.New("failed to delete")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://help.aliyun.com/document_detail/183902.html
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
// 初始化客户端
if err := handler.InitOSSClient(true); err != nil {
return nil, err
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d/quality,q_%d", thumbSize[1], thumbSize[0], thumbEncodeQuality)
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, thumbParam)
thumbOption := []oss.Option{oss.Process(thumbParam)}
thumbURL, err := handler.signSourceURL(
ctx,
file.SourceName,
int64(model.GetIntSetting("preview_timeout", 60)),
thumbOption,
)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: thumbURL,
}, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 初始化客户端
usePublicEndpoint := true
if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok {
usePublicEndpoint = forceUsePublicEndpoint
}
if err := handler.InitOSSClient(usePublicEndpoint); err != nil {
return "", err
}
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 添加各项设置
var signOptions = make([]oss.Option, 0, 2)
if isDownload {
signOptions = append(signOptions, oss.ResponseContentDisposition("attachment; filename=\""+url.PathEscape(fileName)+"\""))
}
if speed > 0 {
// Byte 转换为 bit
speed *= 8
// OSS对速度值有范围限制
if speed < 819200 {
speed = 819200
}
if speed > 838860800 {
speed = 838860800
}
signOptions = append(signOptions, oss.TrafficLimitParam(int64(speed)))
}
return handler.signSourceURL(ctx, path, ttl, signOptions)
}
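// For reference (illustrative): a 2 MB/s limit passed in as speed=2097152 becomes
// 16777216 bit/s after the ×8 conversion above, already inside the 819200-838860800
// bit/s range OSS accepts, so it is forwarded to TrafficLimitParam unchanged.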
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64, options []oss.Option) (string, error) {
signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...)
if err != nil {
return "", err
}
// 将最终生成的签名URL域名换成用户自定义的加速域名如果有
finalURL, err := url.Parse(signedURL)
if err != nil {
return "", err
}
// 公有空间替换掉Key及不支持的头
if !handler.Policy.IsPrivate {
query := finalURL.Query()
query.Del("OSSAccessKeyId")
query.Del("Signature")
query.Del("response-content-disposition")
query.Del("x-oss-traffic-limit")
finalURL.RawQuery = query.Encode()
}
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
finalURL.Host = cdnURL.Host
finalURL.Scheme = cdnURL.Scheme
}
return finalURL.String(), nil
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 初始化客户端
if err := handler.InitOSSClient(true); err != nil {
return nil, err
}
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/oss/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 回调策略
callbackPolicy := CallbackPolicy{
CallbackURL: apiURL.String(),
CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`,
CallbackBodyType: "application/json",
}
callbackPolicyJSON, err := json.Marshal(callbackPolicy)
if err != nil {
return nil, fmt.Errorf("failed to encode callback policy: %w", err)
}
callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON)
// 初始化分片上传
fileInfo := file.Info()
options := []oss.Option{
oss.Expires(time.Now().Add(time.Duration(ttl) * time.Second)),
oss.ForbidOverWrite(true),
oss.ContentType(fileInfo.DetectMimeType()),
}
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
if err != nil {
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
}
uploadSession.UploadID = imur.UploadID
// 为每个分片签名上传 URL
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
urls := make([]string, chunks.Num())
for chunks.Next() {
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
signedURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPut, ttl,
oss.PartNumber(c.Index()+1),
oss.UploadID(imur.UploadID),
oss.ContentType("application/octet-stream"))
if err != nil {
return err
}
urls[c.Index()] = signedURL
return nil
})
if err != nil {
return nil, err
}
}
// 签名完成分片上传的URL
completeURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPost, ttl,
oss.ContentType("application/octet-stream"),
oss.UploadID(imur.UploadID),
oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)),
oss.CompleteAll("yes"),
oss.ForbidOverWrite(true),
oss.CallbackParam(callbackPolicyEncoded))
if err != nil {
return nil, err
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadID: imur.UploadID,
UploadURLs: urls,
CompleteURL: completeURL,
}, nil
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.SavePath}, nil)
}

View File

@ -0,0 +1,354 @@
package qiniu
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
)
// Driver 本地策略适配器
type Driver struct {
Policy *model.Policy
mac *qbox.Mac
cfg *storage.Config
bucket *storage.BucketManager
}
func NewDriver(policy *model.Policy) *Driver {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
mac := qbox.NewMac(policy.AccessKey, policy.SecretKey)
cfg := &storage.Config{UseHTTPS: true}
return &Driver{
Policy: policy,
mac: mac,
cfg: cfg,
bucket: storage.NewBucketManager(mac, cfg),
}
}
// List 列出给定路径下的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
var (
delimiter string
marker string
objects []storage.ListItem
commons []string
)
if !recursive {
delimiter = "/"
}
for {
entries, folders, nextMarker, hasNext, err := handler.bucket.ListFiles(
handler.Policy.BucketName,
base, delimiter, marker, 1000)
if err != nil {
return nil, err
}
objects = append(objects, entries...)
commons = append(commons, folders...)
if !hasNext {
break
}
marker = nextMarker
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(base, object)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(base, object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(object.Key),
Source: object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(object.Fsize),
IsDir: false,
LastModify: time.Unix(object.PutTime/10000000, 0),
})
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 给文件名加上随机参数以强制拉取
path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano())
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
// 凭证有效期
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
// 生成上传策略
fileInfo := file.Info()
scope := handler.Policy.BucketName
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
scope = fmt.Sprintf("%s:%s", handler.Policy.BucketName, fileInfo.SavePath)
}
putPolicy := storage.PutPolicy{
// 指定为覆盖策略
Scope: scope,
SaveKey: fileInfo.SavePath,
ForceSaveKey: true,
FsizeLimit: int64(fileInfo.Size),
}
// 是否开启了MIMEType限制
if handler.Policy.OptionsSerialized.MimeType != "" {
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
}
// 生成上传凭证
token, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, int64(credentialTTL), false)
if err != nil {
return err
}
// 创建上传表单
cfg := storage.Config{}
formUploader := storage.NewFormUploader(&cfg)
ret := storage.PutRet{}
putExtra := storage.PutExtra{
Params: map[string]string{},
}
// 开始上传
err = formUploader.Put(ctx, &ret, token.Credential, fileInfo.SavePath, file, int64(fileInfo.Size), &putExtra)
if err != nil {
return err
}
return nil
}
// Delete 删除一个或多个文件,
// 返回未删除的文件
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// TODO 大于一千个文件需要分批发送
deleteOps := make([]string, 0, len(files))
for _, key := range files {
deleteOps = append(deleteOps, storage.URIDelete(handler.Policy.BucketName, key))
}
rets, err := handler.bucket.Batch(deleteOps)
// 处理删除结果
if err != nil {
failed := make([]string, 0, len(rets))
for k, ret := range rets {
if ret.Code != 200 && ret.Code != 612 {
failed = append(failed, files[k])
}
}
return failed, errors.New("failed to delete")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://developer.qiniu.com/dora/api/basic-processing-images-imageview2
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d/q/%d", file.SourceName, thumbSize[0], thumbSize[1], thumbEncodeQuality)
return &response.ContentResponse{
Redirect: true,
URL: handler.signSourceURL(
ctx,
thumb,
int64(model.GetIntSetting("preview_timeout", 60)),
),
}, nil
}
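// For reference, the imageView2 string composed above follows Qiniu's basic
// image-processing syntax. An illustrative sketch, not part of the driver;
// the object key is a placeholder:
func exampleThumbParam() string {
	// => "uploads/1/photo.jpg?imageView2/1/w/400/h/300/q/85"
	return fmt.Sprintf("%s?imageView2/1/w/%d/h/%d/q/%d", "uploads/1/photo.jpg", 400, 300, 85)
}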
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 加入下载相关设置
if isDownload {
path = path + "?attname=" + url.PathEscape(fileName)
}
// 取得原始文件地址
return handler.signSourceURL(ctx, path, ttl), nil
}
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64) string {
var sourceURL string
if handler.Policy.IsPrivate {
deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix()
sourceURL = storage.MakePrivateURL(handler.mac, handler.Policy.BaseURL, path, deadline)
} else {
sourceURL = storage.MakePublicURL(handler.Policy.BaseURL, path)
}
return sourceURL
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/qiniu/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 创建上传策略
fileInfo := file.Info()
putPolicy := storage.PutPolicy{
Scope: handler.Policy.BucketName,
CallbackURL: apiURL.String(),
CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`,
CallbackBodyType: "application/json",
SaveKey: fileInfo.SavePath,
ForceSaveKey: true,
FsizeLimit: int64(handler.Policy.MaxSize),
}
// 是否开启了MIMEType限制
if handler.Policy.OptionsSerialized.MimeType != "" {
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
}
credential, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, ttl, true)
if err != nil {
return nil, fmt.Errorf("failed to init parts: %w", err)
}
credential.SessionID = uploadSession.Key
credential.ChunkSize = handler.Policy.OptionsSerialized.ChunkSize
uploadSession.UploadURL = credential.UploadURLs[0]
uploadSession.Credential = credential.Credential
return credential, nil
}
// getUploadCredential 签名上传策略并创建上传会话
func (handler *Driver) getUploadCredential(ctx context.Context, policy storage.PutPolicy, file *fsctx.UploadTaskInfo, TTL int64, resume bool) (*serializer.UploadCredential, error) {
// 上传凭证
policy.Expires = uint64(TTL)
upToken := policy.UploadToken(handler.mac)
// 初始化分片上传
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
upHost, err := resumeUploader.UpHost(handler.Policy.AccessKey, handler.Policy.BucketName)
if err != nil {
return nil, err
}
ret := &storage.InitPartsRet{}
if resume {
err = resumeUploader.InitParts(ctx, upToken, upHost, handler.Policy.BucketName, file.SavePath, true, ret)
}
return &serializer.UploadCredential{
UploadURLs: []string{upHost + "/buckets/" + handler.Policy.BucketName + "/objects/" + base64.URLEncoding.EncodeToString([]byte(file.SavePath)) + "/uploads/" + ret.UploadID},
Credential: upToken,
}, err
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0)
}

View File

@ -0,0 +1,195 @@
package remote
import (
"context"
"encoding/json"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gofrs/uuid"
"io"
"net/http"
"net/url"
"path"
"strings"
"time"
)
const (
basePath = "/api/v3/slave/"
OverwriteHeader = auth.CrHeaderPrefix + "Overwrite"
chunkRetrySleep = time.Duration(5) * time.Second
)
// Client to operate uploading to remote slave server
type Client interface {
// CreateUploadSession creates remote upload session
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error
// GetUploadURL signs an url for uploading file
GetUploadURL(ttl int64, sessionID string) (string, string, error)
// Upload uploads file to remote server
Upload(ctx context.Context, file fsctx.FileHeader) error
// DeleteUploadSession deletes remote upload session
DeleteUploadSession(ctx context.Context, sessionID string) error
}
// NewClient creates new Client from given policy
func NewClient(policy *model.Policy) (Client, error) {
authInstance := auth.HMACAuth{[]byte(policy.SecretKey)}
serverURL, err := url.Parse(policy.Server)
if err != nil {
return nil, err
}
base, _ := url.Parse(basePath)
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return &remoteClient{
policy: policy,
authInstance: authInstance,
httpClient: request.NewClient(
request.WithEndpoint(serverURL.ResolveReference(base).String()),
request.WithCredential(authInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(policy.AccessKey),
),
}, nil
}
type remoteClient struct {
policy *model.Policy
authInstance auth.Auth
httpClient request.Client
}
func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error {
ttl := model.GetIntSetting("upload_session_timeout", 86400)
fileInfo := file.Info()
session := &serializer.UploadSession{
Key: uuid.Must(uuid.NewV4()).String(),
VirtualPath: fileInfo.VirtualPath,
Name: fileInfo.FileName,
Size: fileInfo.Size,
SavePath: fileInfo.SavePath,
LastModified: fileInfo.LastModified,
Policy: *c.policy,
}
// Create upload session
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
if err := c.CreateUploadSession(ctx, session, int64(ttl), overwrite); err != nil {
return fmt.Errorf("failed to create upload session: %w", err)
}
// Initial chunk groups
chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
Max: model.GetIntSetting("chunk_retries", 5),
Sleep: chunkRetrySleep,
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length())
}
// upload chunks
for chunks.Next() {
if err := chunks.Process(uploadFunc); err != nil {
if err := c.DeleteUploadSession(ctx, session.Key); err != nil {
util.Log().Warning("failed to delete upload session: %s", err)
}
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
}
}
return nil
}
func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error {
resp, err := c.httpClient.Request(
"DELETE",
"upload/"+sessionID,
nil,
request.WithContext(ctx),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error {
reqBodyEncoded, err := json.Marshal(map[string]interface{}{
"session": session,
"ttl": ttl,
"overwrite": overwrite,
})
if err != nil {
return err
}
bodyReader := strings.NewReader(string(reqBodyEncoded))
resp, err := c.httpClient.Request(
"PUT",
"upload",
bodyReader,
request.WithContext(ctx),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}
func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string, error) {
base, err := url.Parse(c.policy.Server)
if err != nil {
return "", "", err
}
base.Path = path.Join(base.Path, basePath, "upload", sessionID)
req, err := http.NewRequest("POST", base.String(), nil)
if err != nil {
return "", "", err
}
req = auth.SignRequest(c.authInstance, req, ttl)
return req.URL.String(), req.Header["Authorization"][0], nil
}
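// exampleUploadChunkDirectly is an illustrative sketch, not part of the
// client: it consumes the signed URL and Authorization value returned by
// GetUploadURL to push one chunk straight to the slave. How the chunk index
// query parameter is appended (see uploadChunk below) is left to the real
// uploader, and body is a placeholder reader.
func exampleUploadChunkDirectly(c Client, sessionID string, body io.Reader) error {
	uploadURL, credential, err := c.GetUploadURL(3600, sessionID)
	if err != nil {
		return err
	}
	req, err := http.NewRequest("POST", uploadURL, body)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", credential)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	return resp.Body.Close()
}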
func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error {
resp, err := c.httpClient.Request(
"POST",
fmt.Sprintf("upload/%s?chunk=%d", sessionID, index),
chunk,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithContentLength(size),
request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return err
}
if resp.Code != 0 {
return serializer.NewErrorFromResponse(resp)
}
return nil
}

View File

@ -0,0 +1,311 @@
package remote
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/url"
"path"
"path/filepath"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Driver 远程存储策略适配器
type Driver struct {
Client request.Client
Policy *model.Policy
AuthInstance auth.Auth
uploadClient Client
}
// NewDriver initializes a new Driver from policy
// TODO: refactor all method into upload client
func NewDriver(policy *model.Policy) (*Driver, error) {
client, err := NewClient(policy)
if err != nil {
return nil, err
}
return &Driver{
Policy: policy,
Client: request.NewClient(),
AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)},
uploadClient: client,
}, nil
}
// List 列取文件
func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
var res []response.Object
reqBody := serializer.ListRequest{
Path: path,
Recursive: recursive,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return res, err
}
// 发送列表请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := handler.Client.Request(
"POST",
handler.getAPIUrl("list"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).DecodeResponse()
if err != nil {
return res, err
}
// 处理列取结果
if resp.Code != 0 {
return res, errors.New(resp.Error)
}
if resStr, ok := resp.Data.(string); ok {
err = json.Unmarshal([]byte(resStr), &res)
if err != nil {
return res, err
}
}
return res, nil
}
// getAPIUrl 获取接口请求地址
func (handler *Driver) getAPIUrl(scope string, routes ...string) string {
serverURL, err := url.Parse(handler.Policy.Server)
if err != nil {
return ""
}
var controller *url.URL
switch scope {
case "delete":
controller, _ = url.Parse("/api/v3/slave/delete")
case "thumb":
controller, _ = url.Parse("/api/v3/slave/thumb")
case "list":
controller, _ = url.Parse("/api/v3/slave/list")
default:
controller = serverURL
}
for _, r := range routes {
controller.Path = path.Join(controller.Path, r)
}
return serverURL.ResolveReference(controller).String()
}
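// exampleAPIUrl is an illustrative sketch, not part of the driver: with
// Policy.Server = "https://slave.example.com" it returns
// "https://slave.example.com/api/v3/slave/thumb/some-id". Because the
// controller paths above are absolute, any path prefix carried by
// Policy.Server is dropped during ResolveReference.
func exampleAPIUrl(handler *Driver) string {
	return handler.getAPIUrl("thumb", "some-id")
}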
// Get 获取文件内容
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 尝试获取速度限制
speedLimit := 0
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
speedLimit = user.Group.SpeedLimit
}
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit)
if err != nil {
return nil, err
}
// 获取文件数据流
resp, err := handler.Client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithTimeout(time.Duration(0)),
request.WithMasterMeta(),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
return handler.uploadClient.Upload(ctx, file)
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
// 封装接口请求正文
reqBody := serializer.RemoteDeleteRequest{
Files: files,
}
reqBodyEncoded, err := json.Marshal(reqBody)
if err != nil {
return files, err
}
// 发送删除请求
bodyReader := strings.NewReader(string(reqBodyEncoded))
signTTL := model.GetIntSetting("slave_api_timeout", 60)
resp, err := handler.Client.Request(
"POST",
handler.getAPIUrl("delete"),
bodyReader,
request.WithCredential(handler.AuthInstance, int64(signTTL)),
request.WithMasterMeta(),
request.WithSlaveMeta(handler.Policy.AccessKey),
).CheckHTTPResponse(200).GetResponse()
if err != nil {
return files, err
}
// 处理删除结果
var reqResp serializer.Response
err = json.Unmarshal([]byte(resp), &reqResp)
if err != nil {
return files, err
}
if reqResp.Code != 0 {
var failedResp serializer.RemoteDeleteRequest
if failed, ok := reqResp.Data.(string); ok {
err = json.Unmarshal([]byte(failed), &failedResp)
if err == nil {
return failedResp.Files, errors.New(reqResp.Error)
}
}
return files, errors.New("unknown format of returned response")
}
return []string{}, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
supported := []string{"png", "jpg", "jpeg", "gif"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) {
return nil, driver.ErrorThumbNotSupported
}
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(file.SourceName))
thumbURL := fmt.Sprintf("%s/%s/%s", handler.getAPIUrl("thumb"), sourcePath, filepath.Ext(file.Name))
ttl := model.GetIntSetting("preview_timeout", 60)
signedThumbURL, err := auth.SignURI(handler.AuthInstance, thumbURL, int64(ttl))
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: signedThumbURL.String(),
}, nil
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := "file"
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
serverURL, err := url.Parse(handler.Policy.Server)
if err != nil {
return "", errors.New("无法解析远程服务端地址")
}
// 是否启用了CDN
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
serverURL = cdnURL
}
var (
signedURI *url.URL
controller = "/api/v3/slave/download"
)
if !isDownload {
controller = "/api/v3/slave/source"
}
// 签名下载地址
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path))
signedURI, err = auth.SignURI(
handler.AuthInstance,
fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, url.PathEscape(fileName)),
ttl,
)
if err != nil {
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign URL", err)
}
finalURL := serverURL.ResolveReference(signedURI).String()
return finalURL, nil
}
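// exampleUnsignedSourcePath is an illustrative sketch, not part of the driver:
// it shows the path handed to auth.SignURI above before the signature query
// string is appended. The source path and file name are placeholders.
func exampleUnsignedSourcePath() string {
	// => "/api/v3/slave/download/0/<base64url of the source path>/report.pdf"
	return fmt.Sprintf("/api/v3/slave/download/%d/%s/%s",
		0,
		base64.RawURLEncoding.EncodeToString([]byte("uploads/1/report.pdf")),
		url.PathEscape("report.pdf"),
	)
}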
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret))
apiURL := siteURL.ResolveReference(apiBaseURI)
// 在从机端创建上传会话
uploadSession.Callback = apiURL.String()
if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl, false); err != nil {
return nil, err
}
// 获取上传地址
uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key)
if err != nil {
return nil, fmt.Errorf("failed to sign upload url: %w", err)
}
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadURLs: []string{uploadURL},
Credential: sign,
}, nil
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key)
}

View File

@ -0,0 +1,440 @@
package s3
import (
"context"
"errors"
"fmt"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"io"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"time"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 适配器模板
type Driver struct {
Policy *model.Policy
sess *session.Session
svc *s3.S3
}
// UploadPolicy S3上传策略
type UploadPolicy struct {
Expiration string `json:"expiration"`
Conditions []interface{} `json:"conditions"`
}
// MetaData 文件信息
type MetaData struct {
Size uint64
Etag string
}
func NewDriver(policy *model.Policy) (*Driver, error) {
if policy.OptionsSerialized.ChunkSize == 0 {
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
}
driver := &Driver{
Policy: policy,
}
return driver, driver.InitS3Client()
}
// InitS3Client 初始化S3会话
func (handler *Driver) InitS3Client() error {
if handler.Policy == nil {
return errors.New("empty policy")
}
if handler.svc == nil {
// 初始化会话
sess, err := session.NewSession(&aws.Config{
Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
Endpoint: &handler.Policy.Server,
Region: &handler.Policy.OptionsSerialized.Region,
S3ForcePathStyle: &handler.Policy.OptionsSerialized.S3ForcePathStyle,
})
if err != nil {
return err
}
handler.sess = sess
handler.svc = s3.New(sess)
}
return nil
}
// List 列出给定路径下的文件
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
// 初始化列目录参数
base = strings.TrimPrefix(base, "/")
if base != "" {
base += "/"
}
opt := &s3.ListObjectsInput{
Bucket: &handler.Policy.BucketName,
Prefix: &base,
MaxKeys: aws.Int64(1000),
}
// 是否为递归列出
if !recursive {
opt.Delimiter = aws.String("/")
}
var (
objects []*s3.Object
commons []*s3.CommonPrefix
)
for {
res, err := handler.svc.ListObjectsWithContext(ctx, opt)
if err != nil {
return nil, err
}
objects = append(objects, res.Contents...)
commons = append(commons, res.CommonPrefixes...)
// 如果本次未列取完则继续使用marker获取结果
if *res.IsTruncated {
opt.Marker = res.NextMarker
} else {
break
}
}
// 处理列取结果
res := make([]response.Object, 0, len(objects)+len(commons))
// 处理目录
for _, object := range commons {
rel, err := filepath.Rel(*opt.Prefix, *object.Prefix)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(*object.Prefix),
RelativePath: filepath.ToSlash(rel),
Size: 0,
IsDir: true,
LastModify: time.Now(),
})
}
// 处理文件
for _, object := range objects {
rel, err := filepath.Rel(*opt.Prefix, *object.Key)
if err != nil {
continue
}
res = append(res, response.Object{
Name: path.Base(*object.Key),
Source: *object.Key,
RelativePath: filepath.ToSlash(rel),
Size: uint64(*object.Size),
IsDir: false,
LastModify: time.Now(),
})
}
return res, nil
}
// Get 获取文件
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
// 初始化客户端
if err := handler.InitS3Client(); err != nil {
return err
}
uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) {
u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize)
})
dst := file.Info().SavePath
_, err := uploader.Upload(&s3manager.UploadInput{
Bucket: &handler.Policy.BucketName,
Key: &dst,
Body: io.LimitReader(file, int64(file.Info().Size)),
})
if err != nil {
return err
}
return nil
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
failed := make([]string, 0, len(files))
deleted := make([]string, 0, len(files))
keys := make([]*s3.ObjectIdentifier, 0, len(files))
for _, file := range files {
filePath := file
keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
}
// 发送异步删除请求
res, err := handler.svc.DeleteObjects(
&s3.DeleteObjectsInput{
Bucket: &handler.Policy.BucketName,
Delete: &s3.Delete{
Objects: keys,
},
})
if err != nil {
return files, err
}
// 统计未删除的文件
for _, deleteRes := range res.Deleted {
deleted = append(deleted, *deleteRes.Key)
}
failed = util.SliceDifference(files, deleted)
return failed, nil
}
// Thumb 获取文件缩略图
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, driver.ErrorThumbNotSupported
}
// Source 获取外链URL
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
// 初始化客户端
if err := handler.InitS3Client(); err != nil {
return "", err
}
contentDescription := aws.String("attachment; filename=\"" + url.PathEscape(fileName) + "\"")
if !isDownload {
contentDescription = nil
}
req, _ := handler.svc.GetObjectRequest(
&s3.GetObjectInput{
Bucket: &handler.Policy.BucketName,
Key: &path,
ResponseContentDisposition: contentDescription,
})
signedURL, err := req.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return "", err
}
// 将最终生成的签名URL域名换成用户自定义的加速域名如果有
finalURL, err := url.Parse(signedURL)
if err != nil {
return "", err
}
// 公有空间替换掉Key及不支持的头
if !handler.Policy.IsPrivate {
finalURL.RawQuery = ""
}
if handler.Policy.BaseURL != "" {
cdnURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
finalURL.Host = cdnURL.Host
finalURL.Scheme = cdnURL.Scheme
}
return finalURL.String(), nil
}
// Token 获取上传策略和认证Token
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 检查文件是否存在
fileInfo := file.Info()
if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil {
return nil, fmt.Errorf("file already exist")
}
// 创建分片上传
expires := time.Now().Add(time.Duration(ttl) * time.Second)
res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
Expires: &expires,
ContentType: aws.String(fileInfo.DetectMimeType()),
})
if err != nil {
return nil, fmt.Errorf("failed to create multipart upload: %w", err)
}
uploadSession.UploadID = *res.UploadId
// 为每个分片签名上传 URL
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
urls := make([]string, chunks.Num())
for chunks.Next() {
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
PartNumber: aws.Int64(int64(c.Index() + 1)),
UploadId: res.UploadId,
})
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return err
}
urls[c.Index()] = signedURL
return nil
})
if err != nil {
return nil, err
}
}
// 签名完成分片上传的请求URL
signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
Bucket: &handler.Policy.BucketName,
Key: &fileInfo.SavePath,
UploadId: res.UploadId,
})
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
if err != nil {
return nil, err
}
// 生成上传凭证
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
UploadID: *res.UploadId,
UploadURLs: urls,
CompleteURL: signedURL,
}, nil
}
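// exampleCompleteMultipart is an illustrative client-side sketch, not part of
// the driver, of how the credential above is typically consumed: each chunk is
// PUT to its presigned URL, the returned ETags are collected, and the upload
// is finished by POSTing the standard CompleteMultipartUpload XML document to
// CompleteURL. readChunk is a placeholder supplied by the caller.
func exampleCompleteMultipart(cred *serializer.UploadCredential, readChunk func(i int) io.Reader) error {
	var body strings.Builder
	body.WriteString("<CompleteMultipartUpload>")
	for i, u := range cred.UploadURLs {
		req, err := http.NewRequest("PUT", u, readChunk(i))
		if err != nil {
			return err
		}
		resp, err := http.DefaultClient.Do(req)
		if err != nil {
			return err
		}
		etag := resp.Header.Get("ETag")
		resp.Body.Close()
		fmt.Fprintf(&body, "<Part><PartNumber>%d</PartNumber><ETag>%s</ETag></Part>", i+1, etag)
	}
	body.WriteString("</CompleteMultipartUpload>")
	req, err := http.NewRequest("POST", cred.CompleteURL, strings.NewReader(body.String()))
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	return resp.Body.Close()
}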
// Meta 获取文件信息
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
res, err := handler.svc.HeadObject(
&s3.HeadObjectInput{
Bucket: &handler.Policy.BucketName,
Key: &path,
})
if err != nil {
return nil, err
}
return &MetaData{
Size: uint64(*res.ContentLength),
Etag: *res.ETag,
}, nil
}
// CORS 创建跨域策略
func (handler *Driver) CORS() error {
rule := s3.CORSRule{
AllowedMethods: aws.StringSlice([]string{
"GET",
"POST",
"PUT",
"DELETE",
"HEAD",
}),
AllowedOrigins: aws.StringSlice([]string{"*"}),
AllowedHeaders: aws.StringSlice([]string{"*"}),
ExposeHeaders: aws.StringSlice([]string{"ETag"}),
MaxAgeSeconds: aws.Int64(3600),
}
_, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{
Bucket: &handler.Policy.BucketName,
CORSConfiguration: &s3.CORSConfiguration{
CORSRules: []*s3.CORSRule{&rule},
},
})
return err
}
// 取消上传凭证
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
_, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
UploadId: &uploadSession.UploadID,
Bucket: &handler.Policy.BucketName,
Key: &uploadSession.SavePath,
})
return err
}

View File

@ -0,0 +1,7 @@
package masterinslave
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
)

View File

@ -0,0 +1,60 @@
package masterinslave
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 影子存储策略,用于在从机端上传文件
type Driver struct {
master cluster.Node
handler driver.Handler
policy *model.Policy
}
// NewDriver 返回新的处理器
func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
return &Driver{
master: master,
handler: handler,
policy: policy,
}
}
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
return d.handler.Put(ctx, file)
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
return nil, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@ -0,0 +1,9 @@
package slaveinmaster
import "errors"
var (
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result")
)

View File

@ -0,0 +1,124 @@
package slaveinmaster
import (
"bytes"
"context"
"encoding/json"
"errors"
"net/url"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
type Driver struct {
node cluster.Node
handler driver.Handler
policy *model.Policy
client request.Client
}
// NewDriver 返回新的从机指派处理器
func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
var endpoint *url.URL
if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
var controller *url.URL
controller, _ = url.Parse("/api/v3/slave/")
endpoint = serverURL.ResolveReference(controller)
}
signTTL := model.GetIntSetting("slave_api_timeout", 60)
return &Driver{
node: node,
handler: handler,
policy: policy,
client: request.NewClient(
request.WithMasterMeta(),
request.WithTimeout(time.Duration(signTTL)*time.Second),
request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)),
request.WithEndpoint(endpoint.String()),
),
}
}
// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
fileInfo := file.Info()
req := serializer.SlaveTransferReq{
Src: fileInfo.Src,
Dst: fileInfo.SavePath,
Policy: d.policy,
}
body, err := json.Marshal(req)
if err != nil {
return err
}
// 订阅转存结果
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)).
CheckHTTPResponse(200).
DecodeResponse()
if err != nil {
return err
}
if res.Code != 0 {
return serializer.NewErrorFromResponse(res)
}
// 等待转存结果或者超时
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
select {
case <-time.After(time.Duration(waitTimeout) * time.Second):
return ErrWaitResultTimeout
case msg := <-resChan:
if msg.Event != serializer.SlaveTransferSuccess {
return errors.New(msg.Content.(serializer.SlaveTransferResult).Error)
}
}
return nil
}
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
return d.handler.Delete(ctx, files)
}
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
return nil, ErrNotImplemented
}
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
return "", ErrNotImplemented
}
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
return nil, ErrNotImplemented
}
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
return nil, ErrNotImplemented
}
// 取消上传凭证
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}

View File

@ -0,0 +1,358 @@
package upyun
import (
"context"
"crypto/hmac"
"crypto/md5"
"crypto/sha1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"sync"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/upyun/go-sdk/upyun"
)
// UploadPolicy 又拍云上传策略
type UploadPolicy struct {
Bucket string `json:"bucket"`
SaveKey string `json:"save-key"`
Expiration int64 `json:"expiration"`
CallbackURL string `json:"notify-url"`
ContentLength uint64 `json:"content-length"`
ContentLengthRange string `json:"content-length-range,omitempty"`
AllowFileType string `json:"allow-file-type,omitempty"`
}
// Driver 又拍云策略适配器
type Driver struct {
Policy *model.Policy
}
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
base = strings.TrimPrefix(base, "/")
// 用于接受SDK返回对象的chan
objChan := make(chan *upyun.FileInfo)
objects := []*upyun.FileInfo{}
// 列取配置
listConf := &upyun.GetObjectsConfig{
Path: "/" + base,
ObjectsChan: objChan,
MaxListTries: 1,
}
// 递归列取时不限制递归次数
if recursive {
listConf.MaxListLevel = -1
}
// 启动一个goroutine收集列取的对象信息
wg := &sync.WaitGroup{}
wg.Add(1)
go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) {
defer wg.Done()
for {
file, ok := <-input
if !ok {
return
}
*output = append(*output, file)
}
}(objChan, &objects, wg)
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
err := up.List(listConf)
if err != nil {
return nil, err
}
wg.Wait()
// 汇总处理列取结果
res := make([]response.Object, 0, len(objects))
for _, object := range objects {
res = append(res, response.Object{
Name: path.Base(object.Name),
RelativePath: object.Name,
Source: path.Join(base, object.Name),
Size: uint64(object.Size),
IsDir: object.IsDir,
LastModify: object.Time,
})
}
return res, nil
}
// Get 获取文件
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
// 获取文件源地址
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
// 获取文件数据流
client := request.NewClient()
resp, err := client.Request(
"GET",
downloadURL,
nil,
request.WithContext(ctx),
request.WithHeader(
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
),
request.WithTimeout(time.Duration(0)),
).CheckHTTPResponse(200).GetRSCloser()
if err != nil {
return nil, err
}
resp.SetFirstFakeChunk()
// 尝试自主获取文件大小
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
resp.SetContentLength(int64(file.Size))
}
return resp, nil
}
// Put 将文件流保存到指定目录
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
defer file.Close()
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
err := up.Put(&upyun.PutObjectConfig{
Path: file.Info().SavePath,
Reader: file,
})
return err
}
// Delete 删除一个或多个文件,
// 返回未删除的文件,及遇到的最后一个错误
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
up := upyun.NewUpYun(&upyun.UpYunConfig{
Bucket: handler.Policy.BucketName,
Operator: handler.Policy.AccessKey,
Password: handler.Policy.SecretKey,
})
var (
failed = make([]string, 0, len(files))
lastErr error
currentIndex = 0
indexLock sync.Mutex
failedLock sync.Mutex
wg sync.WaitGroup
routineNum = 4
)
wg.Add(routineNum)
// upyun不支持批量操作这里开四个协程并行操作
for i := 0; i < routineNum; i++ {
go func() {
for {
// 取得待删除文件
indexLock.Lock()
if currentIndex >= len(files) {
// 所有文件处理完成
wg.Done()
indexLock.Unlock()
return
}
path := files[currentIndex]
currentIndex++
indexLock.Unlock()
// 发送异步删除请求
err := up.Delete(&upyun.DeleteObjectConfig{
Path: path,
Async: true,
})
// 处理错误
if err != nil {
failedLock.Lock()
lastErr = err
failed = append(failed, path)
failedLock.Unlock()
}
}
}()
}
wg.Wait()
return failed, lastErr
}
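// exampleDeleteWorkers is an illustrative alternative to the index-sharing
// loop above, not the driver's actual code: it feeds paths through a channel
// and lets a fixed number of workers drain it.
func exampleDeleteWorkers(up *upyun.UpYun, files []string, workers int) ([]string, error) {
	var (
		failed     []string
		lastErr    error
		failedLock sync.Mutex
		wg         sync.WaitGroup
	)
	jobs := make(chan string)
	go func() {
		// hand out paths one by one, then close the channel
		for _, p := range files {
			jobs <- p
		}
		close(jobs)
	}()
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for p := range jobs {
				if err := up.Delete(&upyun.DeleteObjectConfig{Path: p, Async: true}); err != nil {
					failedLock.Lock()
					lastErr = err
					failed = append(failed, p)
					failedLock.Unlock()
				}
			}
		}()
	}
	wg.Wait()
	return failed, lastErr
}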
// Thumb 获取文件缩略图
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
// quick check by extension name
// https://help.upyun.com/knowledge-base/image/
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"}
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
supported = handler.Policy.OptionsSerialized.ThumbExts
}
if !util.IsInExtensionList(supported, file.Name) {
return nil, driver.ErrorThumbNotSupported
}
var (
thumbSize = [2]uint{400, 300}
ok = false
)
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
return nil, errors.New("failed to get thumbnail size")
}
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
thumbParam := fmt.Sprintf("!/fwfh/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: thumbURL,
}, nil
}
// Source 获取外链URL
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
// 尝试从上下文获取文件名
fileName := ""
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
fileName = file.Name
}
sourceURL, err := url.Parse(handler.Policy.BaseURL)
if err != nil {
return "", err
}
fileKey, err := url.Parse(url.PathEscape(path))
if err != nil {
return "", err
}
sourceURL = sourceURL.ResolveReference(fileKey)
// 如果是下载文件URL
if isDownload {
query := sourceURL.Query()
query.Add("_upd", fileName)
sourceURL.RawQuery = query.Encode()
}
return handler.signURL(ctx, sourceURL, ttl)
}
func (handler Driver) signURL(ctx context.Context, path *url.URL, TTL int64) (string, error) {
if !handler.Policy.IsPrivate {
// 未开启Token防盗链时直接返回
return path.String(), nil
}
etime := time.Now().Add(time.Duration(TTL) * time.Second).Unix()
signStr := fmt.Sprintf(
"%s&%d&%s",
handler.Policy.OptionsSerialized.Token,
etime,
path.Path,
)
signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10)
// 将签名添加到URL中
query := path.Query()
query.Add("_upt", finalSign)
path.RawQuery = query.Encode()
return path.String(), nil
}
// Token 获取上传策略和认证Token
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
// 生成回调地址
siteURL := model.GetSiteURL()
apiBaseURI, _ := url.Parse("/api/v3/callback/upyun/" + uploadSession.Key)
apiURL := siteURL.ResolveReference(apiBaseURI)
// 上传策略
fileInfo := file.Info()
putPolicy := UploadPolicy{
Bucket: handler.Policy.BucketName,
// TODO escape
SaveKey: fileInfo.SavePath,
Expiration: time.Now().Add(time.Duration(ttl) * time.Second).Unix(),
CallbackURL: apiURL.String(),
ContentLength: fileInfo.Size,
ContentLengthRange: fmt.Sprintf("0,%d", fileInfo.Size),
AllowFileType: strings.Join(handler.Policy.OptionsSerialized.FileType, ","),
}
// 生成上传凭证
policyJSON, err := json.Marshal(putPolicy)
if err != nil {
return nil, err
}
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
// 生成签名
elements := []string{"POST", "/" + handler.Policy.BucketName, policyEncoded}
signStr := handler.Sign(ctx, elements)
return &serializer.UploadCredential{
SessionID: uploadSession.Key,
Policy: policyEncoded,
Credential: signStr,
UploadURLs: []string{"https://v0.api.upyun.com/" + handler.Policy.BucketName},
}, nil
}
// 取消上传凭证
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
return nil
}
// Sign 计算又拍云的签名头
func (handler Driver) Sign(ctx context.Context, elements []string) string {
password := fmt.Sprintf("%x", md5.Sum([]byte(handler.Policy.SecretKey)))
mac := hmac.New(sha1.New, []byte(password))
value := strings.Join(elements, "&")
mac.Write([]byte(value))
signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil)))
return fmt.Sprintf("UPYUN %s:%s", handler.Policy.AccessKey, signStr)
}

25
pkg/filesystem/errors.go Normal file
View File

@ -0,0 +1,25 @@
package filesystem
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
var (
ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil)
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil)
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil)
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil)
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil)
ErrClientCanceled = errors.New("Client canceled operation")
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil)
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil)
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil)
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil)
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil)
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil)
ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil)
ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil)
ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil)
ErrOneObjectOnly = serializer.ParamErr("You can only copy one object at the same time", nil)
)

387
pkg/filesystem/file.go Normal file
View File

@ -0,0 +1,387 @@
package filesystem
import (
"context"
"fmt"
"io"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/juju/ratelimit"
)
/* ============
文件相关
============
*/
// 限速后的ReaderSeeker
type lrs struct {
response.RSCloser
r io.Reader
}
func (r lrs) Read(p []byte) (int, error) {
return r.r.Read(p)
}
// withSpeedLimit 给原有的ReadSeeker加上限速
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
// 如果用户组有速度限制就返回限制流速的ReaderSeeker
if fs.User.Group.SpeedLimit != 0 {
speed := fs.User.Group.SpeedLimit
bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed))
lrs := lrs{rs, ratelimit.Reader(rs, bucket)}
return lrs
}
// 否则返回原始流
return rs
}
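// exampleLimitCopy is an illustrative sketch, not part of Cloudreve: it copies
// src to dst at roughly 1 MiB/s using the same juju/ratelimit token bucket
// that withSpeedLimit wraps around download streams above.
func exampleLimitCopy(dst io.Writer, src io.Reader) error {
	bucket := ratelimit.NewBucketWithRate(float64(1<<20), int64(1<<20))
	_, err := io.Copy(dst, ratelimit.Reader(src, bucket))
	return err
}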
// AddFile 新增文件记录
func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fsctx.FileHeader) (*model.File, error) {
// 添加文件记录前的钩子
err := fs.Trigger(ctx, "BeforeAddFile", file)
if err != nil {
return nil, err
}
uploadInfo := file.Info()
newFile := model.File{
Name: uploadInfo.FileName,
SourceName: uploadInfo.SavePath,
UserID: fs.User.ID,
Size: uploadInfo.Size,
FolderID: parent.ID,
PolicyID: fs.Policy.ID,
MetadataSerialized: uploadInfo.Metadata,
UploadSessionID: uploadInfo.UploadSessionID,
}
err = newFile.Create()
if err != nil {
if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil {
util.Log().Debug("AfterValidateFailed hook execution failed: %s", err)
}
return nil, ErrFileExisted.WithError(err)
}
fs.User.Storage += newFile.Size
return &newFile, nil
}
// GetPhysicalFileContent 根据文件物理路径获取文件流
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
// 重设上传策略
fs.Policy = &model.Policy{Type: "local"}
_ = fs.DispatchHandler()
// 获取文件流
rs, err := fs.Handler.Get(ctx, path)
if err != nil {
return nil, err
}
return fs.withSpeedLimit(rs), nil
}
// Preview 预览文件
//
// path - 文件虚拟路径
// isText - 是否为文本文件,文本文件会忽略重定向,直接由
// 服务端拉取中转给用户,故会对文件大小进行限制
func (fs *FileSystem) Preview(ctx context.Context, id uint, isText bool) (*response.ContentResponse, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, err
}
// 如果是文本文件预览,需要检查大小限制
sizeLimit := model.GetIntSetting("maxEditSize", 2<<20)
if isText && fs.FileTarget[0].Size > uint64(sizeLimit) {
return nil, ErrFileSizeTooBig
}
// 是否直接返回文件内容
if isText || fs.Policy.IsDirectlyPreview() {
resp, err := fs.GetDownloadContent(ctx, id)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: false,
Content: resp,
}, nil
}
// 否则重定向到签名的预览URL
ttl := model.GetIntSetting("preview_timeout", 60)
previewURL, err := fs.SignURL(ctx, &fs.FileTarget[0], int64(ttl), false)
if err != nil {
return nil, err
}
return &response.ContentResponse{
Redirect: true,
URL: previewURL,
MaxAge: ttl,
}, nil
}
// GetDownloadContent 获取用于下载的文件流
func (fs *FileSystem) GetDownloadContent(ctx context.Context, id uint) (response.RSCloser, error) {
// 获取原始文件流
rs, err := fs.GetContent(ctx, id)
if err != nil {
return nil, err
}
// 返回限速处理后的文件流
return fs.withSpeedLimit(rs), nil
}
// GetContent 获取文件内容path为虚拟路径
func (fs *FileSystem) GetContent(ctx context.Context, id uint) (response.RSCloser, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, err
}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
// 获取文件流
rs, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
if err != nil {
return nil, ErrIO.WithError(err)
}
return rs, nil
}
// deleteGroupedFile 对分组好的文件执行删除操作,
// 返回每个分组失败的文件列表
func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*model.File) map[uint][]string {
// 失败的文件列表
// TODO 并行删除
failed := make(map[uint][]string, len(files))
thumbs := make([]string, 0)
for policyID, toBeDeletedFiles := range files {
// 列举出需要物理删除的文件的物理路径
sourceNamesAll := make([]string, 0, len(toBeDeletedFiles))
uploadSessions := make([]*serializer.UploadSession, 0, len(toBeDeletedFiles))
for i := 0; i < len(toBeDeletedFiles); i++ {
sourceNamesAll = append(sourceNamesAll, toBeDeletedFiles[i].SourceName)
if toBeDeletedFiles[i].UploadSessionID != nil {
if session, ok := cache.Get(UploadSessionCachePrefix + *toBeDeletedFiles[i].UploadSessionID); ok {
uploadSession := session.(serializer.UploadSession)
uploadSessions = append(uploadSessions, &uploadSession)
}
}
// Check if sidecar thumb file exist
if model.IsTrueVal(toBeDeletedFiles[i].MetadataSerialized[model.ThumbSidecarMetadataKey]) {
thumbs = append(thumbs, toBeDeletedFiles[i].ThumbFile())
}
}
// 切换上传策略
fs.Policy = toBeDeletedFiles[0].GetPolicy()
err := fs.DispatchHandler()
if err != nil {
failed[policyID] = sourceNamesAll
continue
}
// 取消上传会话
for _, upSession := range uploadSessions {
if err := fs.Handler.CancelToken(ctx, upSession); err != nil {
util.Log().Warning("Failed to cancel upload session for %q: %s", upSession.Name, err)
}
cache.Deletes([]string{upSession.Key}, UploadSessionCachePrefix)
}
// 执行删除
toBeDeletedSrcs := append(sourceNamesAll, thumbs...)
failedFile, _ := fs.Handler.Delete(ctx, toBeDeletedSrcs)
// Exclude failed results related to thumb file
failed[policyID] = util.SliceDifference(failedFile, thumbs)
}
return failed
}
// GroupFileByPolicy 将目标文件按照存储策略分组
func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) map[uint][]*model.File {
var policyGroup = make(map[uint][]*model.File)
for key := range files {
if file, ok := policyGroup[files[key].PolicyID]; ok {
// 如果已存在分组,直接追加
policyGroup[files[key].PolicyID] = append(file, &files[key])
} else {
// 分组不存在,创建
policyGroup[files[key].PolicyID] = make([]*model.File, 0)
policyGroup[files[key].PolicyID] = append(policyGroup[files[key].PolicyID], &files[key])
}
}
return policyGroup
}
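// exampleDeleteByPolicy is an illustrative sketch, not part of Cloudreve:
// grouping exists so physical deletion can be dispatched per storage policy,
// which is exactly what deleteGroupedFile above expects as input.
func exampleDeleteByPolicy(ctx context.Context, fs *FileSystem, files []model.File) map[uint][]string {
	return fs.deleteGroupedFile(ctx, fs.GroupFileByPolicy(ctx, files))
}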
// GetDownloadURL 创建文件下载链接, timeout 为数据库中存储过期时间的字段
func (fs *FileSystem) GetDownloadURL(ctx context.Context, id uint, timeout string) (string, error) {
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return "", err
}
fileTarget := &fs.FileTarget[0]
// 生成下载地址
ttl := model.GetIntSetting(timeout, 60)
source, err := fs.SignURL(
ctx,
fileTarget,
int64(ttl),
true,
)
if err != nil {
return "", err
}
return source, nil
}
// GetSource 获取可直接访问文件的外链地址
func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) {
// 查找文件记录
err := fs.resetFileIDIfNotExist(ctx, fileID)
if err != nil {
return "", ErrObjectNotExist.WithError(err)
}
// 检查存储策略是否可以获得外链
if !fs.Policy.IsOriginLinkEnable {
return "", serializer.NewError(
serializer.CodePolicyNotAllowed,
"This policy is not enabled for getting source link",
nil,
)
}
source, err := fs.SignURL(ctx, &fs.FileTarget[0], 0, false)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
}
return source, nil
}
// SignURL 签名文件原始 URL
func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) {
fs.FileTarget = []model.File{*file}
ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file)
err := fs.resetPolicyToFirstFile(ctx)
if err != nil {
return "", err
}
// 签名并生成外链地址
source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit)
if err != nil {
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
}
return source, nil
}
// ResetFileIfNotExist 重设当前目标文件为 path如果当前目标为空
func (fs *FileSystem) ResetFileIfNotExist(ctx context.Context, path string) error {
// 找到文件
if len(fs.FileTarget) == 0 {
exist, file := fs.IsFileExist(path)
if !exist {
return ErrObjectNotExist
}
fs.FileTarget = []model.File{*file}
}
// 将当前存储策略重设为文件使用的
return fs.resetPolicyToFirstFile(ctx)
}
// resetFileIDIfNotExist 重设当前目标文件为 id,如果当前目标为空
func (fs *FileSystem) resetFileIDIfNotExist(ctx context.Context, id uint) error {
// 找到文件
if len(fs.FileTarget) == 0 {
file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID)
if err != nil || len(file) == 0 {
return ErrObjectNotExist
}
fs.FileTarget = []model.File{file[0]}
}
// 如果上下文限制了父目录,则进行检查
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
if parent.ID != fs.FileTarget[0].FolderID {
return ErrObjectNotExist
}
}
// 将当前存储策略重设为文件使用的
return fs.resetPolicyToFirstFile(ctx)
}
// resetPolicyToFirstFile 将当前存储策略重设为第一个目标文件文件使用的
func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error {
if len(fs.FileTarget) == 0 {
return ErrObjectNotExist
}
// 从机模式不进行操作
if conf.SystemConfig.Mode == "slave" {
return nil
}
fs.Policy = fs.FileTarget[0].GetPolicy()
err := fs.DispatchHandler()
if err != nil {
return err
}
return nil
}
// Search 搜索文件
func (fs *FileSystem) Search(ctx context.Context, keywords ...interface{}) ([]serializer.Object, error) {
parents := make([]uint, 0)
// 如果限定了根目录,则只在这个根目录下搜索。
if fs.Root != nil {
allFolders, err := model.GetRecursiveChildFolder([]uint{fs.Root.ID}, fs.User.ID, true)
if err != nil {
return nil, fmt.Errorf("failed to list all folders: %w", err)
}
for _, folder := range allFolders {
parents = append(parents, folder.ID)
}
}
files, _ := model.GetFilesByKeywords(fs.User.ID, parents, keywords...)
fs.SetTargetFile(&files)
return fs.listObjects(ctx, "/", files, nil, nil), nil
}

View File

@ -0,0 +1,295 @@
package filesystem
import (
"errors"
"fmt"
"net/http"
"net/url"
"sync"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/qiniu"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gin-gonic/gin"
cossdk "github.com/tencentyun/cos-go-sdk-v5"
)
// FSPool 文件系统资源池
var FSPool = sync.Pool{
New: func() interface{} {
return &FileSystem{}
},
}
// FileSystem 管理文件的文件系统
type FileSystem struct {
// 文件系统所有者
User *model.User
// 操作文件使用的存储策略
Policy *model.Policy
// 当前正在处理的文件对象
FileTarget []model.File
// 当前正在处理的目录对象
DirTarget []model.Folder
// 相对根目录
Root *model.Folder
// 互斥锁
Lock sync.Mutex
/*
钩子函数
*/
Hooks map[string][]Hook
/*
文件系统处理适配器
*/
Handler driver.Handler
// 回收锁
recycleLock sync.Mutex
}
// getEmptyFS 从pool中获取新的FileSystem
func getEmptyFS() *FileSystem {
fs := FSPool.Get().(*FileSystem)
return fs
}
// Recycle 回收FileSystem资源
func (fs *FileSystem) Recycle() {
fs.recycleLock.Lock()
fs.reset()
FSPool.Put(fs)
}
// reset 重设文件系统,以便回收使用
func (fs *FileSystem) reset() {
fs.User = nil
fs.CleanTargets()
fs.Policy = nil
fs.Hooks = nil
fs.Handler = nil
fs.Root = nil
fs.Lock = sync.Mutex{}
fs.recycleLock = sync.Mutex{}
}
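// exampleWithFileSystem is an illustrative sketch, not part of Cloudreve: it
// shows the intended pool lifecycle: borrow a FileSystem for one request via
// NewFileSystem and hand it back with Recycle when the request finishes.
func exampleWithFileSystem(user *model.User) error {
	fs, err := NewFileSystem(user)
	if err != nil {
		return err
	}
	defer fs.Recycle()
	// ... use fs for the duration of the request ...
	return nil
}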
// NewFileSystem 初始化一个文件系统
func NewFileSystem(user *model.User) (*FileSystem, error) {
fs := getEmptyFS()
fs.User = user
fs.Policy = user.GetPolicyID(nil)
// 分配存储策略适配器
err := fs.DispatchHandler()
return fs, err
}
// NewAnonymousFileSystem 初始化匿名文件系统
func NewAnonymousFileSystem() (*FileSystem, error) {
fs := getEmptyFS()
fs.User = &model.User{}
// 如果是主机模式下,则为匿名文件系统分配游客用户组
if conf.SystemConfig.Mode == "master" {
anonymousGroup, err := model.GetGroupByID(3)
if err != nil {
return nil, err
}
fs.User.Group = anonymousGroup
} else {
// 从机模式下,分配本地策略处理器
fs.Handler = local.Driver{}
}
return fs, nil
}
// DispatchHandler 根据存储策略分配文件适配器
func (fs *FileSystem) DispatchHandler() error {
handler, err := getNewPolicyHandler(fs.Policy)
fs.Handler = handler
return err
}
// getNewPolicyHandler 根据存储策略类型字段获取处理器
func getNewPolicyHandler(policy *model.Policy) (driver.Handler, error) {
if policy == nil {
return nil, ErrUnknownPolicyType
}
switch policy.Type {
case "mock", "anonymous":
return nil, nil
case "local":
return local.Driver{
Policy: policy,
}, nil
case "remote":
return remote.NewDriver(policy)
case "qiniu":
return qiniu.NewDriver(policy), nil
case "oss":
return oss.NewDriver(policy)
case "upyun":
return upyun.Driver{
Policy: policy,
}, nil
case "onedrive":
return onedrive.NewDriver(policy)
case "cos":
u, _ := url.Parse(policy.Server)
b := &cossdk.BaseURL{BucketURL: u}
return cos.Driver{
Policy: policy,
Client: cossdk.NewClient(b, &http.Client{
Transport: &cossdk.AuthorizationTransport{
SecretID: policy.AccessKey,
SecretKey: policy.SecretKey,
},
}),
HTTPClient: request.NewClient(),
}, nil
case "s3":
return s3.NewDriver(policy)
case "googledrive":
return googledrive.NewDriver(policy)
default:
return nil, ErrUnknownPolicyType
}
}
// NewFileSystemFromContext 从gin.Context创建文件系统
func NewFileSystemFromContext(c *gin.Context) (*FileSystem, error) {
user, exist := c.Get("user")
if !exist {
return NewAnonymousFileSystem()
}
fs, err := NewFileSystem(user.(*model.User))
return fs, err
}
// NewFileSystemFromCallback 从gin.Context创建回调用文件系统
func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
fs, err := NewFileSystemFromContext(c)
if err != nil {
return nil, err
}
// 获取回调会话
callbackSessionRaw, ok := c.Get(UploadSessionCtx)
if !ok {
return nil, errors.New("upload session not exist")
}
callbackSession := callbackSessionRaw.(*serializer.UploadSession)
// 重新指向上传策略
fs.Policy = &callbackSession.Policy
err = fs.DispatchHandler()
return fs, err
}
// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点
func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) {
fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, fs.Policy)
}
// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器
func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) {
switch fs.Policy.Type {
case "local":
fs.Policy.Type = "remote"
fs.Policy.Server = masterURL
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
fs.Policy.SecretKey = master.DBModel().MasterKey
fs.DispatchHandler()
case "onedrive":
fs.Policy.MasterID = masterID
}
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy)
}
// SetTargetFile 设置当前处理的目标文件
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
if len(fs.FileTarget) == 0 {
fs.FileTarget = *files
} else {
fs.FileTarget = append(fs.FileTarget, *files...)
}
}
// SetTargetDir 设置当前处理的目标目录
func (fs *FileSystem) SetTargetDir(dirs *[]model.Folder) {
if len(fs.DirTarget) == 0 {
fs.DirTarget = *dirs
} else {
fs.DirTarget = append(fs.DirTarget, *dirs...)
}
}
// SetTargetFileByIDs 根据文件ID设置目标文件忽略用户ID
func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error {
files, err := model.GetFilesByIDs(ids, 0)
if err != nil || len(files) == 0 {
return ErrFileExisted.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// SetTargetByInterface 根据 model.File 或者 model.Folder 设置目标对象
// TODO 测试
func (fs *FileSystem) SetTargetByInterface(target interface{}) error {
if file, ok := target.(*model.File); ok {
fs.SetTargetFile(&[]model.File{*file})
return nil
}
if folder, ok := target.(*model.Folder); ok {
fs.SetTargetDir(&[]model.Folder{*folder})
return nil
}
return ErrObjectNotExist
}
// CleanTargets 清空目标
func (fs *FileSystem) CleanTargets() {
fs.FileTarget = fs.FileTarget[:0]
fs.DirTarget = fs.DirTarget[:0]
}
// SetPolicyFromPath 根据给定路径尝试设定偏好存储策略
func (fs *FileSystem) SetPolicyFromPath(filePath string) error {
_, parent := fs.getClosedParent(filePath)
// 尝试获取并重设存储策略
fs.Policy = fs.User.GetPolicyID(parent)
return fs.DispatchHandler()
}
// SetPolicyFromPreference 尝试设定偏好存储策略
func (fs *FileSystem) SetPolicyFromPreference(preference uint) error {
// 尝试获取并重设存储策略
fs.Policy = fs.User.GetPolicyByPreference(preference)
return fs.DispatchHandler()
}
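A minimal usage sketch (not part of this commit) of how a FileSystem is typically obtained inside a gin handler and used to list a directory. It assumes the application (configuration, database, anonymous user group) is already initialized and that an auth middleware stores the current *model.User under the "user" key; the route and port are placeholders.
package main

import (
	"context"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()
	r.GET("/demo/list", func(c *gin.Context) {
		// Falls back to an anonymous file system when no user is present.
		fs, err := filesystem.NewFileSystemFromContext(c)
		if err != nil {
			c.JSON(500, gin.H{"error": err.Error()})
			return
		}
		objects, err := fs.List(context.Background(), "/", nil)
		if err != nil {
			c.JSON(500, gin.H{"error": err.Error()})
			return
		}
		c.JSON(200, objects)
	})
	r.Run(":8080")
}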

View File

@ -0,0 +1,44 @@
package fsctx
type key int
const (
// GinCtx Gin的上下文
GinCtx key = iota
// PathCtx 文件或目录的虚拟路径
PathCtx
// FileModelCtx 文件数据库模型
FileModelCtx
// FolderModelCtx 目录数据库模型
FolderModelCtx
// HTTPCtx HTTP请求的上下文
HTTPCtx
// UploadPolicyCtx 上传策略一般为slave模式下使用
UploadPolicyCtx
// UserCtx 用户
UserCtx
// ThumbSizeCtx 缩略图尺寸
ThumbSizeCtx
// FileSizeCtx 文件大小
FileSizeCtx
// ShareKeyCtx 分享文件的 HashID
ShareKeyCtx
// LimitParentCtx 限制父目录
LimitParentCtx
// IgnoreDirectoryConflictCtx 忽略目录重名冲突
IgnoreDirectoryConflictCtx
// RetryCtx 失败重试次数
RetryCtx
// ForceUsePublicEndpointCtx 强制使用公网 Endpoint
ForceUsePublicEndpointCtx
// CancelFuncCtx Context 取消函數
CancelFuncCtx
// 文件在从机节点中的路径
SlaveSrcPath
// Webdav目标名称
WebdavDstName
// WebDAVCtx WebDAV
WebDAVCtx
// WebDAV反代Url
WebDAVProxyUrlCtx
)

View File

@ -0,0 +1,123 @@
package fsctx
import (
"errors"
"github.com/HFO4/aliyun-oss-go-sdk/oss"
"io"
"time"
)
type WriteMode int
const (
Overwrite WriteMode = 0x00001
// Append 只适用于本地策略
Append WriteMode = 0x00002
Nop WriteMode = 0x00004
)
type UploadTaskInfo struct {
Size uint64
MimeType string
FileName string
VirtualPath string
Mode WriteMode
Metadata map[string]string
LastModified *time.Time
SavePath string
UploadSessionID *string
AppendStart uint64
Model interface{}
Src string
}
// Get mimetype of uploaded file, if it's not defined, detect it from file name
func (u *UploadTaskInfo) DetectMimeType() string {
if u.MimeType != "" {
return u.MimeType
}
return oss.TypeByExtension(u.FileName)
}
// FileHeader 上传来的文件数据处理器
type FileHeader interface {
io.Reader
io.Closer
io.Seeker
Info() *UploadTaskInfo
SetSize(uint64)
SetModel(fileModel interface{})
Seekable() bool
}
// FileStream 用户传来的文件
type FileStream struct {
Mode WriteMode
LastModified *time.Time
Metadata map[string]string
File io.ReadCloser
Seeker io.Seeker
Size uint64
VirtualPath string
Name string
MimeType string
SavePath string
UploadSessionID *string
AppendStart uint64
Model interface{}
Src string
}
func (file *FileStream) Read(p []byte) (n int, err error) {
if file.File != nil {
return file.File.Read(p)
}
return 0, io.EOF
}
func (file *FileStream) Close() error {
if file.File != nil {
return file.File.Close()
}
return nil
}
func (file *FileStream) Seek(offset int64, whence int) (int64, error) {
if file.Seekable() {
return file.Seeker.Seek(offset, whence)
}
return 0, errors.New("no seeker")
}
func (file *FileStream) Seekable() bool {
return file.Seeker != nil
}
func (file *FileStream) Info() *UploadTaskInfo {
return &UploadTaskInfo{
Size: file.Size,
MimeType: file.MimeType,
FileName: file.Name,
VirtualPath: file.VirtualPath,
Mode: file.Mode,
Metadata: file.Metadata,
LastModified: file.LastModified,
SavePath: file.SavePath,
UploadSessionID: file.UploadSessionID,
AppendStart: file.AppendStart,
Model: file.Model,
Src: file.Src,
}
}
func (file *FileStream) SetSize(size uint64) {
file.Size = size
}
func (file *FileStream) SetModel(fileModel interface{}) {
file.Model = fileModel
}
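A small sketch of how a FileStream is usually assembled from an *os.File (the file name is a placeholder). An *os.File serves as both the reader and the seeker, so Seekable() reports true, and DetectMimeType falls back to the file extension because MimeType is left empty.
package main

import (
	"fmt"
	"os"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
)

func main() {
	f, err := os.Open("example.txt") // placeholder file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	info, _ := f.Stat()
	stream := &fsctx.FileStream{
		File:   f,
		Seeker: f,
		Size:   uint64(info.Size()),
		Name:   "example.txt",
		Mode:   fsctx.Overwrite,
	}
	fmt.Println(stream.Seekable())              // true
	fmt.Println(stream.Info().DetectMimeType()) // derived from the ".txt" extension
}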

320
pkg/filesystem/hooks.go Normal file
View File

@ -0,0 +1,320 @@
package filesystem
import (
"context"
"io/ioutil"
"net/http"
"strconv"
"strings"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
// Hook 钩子函数
type Hook func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error
// Use 注入钩子
func (fs *FileSystem) Use(name string, hook Hook) {
if fs.Hooks == nil {
fs.Hooks = make(map[string][]Hook)
}
if _, ok := fs.Hooks[name]; ok {
fs.Hooks[name] = append(fs.Hooks[name], hook)
return
}
fs.Hooks[name] = []Hook{hook}
}
// CleanHooks 清空钩子,name为空表示全部清空
func (fs *FileSystem) CleanHooks(name string) {
if name == "" {
fs.Hooks = nil
} else {
delete(fs.Hooks, name)
}
}
// Trigger 触发钩子,遇到第一个错误时
// 返回错误,后续钩子不会继续执行
func (fs *FileSystem) Trigger(ctx context.Context, name string, file fsctx.FileHeader) error {
if hooks, ok := fs.Hooks[name]; ok {
for _, hook := range hooks {
err := hook(ctx, fs, file)
if err != nil {
util.Log().Warning("Failed to execute hook%s", err)
return err
}
}
}
return nil
}
// HookValidateFile 一系列对文件检验的集合
func HookValidateFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
fileInfo := file.Info()
// 验证单文件尺寸
if !fs.ValidateFileSize(ctx, fileInfo.Size) {
return ErrFileSizeTooBig
}
// 验证文件名
if !fs.ValidateLegalName(ctx, fileInfo.FileName) {
return ErrIllegalObjectName
}
// 验证扩展名
if !fs.ValidateExtension(ctx, fileInfo.FileName) {
return ErrFileExtensionNotAllowed
}
return nil
}
// HookResetPolicy 重设存储策略为上下文已有文件
func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
fs.Policy = originFile.GetPolicy()
return fs.DispatchHandler()
}
// HookValidateCapacity 验证用户容量
func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 验证剩余容量是否足够
if fs.User.GetRemainingCapacity() < file.Info().Size {
return ErrInsufficientCapacity
}
return nil
}
// HookValidateCapacityDiff 根据原有文件和新文件的大小验证用户容量
func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
originFile := ctx.Value(fsctx.FileModelCtx).(model.File)
newFileSize := newFile.Info().Size
if newFileSize > originFile.Size {
return HookValidateCapacity(ctx, fs, newFile)
}
return nil
}
// HookDeleteTempFile 删除已保存的临时文件
func HookDeleteTempFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 删除临时文件
_, err := fs.Handler.Delete(ctx, []string{file.Info().SavePath})
if err != nil {
util.Log().Warning("Failed to clean-up temp files: %s", err)
}
return nil
}
// HookCleanFileContent 清空文件内容
func HookCleanFileContent(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
// 清空内容
return fs.Handler.Put(ctx, &fsctx.FileStream{
File: ioutil.NopCloser(strings.NewReader("")),
SavePath: file.Info().SavePath,
Size: 0,
Mode: fsctx.Overwrite,
})
}
// HookClearFileSize 将原始文件的尺寸设为0
func HookClearFileSize(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
return originFile.UpdateSize(0)
}
// HookCancelContext 取消上下文
func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc)
if ok {
cancelFunc()
}
return nil
}
// HookUpdateSourceName 更新文件SourceName
func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
return originFile.UpdateSourceName(originFile.SourceName)
}
// GenericAfterUpdate 文件内容更新后
func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
// 更新文件尺寸
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
if !ok {
return ErrObjectNotExist
}
newFile.SetModel(&originFile)
err := originFile.UpdateSize(newFile.Info().Size)
if err != nil {
return err
}
return nil
}
// SlaveAfterUpload Slave模式下上传完成钩子
func SlaveAfterUpload(session *serializer.UploadSession) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
if session.Callback == "" {
return nil
}
// 发送回调请求
callbackBody := serializer.UploadCallback{}
return cluster.RemoteCallback(session.Callback, callbackBody)
}
}
// GenericAfterUpload 文件上传完成后,包含数据库操作
func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 创建或查找根目录
folder, err := fs.CreateDirectory(ctx, fileInfo.VirtualPath)
if err != nil {
return err
}
// 检查文件是否存在
if ok, file := fs.IsChildFileExist(
folder,
fileInfo.FileName,
); ok {
if file.UploadSessionID != nil {
return ErrFileUploadSessionExisted
}
return ErrFileExisted
}
// 向数据库中插入记录
file, err := fs.AddFile(ctx, folder, fileHeader)
if err != nil {
return ErrInsertFileRecord
}
fileHeader.SetModel(file)
return nil
}
// HookGenerateThumb 生成缩略图
// func HookGenerateThumb(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
// // 异步尝试生成缩略图
// fileMode := fileHeader.Info().Model.(*model.File)
// if fs.Policy.IsThumbGenerateNeeded() {
// fs.recycleLock.Lock()
// go func() {
// defer fs.recycleLock.Unlock()
// _, _ = fs.Handler.Delete(ctx, []string{fileMode.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
// fs.GenerateThumbnail(ctx, fileMode)
// }()
// }
// return nil
// }
// HookClearFileHeaderSize 将FileHeader大小设定为0
func HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileHeader.SetSize(0)
return nil
}
// HookTruncateFileTo 将物理文件截断至 size
func HookTruncateFileTo(size uint64) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
if handler, ok := fs.Handler.(local.Driver); ok {
return handler.Truncate(ctx, fileHeader.Info().SavePath, size)
}
return nil
}
}
// HookChunkUploaded 单个分片上传结束后
func HookChunkUploaded(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 更新文件大小
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart + fileInfo.Size)
}
// HookChunkUploadFailed 单个分片上传失败后
func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
// 更新文件大小
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart)
}
// HookPopPlaceholderToFile 将占位文件提升为正式文件
func HookPopPlaceholderToFile(picInfo string) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
fileInfo := fileHeader.Info()
fileModel := fileInfo.Model.(*model.File)
return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo)
}
}
// HookDeleteUploadSession 上传结束后删除上传会话
func HookDeleteUploadSession(id string) Hook {
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
cache.Deletes([]string{id}, UploadSessionCachePrefix)
return nil
}
}
// NewWebdavAfterUploadHook 为每个请求创建新的上传完成钩子。rclone 在 PUT 请求中
// 会附带 OC-Checksum 和 X-OC-Mtime 请求头，此处将其写入文件记录
func NewWebdavAfterUploadHook(request *http.Request) func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
var modtime time.Time
if timeVal := request.Header.Get("X-OC-Mtime"); timeVal != "" {
timeUnix, err := strconv.ParseInt(timeVal, 10, 64)
if err == nil {
modtime = time.Unix(timeUnix, 0)
}
}
checksum := request.Header.Get("OC-Checksum")
return func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
file := newFile.Info().Model.(*model.File)
if !modtime.IsZero() {
err := model.DB.Model(file).UpdateColumn("updated_at", modtime).Error
if err != nil {
return err
}
}
if checksum != "" {
return file.UpdateMetadata(map[string]string{
model.ChecksumMetadataKey: checksum,
})
}
return nil
}
}
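A sketch of how these hooks are meant to be combined; the same wiring appears in UploadFromStream later in this commit. The helper name is illustrative, and fs.Use is safe to call on a zero-value FileSystem because it lazily creates the Hooks map.
package main

import (
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
)

// registerUploadHooks wires validation before the upload, database
// bookkeeping after it, and temp-file cleanup on failure or cancellation.
func registerUploadHooks(fs *filesystem.FileSystem) {
	fs.Use("BeforeUpload", filesystem.HookValidateFile)
	fs.Use("BeforeUpload", filesystem.HookValidateCapacity)
	fs.Use("AfterUpload", filesystem.GenericAfterUpload)
	fs.Use("AfterValidateFailed", filesystem.HookDeleteTempFile)
	fs.Use("AfterUploadCanceled", filesystem.HookDeleteTempFile)
}

func main() {
	fs := &filesystem.FileSystem{}
	registerUploadHooks(fs)
	fmt.Println(len(fs.Hooks["BeforeUpload"])) // 2
}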

218
pkg/filesystem/image.go Normal file
View File

@ -0,0 +1,218 @@
package filesystem
import (
"context"
"errors"
"fmt"
"os"
"sync"
"runtime"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ================
图像处理相关
================
*/
// GetThumb 获取文件的缩略图
func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentResponse, error) {
// 根据 ID 查找文件
err := fs.resetFileIDIfNotExist(ctx, id)
if err != nil {
return nil, ErrObjectNotExist
}
file := fs.FileTarget[0]
if !file.ShouldLoadThumb() {
return nil, ErrObjectNotExist
}
w, h := fs.GenerateThumbnailSize(0, 0)
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h})
ctx = context.WithValue(ctx, fsctx.FileModelCtx, file)
res, err := fs.Handler.Thumb(ctx, &file)
if errors.Is(err, driver.ErrorThumbNotExist) {
// Regenerate thumb if the thumb is not initialized yet
if generateErr := fs.generateThumbnail(ctx, &file); generateErr == nil {
res, err = fs.Handler.Thumb(ctx, &file)
} else {
err = generateErr
}
} else if errors.Is(err, driver.ErrorThumbNotSupported) {
// Policy handler explicitly indicates thumb not available, check if proxy is enabled
if fs.Policy.CouldProxyThumb() {
// if thumb id marked as existed, redirect to "sidecar" thumb file.
if file.MetadataSerialized != nil &&
file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusExist {
// redirect to sidecar file
res = &response.ContentResponse{
Redirect: true,
}
res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
} else {
// if not exist, generate and upload the sidecar thumb.
if err = fs.generateThumbnail(ctx, &file); err == nil {
return fs.GetThumb(ctx, id)
}
}
} else {
// thumb not supported and proxy is disabled, mark as not available
_ = updateThumbStatus(&file, model.ThumbStatusNotAvailable)
}
}
if err == nil && conf.SystemConfig.Mode == "master" {
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
}
return res, err
}
// thumbPool 要使用的任务池
var thumbPool *Pool
var once sync.Once
// Pool 带有最大配额的任务池
type Pool struct {
// 容量
worker chan int
}
// Init 初始化任务池
func getThumbWorker() *Pool {
once.Do(func() {
maxWorker := model.GetIntSetting("thumb_max_task_count", -1)
if maxWorker <= 0 {
maxWorker = runtime.GOMAXPROCS(0)
}
thumbPool = &Pool{
worker: make(chan int, maxWorker),
}
util.Log().Debug("Initialize thumbnails task queue with: WorkerNum = %d", maxWorker)
})
return thumbPool
}
func (pool *Pool) addWorker() {
pool.worker <- 1
util.Log().Debug("Worker added to thumbnails task queue.")
}
func (pool *Pool) releaseWorker() {
util.Log().Debug("Worker released from thumbnails task queue.")
<-pool.worker
}
// generateThumbnail generates thumb for given file, upload the thumb file back with given suffix
func (fs *FileSystem) generateThumbnail(ctx context.Context, file *model.File) error {
// 新建上下文
newCtx, cancel := context.WithCancel(context.Background())
defer cancel()
// TODO: check file size
if file.Size > uint64(model.GetIntSetting("thumb_max_src_size", 31457280)) {
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
return errors.New("file too large")
}
getThumbWorker().addWorker()
defer getThumbWorker().releaseWorker()
// 获取文件数据
source, err := fs.Handler.Get(newCtx, file.SourceName)
if err != nil {
return fmt.Errorf("faield to fetch original file %q: %w", file.SourceName, err)
}
defer source.Close()
// Provide file source path for local policy files
src := ""
if conf.SystemConfig.Mode == "slave" || file.GetPolicy().Type == "local" {
src = file.SourceName
}
thumbRes, err := thumb.Generators.Generate(ctx, source, src, file.Name, model.GetSettingByNames(
"thumb_width",
"thumb_height",
"thumb_builtin_enabled",
"thumb_vips_enabled",
"thumb_ffmpeg_enabled",
"thumb_libreoffice_enabled",
))
if err != nil {
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
return fmt.Errorf("failed to generate thumb for %q: %w", file.Name, err)
}
defer os.Remove(thumbRes.Path)
thumbFile, err := os.Open(thumbRes.Path)
if err != nil {
return fmt.Errorf("failed to open temp thumb %q: %w", thumbRes.Path, err)
}
defer thumbFile.Close()
fileInfo, err := thumbFile.Stat()
if err != nil {
return fmt.Errorf("failed to stat temp thumb %q: %w", thumbRes.Path, err)
}
if err = fs.Handler.Put(newCtx, &fsctx.FileStream{
Mode: fsctx.Overwrite,
File: thumbFile,
Seeker: thumbFile,
Size: uint64(fileInfo.Size()),
SavePath: file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"),
}); err != nil {
return fmt.Errorf("failed to save thumb for %q: %w", file.Name, err)
}
if model.IsTrueVal(model.GetSettingByName("thumb_gc_after_gen")) {
util.Log().Debug("generateThumbnail runtime.GC")
runtime.GC()
}
// Mark this file as thumb available
err = updateThumbStatus(file, model.ThumbStatusExist)
// 失败时删除缩略图文件
if err != nil {
_, _ = fs.Handler.Delete(newCtx, []string{file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
}
return nil
}
// GenerateThumbnailSize 获取要生成的缩略图的尺寸
func (fs *FileSystem) GenerateThumbnailSize(w, h int) (uint, uint) {
return uint(model.GetIntSetting("thumb_width", 400)), uint(model.GetIntSetting("thumb_height", 300))
}
func updateThumbStatus(file *model.File, status string) error {
if file.Model.ID > 0 {
meta := map[string]string{
model.ThumbStatusMetadataKey: status,
}
if status == model.ThumbStatusExist {
meta[model.ThumbSidecarMetadataKey] = "true"
}
return file.UpdateMetadata(meta)
} else {
if file.MetadataSerialized == nil {
file.MetadataSerialized = map[string]string{}
}
file.MetadataSerialized[model.ThumbStatusMetadataKey] = status
}
return nil
}
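The Pool above is a plain buffered-channel semaphore: addWorker blocks once the channel is full, releaseWorker frees a slot. The standalone sketch below shows the same pattern in isolation; the type name and the worker count of 2 are illustrative.
package main

import (
	"fmt"
	"sync"
)

type semaphore struct{ slots chan struct{} }

func (s *semaphore) acquire() { s.slots <- struct{}{} } // blocks when all slots are taken
func (s *semaphore) release() { <-s.slots }

func main() {
	sem := &semaphore{slots: make(chan struct{}, 2)} // at most 2 thumbnails at once
	var wg sync.WaitGroup
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			sem.acquire()
			defer sem.release()
			fmt.Println("generating thumbnail", id)
		}(i)
	}
	wg.Wait()
}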

479
pkg/filesystem/manage.go Normal file
View File

@ -0,0 +1,479 @@
package filesystem
import (
"context"
"fmt"
"path"
"strings"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* =================
文件/目录管理
=================
*/
// Rename 重命名对象
func (fs *FileSystem) Rename(ctx context.Context, dir, file []uint, new string) (err error) {
// 验证新名字
if !fs.ValidateLegalName(ctx, new) || (len(file) > 0 && !fs.ValidateExtension(ctx, new)) {
return ErrIllegalObjectName
}
// 如果源对象是文件
if len(file) > 0 {
fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID)
if err != nil || len(fileObject) == 0 {
return ErrPathNotExist
}
err = fileObject[0].Rename(new)
if err != nil {
return ErrFileExisted
}
return nil
}
if len(dir) > 0 {
folderObject, err := model.GetFoldersByIDs([]uint{dir[0]}, fs.User.ID)
if err != nil || len(folderObject) == 0 {
return ErrPathNotExist
}
err = folderObject[0].Rename(new)
if err != nil {
return ErrFileExisted
}
return nil
}
return ErrPathNotExist
}
// Copy 复制src目录下的文件或目录到dst
// 暂时只支持单文件
func (fs *FileSystem) Copy(ctx context.Context, dirs, files []uint, src, dst string) error {
// 获取目的目录
isDstExist, dstFolder := fs.IsPathExist(dst)
isSrcExist, srcFolder := fs.IsPathExist(src)
// 不存在时返回空的结果
if !isDstExist || !isSrcExist {
return ErrPathNotExist
}
// 记录复制的文件的总容量
var newUsedStorage uint64
// 设置webdav目标名
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
dstFolder.WebdavDstName = dstName
}
// 复制目录
if len(dirs) > 0 {
subFileSizes, err := srcFolder.CopyFolderTo(dirs[0], dstFolder)
if err != nil {
return ErrObjectNotExist.WithError(err)
}
newUsedStorage += subFileSizes
}
// 复制文件
if len(files) > 0 {
subFileSizes, err := srcFolder.MoveOrCopyFileTo(files, dstFolder, true)
if err != nil {
return ErrObjectNotExist.WithError(err)
}
newUsedStorage += subFileSizes
}
// 扣除容量
fs.User.IncreaseStorageWithoutCheck(newUsedStorage)
return nil
}
// Move 移动文件和目录, 将id列表dirs和files从src移动至dst
func (fs *FileSystem) Move(ctx context.Context, dirs, files []uint, src, dst string) error {
// 获取目的目录
isDstExist, dstFolder := fs.IsPathExist(dst)
isSrcExist, srcFolder := fs.IsPathExist(src)
// 不存在时返回空的结果
if !isDstExist || !isSrcExist {
return ErrPathNotExist
}
// 设置webdav目标名
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
dstFolder.WebdavDstName = dstName
}
// 处理目录及子文件移动
err := srcFolder.MoveFolderTo(dirs, dstFolder)
if err != nil {
return ErrFileExisted.WithError(err)
}
// 处理文件移动
_, err = srcFolder.MoveOrCopyFileTo(files, dstFolder, false)
if err != nil {
return ErrFileExisted.WithError(err)
}
// 移动文件
return err
}
// Delete 递归删除对象, force 为 true 时强制删除文件记录,忽略物理删除是否成功;
// unlink 为 true 时只删除虚拟文件系统的文件记录,不删除物理文件。
func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force, unlink bool) error {
// 已删除的文件ID
var deletedFiles = make([]*model.File, 0, len(fs.FileTarget))
// 删除失败的文件的父目录ID
// 所有文件的ID
var allFiles = make([]*model.File, 0, len(fs.FileTarget))
// 列出要删除的目录
if len(dirs) > 0 {
err := fs.ListDeleteDirs(ctx, dirs)
if err != nil {
return err
}
}
// 列出要删除的文件
if len(files) > 0 {
err := fs.ListDeleteFiles(ctx, files)
if err != nil {
return err
}
}
// 去除待删除文件中包含软连接的部分
filesToBeDelete, err := model.RemoveFilesWithSoftLinks(fs.FileTarget)
if err != nil {
return ErrDBListObjects.WithError(err)
}
// 根据存储策略将文件分组
policyGroup := fs.GroupFileByPolicy(ctx, filesToBeDelete)
// 按照存储策略分组删除对象
failed := make(map[uint][]string)
if !unlink {
failed = fs.deleteGroupedFile(ctx, policyGroup)
}
// 整理删除结果
for i := 0; i < len(fs.FileTarget); i++ {
if !util.ContainsString(failed[fs.FileTarget[i].PolicyID], fs.FileTarget[i].SourceName) {
// 已成功删除的文件
deletedFiles = append(deletedFiles, &fs.FileTarget[i])
}
// 全部文件
allFiles = append(allFiles, &fs.FileTarget[i])
}
// 如果强制删除,则将全部文件视为删除成功
if force {
deletedFiles = allFiles
}
// 删除文件记录
err = model.DeleteFiles(deletedFiles, fs.User.ID)
if err != nil {
return ErrDBDeleteObjects.WithError(err)
}
// 删除文件记录对应的分享记录
// TODO 先取消分享再删除文件
deletedFileIDs := make([]uint, len(deletedFiles))
for k, file := range deletedFiles {
deletedFileIDs[k] = file.ID
}
model.DeleteShareBySourceIDs(deletedFileIDs, false)
// 如果文件全部删除成功,继续删除目录
if len(deletedFiles) == len(allFiles) {
var allFolderIDs = make([]uint, 0, len(fs.DirTarget))
for _, value := range fs.DirTarget {
allFolderIDs = append(allFolderIDs, value.ID)
}
err = model.DeleteFolderByIDs(allFolderIDs)
if err != nil {
return ErrDBDeleteObjects.WithError(err)
}
// 删除目录记录对应的分享记录
model.DeleteShareBySourceIDs(allFolderIDs, true)
}
if notDeleted := len(fs.FileTarget) - len(deletedFiles); notDeleted > 0 {
return serializer.NewError(
serializer.CodeNotFullySuccess,
fmt.Sprintf("Failed to delete %d file(s).", notDeleted),
nil,
)
}
return nil
}
// ListDeleteDirs 递归列出要删除目录,及目录下所有文件
func (fs *FileSystem) ListDeleteDirs(ctx context.Context, ids []uint) error {
// 列出所有递归子目录
folders, err := model.GetRecursiveChildFolder(ids, fs.User.ID, true)
if err != nil {
return ErrDBListObjects.WithError(err)
}
// 忽略根目录
for i := 0; i < len(folders); i++ {
if folders[i].ParentID == nil {
folders = append(folders[:i], folders[i+1:]...)
break
}
}
fs.SetTargetDir(&folders)
// 检索目录下的子文件
files, err := model.GetChildFilesOfFolders(&folders)
if err != nil {
return ErrDBListObjects.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// ListDeleteFiles 根据给定的路径列出要删除的文件
func (fs *FileSystem) ListDeleteFiles(ctx context.Context, ids []uint) error {
files, err := model.GetFilesByIDs(ids, fs.User.ID)
if err != nil {
return ErrDBListObjects.WithError(err)
}
fs.SetTargetFile(&files)
return nil
}
// List 列出路径下的内容,
// pathProcessor为最终对象路径的处理钩子。
// 有些情况下(如在分享页面列对象)时,
// 路径需要截取掉被分享目录路径之前的部分。
func (fs *FileSystem) List(ctx context.Context, dirPath string, pathProcessor func(string) string) ([]serializer.Object, error) {
// 获取父目录
isExist, folder := fs.IsPathExist(dirPath)
if !isExist {
return nil, ErrPathNotExist
}
fs.SetTargetDir(&[]model.Folder{*folder})
var parentPath = path.Join(folder.Position, folder.Name)
var childFolders []model.Folder
var childFiles []model.File
// 获取子目录
childFolders, _ = folder.GetChildFolder()
// 获取子文件
childFiles, _ = folder.GetChildFiles()
return fs.listObjects(ctx, parentPath, childFiles, childFolders, pathProcessor), nil
}
// ListPhysical 列出存储策略中的外部目录
// TODO:测试
func (fs *FileSystem) ListPhysical(ctx context.Context, dirPath string) ([]serializer.Object, error) {
if err := fs.DispatchHandler(); fs.Policy == nil || err != nil {
return nil, ErrUnknownPolicyType
}
// 存储策略不支持列取时,返回空结果
if !fs.Policy.CanStructureBeListed() {
return nil, nil
}
// 列取路径
objects, err := fs.Handler.List(ctx, dirPath, false)
if err != nil {
return nil, err
}
var (
folders []model.Folder
)
for _, object := range objects {
if object.IsDir {
folders = append(folders, model.Folder{
Name: object.Name,
})
}
}
return fs.listObjects(ctx, dirPath, nil, folders, nil), nil
}
func (fs *FileSystem) listObjects(ctx context.Context, parent string, files []model.File, folders []model.Folder, pathProcessor func(string) string) []serializer.Object {
// 分享文件的ID
shareKey := ""
if key, ok := ctx.Value(fsctx.ShareKeyCtx).(string); ok {
shareKey = key
}
// 汇总处理结果
objects := make([]serializer.Object, 0, len(files)+len(folders))
// 所有对象的父目录
var processedPath string
for _, subFolder := range folders {
// 路径处理钩子,
// 所有对象父目录都是一样的,所以只处理一次
if processedPath == "" {
if pathProcessor != nil {
processedPath = pathProcessor(parent)
} else {
processedPath = parent
}
}
objects = append(objects, serializer.Object{
ID: hashid.HashID(subFolder.ID, hashid.FolderID),
Name: subFolder.Name,
Path: processedPath,
Size: 0,
Type: "dir",
Date: subFolder.UpdatedAt,
CreateDate: subFolder.CreatedAt,
})
}
for _, file := range files {
if processedPath == "" {
if pathProcessor != nil {
processedPath = pathProcessor(parent)
} else {
processedPath = parent
}
}
if file.UploadSessionID == nil {
newFile := serializer.Object{
ID: hashid.HashID(file.ID, hashid.FileID),
Name: file.Name,
Path: processedPath,
Thumb: file.ShouldLoadThumb(),
Size: file.Size,
Type: "file",
Date: file.UpdatedAt,
SourceEnabled: file.GetPolicy().IsOriginLinkEnable,
CreateDate: file.CreatedAt,
}
if shareKey != "" {
newFile.Key = shareKey
}
objects = append(objects, newFile)
}
}
return objects
}
// CreateDirectory 根据给定的完整路径创建目录，支持递归创建。如果目录已存在，则直接
// 返回已存在的目录。
func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*model.Folder, error) {
if fullPath == "." || fullPath == "" {
return nil, ErrRootProtected
}
if fullPath == "/" {
if fs.Root != nil {
return fs.Root, nil
}
return fs.User.Root()
}
// 获取要创建目录的父路径和目录名
fullPath = path.Clean(fullPath)
base := path.Dir(fullPath)
dir := path.Base(fullPath)
// 去掉结尾空格
dir = strings.TrimRight(dir, " ")
// 检查目录名是否合法
if !fs.ValidateLegalName(ctx, dir) {
return nil, ErrIllegalObjectName
}
// 父目录是否存在
isExist, parent := fs.IsPathExist(base)
if !isExist {
newParent, err := fs.CreateDirectory(ctx, base)
if err != nil {
return nil, err
}
parent = newParent
}
// 是否有同名文件
if ok, _ := fs.IsChildFileExist(parent, dir); ok {
return nil, ErrFileExisted
}
// 创建目录
newFolder := model.Folder{
Name: dir,
ParentID: &parent.ID,
OwnerID: fs.User.ID,
}
_, err := newFolder.Create()
if err != nil {
return nil, fmt.Errorf("failed to create folder: %w", err)
}
return &newFolder, nil
}
// SaveTo 将别人分享的文件转存到目标路径下
func (fs *FileSystem) SaveTo(ctx context.Context, path string) error {
// 获取父目录
isExist, folder := fs.IsPathExist(path)
if !isExist {
return ErrPathNotExist
}
var (
totalSize uint64
err error
)
if len(fs.DirTarget) > 0 {
totalSize, err = fs.DirTarget[0].CopyFolderTo(fs.DirTarget[0].ID, folder)
} else {
parent := model.Folder{
OwnerID: fs.FileTarget[0].UserID,
}
parent.ID = fs.FileTarget[0].FolderID
totalSize, err = parent.MoveOrCopyFileTo([]uint{fs.FileTarget[0].ID}, folder, true)
}
// 扣除用户容量
fs.User.IncreaseStorageWithoutCheck(totalSize)
if err != nil {
return ErrFileExisted.WithError(err)
}
return nil
}
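A brief sketch of CreateDirectory's recursive behavior: asking for a nested path creates every missing parent first. fs is assumed to be an initialized *filesystem.FileSystem bound to a user, and the path is a placeholder.
package example

import (
	"context"
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
)

func ensureBackupDir(fs *filesystem.FileSystem) error {
	// Creates /backups and /backups/2024 first if they do not exist yet.
	folder, err := fs.CreateDirectory(context.Background(), "/backups/2024/02")
	if err != nil {
		return err
	}
	fmt.Println("directory id:", folder.ID)
	return nil
}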

View File

@ -0,0 +1,25 @@
package oauth
import "sync"
// CredentialLock 针对存储策略凭证的锁
type CredentialLock interface {
Lock(uint)
Unlock(uint)
}
var GlobalMutex = mutexMap{}
type mutexMap struct {
locks sync.Map
}
func (m *mutexMap) Lock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Lock()
}
func (m *mutexMap) Unlock(id uint) {
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
lock.(*sync.Mutex).Unlock()
}
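mutexMap hands out one mutex per policy ID via sync.Map.LoadOrStore, so the first caller creates the lock and later callers reuse it. The self-contained sketch below reproduces that pattern to make the lazy per-ID locking easier to see; the type name and the ID are illustrative.
package main

import (
	"fmt"
	"sync"
)

type idLocks struct{ m sync.Map }

func (l *idLocks) Lock(id uint) {
	mu, _ := l.m.LoadOrStore(id, &sync.Mutex{}) // first caller creates the mutex
	mu.(*sync.Mutex).Lock()
}

func (l *idLocks) Unlock(id uint) {
	mu, _ := l.m.LoadOrStore(id, &sync.Mutex{})
	mu.(*sync.Mutex).Unlock()
}

func main() {
	locks := &idLocks{}
	locks.Lock(42) // serialize credential refresh for policy 42 only
	fmt.Println("refreshing credential for policy 42")
	locks.Unlock(42)
}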

View File

@ -0,0 +1,8 @@
package oauth
import "context"
type TokenProvider interface {
UpdateCredential(ctx context.Context, isSlave bool) error
AccessToken() string
}

84
pkg/filesystem/path.go Normal file
View File

@ -0,0 +1,84 @@
package filesystem
import (
"path"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* =================
路径/目录相关
=================
*/
// IsPathExist 返回给定目录是否存在
// 如果存在就返回目录
func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) {
tracedEnd, currentFolder := fs.getClosedParent(path)
if tracedEnd {
return true, currentFolder
}
return false, nil
}
func (fs *FileSystem) getClosedParent(path string) (bool, *model.Folder) {
pathList := util.SplitPath(path)
if len(pathList) == 0 {
return false, nil
}
// 递归步入目录
var currentFolder *model.Folder
// 如果已设定根目录对象，则从给定目录向下遍历
if fs.Root != nil {
currentFolder = fs.Root
}
for _, folderName := range pathList {
var err error
// 根目录
if folderName == "/" {
if currentFolder != nil {
continue
}
currentFolder, err = fs.User.Root()
if err != nil {
return false, nil
}
} else {
nextFolder, err := currentFolder.GetChild(folderName)
if err != nil {
return false, currentFolder
}
currentFolder = nextFolder
}
}
return true, currentFolder
}
// IsFileExist 返回给定路径的文件是否存在
func (fs *FileSystem) IsFileExist(fullPath string) (bool, *model.File) {
basePath := path.Dir(fullPath)
fileName := path.Base(fullPath)
// 获得父目录
exist, parent := fs.IsPathExist(basePath)
if !exist {
return false, nil
}
file, err := parent.GetChildFile(fileName)
return err == nil, file
}
// IsChildFileExist 确定folder目录下是否有名为name的文件
func (fs *FileSystem) IsChildFileExist(folder *model.Folder, name string) (bool, *model.File) {
file, err := folder.GetChildFile(name)
return err == nil, file
}
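A sketch of the existence helpers above, assuming fs is an initialized *filesystem.FileSystem for some user; the virtual paths are placeholders.
package example

import (
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
)

func checkConflict(fs *filesystem.FileSystem) {
	if exist, file := fs.IsFileExist("/docs/readme.md"); exist {
		fmt.Println("already stored as", file.SourceName)
	}
	if exist, folder := fs.IsPathExist("/docs"); exist {
		fmt.Println("parent folder id:", folder.ID)
	}
}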

102
pkg/filesystem/relocate.go Normal file
View File

@ -0,0 +1,102 @@
package filesystem
import (
"context"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ================
存储策略迁移
================
*/
// Relocate 将目标文件转移到当前存储策略下
func (fs *FileSystem) Relocate(ctx context.Context, files []model.File, policy *model.Policy) error {
// 重设存储策略为要转移的目的策略
fs.Policy = policy
if err := fs.DispatchHandler(); err != nil {
return err
}
// 将目前文件根据存储策略分组
fileGroup := fs.GroupFileByPolicy(ctx, files)
// 按照存储策略分组处理每个文件
for _, fileList := range fileGroup {
// 如果存储策略一样,则跳过
if fileList[0].GetPolicy().ID == fs.Policy.ID {
util.Log().Debug("Skip relocating %d file(s), since they are already in desired policy.",
len(fileList))
continue
}
// 获取当前存储策略的处理器
currentPolicy, _ := model.GetPolicyByID(fileList[0].PolicyID)
currentHandler, err := getNewPolicyHandler(&currentPolicy)
if err != nil {
return err
}
// 记录转移完毕需要删除的文件
toBeDeleted := make([]model.File, 0, len(fileList))
// 循环处理每一个文件
// for id, r := 0, len(fileList); id < r; id++ {
for id := range fileList {
// 验证文件是否符合新存储策略的规定
if err := HookValidateFile(ctx, fs, fileList[id]); err != nil {
util.Log().Debug("File %q failed to pass validators in new policy %q, skipping.",
fileList[id].Name, err)
continue
}
// 为文件生成新存储策略下的物理路径
savePath := fs.GenerateSavePath(ctx, fileList[id])
// 获取原始文件
src, err := currentHandler.Get(ctx, fileList[id].SourceName)
if err != nil {
util.Log().Debug("Failed to get file %q: %s, skipping.",
fileList[id].Name, err)
continue
}
// 转存到新的存储策略
if err := fs.Handler.Put(ctx, &fsctx.FileStream{
File: src,
SavePath: savePath,
Size: fileList[id].Size,
}); err != nil {
util.Log().Debug("Failed to migrate file %q: %s, skipping.",
fileList[id].Name, err)
continue
}
toBeDeleted = append(toBeDeleted, *fileList[id])
// 更新文件信息
fileList[id].Relocate(savePath, fs.Policy.ID)
}
// 排除带有软链接的文件
toBeDeletedClean, err := model.RemoveFilesWithSoftLinks(toBeDeleted)
if err != nil {
util.Log().Warning("Failed to check soft links: %s", err)
}
deleteSourceNames := make([]string, 0, len(toBeDeleted))
for i := 0; i < len(toBeDeletedClean); i++ {
deleteSourceNames = append(deleteSourceNames, toBeDeletedClean[i].SourceName)
}
// 删除原始策略中的文件
if _, err := currentHandler.Delete(ctx, deleteSourceNames); err != nil {
util.Log().Warning("Cannot delete files in origin policy after relocating: %s", err)
}
}
return nil
}

View File

@ -0,0 +1,32 @@
package response
import (
"io"
"time"
)
// ContentResponse 获取文件内容类方法的通用返回值。
// 有些上传策略需要重定向,
// 有些直接写文件数据到浏览器
type ContentResponse struct {
Redirect bool
Content RSCloser
URL string
MaxAge int
}
// RSCloser 存储策略适配器返回的文件流有些策略需要带有Closer
type RSCloser interface {
io.ReadSeeker
io.Closer
}
// Object 列出文件、目录时返回的对象
type Object struct {
Name string `json:"name"`
RelativePath string `json:"relative_path"`
Source string `json:"source"`
Size uint64 `json:"size"`
IsDir bool `json:"is_dir"`
LastModify time.Time `json:"last_modify"`
}

View File

View File

Binary file not shown.

243
pkg/filesystem/upload.go Normal file
View File

@ -0,0 +1,243 @@
package filesystem
import (
"context"
"os"
"path"
"time"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/gin-gonic/gin"
"github.com/gofrs/uuid"
)
/* ================
上传处理相关
================
*/
const (
UploadSessionMetaKey = "upload_session"
UploadSessionCtx = "uploadSession"
UserCtx = "user"
UploadSessionCachePrefix = "callback_"
)
// Upload 上传文件
func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err error) {
// 上传前的钩子
err = fs.Trigger(ctx, "BeforeUpload", file)
if err != nil {
request.BlackHole(file)
return err
}
// 生成文件名和路径,
var savePath string
if file.SavePath == "" {
// 如果是更新操作就从上下文中获取
if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
savePath = originFile.SourceName
} else {
savePath = fs.GenerateSavePath(ctx, file)
}
file.SavePath = savePath
}
// 保存文件
if file.Mode&fsctx.Nop != fsctx.Nop {
// 处理客户端未完成上传时,关闭连接
go fs.CancelUpload(ctx, savePath, file)
err = fs.Handler.Put(ctx, file)
if err != nil {
fs.Trigger(ctx, "AfterUploadFailed", file)
return err
}
}
// 上传完成后的钩子
err = fs.Trigger(ctx, "AfterUpload", file)
if err != nil {
// 上传完成后续处理失败
followUpErr := fs.Trigger(ctx, "AfterValidateFailed", file)
// 失败后再失败...
if followUpErr != nil {
util.Log().Debug("AfterValidateFailed hook execution failed: %s", followUpErr)
}
return err
}
return nil
}
// GenerateSavePath 生成要存放文件的路径
// TODO 完善测试
func (fs *FileSystem) GenerateSavePath(ctx context.Context, file fsctx.FileHeader) string {
fileInfo := file.Info()
return path.Join(
fs.Policy.GeneratePath(
fs.User.Model.ID,
fileInfo.VirtualPath,
),
fs.Policy.GenerateFileName(
fs.User.Model.ID,
fileInfo.FileName,
),
)
}
// CancelUpload 监测客户端取消上传
func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file fsctx.FileHeader) {
var reqContext context.Context
if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok {
reqContext = ginCtx.Request.Context()
} else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok {
reqContext = reqCtx
} else {
return
}
select {
case <-reqContext.Done():
select {
case <-ctx.Done():
// 客户端正常关闭,不执行操作
default:
// 客户端取消上传,删除临时文件
util.Log().Debug("Client canceled upload.")
if fs.Hooks["AfterUploadCanceled"] == nil {
return
}
err := fs.Trigger(ctx, "AfterUploadCanceled", file)
if err != nil {
util.Log().Debug("AfterUploadCanceled hook execution failed: %s", err)
}
}
}
}
// CreateUploadSession 创建上传会话
func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileStream) (*serializer.UploadCredential, error) {
// 获取相关有效期设置
callBackSessionTTL := model.GetIntSetting("upload_session_timeout", 86400)
callbackKey := uuid.Must(uuid.NewV4()).String()
fileSize := file.Size
// 创建占位的文件,同时校验文件信息
file.Mode = fsctx.Nop
if callbackKey != "" {
file.UploadSessionID = &callbackKey
}
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
// 验证文件规格
if err := fs.Upload(ctx, file); err != nil {
return nil, err
}
uploadSession := &serializer.UploadSession{
Key: callbackKey,
UID: fs.User.ID,
Policy: *fs.Policy,
VirtualPath: file.VirtualPath,
Name: file.Name,
Size: fileSize,
SavePath: file.SavePath,
LastModified: file.LastModified,
CallbackSecret: util.RandStringRunes(32),
}
// 获取上传凭证
credential, err := fs.Handler.Token(ctx, int64(callBackSessionTTL), uploadSession, file)
if err != nil {
return nil, err
}
// 创建占位符
if !fs.Policy.IsUploadPlaceholderWithSize() {
fs.Use("AfterUpload", HookClearFileHeaderSize)
}
fs.Use("AfterUpload", GenericAfterUpload)
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
if err := fs.Upload(ctx, file); err != nil {
return nil, err
}
// 创建回调会话
err = cache.Set(
UploadSessionCachePrefix+callbackKey,
*uploadSession,
callBackSessionTTL,
)
if err != nil {
return nil, err
}
// 补全上传凭证其他信息
credential.Expires = time.Now().Add(time.Duration(callBackSessionTTL) * time.Second).Unix()
return credential, nil
}
// UploadFromStream 从文件流上传文件
func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error {
// 给文件系统分配钩子
fs.Lock.Lock()
if resetPolicy {
err := fs.SetPolicyFromPath(file.VirtualPath)
if err != nil {
fs.Lock.Unlock()
return err
}
}
if fs.Hooks == nil {
fs.Use("BeforeUpload", HookValidateFile)
fs.Use("BeforeUpload", HookValidateCapacity)
fs.Use("AfterUploadCanceled", HookDeleteTempFile)
fs.Use("AfterUpload", GenericAfterUpload)
fs.Use("AfterValidateFailed", HookDeleteTempFile)
}
fs.Lock.Unlock()
// 开始上传
return fs.Upload(ctx, file)
}
// UploadFromPath 将本机已有文件上传到用户的文件系统
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error {
file, err := os.Open(util.RelativePath(src))
if err != nil {
return err
}
defer file.Close()
// 获取源文件大小
fi, err := file.Stat()
if err != nil {
return err
}
size := fi.Size()
// 开始上传
return fs.UploadFromStream(ctx, &fsctx.FileStream{
File: file,
Seeker: file,
Size: uint64(size),
Name: path.Base(dst),
VirtualPath: path.Dir(dst),
Mode: mode,
}, true)
}
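A sketch of importing a file that already exists on the server's disk into a user's space via UploadFromPath. fs is assumed to be an initialized *filesystem.FileSystem, both paths are placeholders, and passing true to resetPolicy inside UploadFromStream picks the policy from the destination path.
package example

import (
	"context"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
)

func importLocalFile(fs *filesystem.FileSystem) error {
	// Reads /tmp/report.pdf from local disk and stores it at the virtual
	// path /imports/report.pdf, overwriting any existing content.
	return fs.UploadFromPath(context.Background(), "/tmp/report.pdf", "/imports/report.pdf", fsctx.Overwrite)
}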

View File

@ -0,0 +1,66 @@
package filesystem
import (
"context"
"strings"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
)
/* ==========
验证器
==========
*/
// 文件/路径名保留字符
var reservedCharacter = []string{"\\", "?", "*", "<", "\"", ":", ">", "/", "|"}
// ValidateLegalName 验证文件名/文件夹名是否合法
func (fs *FileSystem) ValidateLegalName(ctx context.Context, name string) bool {
// 是否包含保留字符
for _, value := range reservedCharacter {
if strings.Contains(name, value) {
return false
}
}
// 是否超出长度限制
if len(name) >= 256 {
return false
}
// 是否为空限制
if len(name) == 0 {
return false
}
// 结尾不能是空格
if strings.HasSuffix(name, " ") {
return false
}
return true
}
// ValidateFileSize 验证上传的文件大小是否超出限制
func (fs *FileSystem) ValidateFileSize(ctx context.Context, size uint64) bool {
if fs.Policy.MaxSize == 0 {
return true
}
return size <= fs.Policy.MaxSize
}
// ValidateCapacity 验证并扣除用户容量
func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool {
return fs.User.IncreaseStorage(size)
}
// ValidateExtension 验证文件扩展名
func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bool {
// 不需要验证
if len(fs.Policy.OptionsSerialized.FileType) == 0 {
return true
}
return util.IsInExtensionList(fs.Policy.OptionsSerialized.FileType, fileName)
}
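A few inputs against ValidateLegalName; it only inspects the name itself, so a zero-value FileSystem is enough here. Expected results are noted in the comments.
package main

import (
	"context"
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
)

func main() {
	fs := &filesystem.FileSystem{}
	ctx := context.Background()
	fmt.Println(fs.ValidateLegalName(ctx, "notes.txt")) // true
	fmt.Println(fs.ValidateLegalName(ctx, "a:b.txt"))   // false, ":" is reserved
	fmt.Println(fs.ValidateLegalName(ctx, "draft "))    // false, trailing space
	fmt.Println(fs.ValidateLegalName(ctx, ""))          // false, empty name
}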

70
pkg/hashid/hash.go Normal file
View File

@ -0,0 +1,70 @@
package hashid
import (
"errors"
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
"github.com/speps/go-hashids"
)
// ID类型
const (
ShareID = iota // 分享
UserID // 用户
FileID // 文件ID
FolderID // 目录ID
TagID // 标签ID
PolicyID // 存储策略ID
SourceLinkID
)
var (
// ErrTypeNotMatch ID类型不匹配
ErrTypeNotMatch = errors.New("mismatched ID type.")
)
// HashEncode 对给定数据计算HashID
func HashEncode(v []int) (string, error) {
hd := hashids.NewData()
hd.Salt = conf.SystemConfig.HashIDSalt
h, err := hashids.NewWithData(hd)
if err != nil {
return "", err
}
id, err := h.Encode(v)
if err != nil {
return "", err
}
return id, nil
}
// HashDecode 对给定数据计算原始数据
func HashDecode(raw string) ([]int, error) {
hd := hashids.NewData()
hd.Salt = conf.SystemConfig.HashIDSalt
h, err := hashids.NewWithData(hd)
if err != nil {
return []int{}, err
}
return h.DecodeWithError(raw)
}
// HashID 计算数据库内主键对应的HashID
func HashID(id uint, t int) string {
v, _ := HashEncode([]int{int(id), t})
return v
}
// DecodeHashID 计算HashID对应的数据库ID
func DecodeHashID(id string, t int) (uint, error) {
v, _ := HashDecode(id)
if len(v) != 2 || v[1] != t {
return 0, ErrTypeNotMatch
}
return uint(v[0]), nil
}
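Round-trip sketch of the HashID helpers. The salt is normally loaded from the configuration file, so setting it directly here is purely for illustration; the ID value and salt are placeholders.
package main

import (
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/conf"
	"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
)

func main() {
	conf.SystemConfig.HashIDSalt = "example-salt" // illustration only

	public := hashid.HashID(1024, hashid.FileID)
	fmt.Println("public id:", public)

	id, err := hashid.DecodeHashID(public, hashid.FileID)
	fmt.Println(id, err) // 1024 <nil>

	// Decoding with the wrong type fails with ErrTypeNotMatch.
	_, err = hashid.DecodeHashID(public, hashid.FolderID)
	fmt.Println(err)
}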

View File

@ -0,0 +1,37 @@
package cachemock
import "github.com/stretchr/testify/mock"
type CacheClientMock struct {
mock.Mock
}
func (c CacheClientMock) Set(key string, value interface{}, ttl int) error {
return c.Called(key, value, ttl).Error(0)
}
func (c CacheClientMock) Get(key string) (interface{}, bool) {
args := c.Called(key)
return args.Get(0), args.Bool(1)
}
func (c CacheClientMock) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
args := c.Called(keys, prefix)
return args.Get(0).(map[string]interface{}), args.Get(1).([]string)
}
func (c CacheClientMock) Sets(values map[string]interface{}, prefix string) error {
return c.Called(values).Error(0)
}
func (c CacheClientMock) Delete(keys []string, prefix string) error {
return c.Called(keys, prefix).Error(0)
}
func (c CacheClientMock) Persist(path string) error {
return c.Called(path).Error(0)
}
func (c CacheClientMock) Restore(path string) error {
return c.Called(path).Error(0)
}

View File

@ -0,0 +1,43 @@
package controllermock
import (
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/mock"
)
type SlaveControllerMock struct {
mock.Mock
}
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
args := s.Called(pingReq)
return args.Get(0).(serializer.NodePingResp), args.Error(1)
}
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
args := s.Called(s2)
return args.Get(0).(common.Aria2), args.Error(1)
}
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
args := s.Called(s3, s2, message)
return args.Error(0)
}
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
args := s.Called(s3, i, s2, f)
return args.Error(0)
}
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
args := s.Called(s2)
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
}
func (s SlaveControllerMock) GetPolicyOauthToken(s2 string, u uint) (string, error) {
args := s.Called(s2, u)
return args.String(0), args.Error(1)
}

151
pkg/mocks/mocks.go Normal file
View File

@ -0,0 +1,151 @@
package mocks
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/task"
testMock "github.com/stretchr/testify/mock"
)
type NodePoolMock struct {
testMock.Mock
}
func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) {
args := n.Called(feature, lb)
return args.Error(0), args.Get(1).(cluster.Node)
}
func (n NodePoolMock) GetNodeByID(id uint) cluster.Node {
args := n.Called(id)
if res, ok := args.Get(0).(cluster.Node); ok {
return res
}
return nil
}
func (n NodePoolMock) Add(node *model.Node) {
n.Called(node)
}
func (n NodePoolMock) Delete(id uint) {
n.Called(id)
}
type NodeMock struct {
testMock.Mock
}
func (n NodeMock) Init(node *model.Node) {
n.Called(node)
}
func (n NodeMock) IsFeatureEnabled(feature string) bool {
args := n.Called(feature)
return args.Bool(0)
}
func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
n.Called(callback)
}
func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
args := n.Called(req)
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
}
func (n NodeMock) IsActive() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) GetAria2Instance() common.Aria2 {
args := n.Called()
return args.Get(0).(common.Aria2)
}
func (n NodeMock) ID() uint {
args := n.Called()
return args.Get(0).(uint)
}
func (n NodeMock) Kill() {
n.Called()
}
func (n NodeMock) IsMater() bool {
args := n.Called()
return args.Bool(0)
}
func (n NodeMock) MasterAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) SlaveAuthInstance() auth.Auth {
args := n.Called()
return args.Get(0).(auth.Auth)
}
func (n NodeMock) DBModel() *model.Node {
args := n.Called()
return args.Get(0).(*model.Node)
}
type Aria2Mock struct {
testMock.Mock
}
func (a Aria2Mock) Init() error {
args := a.Called()
return args.Error(0)
}
func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
args := a.Called(task, options)
return args.String(0), args.Error(1)
}
func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) {
args := a.Called(task)
return args.Get(0).(rpc.StatusInfo), args.Error(1)
}
func (a Aria2Mock) Cancel(task *model.Download) error {
args := a.Called(task)
return args.Error(0)
}
func (a Aria2Mock) Select(task *model.Download, files []int) error {
args := a.Called(task, files)
return args.Error(0)
}
func (a Aria2Mock) GetConfig() model.Aria2Option {
args := a.Called()
return args.Get(0).(model.Aria2Option)
}
func (a Aria2Mock) DeleteTempFile(download *model.Download) error {
args := a.Called(download)
return args.Error(0)
}
type TaskPoolMock struct {
testMock.Mock
}
func (t TaskPoolMock) Add(num int) {
t.Called(num)
}
func (t TaskPoolMock) Submit(job task.Job) {
t.Called(job)
}

View File

@ -0,0 +1,33 @@
package remoteclientmock
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/stretchr/testify/mock"
)
type RemoteClientMock struct {
mock.Mock
}
func (r *RemoteClientMock) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error {
return r.Called(ctx, session, ttl, overwrite).Error(0)
}
func (r *RemoteClientMock) GetUploadURL(ttl int64, sessionID string) (string, string, error) {
args := r.Called(ttl, sessionID)
return args.String(0), args.String(1), args.Error(2)
}
func (r *RemoteClientMock) Upload(ctx context.Context, file fsctx.FileHeader) error {
args := r.Called(ctx, file)
return args.Error(0)
}
func (r *RemoteClientMock) DeleteUploadSession(ctx context.Context, sessionID string) error {
args := r.Called(ctx, sessionID)
return args.Error(0)
}

View File

@ -0,0 +1,15 @@
package requestmock
import (
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/stretchr/testify/mock"
"io"
)
type RequestMock struct {
mock.Mock
}
func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
return r.Called(method, target, body, opts).Get(0).(*request.Response)
}

View File

@ -0,0 +1,25 @@
package thumbmock
import (
"context"
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
"github.com/stretchr/testify/mock"
"io"
)
type GeneratorMock struct {
mock.Mock
}
func (g GeneratorMock) Generate(ctx context.Context, file io.Reader, src string, name string, options map[string]string) (*thumb.Result, error) {
res := g.Called(ctx, file, src, name, options)
return res.Get(0).(*thumb.Result), res.Error(1)
}
func (g GeneratorMock) Priority() int {
return 0
}
func (g GeneratorMock) EnableFlag() string {
return "thumb_vips_enabled"
}

View File

@ -0,0 +1,21 @@
package wopimock
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
"github.com/stretchr/testify/mock"
)
type WopiClientMock struct {
mock.Mock
}
func (w *WopiClientMock) NewSession(user uint, file *model.File, action wopi.ActonType) (*wopi.Session, error) {
args := w.Called(user, file, action)
return args.Get(0).(*wopi.Session), args.Error(1)
}
func (w *WopiClientMock) AvailableExts() []string {
args := w.Called()
return args.Get(0).([]string)
}

160
pkg/mq/mq.go Normal file
View File

@ -0,0 +1,160 @@
package mq
import (
"encoding/gob"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
"strconv"
"sync"
"time"
)
// Message 消息事件正文
type Message struct {
// 消息触发者
TriggeredBy string
// 事件标识
Event string
// 消息正文
Content interface{}
}
type CallbackFunc func(Message)
// MQ 消息队列
type MQ interface {
rpc.Notifier
// 发布一个消息
Publish(string, Message)
// 订阅一个消息主题
Subscribe(string, int) <-chan Message
// 订阅一个消息主题,注册触发回调函数
SubscribeCallback(string, CallbackFunc)
// 取消订阅一个消息主题
Unsubscribe(string, <-chan Message)
}
var GlobalMQ = NewMQ()
func NewMQ() MQ {
return &inMemoryMQ{
topics: make(map[string][]chan Message),
callbacks: make(map[string][]CallbackFunc),
}
}
func init() {
gob.Register(Message{})
gob.Register([]rpc.Event{})
}
type inMemoryMQ struct {
topics map[string][]chan Message
callbacks map[string][]CallbackFunc
sync.RWMutex
}
func (i *inMemoryMQ) Publish(topic string, message Message) {
i.RLock()
subscribersChan, okChan := i.topics[topic]
subscribersCallback, okCallback := i.callbacks[topic]
i.RUnlock()
if okChan {
go func(subscribersChan []chan Message) {
for i := 0; i < len(subscribersChan); i++ {
select {
case subscribersChan[i] <- message:
case <-time.After(time.Millisecond * 500):
}
}
}(subscribersChan)
}
if okCallback {
for i := 0; i < len(subscribersCallback); i++ {
go subscribersCallback[i](message)
}
}
}
func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message {
ch := make(chan Message, buffer)
i.Lock()
i.topics[topic] = append(i.topics[topic], ch)
i.Unlock()
return ch
}
func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) {
i.Lock()
i.callbacks[topic] = append(i.callbacks[topic], callbackFunc)
i.Unlock()
}
func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) {
i.Lock()
defer i.Unlock()
subscribers, ok := i.topics[topic]
if !ok {
return
}
var newSubs []chan Message
for _, subscriber := range subscribers {
if subscriber == sub {
continue
}
newSubs = append(newSubs, subscriber)
}
i.topics[topic] = newSubs
}
func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) {
for _, event := range events {
i.Publish(event.Gid, Message{
TriggeredBy: event.Gid,
Event: strconv.FormatInt(int64(status), 10),
Content: events,
})
}
}
// OnDownloadStart 下载开始
func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) {
i.Aria2Notify(events, common.Downloading)
}
// OnDownloadPause 下载暂停
func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) {
i.Aria2Notify(events, common.Paused)
}
// OnDownloadStop 下载停止
func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) {
i.Aria2Notify(events, common.Canceled)
}
// OnDownloadComplete 下载完成
func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) {
i.Aria2Notify(events, common.Complete)
}
// OnDownloadError 下载出错
func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) {
i.Aria2Notify(events, common.Error)
}
// OnBtDownloadComplete BT下载完成
func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) {
i.Aria2Notify(events, common.Complete)
}
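A publish/subscribe sketch using the in-memory queue above. The topic string is arbitrary here; in this codebase topics are usually aria2 GIDs. Note that Publish gives up on a subscriber whose channel stays full for 500 ms, so the buffer size matters.
package main

import (
	"fmt"

	"github.com/cloudreve/Cloudreve/v3/pkg/mq"
)

func main() {
	topic := "demo-topic" // illustrative topic name
	sub := mq.GlobalMQ.Subscribe(topic, 1)
	defer mq.GlobalMQ.Unsubscribe(topic, sub)

	mq.GlobalMQ.Publish(topic, mq.Message{
		TriggeredBy: "example",
		Event:       "complete",
		Content:     "payload",
	})

	msg := <-sub
	fmt.Println(msg.Event, msg.Content) // complete payload
}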

43
pkg/payment/alipay.go Normal file
View File

@ -0,0 +1,43 @@
package payment
import (
"fmt"
"net/url"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
alipay "github.com/smartwalle/alipay/v3"
)
// Alipay 支付宝当面付支付处理
type Alipay struct {
Client *alipay.Client
}
// Create 创建订单
func (pay *Alipay) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
gateway, _ := url.Parse("/api/v3/callback/alipay")
var p = alipay.TradePreCreate{
Trade: alipay.Trade{
NotifyURL: model.GetSiteURL().ResolveReference(gateway).String(),
Subject: order.Name,
OutTradeNo: order.OrderNo,
TotalAmount: fmt.Sprintf("%.2f", float64(order.Price*order.Num)/100),
},
}
if _, err := order.Create(); err != nil {
return nil, ErrInsertOrder.WithError(err)
}
res, err := pay.Client.TradePreCreate(p)
if err != nil {
return nil, ErrIssueOrder.WithError(err)
}
return &OrderCreateRes{
Payment: true,
QRCode: res.QRCode,
ID: order.OrderNo,
}, nil
}

93
pkg/payment/custom.go Normal file
View File

@ -0,0 +1,93 @@
package payment
import (
"encoding/json"
"errors"
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/request"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/gofrs/uuid"
"github.com/qiniu/go-sdk/v7/sms/bytes"
"net/http"
"net/url"
)
// Custom payment client
type Custom struct {
client request.Client
endpoint string
authClient auth.Auth
}
const (
paymentTTL = 3600 * 24 // 24h
CallbackSessionPrefix = "custom_payment_callback_"
)
func newCustomClient(endpoint, secret string) *Custom {
authClient := auth.HMACAuth{
SecretKey: []byte(secret),
}
return &Custom{
endpoint: endpoint,
authClient: auth.General,
client: request.NewClient(
request.WithCredential(authClient, paymentTTL),
request.WithMasterMeta(),
),
}
}
// Request body from Cloudreve to create a new payment
type NewCustomOrderRequest struct {
Name string `json:"name"` // Order name
OrderNo string `json:"order_no"` // Order number
NotifyURL string `json:"notify_url"` // Payment callback url
Amount int64 `json:"amount"` // Order total amount
}
// Create a new payment
func (pay *Custom) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
callbackID := uuid.Must(uuid.NewV4())
gateway, _ := url.Parse(fmt.Sprintf("/api/v3/callback/custom/%s/%s", order.OrderNo, callbackID))
callback, err := auth.SignURI(pay.authClient, model.GetSiteURL().ResolveReference(gateway).String(), paymentTTL)
if err != nil {
return nil, fmt.Errorf("failed to sign callback url: %w", err)
}
cache.Set(CallbackSessionPrefix+callbackID.String(), order.OrderNo, paymentTTL)
body := &NewCustomOrderRequest{
Name: order.Name,
OrderNo: order.OrderNo,
NotifyURL: callback.String(),
Amount: int64(order.Price * order.Num),
}
bodyJson, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("failed to encode body: %w", err)
}
res, err := pay.client.Request("POST", pay.endpoint, bytes.NewReader(bodyJson)).
CheckHTTPResponse(http.StatusOK).DecodeResponse()
if err != nil {
return nil, fmt.Errorf("failed to request payment gateway: %w", err)
}
if res.Code != 0 {
return nil, errors.New(res.Error)
}
if _, err := order.Create(); err != nil {
return nil, ErrInsertOrder.WithError(err)
}
return &OrderCreateRes{
Payment: true,
QRCode: res.Data.(string),
ID: order.OrderNo,
}, nil
}
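For the gateway side, the body sent above is plain JSON. The sketch below is a minimal, hypothetical third-party endpoint that decodes NewCustomOrderRequest; it omits verification of the HMAC-signed request, and the response shape (code 0 plus a payment URL in data) is an assumption based on how the client reads res.Code and res.Data here.
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// orderRequest mirrors NewCustomOrderRequest above; field names must match its JSON tags.
type orderRequest struct {
	Name      string `json:"name"`
	OrderNo   string `json:"order_no"`
	NotifyURL string `json:"notify_url"`
	Amount    int64  `json:"amount"`
}

func main() {
	http.HandleFunc("/order", func(w http.ResponseWriter, r *http.Request) {
		var req orderRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		fmt.Printf("new order %s for %d cents\n", req.OrderNo, req.Amount)
		// Respond in the shape the Cloudreve client expects: code 0 plus a
		// payment page / QR code URL in data (assumed field names).
		json.NewEncoder(w).Encode(map[string]interface{}{
			"code": 0,
			"data": "https://pay.example.com/qr/" + req.OrderNo, // placeholder URL
		})
	})
	http.ListenAndServe(":8081", nil)
}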

171
pkg/payment/order.go Normal file
View File

@ -0,0 +1,171 @@
package payment
import (
"fmt"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/iGoogle-ink/gopay/wechat/v3"
"github.com/qingwg/payjs"
"github.com/smartwalle/alipay/v3"
"math/rand"
"net/url"
"time"
)
var (
// ErrUnknownPaymentMethod 未知支付方式
ErrUnknownPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "Unknown payment method", nil)
// ErrUnsupportedPaymentMethod 不支持的支付方式
ErrUnsupportedPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "This order cannot be paid with this method", nil)
// ErrInsertOrder 无法插入订单记录
ErrInsertOrder = serializer.NewError(serializer.CodeDBError, "Failed to insert order record", nil)
// ErrScoreNotEnough 积分不足
ErrScoreNotEnough = serializer.NewError(serializer.CodeInsufficientCredit, "", nil)
// ErrCreateStoragePack 无法创建容量包
ErrCreateStoragePack = serializer.NewError(serializer.CodeDBError, "Failed to create storage pack record", nil)
// ErrGroupConflict 用户组冲突
ErrGroupConflict = serializer.NewError(serializer.CodeGroupConflict, "", nil)
// ErrGroupInvalid 用户组冲突
ErrGroupInvalid = serializer.NewError(serializer.CodeGroupInvalid, "", nil)
// ErrAdminFulfillGroup 管理员无法购买用户组
ErrAdminFulfillGroup = serializer.NewError(serializer.CodeFulfillAdminGroup, "", nil)
// ErrUpgradeGroup 用户组冲突
ErrUpgradeGroup = serializer.NewError(serializer.CodeDBError, "Failed to update user's group", nil)
// ErrUInitPayment 无法初始化支付实例
ErrUInitPayment = serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize payment client", nil)
// ErrIssueOrder 订单接口请求失败
ErrIssueOrder = serializer.NewError(serializer.CodeInternalSetting, "Failed to create order", nil)
// ErrOrderNotFound 订单不存在
ErrOrderNotFound = serializer.NewError(serializer.CodeNotFound, "", nil)
)
// Pay is the interface implemented by all payment handlers
type Pay interface {
Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error)
}
// OrderCreateRes is the result of creating an order
type OrderCreateRes struct {
Payment bool   `json:"payment"` // Whether further payment is required
ID      string `json:"id,omitempty"` // Order number
QRCode  string `json:"qr_code,omitempty"` // URL encoded in the payment QR code
}
// NewPaymentInstance returns a payment handler for the given method
func NewPaymentInstance(method string) (Pay, error) {
switch method {
case "score":
return &ScorePayment{}, nil
case "alipay":
options := model.GetSettingByNames("alipay_enabled", "appid", "appkey", "shopid")
if options["alipay_enabled"] != "1" {
return nil, ErrUnknownPaymentMethod
}
// Initialize the Alipay client
var client, err = alipay.New(options["appid"], options["appkey"], true)
if err != nil {
return nil, ErrUInitPayment.WithError(err)
}
// Load the Alipay public key
err = client.LoadAliPayPublicKey(options["shopid"])
if err != nil {
return nil, ErrUInitPayment.WithError(err)
}
return &Alipay{Client: client}, nil
case "payjs":
options := model.GetSettingByNames("payjs_enabled", "payjs_secret", "payjs_id")
if options["payjs_enabled"] != "1" {
return nil, ErrUnknownPaymentMethod
}
callback, _ := url.Parse("/api/v3/callback/payjs")
payjsConfig := &payjs.Config{
Key: options["payjs_secret"],
MchID: options["payjs_id"],
NotifyUrl: model.GetSiteURL().ResolveReference(callback).String(),
}
return &PayJSClient{Client: payjs.New(payjsConfig)}, nil
case "wechat":
options := model.GetSettingByNames("wechat_enabled", "wechat_appid", "wechat_mchid", "wechat_serial_no", "wechat_api_key", "wechat_pk_content")
if options["wechat_enabled"] != "1" {
return nil, ErrUnknownPaymentMethod
}
client, err := wechat.NewClientV3(options["wechat_appid"], options["wechat_mchid"], options["wechat_serial_no"], options["wechat_api_key"], options["wechat_pk_content"])
if err != nil {
return nil, ErrUInitPayment.WithError(err)
}
return &Wechat{Client: client, ApiV3Key: options["wechat_api_key"]}, nil
case "custom":
options := model.GetSettingByNames("custom_payment_enabled", "custom_payment_endpoint", "custom_payment_secret")
if !model.IsTrueVal(options["custom_payment_enabled"]) {
return nil, ErrUnknownPaymentMethod
}
return newCustomClient(options["custom_payment_endpoint"], options["custom_payment_secret"]), nil
default:
return nil, ErrUnknownPaymentMethod
}
}
// NewOrder creates a new order and dispatches it to the selected payment handler
func NewOrder(pack *serializer.PackProduct, group *serializer.GroupProducts, num int, method string, user *model.User) (*OrderCreateRes, error) {
// Get the payment handler
pay, err := NewPaymentInstance(method)
if err != nil {
return nil, err
}
var (
orderType int
productID int64
title string
price int
)
if pack != nil {
orderType = model.PackOrderType
productID = pack.ID
title = pack.Name
price = pack.Price
} else if group != nil {
if err := checkGroupUpgrade(user, group); err != nil {
return nil, err
}
orderType = model.GroupOrderType
productID = group.ID
title = group.Name
price = group.Price
} else {
orderType = model.ScoreOrderType
productID = 0
title = fmt.Sprintf("%d 积分", num)
price = model.GetIntSetting("score_price", 1)
}
// Build the order record; it is persisted by the selected payment handler
order := &model.Order{
UserID: user.ID,
OrderNo: orderID(),
Type: orderType,
Method: method,
ProductID: productID,
Num: num,
Name: fmt.Sprintf("%s - %s", model.GetSettingByName("siteName"), title),
Price: price,
Status: model.OrderUnpaid,
}
return pay.Create(order, pack, group, user)
}
func orderID() string {
return fmt.Sprintf("%s%d",
time.Now().Format("20060102150405"),
100000+rand.Intn(900000),
)
}
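
As a quick illustration of the order-number layout produced by orderID above (a 14-digit timestamp followed by a 6-digit random suffix, 20 characters in total), the function can be copied into a standalone program; the printed value is an example only.

package main

import (
	"fmt"
	"math/rand"
	"time"
)

// Copy of orderID from pkg/payment/order.go, for illustration.
func orderID() string {
	return fmt.Sprintf("%s%d",
		time.Now().Format("20060102150405"),
		100000+rand.Intn(900000),
	)
}

func main() {
	// Prints something like "20240225083034123456": yyyyMMddHHmmss plus 6 random digits.
	fmt.Println(orderID())
}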

31
pkg/payment/payjs.go Normal file

@ -0,0 +1,31 @@
package payment
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/qingwg/payjs"
)
// PayJSClient handles payments through PayJS
type PayJSClient struct {
Client *payjs.PayJS
}
// Create creates a new PayJS order
func (pay *PayJSClient) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
if _, err := order.Create(); err != nil {
return nil, ErrInsertOrder.WithError(err)
}
PayNative := pay.Client.GetNative()
res, err := PayNative.Create(int64(order.Price*order.Num), order.Name, order.OrderNo, "", "")
if err != nil {
return nil, ErrIssueOrder.WithError(err)
}
return &OrderCreateRes{
Payment: true,
QRCode: res.CodeUrl,
ID: order.OrderNo,
}, nil
}

137
pkg/payment/purchase.go Normal file

@ -0,0 +1,137 @@
package payment
import (
"encoding/json"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"strconv"
"time"
)
// GivePack grants a storage pack to the user
func GivePack(user *model.User, packInfo *serializer.PackProduct, num int) error {
timeNow := time.Now()
expires := timeNow.Add(time.Duration(packInfo.Time*int64(num)) * time.Second)
pack := model.StoragePack{
Name: packInfo.Name,
UserID: user.ID,
ActiveTime: &timeNow,
ExpiredTime: &expires,
Size: packInfo.Size,
}
if _, err := pack.Create(); err != nil {
return ErrCreateStoragePack.WithError(err)
}
cache.Deletes([]string{strconv.FormatUint(uint64(user.ID), 10)}, "pack_size_")
return nil
}
// checkGroupUpgrade validates whether the user can purchase the given group
func checkGroupUpgrade(user *model.User, groupInfo *serializer.GroupProducts) error {
// Administrators (group 1) cannot purchase user groups
if user.Group.ID == 1 {
return ErrAdminFulfillGroup
}
// A user who still holds an unexpired purchased group can only renew that same group
if user.PreviousGroupID != 0 && user.GroupID != groupInfo.GroupID {
return ErrGroupConflict
}
// The target group must differ from the user's current, non-purchased group
if user.GroupID == groupInfo.GroupID && user.PreviousGroupID == 0 {
return ErrGroupInvalid
}
return nil
}
// GiveGroup upgrades or renews the user's group
func GiveGroup(user *model.User, groupInfo *serializer.GroupProducts, num int) error {
if err := checkGroupUpgrade(user, groupInfo); err != nil {
return err
}
timeNow := time.Now()
expires := timeNow.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
if user.PreviousGroupID != 0 {
expires = user.GroupExpires.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
}
if err := user.UpgradeGroup(groupInfo.GroupID, &expires); err != nil {
return ErrUpgradeGroup.WithError(err)
}
return nil
}
// GiveScore credits the purchased score to the user's account
func GiveScore(user *model.User, num int) error {
user.AddScore(num)
return nil
}
// GiveProduct "ships" the purchased product to the user
func GiveProduct(user *model.User, pack *serializer.PackProduct, group *serializer.GroupProducts, num int) error {
if pack != nil {
return GivePack(user, pack, num)
} else if group != nil {
return GiveGroup(user, group, num)
} else {
return GiveScore(user, num)
}
}
// OrderPaid fulfills an order after it has been paid
func OrderPaid(orderNo string) error {
order, err := model.GetOrderByNo(orderNo)
// Treat already-paid orders the same as missing ones so that payment callbacks stay idempotent
if err != nil || order.Status == model.OrderPaid {
return ErrOrderNotFound.WithError(err)
}
// Mark the order as paid
order.UpdateStatus(model.OrderPaid)
user, err := model.GetActiveUserByID(order.UserID)
if err != nil {
return serializer.NewError(serializer.CodeUserNotFound, "", err)
}
// Load the product catalogs
options := model.GetSettingByNames("pack_data", "group_sell_data")
var (
packs []serializer.PackProduct
groups []serializer.GroupProducts
)
if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil {
return err
}
if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil {
return err
}
// Locate the purchased product
var (
pack *serializer.PackProduct
group *serializer.GroupProducts
)
if order.Type == model.GroupOrderType {
for _, v := range groups {
if v.ID == order.ProductID {
group = &v
break
}
}
} else if order.Type == model.PackOrderType {
for _, v := range packs {
if v.ID == order.ProductID {
pack = &v
break
}
}
}
// "发货"
return GiveProduct(&user, pack, group, order.Num)
}
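
The expiry arithmetic in GiveGroup is easier to follow with concrete numbers: on a first purchase the paid period starts from now, while a renewal of the same group appends it to the existing GroupExpires. The standalone sketch below reproduces only that calculation, with made-up values (a 30-day group, two units, 10 days left on the current period).

package main

import (
	"fmt"
	"time"
)

func main() {
	unit := int64(30 * 24 * 3600) // hypothetical group validity: 30 days per unit, in seconds
	num := 2                      // units purchased

	paid := time.Duration(unit*int64(num)) * time.Second

	// First purchase: the expiry is counted from now (PreviousGroupID == 0 in GiveGroup).
	firstExpiry := time.Now().Add(paid)

	// Renewal: the paid period extends the current expiry instead of restarting it.
	currentExpiry := time.Now().Add(10 * 24 * time.Hour) // e.g. 10 days still left
	renewedExpiry := currentExpiry.Add(paid)

	fmt.Println("first purchase expires:", firstExpiry.Format(time.RFC3339))
	fmt.Println("renewal expires:       ", renewedExpiry.Format(time.RFC3339))
}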

45
pkg/payment/score.go Normal file

@ -0,0 +1,45 @@
package payment
import (
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
)
// ScorePayment handles orders paid with credits (score)
type ScorePayment struct {
}
// Create creates a new order paid with credits
func (pay *ScorePayment) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
if pack != nil {
order.Price = pack.Score
} else if group != nil {
order.Price = group.Score
} else {
// Credits themselves cannot be bought with credits
return nil, ErrUnsupportedPaymentMethod
}
// Check whether this order can be paid with credits at all
if order.Price == 0 {
return nil, ErrUnsupportedPaymentMethod
}
// Deduct the user's credits
if !user.PayScore(order.Price * order.Num) {
return nil, ErrScoreNotEnough
}
// "Ship" the product; refund the credits if shipping fails
if err := GiveProduct(user, pack, group, order.Num); err != nil {
user.AddScore(order.Price * order.Num)
return nil, err
}
// Record the order, already marked as paid
order.Status = model.OrderPaid
if _, err := order.Create(); err != nil {
return nil, ErrInsertOrder.WithError(err)
}
return &OrderCreateRes{
Payment: false,
}, nil
}

88
pkg/payment/wechat.go Normal file

@ -0,0 +1,88 @@
package payment
import (
"errors"
model "github.com/cloudreve/Cloudreve/v3/models"
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
"github.com/cloudreve/Cloudreve/v3/pkg/util"
"github.com/iGoogle-ink/gopay"
"github.com/iGoogle-ink/gopay/wechat/v3"
"net/url"
"time"
)
// Wechat handles WeChat Pay native (QR code) payments
type Wechat struct {
Client *wechat.ClientV3
ApiV3Key string
}
// Create creates a new order
func (pay *Wechat) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
gateway, _ := url.Parse("/api/v3/callback/wechat")
bm := make(gopay.BodyMap)
bm.
Set("description", order.Name).
Set("out_trade_no", order.OrderNo).
Set("notify_url", model.GetSiteURL().ResolveReference(gateway).String()).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", int64(order.Price*order.Num)).
Set("currency", "CNY")
})
wxRsp, err := pay.Client.V3TransactionNative(bm)
if err != nil {
return nil, ErrIssueOrder.WithError(err)
}
if wxRsp.Code == wechat.Success {
if _, err := order.Create(); err != nil {
return nil, ErrInsertOrder.WithError(err)
}
return &OrderCreateRes{
Payment: true,
QRCode: wxRsp.Response.CodeUrl,
ID: order.OrderNo,
}, nil
}
return nil, ErrIssueOrder.WithError(errors.New(wxRsp.Error))
}
// GetPlatformCert returns the WeChat Pay platform certificate, using the cached copy when available
func (pay *Wechat) GetPlatformCert() string {
if cert, ok := cache.Get("wechat_platform_cert"); ok {
return cert.(string)
}
res, err := pay.Client.GetPlatformCerts()
if err == nil {
// Use the certificate with the latest effective time among those returned
var (
currentLatest *time.Time
currentCert string
)
for _, cert := range res.Certs {
effectiveTime, err := time.Parse("2006-01-02T15:04:05-0700", cert.EffectiveTime)
// Only certificates whose effective time parses successfully are considered
if err == nil {
if currentLatest == nil {
currentLatest = &effectiveTime
currentCert = cert.PublicKey
continue
}
if currentLatest.Before(effectiveTime) {
currentLatest = &effectiveTime
currentCert = cert.PublicKey
}
}
}
cache.Set("wechat_platform_cert", currentCert, 3600*10)
return currentCert
}
util.Log().Debug("Failed to get Wechat Pay platform certificate: %s", err)
return ""
}

Some files were not shown because too many files have changed in this diff.