init
This commit is contained in:
67
pkg/aria2/aria2.go
Normal file
67
pkg/aria2/aria2.go
Normal file
@ -0,0 +1,67 @@
|
||||
package aria2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/monitor"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
)
|
||||
|
||||
// Instance 默认使用的Aria2处理实例
|
||||
var Instance common.Aria2 = &common.DummyAria2{}
|
||||
|
||||
// LB 获取 Aria2 节点的负载均衡器
|
||||
var LB balancer.Balancer
|
||||
|
||||
// Lock Instance的读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// GetLoadBalancer 返回供Aria2使用的负载均衡器
|
||||
func GetLoadBalancer() balancer.Balancer {
|
||||
Lock.RLock()
|
||||
defer Lock.RUnlock()
|
||||
return LB
|
||||
}
|
||||
|
||||
// Init 初始化
|
||||
func Init(isReload bool, pool cluster.Pool, mqClient mq.MQ) {
|
||||
Lock.Lock()
|
||||
LB = balancer.NewBalancer("RoundRobin")
|
||||
Lock.Unlock()
|
||||
|
||||
if !isReload {
|
||||
// 从数据库中读取未完成任务,创建监控
|
||||
unfinished := model.GetDownloadsByStatus(common.Ready, common.Paused, common.Downloading, common.Seeding)
|
||||
|
||||
for i := 0; i < len(unfinished); i++ {
|
||||
// 创建任务监控
|
||||
monitor.NewMonitor(&unfinished[i], pool, mqClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRPCConnection 发送测试用的 RPC 请求,测试服务连通性
|
||||
func TestRPCConnection(server, secret string, timeout int) (rpc.VersionInfo, error) {
|
||||
// 解析RPC服务地址
|
||||
rpcServer, err := url.Parse(server)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot parse RPC server: %w", err)
|
||||
}
|
||||
|
||||
rpcServer.Path = "/jsonrpc"
|
||||
caller, err := rpc.New(context.Background(), rpcServer.String(), secret, time.Duration(timeout)*time.Second, nil)
|
||||
if err != nil {
|
||||
return rpc.VersionInfo{}, fmt.Errorf("cannot initialize rpc connection: %w", err)
|
||||
}
|
||||
|
||||
return caller.GetVersion()
|
||||
}
|
119
pkg/aria2/common/common.go
Normal file
119
pkg/aria2/common/common.go
Normal file
@ -0,0 +1,119 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Aria2 离线下载处理接口
|
||||
type Aria2 interface {
|
||||
// Init 初始化客户端连接
|
||||
Init() error
|
||||
// CreateTask 创建新的任务
|
||||
CreateTask(task *model.Download, options map[string]interface{}) (string, error)
|
||||
// 返回状态信息
|
||||
Status(task *model.Download) (rpc.StatusInfo, error)
|
||||
// 取消任务
|
||||
Cancel(task *model.Download) error
|
||||
// 选择要下载的文件
|
||||
Select(task *model.Download, files []int) error
|
||||
// 获取离线下载配置
|
||||
GetConfig() model.Aria2Option
|
||||
// 删除临时下载文件
|
||||
DeleteTempFile(*model.Download) error
|
||||
}
|
||||
|
||||
const (
|
||||
// URLTask 从URL添加的任务
|
||||
URLTask = iota
|
||||
// TorrentTask 种子任务
|
||||
TorrentTask
|
||||
)
|
||||
|
||||
const (
|
||||
// Ready 准备就绪
|
||||
Ready = iota
|
||||
// Downloading 下载中
|
||||
Downloading
|
||||
// Paused 暂停中
|
||||
Paused
|
||||
// Error 出错
|
||||
Error
|
||||
// Complete 完成
|
||||
Complete
|
||||
// Canceled 取消/停止
|
||||
Canceled
|
||||
// Unknown 未知状态
|
||||
Unknown
|
||||
// Seeding 做种中
|
||||
Seeding
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrNotEnabled 功能未开启错误
|
||||
ErrNotEnabled = serializer.NewError(serializer.CodeFeatureNotEnabled, "not enabled", nil)
|
||||
// ErrUserNotFound 未找到下载任务创建者
|
||||
ErrUserNotFound = serializer.NewError(serializer.CodeUserNotFound, "", nil)
|
||||
)
|
||||
|
||||
// DummyAria2 未开启Aria2功能时使用的默认处理器
|
||||
type DummyAria2 struct {
|
||||
}
|
||||
|
||||
func (instance *DummyAria2) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务,此处直接返回未开启错误
|
||||
func (instance *DummyAria2) CreateTask(model *model.Download, options map[string]interface{}) (string, error) {
|
||||
return "", ErrNotEnabled
|
||||
}
|
||||
|
||||
// Status 返回未开启错误
|
||||
func (instance *DummyAria2) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
return rpc.StatusInfo{}, ErrNotEnabled
|
||||
}
|
||||
|
||||
// Cancel 返回未开启错误
|
||||
func (instance *DummyAria2) Cancel(task *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// Select 返回未开启错误
|
||||
func (instance *DummyAria2) Select(task *model.Download, files []int) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) GetConfig() model.Aria2Option {
|
||||
return model.Aria2Option{}
|
||||
}
|
||||
|
||||
// GetConfig 返回空的
|
||||
func (instance *DummyAria2) DeleteTempFile(src *model.Download) error {
|
||||
return ErrNotEnabled
|
||||
}
|
||||
|
||||
// GetStatus 将给定的状态字符串转换为状态标识数字
|
||||
func GetStatus(status rpc.StatusInfo) int {
|
||||
switch status.Status {
|
||||
case "complete":
|
||||
return Complete
|
||||
case "active":
|
||||
if status.BitTorrent.Mode != "" && status.CompletedLength == status.TotalLength {
|
||||
return Seeding
|
||||
}
|
||||
return Downloading
|
||||
case "waiting":
|
||||
return Ready
|
||||
case "paused":
|
||||
return Paused
|
||||
case "error":
|
||||
return Error
|
||||
case "removed":
|
||||
return Canceled
|
||||
default:
|
||||
return Unknown
|
||||
}
|
||||
}
|
320
pkg/aria2/monitor/monitor.go
Normal file
320
pkg/aria2/monitor/monitor.go
Normal file
@ -0,0 +1,320 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Monitor 离线下载状态监控
|
||||
type Monitor struct {
|
||||
Task *model.Download
|
||||
Interval time.Duration
|
||||
|
||||
notifier <-chan mq.Message
|
||||
node cluster.Node
|
||||
retried int
|
||||
}
|
||||
|
||||
var MAX_RETRY = 10
|
||||
|
||||
// NewMonitor 新建离线下载状态监控
|
||||
func NewMonitor(task *model.Download, pool cluster.Pool, mqClient mq.MQ) {
|
||||
monitor := &Monitor{
|
||||
Task: task,
|
||||
notifier: make(chan mq.Message),
|
||||
node: pool.GetNodeByID(task.GetNodeID()),
|
||||
}
|
||||
|
||||
if monitor.node != nil {
|
||||
monitor.Interval = time.Duration(monitor.node.GetAria2Instance().GetConfig().Interval) * time.Second
|
||||
go monitor.Loop(mqClient)
|
||||
|
||||
monitor.notifier = mqClient.Subscribe(monitor.Task.GID, 0)
|
||||
} else {
|
||||
monitor.setErrorStatus(errors.New("node not avaliable"))
|
||||
}
|
||||
}
|
||||
|
||||
// Loop 开启监控循环
|
||||
func (monitor *Monitor) Loop(mqClient mq.MQ) {
|
||||
defer mqClient.Unsubscribe(monitor.Task.GID, monitor.notifier)
|
||||
fmt.Println(cluster.Default)
|
||||
|
||||
// 首次循环立即更新
|
||||
interval := 50 * time.Millisecond
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-monitor.notifier:
|
||||
if monitor.Update() {
|
||||
return
|
||||
}
|
||||
case <-time.After(interval):
|
||||
interval = monitor.Interval
|
||||
if monitor.Update() {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update 更新状态,返回值表示是否退出监控
|
||||
func (monitor *Monitor) Update() bool {
|
||||
status, err := monitor.node.GetAria2Instance().Status(monitor.Task)
|
||||
|
||||
if err != nil {
|
||||
monitor.retried++
|
||||
util.Log().Warning("Cannot get status of download task %q: %s", monitor.Task.GID, err)
|
||||
|
||||
// 十次重试后认定为任务失败
|
||||
if monitor.retried > MAX_RETRY {
|
||||
util.Log().Warning("Cannot get status of download task %q,exceed maximum retry threshold: %s",
|
||||
monitor.Task.GID, err)
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
monitor.retried = 0
|
||||
|
||||
// 磁力链下载需要跟随
|
||||
if len(status.FollowedBy) > 0 {
|
||||
util.Log().Debug("Redirected download task from %q to %q.", monitor.Task.GID, status.FollowedBy[0])
|
||||
monitor.Task.GID = status.FollowedBy[0]
|
||||
monitor.Task.Save()
|
||||
return false
|
||||
}
|
||||
|
||||
// 更新任务信息
|
||||
if err := monitor.UpdateTaskInfo(status); err != nil {
|
||||
util.Log().Warning("Failed to update status of download task %q: %s", monitor.Task.GID, err)
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
util.Log().Debug("Remote download %q status updated to %q.", status.Gid, status.Status)
|
||||
|
||||
switch common.GetStatus(status) {
|
||||
case common.Complete, common.Seeding:
|
||||
return monitor.Complete(task.TaskPoll)
|
||||
case common.Error:
|
||||
return monitor.Error(status)
|
||||
case common.Downloading, common.Ready, common.Paused:
|
||||
return false
|
||||
case common.Canceled:
|
||||
monitor.Task.Status = common.Canceled
|
||||
monitor.Task.Save()
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
default:
|
||||
util.Log().Warning("Download task %q returns unknown status %q.", monitor.Task.GID, status.Status)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateTaskInfo 更新数据库中的任务信息
|
||||
func (monitor *Monitor) UpdateTaskInfo(status rpc.StatusInfo) error {
|
||||
originSize := monitor.Task.TotalSize
|
||||
|
||||
monitor.Task.GID = status.Gid
|
||||
monitor.Task.Status = common.GetStatus(status)
|
||||
|
||||
// 文件大小、已下载大小
|
||||
total, err := strconv.ParseUint(status.TotalLength, 10, 64)
|
||||
if err != nil {
|
||||
total = 0
|
||||
}
|
||||
downloaded, err := strconv.ParseUint(status.CompletedLength, 10, 64)
|
||||
if err != nil {
|
||||
downloaded = 0
|
||||
}
|
||||
monitor.Task.TotalSize = total
|
||||
monitor.Task.DownloadedSize = downloaded
|
||||
monitor.Task.GID = status.Gid
|
||||
monitor.Task.Parent = status.Dir
|
||||
|
||||
// 下载速度
|
||||
speed, err := strconv.Atoi(status.DownloadSpeed)
|
||||
if err != nil {
|
||||
speed = 0
|
||||
}
|
||||
|
||||
monitor.Task.Speed = speed
|
||||
attrs, _ := json.Marshal(status)
|
||||
monitor.Task.Attrs = string(attrs)
|
||||
|
||||
if err := monitor.Task.Save(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if originSize != monitor.Task.TotalSize {
|
||||
// 文件大小更新后,对文件限制等进行校验
|
||||
if err := monitor.ValidateFile(); err != nil {
|
||||
// 验证失败时取消任务
|
||||
monitor.node.GetAria2Instance().Cancel(monitor.Task)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateFile 上传过程中校验文件大小、文件名
|
||||
func (monitor *Monitor) ValidateFile() error {
|
||||
// 找到任务创建者
|
||||
user := monitor.Task.GetOwner()
|
||||
if user == nil {
|
||||
return common.ErrUserNotFound
|
||||
}
|
||||
|
||||
// 创建文件系统
|
||||
fs, err := filesystem.NewFileSystem(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fs.Recycle()
|
||||
|
||||
if err := fs.SetPolicyFromPath(monitor.Task.Dst); err != nil {
|
||||
return fmt.Errorf("failed to switch policy to target dir: %w", err)
|
||||
}
|
||||
|
||||
// 创建上下文环境
|
||||
file := &fsctx.FileStream{
|
||||
Size: monitor.Task.TotalSize,
|
||||
}
|
||||
|
||||
// 验证用户容量
|
||||
if err := filesystem.HookValidateCapacity(context.Background(), fs, file); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证每个文件
|
||||
for _, fileInfo := range monitor.Task.StatusInfo.Files {
|
||||
if fileInfo.Selected == "true" {
|
||||
// 创建上下文环境
|
||||
fileSize, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
|
||||
file := &fsctx.FileStream{
|
||||
Size: fileSize,
|
||||
Name: filepath.Base(fileInfo.Path),
|
||||
}
|
||||
if err := filesystem.HookValidateFile(context.Background(), fs, file); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Error 任务下载出错处理,返回是否中断监控
|
||||
func (monitor *Monitor) Error(status rpc.StatusInfo) bool {
|
||||
monitor.setErrorStatus(errors.New(status.ErrorMessage))
|
||||
|
||||
// 清理临时文件
|
||||
monitor.RemoveTempFolder()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// RemoveTempFolder 清理下载临时目录
|
||||
func (monitor *Monitor) RemoveTempFolder() {
|
||||
monitor.node.GetAria2Instance().DeleteTempFile(monitor.Task)
|
||||
}
|
||||
|
||||
// Complete 完成下载,返回是否中断监控
|
||||
func (monitor *Monitor) Complete(pool task.Pool) bool {
|
||||
// 未开始转存,提交转存任务
|
||||
if monitor.Task.TaskID == 0 {
|
||||
return monitor.transfer(pool)
|
||||
}
|
||||
|
||||
// 做种完成
|
||||
if common.GetStatus(monitor.Task.StatusInfo) == common.Complete {
|
||||
transferTask, err := model.GetTasksByID(monitor.Task.TaskID)
|
||||
if err != nil {
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
// 转存完成,回收下载目录
|
||||
if transferTask.Type == task.TransferTaskType && transferTask.Status >= task.Error {
|
||||
job, err := task.NewRecycleTask(monitor.Task)
|
||||
if err != nil {
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
// 提交回收任务
|
||||
pool.Submit(job)
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (monitor *Monitor) transfer(pool task.Pool) bool {
|
||||
// 创建中转任务
|
||||
file := make([]string, 0, len(monitor.Task.StatusInfo.Files))
|
||||
sizes := make(map[string]uint64, len(monitor.Task.StatusInfo.Files))
|
||||
for i := 0; i < len(monitor.Task.StatusInfo.Files); i++ {
|
||||
fileInfo := monitor.Task.StatusInfo.Files[i]
|
||||
if fileInfo.Selected == "true" {
|
||||
file = append(file, fileInfo.Path)
|
||||
size, _ := strconv.ParseUint(fileInfo.Length, 10, 64)
|
||||
sizes[fileInfo.Path] = size
|
||||
}
|
||||
}
|
||||
|
||||
job, err := task.NewTransferTask(
|
||||
monitor.Task.UserID,
|
||||
file,
|
||||
monitor.Task.Dst,
|
||||
monitor.Task.Parent,
|
||||
true,
|
||||
monitor.node.ID(),
|
||||
sizes,
|
||||
)
|
||||
if err != nil {
|
||||
monitor.setErrorStatus(err)
|
||||
monitor.RemoveTempFolder()
|
||||
return true
|
||||
}
|
||||
|
||||
// 提交中转任务
|
||||
pool.Submit(job)
|
||||
|
||||
// 更新任务ID
|
||||
monitor.Task.TaskID = job.Model().ID
|
||||
monitor.Task.Save()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (monitor *Monitor) setErrorStatus(err error) {
|
||||
monitor.Task.Status = common.Error
|
||||
monitor.Task.Error = err.Error()
|
||||
monitor.Task.Save()
|
||||
}
|
257
pkg/aria2/rpc/README.md
Normal file
257
pkg/aria2/rpc/README.md
Normal file
@ -0,0 +1,257 @@
|
||||
# PACKAGE DOCUMENTATION
|
||||
|
||||
**package rpc**
|
||||
|
||||
import "github.com/matzoe/argo/rpc"
|
||||
|
||||
|
||||
|
||||
## FUNCTIONS
|
||||
|
||||
```
|
||||
func Call(address, method string, params, reply interface{}) error
|
||||
```
|
||||
|
||||
## TYPES
|
||||
|
||||
```
|
||||
type Client struct {
|
||||
// contains filtered or unexported fields
|
||||
}
|
||||
```
|
||||
|
||||
```
|
||||
func New(uri string) *Client
|
||||
```
|
||||
|
||||
```
|
||||
func (id *Client) AddMetalink(uri string, options ...interface{}) (gid string, err error)
|
||||
```
|
||||
`aria2.addMetalink(metalink[, options[, position]])` This method adds Metalink download by uploading ".metalink" file. `metalink` is of type base64 which contains Base64-encoded ".metalink" file. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the
|
||||
waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns array of GID of registered download. If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.metalink. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by `--save-session`.
|
||||
|
||||
```
|
||||
func (id *Client) AddTorrent(filename string, options ...interface{}) (gid string, err error)
|
||||
```
|
||||
`aria2.addTorrent(torrent[, uris[, options[, position]]])` This method adds BitTorrent download by uploading ".torrent" file. If you want to add BitTorrent Magnet URI, use `aria2.addUri()` method instead. torrent is of type base64 which contains Base64-encoded ".torrent" file. `uris` is of type array and its element is URI which is of type string. `uris` is used for Web-seeding. For single file torrents, URI can be a complete URI pointing to the resource or if URI ends with /, name in torrent file is added. For multi-file torrents, name and path in torrent are added to form a URI for each file. options is of type struct and its members are
|
||||
a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at `position` in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download. If `--rpc-save-upload-metadata` is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".torrent" in the
|
||||
directory specified by `--dir` option. The example of filename is 0a3893293e27ac0490424c06de4d09242215f0a6.torrent. If same file already exists, it is overwritten. If the file cannot be saved successfully or `--rpc-save-upload-metadata` is false, the downloads added by this method are not saved by -`-save-session`.
|
||||
|
||||
```
|
||||
func (id *Client) AddUri(uri string, options ...interface{}) (gid string, err error)
|
||||
```
|
||||
|
||||
`aria2.addUri(uris[, options[, position]])` This method adds new HTTP(S)/FTP/BitTorrent Magnet URI. `uris` is of type array and its element is URI which is of type string. For BitTorrent Magnet URI, `uris` must have only one element and it should be BitTorrent Magnet URI. URIs in uris must point to the same file. If you mix other URIs which point to another file, aria2 does not complain but download may
|
||||
fail. `options` is of type struct and its members are a pair of option name and value. See Options below for more details. If `position` is given as an integer starting from 0, the new download is inserted at position in the waiting queue. If `position` is not given or `position` is larger than the size of the queue, it is appended at the end of the queue. This method returns GID of registered download.
|
||||
|
||||
```
|
||||
func (id *Client) ChangeGlobalOption(options map[string]interface{}) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.changeGlobalOption(options)` This method changes global options dynamically. `options` is of type struct. The following `options` are available:
|
||||
|
||||
download-result
|
||||
log
|
||||
log-level
|
||||
max-concurrent-downloads
|
||||
max-download-result
|
||||
max-overall-download-limit
|
||||
max-overall-upload-limit
|
||||
save-cookies
|
||||
save-session
|
||||
server-stat-of
|
||||
|
||||
In addition to them, options listed in Input File subsection are available, except for following options: `checksum`, `index-out`, `out`, `pause` and `select-file`. Using `log` option, you can dynamically start logging or change log file. To stop logging, give empty string("") as a parameter value. Note that log file is always opened in append mode. This method returns OK for success.
|
||||
|
||||
```
|
||||
func (id *Client) ChangeOption(gid string, options map[string]interface{}) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.changeOption(gid, options)` This method changes options of the download denoted by `gid` dynamically. `gid` is of type string. `options` is of type struct. The following `options` are available for active downloads:
|
||||
|
||||
bt-max-peers
|
||||
bt-request-peer-speed-limit
|
||||
bt-remove-unselected-file
|
||||
force-save
|
||||
max-download-limit
|
||||
max-upload-limit
|
||||
|
||||
For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option. This method returns OK for success.
|
||||
|
||||
```
|
||||
func (id *Client) ChangePosition(gid string, pos int, how string) (p int, err error)
|
||||
```
|
||||
|
||||
`aria2.changePosition(gid, pos, how)` This method changes the position of the download denoted by `gid`. `pos` is of type integer. `how` is of type string. If `how` is `POS_SET`, it moves the download to a position relative to the beginning of the queue. If `how` is `POS_CUR`, it moves the download to a position relative to the current position. If `how` is `POS_END`, it moves the download to a position relative to the end of the queue. If the destination position is less than 0 or beyond the end
|
||||
of the queue, it moves the download to the beginning or the end of the queue respectively. The response is of type integer and it is the destination position.
|
||||
|
||||
```
|
||||
func (id *Client) ChangeUri(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error)
|
||||
```
|
||||
|
||||
`aria2.changeUri(gid, fileIndex, delUris, addUris[, position])` This method removes URIs in `delUris` from and appends URIs in `addUris` to download denoted by gid. `delUris` and `addUris` are list of string. A download can contain multiple files and URIs are attached to each file. `fileIndex` is used to select which file to remove/attach given URIs. `fileIndex` is 1-based. `position` is used to specify where URIs are inserted in the existing waiting URI list. `position` is 0-based. When
|
||||
`position` is omitted, URIs are appended to the back of the list. This method first execute removal and then addition. `position` is the `position` after URIs are removed, not the `position` when this method is called. When removing URI, if same URIs exist in download, only one of them is removed for each URI in delUris. In other words, there are three URIs http://example.org/aria2 and you want remove them all, you
|
||||
have to specify (at least) 3 http://example.org/aria2 in delUris. This method returns a list which contains 2 integers. The first integer is the number of URIs deleted. The second integer is the number of URIs added.
|
||||
|
||||
```
|
||||
func (id *Client) ForcePause(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.forcePause(pid)` This method pauses the download denoted by `gid`. This method behaves just like aria2.pause() except that this method pauses download without any action which takes time such as contacting BitTorrent tracker.
|
||||
|
||||
```
|
||||
func (id *Client) ForcePauseAll() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.forcePauseAll()` This method is equal to calling `aria2.forcePause()` for every active/waiting download. This methods returns OK for success.
|
||||
|
||||
```
|
||||
func (id *Client) ForceRemove(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.forceRemove(gid)` This method removes the download denoted by `gid`. This method behaves just like aria2.remove() except that this method removes download without any action which takes time such as contacting BitTorrent tracker.
|
||||
|
||||
```
|
||||
func (id *Client) ForceShutdown() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.forceShutdown()` This method shutdowns aria2. This method behaves like `aria2.shutdown()` except that any actions which takes time such as contacting BitTorrent tracker are skipped. This method returns OK.
|
||||
|
||||
```
|
||||
func (id *Client) GetFiles(gid string) (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getFiles(gid)` This method returns file list of the download denoted by `gid`. `gid` is of type string.
|
||||
|
||||
```
|
||||
func (id *Client) GetGlobalOption() (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getGlobalOption()` This method returns global options. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods. Because global options are used as a template for the options of newly added download, the response contains
|
||||
keys returned by `aria2.getOption()` method.
|
||||
|
||||
```
|
||||
func (id *Client) GetGlobalStat() (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getGlobalStat()` This method returns global statistics such as overall download and upload speed.
|
||||
|
||||
```
|
||||
func (id *Client) GetOption(gid string) (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getOption(gid)` This method returns options of the download denoted by `gid`. The response is of type struct. Its key is the name of option. The value type is string. Note that this method does not return options which have no default value and have not been set by the command-line options, configuration files or RPC methods.
|
||||
|
||||
```
|
||||
func (id *Client) GetPeers(gid string) (m []map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getPeers(gid)` This method returns peer list of the download denoted by `gid`. `gid` is of type string. This method is for BitTorrent only.
|
||||
|
||||
```
|
||||
func (id *Client) GetServers(gid string) (m []map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getServers(gid)` This method returns currently connected HTTP(S)/FTP servers of the download denoted by `gid`. `gid` is of type string.
|
||||
|
||||
```
|
||||
func (id *Client) GetSessionInfo() (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getSessionInfo()` This method returns session information.
|
||||
|
||||
```
|
||||
func (id *Client) GetUris(gid string) (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getUris(gid)` This method returns URIs used in the download denoted by `gid`. `gid` is of type string.
|
||||
|
||||
```
|
||||
func (id *Client) GetVersion() (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.getVersion()` This method returns version of the program and the list of enabled features.
|
||||
|
||||
```
|
||||
func (id *Client) Multicall(methods []map[string]interface{}) (r []interface{}, err error)
|
||||
```
|
||||
|
||||
`system.multicall(methods)` This method encapsulates multiple method calls in a single request. `methods` is of type array and its element is struct. The struct contains two keys: `methodName` and `params`. `methodName` is the method name to call and `params` is array containing parameters to the method. This method returns array of responses. The element of array will either be a one-item array containing the return value of each method call or struct of fault element if an encapsulated method call fails.
|
||||
|
||||
```
|
||||
func (id *Client) Pause(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.pause(gid)` This method pauses the download denoted by `gid`. `gid` is of type string. The status of paused download becomes paused. If the download is active, the download is placed on the first position of waiting queue. As long as the status is paused, the download is not started. To change status to waiting, use `aria2.unpause()` method. This method returns GID of paused download.
|
||||
|
||||
```
|
||||
func (id *Client) PauseAll() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.pauseAll()` This method is equal to calling `aria2.pause()` for every active/waiting download. This methods returns OK for success.
|
||||
|
||||
```
|
||||
func (id *Client) PurgeDowloadResult() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.purgeDownloadResult()` This method purges completed/error/removed downloads to free memory. This method returns OK.
|
||||
|
||||
```
|
||||
func (id *Client) Remove(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.remove(gid)` This method removes the download denoted by gid. `gid` is of type string. If specified download is in progress, it is stopped at first. The status of removed download becomes removed. This method returns GID of removed download.
|
||||
|
||||
```
|
||||
func (id *Client) RemoveDownloadResult(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.removeDownloadResult(gid)` This method removes completed/error/removed download denoted by `gid` from memory. This method returns OK for success.
|
||||
|
||||
```
|
||||
func (id *Client) Shutdown() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.shutdown()` This method shutdowns aria2. This method returns OK.
|
||||
|
||||
```
|
||||
func (id *Client) TellActive(keys ...string) (m []map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.tellActive([keys])` This method returns the list of active downloads. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method. For `keys` parameter, please refer to `aria2.tellStatus()` method.
|
||||
|
||||
```
|
||||
func (id *Client) TellStatus(gid string, keys ...string) (m map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.tellStatus(gid[, keys])` This method returns download progress of the download denoted by `gid`. `gid` is of type string. `keys` is array of string. If it is specified, the response contains only keys in `keys` array. If `keys` is empty or not specified, the response contains all keys. This is useful when you just want specific keys and avoid unnecessary transfers. For example, `aria2.tellStatus("2089b05ecca3d829", ["gid", "status"])` returns `gid` and `status` key.
|
||||
|
||||
```
|
||||
func (id *Client) TellStopped(offset, num int, keys ...string) (m []map[string]interface{}, err error)
|
||||
```
|
||||
|
||||
`aria2.tellStopped(offset, num[, keys])` This method returns the list of stopped download. `offset` is of type integer and specifies the `offset` from the oldest download. `num` is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to `aria2.tellStatus()` method. `offset` and `num` have the same semantics as `aria2.tellWaiting()` method. The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method.
|
||||
|
||||
```
|
||||
func (id *Client) TellWaiting(offset, num int, keys ...string) (m []map[string]interface{}, err error)
|
||||
```
|
||||
`aria2.tellWaiting(offset, num[, keys])` This method returns the list of waiting download, including paused downloads. `offset` is of type integer and specifies the `offset` from the download waiting at the front. num is of type integer and specifies the number of downloads to be returned. For keys parameter, please refer to aria2.tellStatus() method. If `offset` is a positive integer, this method returns downloads
|
||||
in the range of `[offset, offset + num)`. `offset` can be a negative integer. `offset == -1` points last download in the waiting queue and `offset == -2` points the download before the last download, and so on. The downloads in the response are in reversed order. For example, imagine that three downloads "A","B" and "C" are waiting in this order.
|
||||
|
||||
aria2.tellWaiting(0, 1) returns ["A"].
|
||||
aria2.tellWaiting(1, 2) returns ["B", "C"].
|
||||
aria2.tellWaiting(-1, 2) returns ["C", "B"].
|
||||
|
||||
The response is of type array and its element is the same struct returned by `aria2.tellStatus()` method.
|
||||
|
||||
```
|
||||
func (id *Client) Unpause(gid string) (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.unpause(gid)` This method changes the status of the download denoted by `gid` from paused to waiting. This makes the download eligible to restart. `gid` is of type string. This method returns GID of unpaused download.
|
||||
|
||||
```
|
||||
func (id *Client) UnpauseAll() (g string, err error)
|
||||
```
|
||||
|
||||
`aria2.unpauseAll()` This method is equal to calling `aria2.unpause()` for every active/waiting download. This methods returns OK for success.
|
274
pkg/aria2/rpc/call.go
Normal file
274
pkg/aria2/rpc/call.go
Normal file
@ -0,0 +1,274 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type caller interface {
|
||||
// Call sends a request of rpc to aria2 daemon
|
||||
Call(method string, params, reply interface{}) (err error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
type httpCaller struct {
|
||||
uri string
|
||||
c *http.Client
|
||||
cancel context.CancelFunc
|
||||
wg *sync.WaitGroup
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newHTTPCaller(ctx context.Context, u *url.URL, timeout time.Duration, notifer Notifier) *httpCaller {
|
||||
c := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConnsPerHost: 1,
|
||||
MaxConnsPerHost: 1,
|
||||
// TLSClientConfig: tlsConfig,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: timeout,
|
||||
KeepAlive: 60 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 3 * time.Second,
|
||||
ResponseHeaderTimeout: timeout,
|
||||
},
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
h := &httpCaller{uri: u.String(), c: c, cancel: cancel, wg: &wg}
|
||||
if notifer != nil {
|
||||
h.setNotifier(ctx, *u, notifer)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *httpCaller) Close() (err error) {
|
||||
h.once.Do(func() {
|
||||
h.cancel()
|
||||
h.wg.Wait()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (h *httpCaller) setNotifier(ctx context.Context, u url.URL, notifer Notifier) (err error) {
|
||||
u.Scheme = "ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.wg.Add(1)
|
||||
go func() {
|
||||
defer h.wg.Done()
|
||||
defer conn.Close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
if err := conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
log.Printf("sending websocket close message: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}()
|
||||
h.wg.Add(1)
|
||||
go func() {
|
||||
defer h.wg.Done()
|
||||
var request websocketResponse
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
if err = conn.ReadJSON(&request); err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
log.Printf("conn.ReadJSON|err:%v", err.Error())
|
||||
return
|
||||
}
|
||||
switch request.Method {
|
||||
case "aria2.onDownloadStart":
|
||||
notifer.OnDownloadStart(request.Params)
|
||||
case "aria2.onDownloadPause":
|
||||
notifer.OnDownloadPause(request.Params)
|
||||
case "aria2.onDownloadStop":
|
||||
notifer.OnDownloadStop(request.Params)
|
||||
case "aria2.onDownloadComplete":
|
||||
notifer.OnDownloadComplete(request.Params)
|
||||
case "aria2.onDownloadError":
|
||||
notifer.OnDownloadError(request.Params)
|
||||
case "aria2.onBtDownloadComplete":
|
||||
notifer.OnBtDownloadComplete(request.Params)
|
||||
default:
|
||||
log.Printf("unexpected notification: %s", request.Method)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
func (h httpCaller) Call(method string, params, reply interface{}) (err error) {
|
||||
payload, err := EncodeClientRequest(method, params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := h.c.Post(h.uri, "application/json", payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = DecodeClientResponse(r.Body, &reply)
|
||||
r.Body.Close()
|
||||
return
|
||||
}
|
||||
|
||||
type websocketCaller struct {
|
||||
conn *websocket.Conn
|
||||
sendChan chan *sendRequest
|
||||
cancel context.CancelFunc
|
||||
wg *sync.WaitGroup
|
||||
once sync.Once
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func newWebsocketCaller(ctx context.Context, uri string, timeout time.Duration, notifier Notifier) (*websocketCaller, error) {
|
||||
var header = http.Header{}
|
||||
conn, _, err := websocket.DefaultDialer.Dial(uri, header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sendChan := make(chan *sendRequest, 16)
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
w := &websocketCaller{conn: conn, wg: &wg, cancel: cancel, sendChan: sendChan, timeout: timeout}
|
||||
processor := NewResponseProcessor()
|
||||
wg.Add(1)
|
||||
go func() { // routine:recv
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
var resp websocketResponse
|
||||
if err := conn.ReadJSON(&resp); err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
log.Printf("conn.ReadJSON|err:%v", err.Error())
|
||||
return
|
||||
}
|
||||
if resp.Id == nil { // RPC notifications
|
||||
if notifier != nil {
|
||||
switch resp.Method {
|
||||
case "aria2.onDownloadStart":
|
||||
notifier.OnDownloadStart(resp.Params)
|
||||
case "aria2.onDownloadPause":
|
||||
notifier.OnDownloadPause(resp.Params)
|
||||
case "aria2.onDownloadStop":
|
||||
notifier.OnDownloadStop(resp.Params)
|
||||
case "aria2.onDownloadComplete":
|
||||
notifier.OnDownloadComplete(resp.Params)
|
||||
case "aria2.onDownloadError":
|
||||
notifier.OnDownloadError(resp.Params)
|
||||
case "aria2.onBtDownloadComplete":
|
||||
notifier.OnBtDownloadComplete(resp.Params)
|
||||
default:
|
||||
log.Printf("unexpected notification: %s", resp.Method)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
processor.Process(resp.clientResponse)
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() { // routine:send
|
||||
defer wg.Done()
|
||||
defer cancel()
|
||||
defer w.conn.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := w.conn.WriteMessage(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
|
||||
log.Printf("sending websocket close message: %v", err)
|
||||
}
|
||||
return
|
||||
case req := <-sendChan:
|
||||
processor.Add(req.request.Id, func(resp clientResponse) error {
|
||||
err := resp.decode(req.reply)
|
||||
req.cancel()
|
||||
return err
|
||||
})
|
||||
w.conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
w.conn.WriteJSON(req.request)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *websocketCaller) Close() (err error) {
|
||||
w.once.Do(func() {
|
||||
w.cancel()
|
||||
w.wg.Wait()
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (w websocketCaller) Call(method string, params, reply interface{}) (err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), w.timeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case w.sendChan <- &sendRequest{cancel: cancel, request: &clientRequest{
|
||||
Version: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
Id: reqid(),
|
||||
}, reply: reply}:
|
||||
|
||||
default:
|
||||
return errors.New("sending channel blocking")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := ctx.Err(); err == context.DeadlineExceeded {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type sendRequest struct {
|
||||
cancel context.CancelFunc
|
||||
request *clientRequest
|
||||
reply interface{}
|
||||
}
|
||||
|
||||
var reqid = func() func() uint64 {
|
||||
var id = uint64(time.Now().UnixNano())
|
||||
return func() uint64 {
|
||||
return atomic.AddUint64(&id, 1)
|
||||
}
|
||||
}()
|
656
pkg/aria2/rpc/client.go
Normal file
656
pkg/aria2/rpc/client.go
Normal file
@ -0,0 +1,656 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Option is a container for specifying Call parameters and returning results
|
||||
type Option map[string]interface{}
|
||||
|
||||
type Client interface {
|
||||
Protocol
|
||||
Close() error
|
||||
}
|
||||
|
||||
type client struct {
|
||||
caller
|
||||
url *url.URL
|
||||
token string
|
||||
}
|
||||
|
||||
var (
|
||||
errInvalidParameter = errors.New("invalid parameter")
|
||||
errNotImplemented = errors.New("not implemented")
|
||||
errConnTimeout = errors.New("connect to aria2 daemon timeout")
|
||||
)
|
||||
|
||||
// New returns an instance of Client
|
||||
func New(ctx context.Context, uri string, token string, timeout time.Duration, notifier Notifier) (Client, error) {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var caller caller
|
||||
switch u.Scheme {
|
||||
case "http", "https":
|
||||
caller = newHTTPCaller(ctx, u, timeout, notifier)
|
||||
case "ws", "wss":
|
||||
caller, err = newWebsocketCaller(ctx, u.String(), timeout, notifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
default:
|
||||
return nil, errInvalidParameter
|
||||
}
|
||||
c := &client{caller: caller, url: u, token: token}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// `aria2.addUri([secret, ]uris[, options[, position]])`
|
||||
// This method adds a new download. uris is an array of HTTP/FTP/SFTP/BitTorrent URIs (strings) pointing to the same resource.
|
||||
// If you mix URIs pointing to different resources, then the download may fail or be corrupted without aria2 complaining.
|
||||
// When adding BitTorrent Magnet URIs, uris must have only one element and it should be BitTorrent Magnet URI.
|
||||
// options is a struct and its members are pairs of option name and value.
|
||||
// If position is given, it must be an integer starting from 0.
|
||||
// The new download will be inserted at position in the waiting queue.
|
||||
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
|
||||
// This method returns the GID of the newly registered download.
|
||||
func (c *client) AddURI(uri string, options ...interface{}) (gid string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, []string{uri})
|
||||
if options != nil {
|
||||
params = append(params, options...)
|
||||
}
|
||||
err = c.Call(aria2AddURI, params, &gid)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.addTorrent([secret, ]torrent[, uris[, options[, position]]])`
|
||||
// This method adds a BitTorrent download by uploading a ".torrent" file.
|
||||
// If you want to add a BitTorrent Magnet URI, use the aria2.addUri() method instead.
|
||||
// torrent must be a base64-encoded string containing the contents of the ".torrent" file.
|
||||
// uris is an array of URIs (string). uris is used for Web-seeding.
|
||||
// For single file torrents, the URI can be a complete URI pointing to the resource; if URI ends with /, name in torrent file is added.
|
||||
// For multi-file torrents, name and path in torrent are added to form a URI for each file. options is a struct and its members are pairs of option name and value.
|
||||
// If position is given, it must be an integer starting from 0.
|
||||
// The new download will be inserted at position in the waiting queue.
|
||||
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
|
||||
// This method returns the GID of the newly registered download.
|
||||
// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named as the hex string of SHA-1 hash of data plus ".torrent" in the directory specified by --dir option.
|
||||
// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.torrent.
|
||||
// If a file with the same name already exists, it is overwritten!
|
||||
// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session.
|
||||
func (c *client) AddTorrent(filename string, options ...interface{}) (gid string, err error) {
|
||||
co, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
file := base64.StdEncoding.EncodeToString(co)
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, file)
|
||||
if options != nil {
|
||||
params = append(params, options...)
|
||||
}
|
||||
err = c.Call(aria2AddTorrent, params, &gid)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.addMetalink([secret, ]metalink[, options[, position]])`
|
||||
// This method adds a Metalink download by uploading a ".metalink" file.
|
||||
// metalink is a base64-encoded string which contains the contents of the ".metalink" file.
|
||||
// options is a struct and its members are pairs of option name and value.
|
||||
// If position is given, it must be an integer starting from 0.
|
||||
// The new download will be inserted at position in the waiting queue.
|
||||
// If position is omitted or position is larger than the current size of the queue, the new download is appended to the end of the queue.
|
||||
// This method returns an array of GIDs of newly registered downloads.
|
||||
// If --rpc-save-upload-metadata is true, the uploaded data is saved as a file named hex string of SHA-1 hash of data plus ".metalink" in the directory specified by --dir option.
|
||||
// E.g. a file name might be 0a3893293e27ac0490424c06de4d09242215f0a6.metalink.
|
||||
// If a file with the same name already exists, it is overwritten!
|
||||
// If the file cannot be saved successfully or --rpc-save-upload-metadata is false, the downloads added by this method are not saved by --save-session.
|
||||
func (c *client) AddMetalink(filename string, options ...interface{}) (gid []string, err error) {
|
||||
co, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
file := base64.StdEncoding.EncodeToString(co)
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, file)
|
||||
if options != nil {
|
||||
params = append(params, options...)
|
||||
}
|
||||
err = c.Call(aria2AddMetalink, params, &gid)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.remove([secret, ]gid)`
|
||||
// This method removes the download denoted by gid (string).
|
||||
// If the specified download is in progress, it is first stopped.
|
||||
// The status of the removed download becomes removed.
|
||||
// This method returns GID of removed download.
|
||||
func (c *client) Remove(gid string) (g string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2Remove, params, &g)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.forceRemove([secret, ]gid)`
|
||||
// This method removes the download denoted by gid.
|
||||
// This method behaves just like aria2.remove() except that this method removes the download without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first.
|
||||
func (c *client) ForceRemove(gid string) (g string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2ForceRemove, params, &g)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.pause([secret, ]gid)`
|
||||
// This method pauses the download denoted by gid (string).
|
||||
// The status of paused download becomes paused.
|
||||
// If the download was active, the download is placed in the front of waiting queue.
|
||||
// While the status is paused, the download is not started.
|
||||
// To change status to waiting, use the aria2.unpause() method.
|
||||
// This method returns GID of paused download.
|
||||
func (c *client) Pause(gid string) (g string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2Pause, params, &g)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.pauseAll([secret])`
|
||||
// This method is equal to calling aria2.pause() for every active/waiting download.
|
||||
// This methods returns OK.
|
||||
func (c *client) PauseAll() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2PauseAll, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.forcePause([secret, ]gid)`
|
||||
// This method pauses the download denoted by gid.
|
||||
// This method behaves just like aria2.pause() except that this method pauses downloads without performing any actions which take time, such as contacting BitTorrent trackers to unregister the download first.
|
||||
func (c *client) ForcePause(gid string) (g string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2ForcePause, params, &g)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.forcePauseAll([secret])`
|
||||
// This method is equal to calling aria2.forcePause() for every active/waiting download.
|
||||
// This methods returns OK.
|
||||
func (c *client) ForcePauseAll() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2ForcePauseAll, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.unpause([secret, ]gid)`
|
||||
// This method changes the status of the download denoted by gid (string) from paused to waiting, making the download eligible to be restarted.
|
||||
// This method returns the GID of the unpaused download.
|
||||
func (c *client) Unpause(gid string) (g string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2Unpause, params, &g)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.unpauseAll([secret])`
|
||||
// This method is equal to calling aria2.unpause() for every active/waiting download.
|
||||
// This methods returns OK.
|
||||
func (c *client) UnpauseAll() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2UnpauseAll, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.tellStatus([secret, ]gid[, keys])`
|
||||
// This method returns the progress of the download denoted by gid (string).
|
||||
// keys is an array of strings.
|
||||
// If specified, the response contains only keys in the keys array.
|
||||
// If keys is empty or omitted, the response contains all keys.
|
||||
// This is useful when you just want specific keys and avoid unnecessary transfers.
|
||||
// For example, aria2.tellStatus("2089b05ecca3d829", ["gid", "status"]) returns the gid and status keys only.
|
||||
// The response is a struct and contains following keys. Values are strings.
|
||||
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.tellStatus
|
||||
func (c *client) TellStatus(gid string, keys ...string) (info StatusInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
if keys != nil {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err = c.Call(aria2TellStatus, params, &info)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getUris([secret, ]gid)`
|
||||
// This method returns the URIs used in the download denoted by gid (string).
|
||||
// The response is an array of structs and it contains following keys. Values are string.
|
||||
// uri URI
|
||||
// status 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
|
||||
func (c *client) GetURIs(gid string) (infos []URIInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2GetURIs, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getFiles([secret, ]gid)`
|
||||
// This method returns the file list of the download denoted by gid (string).
|
||||
// The response is an array of structs which contain following keys. Values are strings.
|
||||
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getFiles
|
||||
func (c *client) GetFiles(gid string) (infos []FileInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2GetFiles, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getPeers([secret, ]gid)`
|
||||
// This method returns a list peers of the download denoted by gid (string).
|
||||
// This method is for BitTorrent only.
|
||||
// The response is an array of structs and contains the following keys. Values are strings.
|
||||
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getPeers
|
||||
func (c *client) GetPeers(gid string) (infos []PeerInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2GetPeers, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getServers([secret, ]gid)`
|
||||
// This method returns currently connected HTTP(S)/FTP/SFTP servers of the download denoted by gid (string).
|
||||
// The response is an array of structs and contains the following keys. Values are strings.
|
||||
// https://aria2.github.io/manual/en/html/aria2c.html#aria2.getServers
|
||||
func (c *client) GetServers(gid string) (infos []ServerInfo, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2GetServers, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.tellActive([secret][, keys])`
|
||||
// This method returns a list of active downloads.
|
||||
// The response is an array of the same structs as returned by the aria2.tellStatus() method.
|
||||
// For the keys parameter, please refer to the aria2.tellStatus() method.
|
||||
func (c *client) TellActive(keys ...string) (infos []StatusInfo, err error) {
|
||||
params := make([]interface{}, 0, 1)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
if keys != nil {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err = c.Call(aria2TellActive, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.tellWaiting([secret, ]offset, num[, keys])`
|
||||
// This method returns a list of waiting downloads, including paused ones.
|
||||
// offset is an integer and specifies the offset from the download waiting at the front.
|
||||
// num is an integer and specifies the max. number of downloads to be returned.
|
||||
// For the keys parameter, please refer to the aria2.tellStatus() method.
|
||||
// If offset is a positive integer, this method returns downloads in the range of [offset, offset + num).
|
||||
// offset can be a negative integer. offset == -1 points last download in the waiting queue and offset == -2 points the download before the last download, and so on.
|
||||
// Downloads in the response are in reversed order then.
|
||||
// For example, imagine three downloads "A","B" and "C" are waiting in this order.
|
||||
// aria2.tellWaiting(0, 1) returns ["A"].
|
||||
// aria2.tellWaiting(1, 2) returns ["B", "C"].
|
||||
// aria2.tellWaiting(-1, 2) returns ["C", "B"].
|
||||
// The response is an array of the same structs as returned by aria2.tellStatus() method.
|
||||
func (c *client) TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error) {
|
||||
params := make([]interface{}, 0, 3)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, offset)
|
||||
params = append(params, num)
|
||||
if keys != nil {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err = c.Call(aria2TellWaiting, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.tellStopped([secret, ]offset, num[, keys])`
|
||||
// This method returns a list of stopped downloads.
|
||||
// offset is an integer and specifies the offset from the least recently stopped download.
|
||||
// num is an integer and specifies the max. number of downloads to be returned.
|
||||
// For the keys parameter, please refer to the aria2.tellStatus() method.
|
||||
// offset and num have the same semantics as described in the aria2.tellWaiting() method.
|
||||
// The response is an array of the same structs as returned by the aria2.tellStatus() method.
|
||||
func (c *client) TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error) {
|
||||
params := make([]interface{}, 0, 3)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, offset)
|
||||
params = append(params, num)
|
||||
if keys != nil {
|
||||
params = append(params, keys)
|
||||
}
|
||||
err = c.Call(aria2TellStopped, params, &infos)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.changePosition([secret, ]gid, pos, how)`
|
||||
// This method changes the position of the download denoted by gid in the queue.
|
||||
// pos is an integer. how is a string.
|
||||
// If how is POS_SET, it moves the download to a position relative to the beginning of the queue.
|
||||
// If how is POS_CUR, it moves the download to a position relative to the current position.
|
||||
// If how is POS_END, it moves the download to a position relative to the end of the queue.
|
||||
// If the destination position is less than 0 or beyond the end of the queue, it moves the download to the beginning or the end of the queue respectively.
|
||||
// The response is an integer denoting the resulting position.
|
||||
// For example, if GID#2089b05ecca3d829 is currently in position 3, aria2.changePosition('2089b05ecca3d829', -1, 'POS_CUR') will change its position to 2. Additionally aria2.changePosition('2089b05ecca3d829', 0, 'POS_SET') will change its position to 0 (the beginning of the queue).
|
||||
func (c *client) ChangePosition(gid string, pos int, how string) (p int, err error) {
|
||||
params := make([]interface{}, 0, 3)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
params = append(params, pos)
|
||||
params = append(params, how)
|
||||
err = c.Call(aria2ChangePosition, params, &p)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.changeUri([secret, ]gid, fileIndex, delUris, addUris[, position])`
|
||||
// This method removes the URIs in delUris from and appends the URIs in addUris to download denoted by gid.
|
||||
// delUris and addUris are lists of strings.
|
||||
// A download can contain multiple files and URIs are attached to each file.
|
||||
// fileIndex is used to select which file to remove/attach given URIs. fileIndex is 1-based.
|
||||
// position is used to specify where URIs are inserted in the existing waiting URI list. position is 0-based.
|
||||
// When position is omitted, URIs are appended to the back of the list.
|
||||
// This method first executes the removal and then the addition.
|
||||
// position is the position after URIs are removed, not the position when this method is called.
|
||||
// When removing an URI, if the same URIs exist in download, only one of them is removed for each URI in delUris.
|
||||
// In other words, if there are three URIs http://example.org/aria2 and you want remove them all, you have to specify (at least) 3 http://example.org/aria2 in delUris.
|
||||
// This method returns a list which contains two integers.
|
||||
// The first integer is the number of URIs deleted.
|
||||
// The second integer is the number of URIs added.
|
||||
func (c *client) ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) {
|
||||
params := make([]interface{}, 0, 5)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
params = append(params, fileindex)
|
||||
params = append(params, delUris)
|
||||
params = append(params, addUris)
|
||||
if position != nil {
|
||||
params = append(params, position[0])
|
||||
}
|
||||
err = c.Call(aria2ChangeURI, params, &p)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getOption([secret, ]gid)`
|
||||
// This method returns options of the download denoted by gid.
|
||||
// The response is a struct where keys are the names of options.
|
||||
// The values are strings.
|
||||
// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods.
|
||||
func (c *client) GetOption(gid string) (m Option, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2GetOption, params, &m)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.changeOption([secret, ]gid, options)`
|
||||
// This method changes options of the download denoted by gid (string) dynamically. options is a struct.
|
||||
// The following options are available for active downloads:
|
||||
// bt-max-peers
|
||||
// bt-request-peer-speed-limit
|
||||
// bt-remove-unselected-file
|
||||
// force-save
|
||||
// max-download-limit
|
||||
// max-upload-limit
|
||||
// For waiting or paused downloads, in addition to the above options, options listed in Input File subsection are available, except for following options: dry-run, metalink-base-uri, parameterized-uri, pause, piece-length and rpc-save-upload-metadata option.
|
||||
// This method returns OK for success.
|
||||
func (c *client) ChangeOption(gid string, option Option) (ok string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
if option != nil {
|
||||
params = append(params, option)
|
||||
}
|
||||
err = c.Call(aria2ChangeOption, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getGlobalOption([secret])`
|
||||
// This method returns the global options.
|
||||
// The response is a struct.
|
||||
// Its keys are the names of options.
|
||||
// Values are strings.
|
||||
// Note that this method does not return options which have no default value and have not been set on the command-line, in configuration files or RPC methods. Because global options are used as a template for the options of newly added downloads, the response contains keys returned by the aria2.getOption() method.
|
||||
func (c *client) GetGlobalOption() (m Option, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2GetGlobalOption, params, &m)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.changeGlobalOption([secret, ]options)`
|
||||
// This method changes global options dynamically.
|
||||
// options is a struct.
|
||||
// The following options are available:
|
||||
// bt-max-open-files
|
||||
// download-result
|
||||
// log
|
||||
// log-level
|
||||
// max-concurrent-downloads
|
||||
// max-download-result
|
||||
// max-overall-download-limit
|
||||
// max-overall-upload-limit
|
||||
// save-cookies
|
||||
// save-session
|
||||
// server-stat-of
|
||||
// In addition, options listed in the Input File subsection are available, except for following options: checksum, index-out, out, pause and select-file.
|
||||
// With the log option, you can dynamically start logging or change log file.
|
||||
// To stop logging, specify an empty string("") as the parameter value.
|
||||
// Note that log file is always opened in append mode.
|
||||
// This method returns OK for success.
|
||||
func (c *client) ChangeGlobalOption(options Option) (ok string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, options)
|
||||
err = c.Call(aria2ChangeGlobalOption, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getGlobalStat([secret])`
|
||||
// This method returns global statistics such as the overall download and upload speeds.
|
||||
// The response is a struct and contains the following keys. Values are strings.
|
||||
// downloadSpeed Overall download speed (byte/sec).
|
||||
// uploadSpeed Overall upload speed(byte/sec).
|
||||
// numActive The number of active downloads.
|
||||
// numWaiting The number of waiting downloads.
|
||||
// numStopped The number of stopped downloads in the current session.
|
||||
// This value is capped by the --max-download-result option.
|
||||
// numStoppedTotal The number of stopped downloads in the current session and not capped by the --max-download-result option.
|
||||
func (c *client) GetGlobalStat() (info GlobalStatInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2GetGlobalStat, params, &info)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.purgeDownloadResult([secret])`
|
||||
// This method purges completed/error/removed downloads to free memory.
|
||||
// This method returns OK.
|
||||
func (c *client) PurgeDownloadResult() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2PurgeDownloadResult, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.removeDownloadResult([secret, ]gid)`
|
||||
// This method removes a completed/error/removed download denoted by gid from memory.
|
||||
// This method returns OK for success.
|
||||
func (c *client) RemoveDownloadResult(gid string) (ok string, err error) {
|
||||
params := make([]interface{}, 0, 2)
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
params = append(params, gid)
|
||||
err = c.Call(aria2RemoveDownloadResult, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getVersion([secret])`
|
||||
// This method returns the version of aria2 and the list of enabled features.
|
||||
// The response is a struct and contains following keys.
|
||||
// version Version number of aria2 as a string.
|
||||
// enabledFeatures List of enabled features. Each feature is given as a string.
|
||||
func (c *client) GetVersion() (info VersionInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2GetVersion, params, &info)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.getSessionInfo([secret])`
|
||||
// This method returns session information.
|
||||
// The response is a struct and contains following key.
|
||||
// sessionId Session ID, which is generated each time when aria2 is invoked.
|
||||
func (c *client) GetSessionInfo() (info SessionInfo, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2GetSessionInfo, params, &info)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.shutdown([secret])`
|
||||
// This method shutdowns aria2.
|
||||
// This method returns OK.
|
||||
func (c *client) Shutdown() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2Shutdown, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.forceShutdown([secret])`
|
||||
// This method shuts down aria2().
|
||||
// This method behaves like :func:'aria2.shutdown` without performing any actions which take time, such as contacting BitTorrent trackers to unregister downloads first.
|
||||
// This method returns OK.
|
||||
func (c *client) ForceShutdown() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2ForceShutdown, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `aria2.saveSession([secret])`
|
||||
// This method saves the current session to a file specified by the --save-session option.
|
||||
// This method returns OK if it succeeds.
|
||||
func (c *client) SaveSession() (ok string, err error) {
|
||||
params := []string{}
|
||||
if c.token != "" {
|
||||
params = append(params, "token:"+c.token)
|
||||
}
|
||||
err = c.Call(aria2SaveSession, params, &ok)
|
||||
return
|
||||
}
|
||||
|
||||
// `system.multicall(methods)`
|
||||
// This methods encapsulates multiple method calls in a single request.
|
||||
// methods is an array of structs.
|
||||
// The structs contain two keys: methodName and params.
|
||||
// methodName is the method name to call and params is array containing parameters to the method call.
|
||||
// This method returns an array of responses.
|
||||
// The elements will be either a one-item array containing the return value of the method call or a struct of fault element if an encapsulated method call fails.
|
||||
func (c *client) Multicall(methods []Method) (r []interface{}, err error) {
|
||||
if len(methods) == 0 {
|
||||
err = errInvalidParameter
|
||||
return
|
||||
}
|
||||
err = c.Call(aria2Multicall, methods, &r)
|
||||
return
|
||||
}
|
||||
|
||||
// `system.listMethods()`
|
||||
// This method returns the all available RPC methods in an array of string.
|
||||
// Unlike other methods, this method does not require secret token.
|
||||
// This is safe because this method jsut returns the available method names.
|
||||
func (c *client) ListMethods() (methods []string, err error) {
|
||||
err = c.Call(aria2ListMethods, []string{}, &methods)
|
||||
return
|
||||
}
|
39
pkg/aria2/rpc/const.go
Normal file
39
pkg/aria2/rpc/const.go
Normal file
@ -0,0 +1,39 @@
|
||||
package rpc
|
||||
|
||||
const (
|
||||
aria2AddURI = "aria2.addUri"
|
||||
aria2AddTorrent = "aria2.addTorrent"
|
||||
aria2AddMetalink = "aria2.addMetalink"
|
||||
aria2Remove = "aria2.remove"
|
||||
aria2ForceRemove = "aria2.forceRemove"
|
||||
aria2Pause = "aria2.pause"
|
||||
aria2PauseAll = "aria2.pauseAll"
|
||||
aria2ForcePause = "aria2.forcePause"
|
||||
aria2ForcePauseAll = "aria2.forcePauseAll"
|
||||
aria2Unpause = "aria2.unpause"
|
||||
aria2UnpauseAll = "aria2.unpauseAll"
|
||||
aria2TellStatus = "aria2.tellStatus"
|
||||
aria2GetURIs = "aria2.getUris"
|
||||
aria2GetFiles = "aria2.getFiles"
|
||||
aria2GetPeers = "aria2.getPeers"
|
||||
aria2GetServers = "aria2.getServers"
|
||||
aria2TellActive = "aria2.tellActive"
|
||||
aria2TellWaiting = "aria2.tellWaiting"
|
||||
aria2TellStopped = "aria2.tellStopped"
|
||||
aria2ChangePosition = "aria2.changePosition"
|
||||
aria2ChangeURI = "aria2.changeUri"
|
||||
aria2GetOption = "aria2.getOption"
|
||||
aria2ChangeOption = "aria2.changeOption"
|
||||
aria2GetGlobalOption = "aria2.getGlobalOption"
|
||||
aria2ChangeGlobalOption = "aria2.changeGlobalOption"
|
||||
aria2GetGlobalStat = "aria2.getGlobalStat"
|
||||
aria2PurgeDownloadResult = "aria2.purgeDownloadResult"
|
||||
aria2RemoveDownloadResult = "aria2.removeDownloadResult"
|
||||
aria2GetVersion = "aria2.getVersion"
|
||||
aria2GetSessionInfo = "aria2.getSessionInfo"
|
||||
aria2Shutdown = "aria2.shutdown"
|
||||
aria2ForceShutdown = "aria2.forceShutdown"
|
||||
aria2SaveSession = "aria2.saveSession"
|
||||
aria2Multicall = "system.multicall"
|
||||
aria2ListMethods = "system.listMethods"
|
||||
)
|
116
pkg/aria2/rpc/json2.go
Normal file
116
pkg/aria2/rpc/json2.go
Normal file
@ -0,0 +1,116 @@
|
||||
package rpc
|
||||
|
||||
// based on "github.com/gorilla/rpc/v2/json2"
|
||||
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Request and Response
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// clientRequest represents a JSON-RPC request sent by a client.
|
||||
type clientRequest struct {
|
||||
// JSON-RPC protocol.
|
||||
Version string `json:"jsonrpc"`
|
||||
|
||||
// A String containing the name of the method to be invoked.
|
||||
Method string `json:"method"`
|
||||
|
||||
// Object to pass as request parameter to the method.
|
||||
Params interface{} `json:"params"`
|
||||
|
||||
// The request id. This can be of any type. It is used to match the
|
||||
// response with the request that it is replying to.
|
||||
Id uint64 `json:"id"`
|
||||
}
|
||||
|
||||
// clientResponse represents a JSON-RPC response returned to a client.
|
||||
type clientResponse struct {
|
||||
Version string `json:"jsonrpc"`
|
||||
Result *json.RawMessage `json:"result"`
|
||||
Error *json.RawMessage `json:"error"`
|
||||
Id *uint64 `json:"id"`
|
||||
}
|
||||
|
||||
// EncodeClientRequest encodes parameters for a JSON-RPC client request.
|
||||
func EncodeClientRequest(method string, args interface{}) (*bytes.Buffer, error) {
|
||||
var buf bytes.Buffer
|
||||
c := &clientRequest{
|
||||
Version: "2.0",
|
||||
Method: method,
|
||||
Params: args,
|
||||
Id: reqid(),
|
||||
}
|
||||
if err := json.NewEncoder(&buf).Encode(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
func (c clientResponse) decode(reply interface{}) error {
|
||||
if c.Error != nil {
|
||||
jsonErr := &Error{}
|
||||
if err := json.Unmarshal(*c.Error, jsonErr); err != nil {
|
||||
return &Error{
|
||||
Code: E_SERVER,
|
||||
Message: string(*c.Error),
|
||||
}
|
||||
}
|
||||
return jsonErr
|
||||
}
|
||||
|
||||
if c.Result == nil {
|
||||
return ErrNullResult
|
||||
}
|
||||
|
||||
return json.Unmarshal(*c.Result, reply)
|
||||
}
|
||||
|
||||
// DecodeClientResponse decodes the response body of a client request into
|
||||
// the interface reply.
|
||||
func DecodeClientResponse(r io.Reader, reply interface{}) error {
|
||||
var c clientResponse
|
||||
if err := json.NewDecoder(r).Decode(&c); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.decode(reply)
|
||||
}
|
||||
|
||||
type ErrorCode int
|
||||
|
||||
const (
|
||||
E_PARSE ErrorCode = -32700
|
||||
E_INVALID_REQ ErrorCode = -32600
|
||||
E_NO_METHOD ErrorCode = -32601
|
||||
E_BAD_PARAMS ErrorCode = -32602
|
||||
E_INTERNAL ErrorCode = -32603
|
||||
E_SERVER ErrorCode = -32000
|
||||
)
|
||||
|
||||
var ErrNullResult = errors.New("result is null")
|
||||
|
||||
type Error struct {
|
||||
// A Number that indicates the error type that occurred.
|
||||
Code ErrorCode `json:"code"` /* required */
|
||||
|
||||
// A String providing a short description of the error.
|
||||
// The message SHOULD be limited to a concise single sentence.
|
||||
Message string `json:"message"` /* required */
|
||||
|
||||
// A Primitive or Structured value that contains additional information about the error.
|
||||
Data interface{} `json:"data"` /* optional */
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
return e.Message
|
||||
}
|
44
pkg/aria2/rpc/notification.go
Normal file
44
pkg/aria2/rpc/notification.go
Normal file
@ -0,0 +1,44 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"log"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Gid string `json:"gid"` // GID of the download
|
||||
}
|
||||
|
||||
// The RPC server might send notifications to the client.
|
||||
// Notifications is unidirectional, therefore the client which receives the notification must not respond to it.
|
||||
// The method signature of a notification is much like a normal method request but lacks the id key
|
||||
|
||||
type websocketResponse struct {
|
||||
clientResponse
|
||||
Method string `json:"method"`
|
||||
Params []Event `json:"params"`
|
||||
}
|
||||
|
||||
// Notifier handles rpc notification from aria2 server
|
||||
type Notifier interface {
|
||||
// OnDownloadStart will be sent when a download is started.
|
||||
OnDownloadStart([]Event)
|
||||
// OnDownloadPause will be sent when a download is paused.
|
||||
OnDownloadPause([]Event)
|
||||
// OnDownloadStop will be sent when a download is stopped by the user.
|
||||
OnDownloadStop([]Event)
|
||||
// OnDownloadComplete will be sent when a download is complete. For BitTorrent downloads, this notification is sent when the download is complete and seeding is over.
|
||||
OnDownloadComplete([]Event)
|
||||
// OnDownloadError will be sent when a download is stopped due to an error.
|
||||
OnDownloadError([]Event)
|
||||
// OnBtDownloadComplete will be sent when a torrent download is complete but seeding is still going on.
|
||||
OnBtDownloadComplete([]Event)
|
||||
}
|
||||
|
||||
type DummyNotifier struct{}
|
||||
|
||||
func (DummyNotifier) OnDownloadStart(events []Event) { log.Printf("%s started.", events) }
|
||||
func (DummyNotifier) OnDownloadPause(events []Event) { log.Printf("%s paused.", events) }
|
||||
func (DummyNotifier) OnDownloadStop(events []Event) { log.Printf("%s stopped.", events) }
|
||||
func (DummyNotifier) OnDownloadComplete(events []Event) { log.Printf("%s completed.", events) }
|
||||
func (DummyNotifier) OnDownloadError(events []Event) { log.Printf("%s error.", events) }
|
||||
func (DummyNotifier) OnBtDownloadComplete(events []Event) { log.Printf("bt %s completed.", events) }
|
42
pkg/aria2/rpc/proc.go
Normal file
42
pkg/aria2/rpc/proc.go
Normal file
@ -0,0 +1,42 @@
|
||||
package rpc
|
||||
|
||||
import "sync"
|
||||
|
||||
type ResponseProcFn func(resp clientResponse) error
|
||||
|
||||
type ResponseProcessor struct {
|
||||
cbs map[uint64]ResponseProcFn
|
||||
mu *sync.RWMutex
|
||||
}
|
||||
|
||||
func NewResponseProcessor() *ResponseProcessor {
|
||||
return &ResponseProcessor{
|
||||
make(map[uint64]ResponseProcFn),
|
||||
&sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ResponseProcessor) Add(id uint64, fn ResponseProcFn) {
|
||||
r.mu.Lock()
|
||||
r.cbs[id] = fn
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *ResponseProcessor) remove(id uint64) {
|
||||
r.mu.Lock()
|
||||
delete(r.cbs, id)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
// Process called by recv routine
|
||||
func (r *ResponseProcessor) Process(resp clientResponse) error {
|
||||
id := *resp.Id
|
||||
r.mu.RLock()
|
||||
fn, ok := r.cbs[id]
|
||||
r.mu.RUnlock()
|
||||
if ok && fn != nil {
|
||||
defer r.remove(id)
|
||||
return fn(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
40
pkg/aria2/rpc/proto.go
Normal file
40
pkg/aria2/rpc/proto.go
Normal file
@ -0,0 +1,40 @@
|
||||
package rpc
|
||||
|
||||
// Protocol is a set of rpc methods that aria2 daemon supports
|
||||
type Protocol interface {
|
||||
AddURI(uri string, options ...interface{}) (gid string, err error)
|
||||
AddTorrent(filename string, options ...interface{}) (gid string, err error)
|
||||
AddMetalink(filename string, options ...interface{}) (gid []string, err error)
|
||||
Remove(gid string) (g string, err error)
|
||||
ForceRemove(gid string) (g string, err error)
|
||||
Pause(gid string) (g string, err error)
|
||||
PauseAll() (ok string, err error)
|
||||
ForcePause(gid string) (g string, err error)
|
||||
ForcePauseAll() (ok string, err error)
|
||||
Unpause(gid string) (g string, err error)
|
||||
UnpauseAll() (ok string, err error)
|
||||
TellStatus(gid string, keys ...string) (info StatusInfo, err error)
|
||||
GetURIs(gid string) (infos []URIInfo, err error)
|
||||
GetFiles(gid string) (infos []FileInfo, err error)
|
||||
GetPeers(gid string) (infos []PeerInfo, err error)
|
||||
GetServers(gid string) (infos []ServerInfo, err error)
|
||||
TellActive(keys ...string) (infos []StatusInfo, err error)
|
||||
TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error)
|
||||
TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error)
|
||||
ChangePosition(gid string, pos int, how string) (p int, err error)
|
||||
ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error)
|
||||
GetOption(gid string) (m Option, err error)
|
||||
ChangeOption(gid string, option Option) (ok string, err error)
|
||||
GetGlobalOption() (m Option, err error)
|
||||
ChangeGlobalOption(options Option) (ok string, err error)
|
||||
GetGlobalStat() (info GlobalStatInfo, err error)
|
||||
PurgeDownloadResult() (ok string, err error)
|
||||
RemoveDownloadResult(gid string) (ok string, err error)
|
||||
GetVersion() (info VersionInfo, err error)
|
||||
GetSessionInfo() (info SessionInfo, err error)
|
||||
Shutdown() (ok string, err error)
|
||||
ForceShutdown() (ok string, err error)
|
||||
SaveSession() (ok string, err error)
|
||||
Multicall(methods []Method) (r []interface{}, err error)
|
||||
ListMethods() (methods []string, err error)
|
||||
}
|
104
pkg/aria2/rpc/resp.go
Normal file
104
pkg/aria2/rpc/resp.go
Normal file
@ -0,0 +1,104 @@
|
||||
//go:generate easyjson -all
|
||||
|
||||
package rpc
|
||||
|
||||
// StatusInfo represents response of aria2.tellStatus
|
||||
type StatusInfo struct {
|
||||
Gid string `json:"gid"` // GID of the download.
|
||||
Status string `json:"status"` // active for currently downloading/seeding downloads. waiting for downloads in the queue; download is not started. paused for paused downloads. error for downloads that were stopped because of error. complete for stopped and completed downloads. removed for the downloads removed by user.
|
||||
TotalLength string `json:"totalLength"` // Total length of the download in bytes.
|
||||
CompletedLength string `json:"completedLength"` // Completed length of the download in bytes.
|
||||
UploadLength string `json:"uploadLength"` // Uploaded length of the download in bytes.
|
||||
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress. The highest bit corresponds to the piece at index 0. Any set bits indicate loaded pieces, while unset bits indicate not yet loaded and/or missing pieces. Any overflow bits at the end are set to zero. When the download was not started yet, this key will not be included in the response.
|
||||
DownloadSpeed string `json:"downloadSpeed"` // Download speed of this download measured in bytes/sec.
|
||||
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed of this download measured in bytes/sec.
|
||||
InfoHash string `json:"infoHash"` // InfoHash. BitTorrent only.
|
||||
NumSeeders string `json:"numSeeders"` // The number of seeders aria2 has connected to. BitTorrent only.
|
||||
Seeder string `json:"seeder"` // true if the local endpoint is a seeder. Otherwise false. BitTorrent only.
|
||||
PieceLength string `json:"pieceLength"` // Piece length in bytes.
|
||||
NumPieces string `json:"numPieces"` // The number of pieces.
|
||||
Connections string `json:"connections"` // The number of peers/servers aria2 has connected to.
|
||||
ErrorCode string `json:"errorCode"` // The code of the last error for this item, if any. The value is a string. The error codes are defined in the EXIT STATUS section. This value is only available for stopped/completed downloads.
|
||||
ErrorMessage string `json:"errorMessage"` // The (hopefully) human readable error message associated to errorCode.
|
||||
FollowedBy []string `json:"followedBy"` // List of GIDs which are generated as the result of this download. For example, when aria2 downloads a Metalink file, it generates downloads described in the Metalink (see the --follow-metalink option). This value is useful to track auto-generated downloads. If there are no such downloads, this key will not be included in the response.
|
||||
BelongsTo string `json:"belongsTo"` // GID of a parent download. Some downloads are a part of another download. For example, if a file in a Metalink has BitTorrent resources, the downloads of ".torrent" files are parts of that parent. If this download has no parent, this key will not be included in the response.
|
||||
Dir string `json:"dir"` // Directory to save files.
|
||||
Files []FileInfo `json:"files"` // Returns the list of files. The elements of this list are the same structs used in aria2.getFiles() method.
|
||||
BitTorrent BitTorrentInfo `json:"bittorrent"` // Struct which contains information retrieved from the .torrent (file). BitTorrent only. It contains following keys.
|
||||
}
|
||||
|
||||
// URIInfo represents an element of response of aria2.getUris
|
||||
type URIInfo struct {
|
||||
URI string `json:"uri"` // URI
|
||||
Status string `json:"status"` // 'used' if the URI is in use. 'waiting' if the URI is still waiting in the queue.
|
||||
}
|
||||
|
||||
// FileInfo represents an element of response of aria2.getFiles
|
||||
type FileInfo struct {
|
||||
Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file torrent.
|
||||
Path string `json:"path"` // File path.
|
||||
Length string `json:"length"` // File size in bytes.
|
||||
CompletedLength string `json:"completedLength"` // Completed length of this file in bytes. Please note that it is possible that sum of completedLength is less than the completedLength returned by the aria2.tellStatus() method. This is because completedLength in aria2.getFiles() only includes completed pieces. On the other hand, completedLength in aria2.tellStatus() also includes partially completed pieces.
|
||||
Selected string `json:"selected"` // true if this file is selected by --select-file option. If --select-file is not specified or this is single-file torrent or not a torrent download at all, this value is always true. Otherwise false.
|
||||
URIs []URIInfo `json:"uris"` // Returns a list of URIs for this file. The element type is the same struct used in the aria2.getUris() method.
|
||||
}
|
||||
|
||||
// PeerInfo represents an element of response of aria2.getPeers
|
||||
type PeerInfo struct {
|
||||
PeerId string `json:"peerId"` // Percent-encoded peer ID.
|
||||
IP string `json:"ip"` // IP address of the peer.
|
||||
Port string `json:"port"` // Port number of the peer.
|
||||
BitField string `json:"bitfield"` // Hexadecimal representation of the download progress of the peer. The highest bit corresponds to the piece at index 0. Set bits indicate the piece is available and unset bits indicate the piece is missing. Any spare bits at the end are set to zero.
|
||||
AmChoking string `json:"amChoking"` // true if aria2 is choking the peer. Otherwise false.
|
||||
PeerChoking string `json:"peerChoking"` // true if the peer is choking aria2. Otherwise false.
|
||||
DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec) that this client obtains from the peer.
|
||||
UploadSpeed string `json:"uploadSpeed"` // LocalUpload speed(byte/sec) that this client uploads to the peer.
|
||||
Seeder string `json:"seeder"` // true if this peer is a seeder. Otherwise false.
|
||||
}
|
||||
|
||||
// ServerInfo represents an element of response of aria2.getServers
|
||||
type ServerInfo struct {
|
||||
Index string `json:"index"` // Index of the file, starting at 1, in the same order as files appear in the multi-file metalink.
|
||||
Servers []struct {
|
||||
URI string `json:"uri"` // Original URI.
|
||||
CurrentURI string `json:"currentUri"` // This is the URI currently used for downloading. If redirection is involved, currentUri and uri may differ.
|
||||
DownloadSpeed string `json:"downloadSpeed"` // Download speed (byte/sec)
|
||||
} `json:"servers"` // A list of structs which contain the following keys.
|
||||
}
|
||||
|
||||
// GlobalStatInfo represents response of aria2.getGlobalStat
|
||||
type GlobalStatInfo struct {
|
||||
DownloadSpeed string `json:"downloadSpeed"` // Overall download speed (byte/sec).
|
||||
UploadSpeed string `json:"uploadSpeed"` // Overall upload speed(byte/sec).
|
||||
NumActive string `json:"numActive"` // The number of active downloads.
|
||||
NumWaiting string `json:"numWaiting"` // The number of waiting downloads.
|
||||
NumStopped string `json:"numStopped"` // The number of stopped downloads in the current session. This value is capped by the --max-download-result option.
|
||||
NumStoppedTotal string `json:"numStoppedTotal"` // The number of stopped downloads in the current session and not capped by the --max-download-result option.
|
||||
}
|
||||
|
||||
// VersionInfo represents response of aria2.getVersion
|
||||
type VersionInfo struct {
|
||||
Version string `json:"version"` // Version number of aria2 as a string.
|
||||
Features []string `json:"enabledFeatures"` // List of enabled features. Each feature is given as a string.
|
||||
}
|
||||
|
||||
// SessionInfo represents response of aria2.getSessionInfo
|
||||
type SessionInfo struct {
|
||||
Id string `json:"sessionId"` // Session ID, which is generated each time when aria2 is invoked.
|
||||
}
|
||||
|
||||
// Method is an element of parameters used in system.multicall
|
||||
type Method struct {
|
||||
Name string `json:"methodName"` // Method name to call
|
||||
Params []interface{} `json:"params"` // Array containing parameters to the method call
|
||||
}
|
||||
|
||||
type BitTorrentInfo struct {
|
||||
AnnounceList [][]string `json:"announceList"` // List of lists of announce URIs. If the torrent contains announce and no announce-list, announce is converted to the announce-list format.
|
||||
Comment string `json:"comment"` // The comment of the torrent. comment.utf-8 is used if available.
|
||||
CreationDate int64 `json:"creationDate"` // The creation time of the torrent. The value is an integer since the epoch, measured in seconds.
|
||||
Mode string `json:"mode"` // File mode of the torrent. The value is either single or multi.
|
||||
Info struct {
|
||||
Name string `json:"name"` // name in info dictionary. name.utf-8 is used if available.
|
||||
} `json:"info"` // Struct which contains data from Info dictionary. It contains following keys.
|
||||
}
|
145
pkg/auth/auth.go
Normal file
145
pkg/auth/auth.go
Normal file
@ -0,0 +1,145 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAuthFailed = serializer.NewError(serializer.CodeInvalidSign, "invalid sign", nil)
|
||||
ErrAuthHeaderMissing = serializer.NewError(serializer.CodeNoPermissionErr, "authorization header is missing", nil)
|
||||
ErrExpiresMissing = serializer.NewError(serializer.CodeNoPermissionErr, "expire timestamp is missing", nil)
|
||||
ErrExpired = serializer.NewError(serializer.CodeSignExpired, "signature expired", nil)
|
||||
)
|
||||
|
||||
const CrHeaderPrefix = "X-Cr-"
|
||||
|
||||
// General 通用的认证接口
|
||||
var General Auth
|
||||
|
||||
// Auth 鉴权认证
|
||||
type Auth interface {
|
||||
// 对给定Body进行签名,expires为0表示永不过期
|
||||
Sign(body string, expires int64) string
|
||||
// 对给定Body和Sign进行检查
|
||||
Check(body string, sign string) error
|
||||
}
|
||||
|
||||
// SignRequest 对PUT\POST等复杂HTTP请求签名,只会对URI部分、
|
||||
// 请求正文、`X-Cr-`开头的header进行签名
|
||||
func SignRequest(instance Auth, r *http.Request, expires int64) *http.Request {
|
||||
// 处理有效期
|
||||
if expires > 0 {
|
||||
expires += time.Now().Unix()
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(getSignContent(r), expires)
|
||||
|
||||
// 将签名加到请求Header中
|
||||
r.Header["Authorization"] = []string{"Bearer " + sign}
|
||||
return r
|
||||
}
|
||||
|
||||
// CheckRequest 对复杂请求进行签名验证
|
||||
func CheckRequest(instance Auth, r *http.Request) error {
|
||||
var (
|
||||
sign []string
|
||||
ok bool
|
||||
)
|
||||
if sign, ok = r.Header["Authorization"]; !ok || len(sign) == 0 {
|
||||
return ErrAuthHeaderMissing
|
||||
}
|
||||
sign[0] = strings.TrimPrefix(sign[0], "Bearer ")
|
||||
|
||||
return instance.Check(getSignContent(r), sign[0])
|
||||
}
|
||||
|
||||
// getSignContent 签名请求 path、正文、以`X-`开头的 Header. 如果请求 path 为从机上传 API,
|
||||
// 则不对正文签名。返回待签名/验证的字符串
|
||||
func getSignContent(r *http.Request) (rawSignString string) {
|
||||
// 读取所有body正文
|
||||
var body = []byte{}
|
||||
if !strings.Contains(r.URL.Path, "/api/v3/slave/upload/") {
|
||||
if r.Body != nil {
|
||||
body, _ = ioutil.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
}
|
||||
|
||||
// 决定要签名的header
|
||||
var signedHeader []string
|
||||
for k, _ := range r.Header {
|
||||
if strings.HasPrefix(k, CrHeaderPrefix) && k != CrHeaderPrefix+"Filename" {
|
||||
signedHeader = append(signedHeader, fmt.Sprintf("%s=%s", k, r.Header.Get(k)))
|
||||
}
|
||||
}
|
||||
sort.Strings(signedHeader)
|
||||
|
||||
// 读取所有待签名Header
|
||||
rawSignString = serializer.NewRequestSignString(r.URL.Path, strings.Join(signedHeader, "&"), string(body))
|
||||
|
||||
return rawSignString
|
||||
}
|
||||
|
||||
// SignURI 对URI进行签名,签名只针对Path部分,query部分不做验证
|
||||
func SignURI(instance Auth, uri string, expires int64) (*url.URL, error) {
|
||||
// 处理有效期
|
||||
if expires != 0 {
|
||||
expires += time.Now().Unix()
|
||||
}
|
||||
|
||||
base, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign := instance.Sign(base.Path, expires)
|
||||
|
||||
// 将签名加到URI中
|
||||
queries := base.Query()
|
||||
queries.Set("sign", sign)
|
||||
base.RawQuery = queries.Encode()
|
||||
|
||||
return base, nil
|
||||
}
|
||||
|
||||
// CheckURI 对URI进行鉴权
|
||||
func CheckURI(instance Auth, url *url.URL) error {
|
||||
//获取待验证的签名正文
|
||||
queries := url.Query()
|
||||
sign := queries.Get("sign")
|
||||
queries.Del("sign")
|
||||
url.RawQuery = queries.Encode()
|
||||
|
||||
return instance.Check(url.Path, sign)
|
||||
}
|
||||
|
||||
// Init 初始化通用鉴权器
|
||||
func Init() {
|
||||
var secretKey string
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
secretKey = model.GetSettingByName("secret_key")
|
||||
} else {
|
||||
secretKey = conf.SlaveConfig.Secret
|
||||
if secretKey == "" {
|
||||
util.Log().Panic("SlaveSecret is not set, please specify it in config file.")
|
||||
}
|
||||
}
|
||||
General = HMACAuth{
|
||||
SecretKey: []byte(secretKey),
|
||||
}
|
||||
}
|
54
pkg/auth/hmac.go
Normal file
54
pkg/auth/hmac.go
Normal file
@ -0,0 +1,54 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HMACAuth HMAC算法鉴权
|
||||
type HMACAuth struct {
|
||||
SecretKey []byte
|
||||
}
|
||||
|
||||
// Sign 对给定Body生成expires后失效的签名,expires为过期时间戳,
|
||||
// 填写为0表示不限制有效期
|
||||
func (auth HMACAuth) Sign(body string, expires int64) string {
|
||||
h := hmac.New(sha256.New, auth.SecretKey)
|
||||
expireTimeStamp := strconv.FormatInt(expires, 10)
|
||||
_, err := io.WriteString(h, body+":"+expireTimeStamp)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return base64.URLEncoding.EncodeToString(h.Sum(nil)) + ":" + expireTimeStamp
|
||||
}
|
||||
|
||||
// Check 对给定Body和Sign进行鉴权,包括对expires的检查
|
||||
func (auth HMACAuth) Check(body string, sign string) error {
|
||||
signSlice := strings.Split(sign, ":")
|
||||
// 如果未携带expires字段
|
||||
if signSlice[len(signSlice)-1] == "" {
|
||||
return ErrExpiresMissing
|
||||
}
|
||||
|
||||
// 验证是否过期
|
||||
expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64)
|
||||
if err != nil {
|
||||
return ErrAuthFailed.WithError(err)
|
||||
}
|
||||
// 如果签名过期
|
||||
if expires < time.Now().Unix() && expires != 0 {
|
||||
return ErrExpired
|
||||
}
|
||||
|
||||
// 验证签名
|
||||
if auth.Sign(body, expires) != sign {
|
||||
return ErrAuthFailed
|
||||
}
|
||||
return nil
|
||||
}
|
16
pkg/authn/auth.go
Normal file
16
pkg/authn/auth.go
Normal file
@ -0,0 +1,16 @@
|
||||
package authn
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/duo-labs/webauthn/webauthn"
|
||||
)
|
||||
|
||||
// NewAuthnInstance 新建Authn实例
|
||||
func NewAuthnInstance() (*webauthn.WebAuthn, error) {
|
||||
base := model.GetSiteURL()
|
||||
return webauthn.New(&webauthn.Config{
|
||||
RPDisplayName: model.GetSettingByName("siteName"), // Display Name for your site
|
||||
RPID: base.Hostname(), // Generally the FQDN for your site
|
||||
RPOrigin: base.String(), // The origin URL for WebAuthn requests
|
||||
})
|
||||
}
|
15
pkg/balancer/balancer.go
Normal file
15
pkg/balancer/balancer.go
Normal file
@ -0,0 +1,15 @@
|
||||
package balancer
|
||||
|
||||
type Balancer interface {
|
||||
NextPeer(nodes interface{}) (error, interface{})
|
||||
}
|
||||
|
||||
// NewBalancer 根据策略标识返回新的负载均衡器
|
||||
func NewBalancer(strategy string) Balancer {
|
||||
switch strategy {
|
||||
case "RoundRobin":
|
||||
return &RoundRobin{}
|
||||
default:
|
||||
return &RoundRobin{}
|
||||
}
|
||||
}
|
8
pkg/balancer/errors.go
Normal file
8
pkg/balancer/errors.go
Normal file
@ -0,0 +1,8 @@
|
||||
package balancer
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrInputNotSlice = errors.New("Input value is not silice")
|
||||
ErrNoAvaliableNode = errors.New("No nodes avaliable")
|
||||
)
|
30
pkg/balancer/roundrobin.go
Normal file
30
pkg/balancer/roundrobin.go
Normal file
@ -0,0 +1,30 @@
|
||||
package balancer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type RoundRobin struct {
|
||||
current uint64
|
||||
}
|
||||
|
||||
// NextPeer 返回轮盘的下一节点
|
||||
func (r *RoundRobin) NextPeer(nodes interface{}) (error, interface{}) {
|
||||
v := reflect.ValueOf(nodes)
|
||||
if v.Kind() != reflect.Slice {
|
||||
return ErrInputNotSlice, nil
|
||||
}
|
||||
|
||||
if v.Len() == 0 {
|
||||
return ErrNoAvaliableNode, nil
|
||||
}
|
||||
|
||||
next := r.NextIndex(v.Len())
|
||||
return nil, v.Index(next).Interface()
|
||||
}
|
||||
|
||||
// NextIndex 返回下一个节点下标
|
||||
func (r *RoundRobin) NextIndex(total int) int {
|
||||
return int(atomic.AddUint64(&r.current, uint64(1)) % uint64(total))
|
||||
}
|
104
pkg/cache/driver.go
vendored
Normal file
104
pkg/cache/driver.go
vendored
Normal file
@ -0,0 +1,104 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gob.Register(map[string]itemWithTTL{})
|
||||
}
|
||||
|
||||
// Store 缓存存储器
|
||||
var Store Driver = NewMemoStore()
|
||||
|
||||
// Init 初始化缓存
|
||||
func Init() {
|
||||
if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode {
|
||||
Store = NewRedisStore(
|
||||
10,
|
||||
conf.RedisConfig.Network,
|
||||
conf.RedisConfig.Server,
|
||||
conf.RedisConfig.User,
|
||||
conf.RedisConfig.Password,
|
||||
conf.RedisConfig.DB,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Restore restores cache from given disk file
|
||||
func Restore(persistFile string) {
|
||||
if err := Store.Restore(persistFile); err != nil {
|
||||
util.Log().Warning("Failed to restore cache from disk: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func InitSlaveOverwrites() {
|
||||
err := Store.Sets(conf.OptionOverwrite, "setting_")
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to overwrite database setting: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Driver 键值缓存存储容器
|
||||
type Driver interface {
|
||||
// 设置值,ttl为过期时间,单位为秒
|
||||
Set(key string, value interface{}, ttl int) error
|
||||
|
||||
// 取值,并返回是否成功
|
||||
Get(key string) (interface{}, bool)
|
||||
|
||||
// 批量取值,返回成功取值的map即不存在的值
|
||||
Gets(keys []string, prefix string) (map[string]interface{}, []string)
|
||||
|
||||
// 批量设置值,所有的key都会加上prefix前缀
|
||||
Sets(values map[string]interface{}, prefix string) error
|
||||
|
||||
// 删除值
|
||||
Delete(keys []string, prefix string) error
|
||||
|
||||
// Save in-memory cache to disk
|
||||
Persist(path string) error
|
||||
|
||||
// Restore cache from disk
|
||||
Restore(path string) error
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func Set(key string, value interface{}, ttl int) error {
|
||||
return Store.Set(key, value, ttl)
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func Get(key string) (interface{}, bool) {
|
||||
return Store.Get(key)
|
||||
}
|
||||
|
||||
// Deletes 删除值
|
||||
func Deletes(keys []string, prefix string) error {
|
||||
return Store.Delete(keys, prefix)
|
||||
}
|
||||
|
||||
// GetSettings 根据名称批量获取设置项缓存
|
||||
func GetSettings(keys []string, prefix string) (map[string]string, []string) {
|
||||
raw, miss := Store.Gets(keys, prefix)
|
||||
|
||||
res := make(map[string]string, len(raw))
|
||||
for k, v := range raw {
|
||||
res[k] = v.(string)
|
||||
}
|
||||
|
||||
return res, miss
|
||||
}
|
||||
|
||||
// SetSettings 批量设置站点设置缓存
|
||||
func SetSettings(values map[string]string, prefix string) error {
|
||||
var toBeSet = make(map[string]interface{}, len(values))
|
||||
for key, value := range values {
|
||||
toBeSet[key] = interface{}(value)
|
||||
}
|
||||
return Store.Sets(toBeSet, prefix)
|
||||
}
|
181
pkg/cache/memo.go
vendored
Normal file
181
pkg/cache/memo.go
vendored
Normal file
@ -0,0 +1,181 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// MemoStore 内存存储驱动
|
||||
type MemoStore struct {
|
||||
Store *sync.Map
|
||||
}
|
||||
|
||||
// item 存储的对象
|
||||
type itemWithTTL struct {
|
||||
Expires int64
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
const DefaultCacheFile = "cache_persist.bin"
|
||||
|
||||
func newItem(value interface{}, expires int) itemWithTTL {
|
||||
expires64 := int64(expires)
|
||||
if expires > 0 {
|
||||
expires64 = time.Now().Unix() + expires64
|
||||
}
|
||||
return itemWithTTL{
|
||||
Value: value,
|
||||
Expires: expires64,
|
||||
}
|
||||
}
|
||||
|
||||
// getValue 从itemWithTTL中取值
|
||||
func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
if !ok {
|
||||
return nil, ok
|
||||
}
|
||||
|
||||
var itemObj itemWithTTL
|
||||
if itemObj, ok = item.(itemWithTTL); !ok {
|
||||
return item, true
|
||||
}
|
||||
|
||||
if itemObj.Expires > 0 && itemObj.Expires < time.Now().Unix() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return itemObj.Value, ok
|
||||
|
||||
}
|
||||
|
||||
// GarbageCollect 回收已过期的缓存
|
||||
func (store *MemoStore) GarbageCollect() {
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
if item, ok := value.(itemWithTTL); ok {
|
||||
if item.Expires > 0 && item.Expires < time.Now().Unix() {
|
||||
util.Log().Debug("Cache %q is garbage collected.", key.(string))
|
||||
store.Store.Delete(key)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// NewMemoStore 新建内存存储
|
||||
func NewMemoStore() *MemoStore {
|
||||
return &MemoStore{
|
||||
Store: &sync.Map{},
|
||||
}
|
||||
}
|
||||
|
||||
// Set 存储值
|
||||
func (store *MemoStore) Set(key string, value interface{}, ttl int) error {
|
||||
store.Store.Store(key, newItem(value, ttl))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get 取值
|
||||
func (store *MemoStore) Get(key string) (interface{}, bool) {
|
||||
return getValue(store.Store.Load(key))
|
||||
}
|
||||
|
||||
// Gets 批量取值
|
||||
func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
|
||||
var res = make(map[string]interface{})
|
||||
var notFound = make([]string, 0, len(keys))
|
||||
|
||||
for _, key := range keys {
|
||||
if value, ok := getValue(store.Store.Load(prefix + key)); ok {
|
||||
res[key] = value
|
||||
} else {
|
||||
notFound = append(notFound, key)
|
||||
}
|
||||
}
|
||||
|
||||
return res, notFound
|
||||
}
|
||||
|
||||
// Sets 批量设置值
|
||||
func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error {
|
||||
for key, value := range values {
|
||||
store.Store.Store(prefix+key, newItem(value, 0))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 批量删除值
|
||||
func (store *MemoStore) Delete(keys []string, prefix string) error {
|
||||
for _, key := range keys {
|
||||
store.Store.Delete(prefix + key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Persist write memory store into cache
|
||||
func (store *MemoStore) Persist(path string) error {
|
||||
persisted := make(map[string]itemWithTTL)
|
||||
store.Store.Range(func(key, value interface{}) bool {
|
||||
v, ok := store.Store.Load(key)
|
||||
if _, ok := getValue(v, ok); ok {
|
||||
persisted[key.(string)] = v.(itemWithTTL)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
res, err := serializer(persisted)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize cache: %s", err)
|
||||
}
|
||||
|
||||
// err = os.WriteFile(path, res, 0644)
|
||||
file, err := util.CreatNestedFile(path)
|
||||
if err == nil {
|
||||
_, err = file.Write(res)
|
||||
file.Chmod(0644)
|
||||
file.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Restore memory cache from disk file
|
||||
func (store *MemoStore) Restore(path string) error {
|
||||
if !util.Exists(path) {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read cache file: %s", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
f.Close()
|
||||
os.Remove(path)
|
||||
}()
|
||||
|
||||
persisted := &item{}
|
||||
dec := gob.NewDecoder(f)
|
||||
if err := dec.Decode(&persisted); err != nil {
|
||||
return fmt.Errorf("unknown cache file format: %s", err)
|
||||
}
|
||||
|
||||
items := persisted.Value.(map[string]itemWithTTL)
|
||||
loaded := 0
|
||||
for k, v := range items {
|
||||
if _, ok := getValue(v, true); ok {
|
||||
loaded++
|
||||
store.Store.Store(k, v)
|
||||
} else {
|
||||
util.Log().Debug("Persisted cache %q is expired.", k)
|
||||
}
|
||||
}
|
||||
|
||||
util.Log().Info("Restored %d items from %q into memory cache.", loaded, path)
|
||||
return nil
|
||||
}
|
227
pkg/cache/redis.go
vendored
Normal file
227
pkg/cache/redis.go
vendored
Normal file
@ -0,0 +1,227 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
)
|
||||
|
||||
// RedisStore redis存储驱动
|
||||
type RedisStore struct {
|
||||
pool *redis.Pool
|
||||
}
|
||||
|
||||
type item struct {
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func serializer(value interface{}) ([]byte, error) {
|
||||
var buffer bytes.Buffer
|
||||
enc := gob.NewEncoder(&buffer)
|
||||
storeValue := item{
|
||||
Value: value,
|
||||
}
|
||||
err := enc.Encode(storeValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func deserializer(value []byte) (interface{}, error) {
|
||||
var res item
|
||||
buffer := bytes.NewReader(value)
|
||||
dec := gob.NewDecoder(buffer)
|
||||
err := dec.Decode(&res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Value, nil
|
||||
}
|
||||
|
||||
// NewRedisStore 创建新的redis存储
|
||||
func NewRedisStore(size int, network, address, user, password, database string) *RedisStore {
|
||||
return &RedisStore{
|
||||
pool: &redis.Pool{
|
||||
MaxIdle: size,
|
||||
IdleTimeout: 240 * time.Second,
|
||||
TestOnBorrow: func(c redis.Conn, t time.Time) error {
|
||||
_, err := c.Do("PING")
|
||||
return err
|
||||
},
|
||||
Dial: func() (redis.Conn, error) {
|
||||
db, err := strconv.Atoi(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := redis.Dial(
|
||||
network,
|
||||
address,
|
||||
redis.DialDatabase(db),
|
||||
redis.DialUsername(user),
|
||||
redis.DialPassword(password),
|
||||
)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to create Redis connection: %s", err)
|
||||
}
|
||||
return c, nil
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Set 存储值
|
||||
func (store *RedisStore) Set(key string, value interface{}, ttl int) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
|
||||
serialized, err := serializer(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rc.Err() != nil {
|
||||
return rc.Err()
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
_, err = rc.Do("SETEX", key, ttl, serialized)
|
||||
} else {
|
||||
_, err = rc.Do("SET", key, serialized)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Get 取值
|
||||
func (store *RedisStore) Get(key string) (interface{}, bool) {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
v, err := redis.Bytes(rc.Do("GET", key))
|
||||
if err != nil || v == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
finalValue, err := deserializer(v)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return finalValue, true
|
||||
|
||||
}
|
||||
|
||||
// Gets 批量取值
|
||||
func (store *RedisStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return nil, keys
|
||||
}
|
||||
|
||||
var queryKeys = make([]string, len(keys))
|
||||
for key, value := range keys {
|
||||
queryKeys[key] = prefix + value
|
||||
}
|
||||
|
||||
v, err := redis.ByteSlices(rc.Do("MGET", redis.Args{}.AddFlat(queryKeys)...))
|
||||
if err != nil {
|
||||
return nil, keys
|
||||
}
|
||||
|
||||
var res = make(map[string]interface{})
|
||||
var missed = make([]string, 0, len(keys))
|
||||
|
||||
for key, value := range v {
|
||||
decoded, err := deserializer(value)
|
||||
if err != nil || decoded == nil {
|
||||
missed = append(missed, keys[key])
|
||||
} else {
|
||||
res[keys[key]] = decoded
|
||||
}
|
||||
}
|
||||
// 解码所得值
|
||||
return res, missed
|
||||
}
|
||||
|
||||
// Sets 批量设置值
|
||||
func (store *RedisStore) Sets(values map[string]interface{}, prefix string) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return rc.Err()
|
||||
}
|
||||
var setValues = make(map[string]interface{})
|
||||
|
||||
// 编码待设置值
|
||||
for key, value := range values {
|
||||
serialized, err := serializer(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
setValues[prefix+key] = serialized
|
||||
}
|
||||
|
||||
_, err := rc.Do("MSET", redis.Args{}.AddFlat(setValues)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// Delete 批量删除给定的键
|
||||
func (store *RedisStore) Delete(keys []string, prefix string) error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return rc.Err()
|
||||
}
|
||||
|
||||
// 处理前缀
|
||||
for i := 0; i < len(keys); i++ {
|
||||
keys[i] = prefix + keys[i]
|
||||
}
|
||||
|
||||
_, err := rc.Do("DEL", redis.Args{}.AddFlat(keys)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteAll 批量所有键
|
||||
func (store *RedisStore) DeleteAll() error {
|
||||
rc := store.pool.Get()
|
||||
defer rc.Close()
|
||||
if rc.Err() != nil {
|
||||
return rc.Err()
|
||||
}
|
||||
|
||||
_, err := rc.Do("FLUSHDB")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Persist Dummy implementation
|
||||
func (store *RedisStore) Persist(path string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore dummy implementation
|
||||
func (store *RedisStore) Restore(path string) error {
|
||||
return nil
|
||||
}
|
210
pkg/cluster/controller.go
Normal file
210
pkg/cluster/controller.go
Normal file
@ -0,0 +1,210 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
var DefaultController Controller
|
||||
|
||||
// Controller controls communications between master and slave
|
||||
type Controller interface {
|
||||
// Handle heartbeat sent from master
|
||||
HandleHeartBeat(*serializer.NodePingReq) (serializer.NodePingResp, error)
|
||||
|
||||
// Get Aria2 Instance by master node ID
|
||||
GetAria2Instance(string) (common.Aria2, error)
|
||||
|
||||
// Send event change message to master node
|
||||
SendNotification(string, string, mq.Message) error
|
||||
|
||||
// Submit async task into task pool
|
||||
SubmitTask(string, interface{}, string, func(interface{})) error
|
||||
|
||||
// Get master node info
|
||||
GetMasterInfo(string) (*MasterInfo, error)
|
||||
|
||||
// Get master Oauth based policy credential
|
||||
GetPolicyOauthToken(string, uint) (string, error)
|
||||
}
|
||||
|
||||
type slaveController struct {
|
||||
masters map[string]MasterInfo
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// info of master node
|
||||
type MasterInfo struct {
|
||||
ID string
|
||||
TTL int
|
||||
URL *url.URL
|
||||
// used to invoke aria2 rpc calls
|
||||
Instance Node
|
||||
Client request.Client
|
||||
|
||||
jobTracker map[string]bool
|
||||
}
|
||||
|
||||
func InitController() {
|
||||
DefaultController = &slaveController{
|
||||
masters: make(map[string]MasterInfo),
|
||||
}
|
||||
gob.Register(rpc.StatusInfo{})
|
||||
}
|
||||
|
||||
func (c *slaveController) HandleHeartBeat(req *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
req.Node.AfterFind()
|
||||
|
||||
// close old node if exist
|
||||
origin, ok := c.masters[req.SiteID]
|
||||
|
||||
if (ok && req.IsUpdate) || !ok {
|
||||
if ok {
|
||||
origin.Instance.Kill()
|
||||
}
|
||||
|
||||
masterUrl, err := url.Parse(req.SiteURL)
|
||||
if err != nil {
|
||||
return serializer.NodePingResp{}, err
|
||||
}
|
||||
|
||||
c.masters[req.SiteID] = MasterInfo{
|
||||
ID: req.SiteID,
|
||||
URL: masterUrl,
|
||||
TTL: req.CredentialTTL,
|
||||
Client: request.NewClient(
|
||||
request.WithEndpoint(masterUrl.String()),
|
||||
request.WithSlaveMeta(fmt.Sprintf("%d", req.Node.ID)),
|
||||
request.WithCredential(auth.HMACAuth{
|
||||
SecretKey: []byte(req.Node.MasterKey),
|
||||
}, int64(req.CredentialTTL)),
|
||||
),
|
||||
jobTracker: make(map[string]bool),
|
||||
Instance: NewNodeFromDBModel(&model.Node{
|
||||
Model: gorm.Model{ID: req.Node.ID},
|
||||
MasterKey: req.Node.MasterKey,
|
||||
Type: model.MasterNodeType,
|
||||
Aria2Enabled: req.Node.Aria2Enabled,
|
||||
Aria2OptionsSerialized: req.Node.Aria2OptionsSerialized,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
return serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
func (c *slaveController) GetAria2Instance(id string) (common.Aria2, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return node.Instance.GetAria2Instance(), nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
func (c *slaveController) SendNotification(id, subject string, msg mq.Message) error {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
body := bytes.Buffer{}
|
||||
enc := gob.NewEncoder(&body)
|
||||
if err := enc.Encode(&msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"PUT",
|
||||
fmt.Sprintf("/api/v3/slave/notification/%s", subject),
|
||||
&body,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
func (c *slaveController) SubmitTask(id string, job interface{}, hash string, submitter func(interface{})) error {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
if _, ok := node.jobTracker[hash]; ok {
|
||||
// 任务已存在,直接返回
|
||||
return nil
|
||||
}
|
||||
|
||||
node.jobTracker[hash] = true
|
||||
submitter(job)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetMasterInfo 获取主机节点信息
|
||||
func (c *slaveController) GetMasterInfo(id string) (*MasterInfo, error) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
return &node, nil
|
||||
}
|
||||
|
||||
return nil, ErrMasterNotFound
|
||||
}
|
||||
|
||||
// GetPolicyOauthToken 获取主机存储策略 Oauth 凭证
|
||||
func (c *slaveController) GetPolicyOauthToken(id string, policyID uint) (string, error) {
|
||||
c.lock.RLock()
|
||||
|
||||
if node, ok := c.masters[id]; ok {
|
||||
c.lock.RUnlock()
|
||||
|
||||
res, err := node.Client.Request(
|
||||
"GET",
|
||||
fmt.Sprintf("/api/v3/slave/credential/%d", policyID),
|
||||
nil,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), nil
|
||||
}
|
||||
|
||||
c.lock.RUnlock()
|
||||
return "", ErrMasterNotFound
|
||||
}
|
12
pkg/cluster/errors.go
Normal file
12
pkg/cluster/errors.go
Normal file
@ -0,0 +1,12 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrFeatureNotExist = errors.New("No nodes in nodepool match the feature specificed")
|
||||
ErrIlegalPath = errors.New("path out of boundary of setting temp folder")
|
||||
ErrMasterNotFound = serializer.NewError(serializer.CodeMasterNotFound, "Unknown master node id", nil)
|
||||
)
|
272
pkg/cluster/master.go
Normal file
272
pkg/cluster/master.go
Normal file
@ -0,0 +1,272 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
deleteTempFileDuration = 60 * time.Second
|
||||
statusRetryDuration = 10 * time.Second
|
||||
)
|
||||
|
||||
type MasterNode struct {
|
||||
Model *model.Node
|
||||
aria2RPC rpcService
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// RPCService 通过RPC服务的Aria2任务管理器
|
||||
type rpcService struct {
|
||||
Caller rpc.Client
|
||||
Initialized bool
|
||||
|
||||
retryDuration time.Duration
|
||||
deletePaddingDuration time.Duration
|
||||
parent *MasterNode
|
||||
options *clientOptions
|
||||
}
|
||||
|
||||
type clientOptions struct {
|
||||
Options map[string]interface{} // 创建下载时额外添加的设置
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *MasterNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
node.aria2RPC.parent = node
|
||||
node.aria2RPC.retryDuration = statusRetryDuration
|
||||
node.aria2RPC.deletePaddingDuration = deleteTempFileDuration
|
||||
node.lock.Unlock()
|
||||
|
||||
node.lock.RLock()
|
||||
if node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return
|
||||
}
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
func (node *MasterNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *MasterNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
return &serializer.NodePingResp{}, nil
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *MasterNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (node *MasterNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *MasterNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *MasterNode) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *MasterNode) IsActive() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Kill 结束aria2请求
|
||||
func (node *MasterNode) Kill() {
|
||||
if node.aria2RPC.Caller != nil {
|
||||
node.aria2RPC.Caller.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取主机Aria2实例
|
||||
func (node *MasterNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
node.lock.RUnlock()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
if !node.aria2RPC.Initialized {
|
||||
node.lock.RUnlock()
|
||||
node.aria2RPC.Init()
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
defer node.lock.RUnlock()
|
||||
return &node.aria2RPC
|
||||
}
|
||||
|
||||
func (node *MasterNode) IsMater() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (node *MasterNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
func (r *rpcService) Init() error {
|
||||
r.parent.lock.Lock()
|
||||
defer r.parent.lock.Unlock()
|
||||
r.Initialized = false
|
||||
|
||||
// 客户端已存在,则关闭先前连接
|
||||
if r.Caller != nil {
|
||||
r.Caller.Close()
|
||||
}
|
||||
|
||||
// 解析RPC服务地址
|
||||
server, err := url.Parse(r.parent.Model.Aria2OptionsSerialized.Server)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse Aria2 RPC server URL: %s", err)
|
||||
return err
|
||||
}
|
||||
server.Path = "/jsonrpc"
|
||||
|
||||
// 加载自定义下载配置
|
||||
var globalOptions map[string]interface{}
|
||||
if r.parent.Model.Aria2OptionsSerialized.Options != "" {
|
||||
err = json.Unmarshal([]byte(r.parent.Model.Aria2OptionsSerialized.Options), &globalOptions)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to parse aria2 options: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
r.options = &clientOptions{
|
||||
Options: globalOptions,
|
||||
}
|
||||
timeout := r.parent.Model.Aria2OptionsSerialized.Timeout
|
||||
caller, err := rpc.New(context.Background(), server.String(), r.parent.Model.Aria2OptionsSerialized.Token, time.Duration(timeout)*time.Second, mq.GlobalMQ)
|
||||
|
||||
r.Caller = caller
|
||||
r.Initialized = err == nil
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) CreateTask(task *model.Download, groupOptions map[string]interface{}) (string, error) {
|
||||
r.parent.lock.RLock()
|
||||
// 生成存储路径
|
||||
guid, _ := uuid.NewV4()
|
||||
path := filepath.Join(
|
||||
r.parent.Model.Aria2OptionsSerialized.TempPath,
|
||||
"aria2",
|
||||
guid.String(),
|
||||
)
|
||||
r.parent.lock.RUnlock()
|
||||
|
||||
// 创建下载任务
|
||||
options := map[string]interface{}{
|
||||
"dir": path,
|
||||
}
|
||||
for k, v := range r.options.Options {
|
||||
options[k] = v
|
||||
}
|
||||
for k, v := range groupOptions {
|
||||
options[k] = v
|
||||
}
|
||||
|
||||
gid, err := r.Caller.AddURI(task.Source, options)
|
||||
if err != nil || gid == "" {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
func (r *rpcService) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
res, err := r.Caller.TellStatus(task.GID)
|
||||
if err != nil {
|
||||
// 失败后重试
|
||||
util.Log().Debug("Failed to get download task status, please retry later: %s", err)
|
||||
time.Sleep(r.retryDuration)
|
||||
res, err = r.Caller.TellStatus(task.GID)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (r *rpcService) Cancel(task *model.Download) error {
|
||||
// 取消下载任务
|
||||
_, err := r.Caller.Remove(task.GID)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to cancel task %q: %s", task.GID, err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) Select(task *model.Download, files []int) error {
|
||||
var selected = make([]string, len(files))
|
||||
for i := 0; i < len(files); i++ {
|
||||
selected[i] = strconv.Itoa(files[i])
|
||||
}
|
||||
_, err := r.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")})
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *rpcService) GetConfig() model.Aria2Option {
|
||||
r.parent.lock.RLock()
|
||||
defer r.parent.lock.RUnlock()
|
||||
|
||||
return r.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *rpcService) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
// 避免被aria2占用,异步执行删除
|
||||
go func(d time.Duration, src string) {
|
||||
time.Sleep(d)
|
||||
err := os.RemoveAll(src)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to delete temp download folder: %q: %s", src, err)
|
||||
}
|
||||
}(s.deletePaddingDuration, task.Parent)
|
||||
|
||||
return nil
|
||||
}
|
60
pkg/cluster/node.go
Normal file
60
pkg/cluster/node.go
Normal file
@ -0,0 +1,60 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
type Node interface {
|
||||
// Init a node from database model
|
||||
Init(node *model.Node)
|
||||
|
||||
// Check if given feature is enabled
|
||||
IsFeatureEnabled(feature string) bool
|
||||
|
||||
// Subscribe node status change to a callback function
|
||||
SubscribeStatusChange(callback func(isActive bool, id uint))
|
||||
|
||||
// Ping the node
|
||||
Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error)
|
||||
|
||||
// Returns if the node is active
|
||||
IsActive() bool
|
||||
|
||||
// Get instances for aria2 calls
|
||||
GetAria2Instance() common.Aria2
|
||||
|
||||
// Returns unique id of this node
|
||||
ID() uint
|
||||
|
||||
// Kill node and recycle resources
|
||||
Kill()
|
||||
|
||||
// Returns if current node is master node
|
||||
IsMater() bool
|
||||
|
||||
// Get auth instance used to check RPC call from slave to master
|
||||
MasterAuthInstance() auth.Auth
|
||||
|
||||
// Get auth instance used to check RPC call from master to slave
|
||||
SlaveAuthInstance() auth.Auth
|
||||
|
||||
// Get node DB model
|
||||
DBModel() *model.Node
|
||||
}
|
||||
|
||||
// Create new node from DB model
|
||||
func NewNodeFromDBModel(node *model.Node) Node {
|
||||
switch node.Type {
|
||||
case model.SlaveNodeType:
|
||||
slave := &SlaveNode{}
|
||||
slave.Init(node)
|
||||
return slave
|
||||
default:
|
||||
master := &MasterNode{}
|
||||
master.Init(node)
|
||||
return master
|
||||
}
|
||||
}
|
213
pkg/cluster/pool.go
Normal file
213
pkg/cluster/pool.go
Normal file
@ -0,0 +1,213 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/samber/lo"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var Default *NodePool
|
||||
|
||||
// 需要分类的节点组
|
||||
var featureGroup = []string{"aria2"}
|
||||
|
||||
// Pool 节点池
|
||||
type Pool interface {
|
||||
// Returns active node selected by given feature and load balancer
|
||||
BalanceNodeByFeature(feature string, lb balancer.Balancer, available []uint, prefer uint) (error, Node)
|
||||
|
||||
// Returns node by ID
|
||||
GetNodeByID(id uint) Node
|
||||
|
||||
// Add given node into pool. If node existed, refresh node.
|
||||
Add(node *model.Node)
|
||||
|
||||
// Delete and kill node from pool by given node id
|
||||
Delete(id uint)
|
||||
}
|
||||
|
||||
// NodePool 通用节点池
|
||||
type NodePool struct {
|
||||
active map[uint]Node
|
||||
inactive map[uint]Node
|
||||
|
||||
featureMap map[string][]Node
|
||||
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// Init 初始化从机节点池
|
||||
func Init() {
|
||||
Default = &NodePool{}
|
||||
Default.Init()
|
||||
if err := Default.initFromDB(); err != nil {
|
||||
util.Log().Warning("Failed to initialize node pool: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (pool *NodePool) Init() {
|
||||
pool.lock.Lock()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
pool.featureMap = make(map[string][]Node)
|
||||
pool.active = make(map[uint]Node)
|
||||
pool.inactive = make(map[uint]Node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) buildIndexMap() {
|
||||
pool.lock.Lock()
|
||||
for _, feature := range featureGroup {
|
||||
pool.featureMap[feature] = make([]Node, 0)
|
||||
}
|
||||
|
||||
for _, v := range pool.active {
|
||||
for _, feature := range featureGroup {
|
||||
if v.IsFeatureEnabled(feature) {
|
||||
pool.featureMap[feature] = append(pool.featureMap[feature], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
}
|
||||
|
||||
func (pool *NodePool) GetNodeByID(id uint) Node {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
return node
|
||||
}
|
||||
|
||||
return pool.inactive[id]
|
||||
}
|
||||
|
||||
func (pool *NodePool) nodeStatusChange(isActive bool, id uint) {
|
||||
util.Log().Debug("Slave node [ID=%d] status changed to [Active=%t].", id, isActive)
|
||||
var node Node
|
||||
pool.lock.Lock()
|
||||
if n, ok := pool.inactive[id]; ok {
|
||||
node = n
|
||||
delete(pool.inactive, id)
|
||||
} else {
|
||||
node = pool.active[id]
|
||||
delete(pool.active, id)
|
||||
}
|
||||
|
||||
if isActive {
|
||||
pool.active[id] = node
|
||||
} else {
|
||||
pool.inactive[id] = node
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
}
|
||||
|
||||
func (pool *NodePool) initFromDB() error {
|
||||
nodes, err := model.GetNodesByStatus(model.NodeActive)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pool.lock.Lock()
|
||||
for i := 0; i < len(nodes); i++ {
|
||||
pool.add(&nodes[i])
|
||||
}
|
||||
pool.lock.Unlock()
|
||||
|
||||
pool.buildIndexMap()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pool *NodePool) add(node *model.Node) {
|
||||
newNode := NewNodeFromDBModel(node)
|
||||
if newNode.IsActive() {
|
||||
pool.active[node.ID] = newNode
|
||||
} else {
|
||||
pool.inactive[node.ID] = newNode
|
||||
}
|
||||
|
||||
// 订阅节点状态变更
|
||||
newNode.SubscribeStatusChange(func(isActive bool, id uint) {
|
||||
pool.nodeStatusChange(isActive, id)
|
||||
})
|
||||
}
|
||||
|
||||
func (pool *NodePool) Add(node *model.Node) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
var (
|
||||
old Node
|
||||
ok bool
|
||||
)
|
||||
if old, ok = pool.active[node.ID]; !ok {
|
||||
old, ok = pool.inactive[node.ID]
|
||||
}
|
||||
if old != nil {
|
||||
go old.Init(node)
|
||||
return
|
||||
}
|
||||
|
||||
pool.add(node)
|
||||
}
|
||||
|
||||
func (pool *NodePool) Delete(id uint) {
|
||||
pool.lock.Lock()
|
||||
defer pool.buildIndexMap()
|
||||
defer pool.lock.Unlock()
|
||||
|
||||
if node, ok := pool.active[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.active, id)
|
||||
return
|
||||
}
|
||||
|
||||
if node, ok := pool.inactive[id]; ok {
|
||||
node.Kill()
|
||||
delete(pool.inactive, id)
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// BalanceNodeByFeature 根据 feature 和 LoadBalancer 取出节点
|
||||
func (pool *NodePool) BalanceNodeByFeature(feature string, lb balancer.Balancer,
|
||||
available []uint, prefer uint) (error, Node) {
|
||||
pool.lock.RLock()
|
||||
defer pool.lock.RUnlock()
|
||||
if nodes, ok := pool.featureMap[feature]; ok {
|
||||
// Find nodes that are allowed to be used in user group
|
||||
availableNodes := nodes
|
||||
if len(available) > 0 {
|
||||
idHash := make(map[uint]struct{}, len(available))
|
||||
for _, id := range available {
|
||||
idHash[id] = struct{}{}
|
||||
}
|
||||
|
||||
availableNodes = lo.Filter[Node](nodes, func(node Node, index int) bool {
|
||||
_, exist := idHash[node.ID()]
|
||||
return exist
|
||||
})
|
||||
}
|
||||
|
||||
// Return preferred node if exists
|
||||
if preferredNode, found := lo.Find[Node](availableNodes, func(node Node) bool {
|
||||
return node.ID() == prefer
|
||||
}); found {
|
||||
return nil, preferredNode
|
||||
}
|
||||
|
||||
err, res := lb.NextPeer(availableNodes)
|
||||
if err == nil {
|
||||
return nil, res.(Node)
|
||||
}
|
||||
|
||||
return err, nil
|
||||
}
|
||||
|
||||
return ErrFeatureNotExist, nil
|
||||
}
|
451
pkg/cluster/slave.go
Normal file
451
pkg/cluster/slave.go
Normal file
@ -0,0 +1,451 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SlaveNode struct {
|
||||
Model *model.Node
|
||||
Active bool
|
||||
|
||||
caller slaveCaller
|
||||
callback func(bool, uint)
|
||||
close chan bool
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
type slaveCaller struct {
|
||||
parent *SlaveNode
|
||||
Client request.Client
|
||||
}
|
||||
|
||||
// Init 初始化节点
|
||||
func (node *SlaveNode) Init(nodeModel *model.Node) {
|
||||
node.lock.Lock()
|
||||
node.Model = nodeModel
|
||||
|
||||
// Init http request client
|
||||
var endpoint *url.URL
|
||||
if serverURL, err := url.Parse(node.Model.Server); err == nil {
|
||||
var controller *url.URL
|
||||
controller, _ = url.Parse("/api/v3/slave/")
|
||||
endpoint = serverURL.ResolveReference(controller)
|
||||
}
|
||||
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
node.caller.Client = request.NewClient(
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(auth.HMACAuth{SecretKey: []byte(nodeModel.SlaveKey)}, int64(signTTL)),
|
||||
request.WithEndpoint(endpoint.String()),
|
||||
)
|
||||
|
||||
node.caller.parent = node
|
||||
if node.close != nil {
|
||||
node.lock.Unlock()
|
||||
node.close <- true
|
||||
go node.StartPingLoop()
|
||||
} else {
|
||||
node.Active = true
|
||||
node.lock.Unlock()
|
||||
go node.StartPingLoop()
|
||||
}
|
||||
}
|
||||
|
||||
// IsFeatureEnabled 查询节点的某项功能是否启用
|
||||
func (node *SlaveNode) IsFeatureEnabled(feature string) bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
switch feature {
|
||||
case "aria2":
|
||||
return node.Model.Aria2Enabled
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// SubscribeStatusChange 订阅节点状态更改
|
||||
func (node *SlaveNode) SubscribeStatusChange(callback func(bool, uint)) {
|
||||
node.lock.Lock()
|
||||
node.callback = callback
|
||||
node.lock.Unlock()
|
||||
}
|
||||
|
||||
// Ping 从机节点,返回从机负载
|
||||
func (node *SlaveNode) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
reqBodyEncoded, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
|
||||
resp, err := node.caller.Client.Request(
|
||||
"POST",
|
||||
"heartbeat",
|
||||
bodyReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return nil, serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
var res serializer.NodePingResp
|
||||
|
||||
if resStr, ok := resp.Data.(string); ok {
|
||||
err = json.Unmarshal([]byte(resStr), &res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &res, nil
|
||||
}
|
||||
|
||||
// IsActive 返回节点是否在线
|
||||
func (node *SlaveNode) IsActive() bool {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Active
|
||||
}
|
||||
|
||||
// Kill 结束节点内相关循环
|
||||
func (node *SlaveNode) Kill() {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if node.close != nil {
|
||||
close(node.close)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAria2Instance 获取从机Aria2实例
|
||||
func (node *SlaveNode) GetAria2Instance() common.Aria2 {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
if !node.Model.Aria2Enabled {
|
||||
return &common.DummyAria2{}
|
||||
}
|
||||
|
||||
return &node.caller
|
||||
}
|
||||
|
||||
func (node *SlaveNode) ID() uint {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model.ID
|
||||
}
|
||||
|
||||
func (node *SlaveNode) StartPingLoop() {
|
||||
node.lock.Lock()
|
||||
node.close = make(chan bool)
|
||||
node.lock.Unlock()
|
||||
|
||||
tickDuration := time.Duration(model.GetIntSetting("slave_ping_interval", 300)) * time.Second
|
||||
recoverDuration := time.Duration(model.GetIntSetting("slave_recover_interval", 600)) * time.Second
|
||||
pingTicker := time.Duration(0)
|
||||
|
||||
util.Log().Debug("Slave node %q heartbeat loop started.", node.Model.Name)
|
||||
retry := 0
|
||||
recoverMode := false
|
||||
isFirstLoop := true
|
||||
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-time.After(pingTicker):
|
||||
if pingTicker == 0 {
|
||||
pingTicker = tickDuration
|
||||
}
|
||||
|
||||
util.Log().Debug("Slave node %q send ping.", node.Model.Name)
|
||||
res, err := node.Ping(node.getHeartbeatContent(isFirstLoop))
|
||||
isFirstLoop = false
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Error while ping slave node %q: %s", node.Model.Name, err)
|
||||
retry++
|
||||
if retry >= model.GetIntSetting("slave_node_retry", 3) {
|
||||
util.Log().Debug("Retry threshold for pinging slave node %q exceeded, mark it as offline.", node.Model.Name)
|
||||
node.changeStatus(false)
|
||||
|
||||
if !recoverMode {
|
||||
// 启动恢复监控循环
|
||||
util.Log().Debug("Slave node %q entered recovery mode.", node.Model.Name)
|
||||
pingTicker = recoverDuration
|
||||
recoverMode = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if recoverMode {
|
||||
util.Log().Debug("Slave node %q recovered.", node.Model.Name)
|
||||
pingTicker = tickDuration
|
||||
recoverMode = false
|
||||
isFirstLoop = true
|
||||
}
|
||||
|
||||
util.Log().Debug("Status of slave node %q: %s", node.Model.Name, res)
|
||||
node.changeStatus(true)
|
||||
retry = 0
|
||||
}
|
||||
|
||||
case <-node.close:
|
||||
util.Log().Debug("Slave node %q received shutdown signal.", node.Model.Name)
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) IsMater() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (node *SlaveNode) MasterAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.MasterKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) SlaveAuthInstance() auth.Auth {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return auth.HMACAuth{SecretKey: []byte(node.Model.SlaveKey)}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) DBModel() *model.Node {
|
||||
node.lock.RLock()
|
||||
defer node.lock.RUnlock()
|
||||
|
||||
return node.Model
|
||||
}
|
||||
|
||||
// getHeartbeatContent gets serializer.NodePingReq used to send heartbeat to slave
|
||||
func (node *SlaveNode) getHeartbeatContent(isUpdate bool) *serializer.NodePingReq {
|
||||
return &serializer.NodePingReq{
|
||||
SiteURL: model.GetSiteURL().String(),
|
||||
IsUpdate: isUpdate,
|
||||
SiteID: model.GetSettingByName("siteID"),
|
||||
Node: node.Model,
|
||||
CredentialTTL: model.GetIntSetting("slave_api_timeout", 60),
|
||||
}
|
||||
}
|
||||
|
||||
func (node *SlaveNode) changeStatus(isActive bool) {
|
||||
node.lock.RLock()
|
||||
id := node.Model.ID
|
||||
if isActive != node.Active {
|
||||
node.lock.RUnlock()
|
||||
node.lock.Lock()
|
||||
node.Active = isActive
|
||||
node.lock.Unlock()
|
||||
node.callback(isActive, id)
|
||||
} else {
|
||||
node.lock.RUnlock()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Init() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendAria2Call send remote aria2 call to slave node
|
||||
func (s *slaveCaller) SendAria2Call(body *serializer.SlaveAria2Call, scope string) (*serializer.Response, error) {
|
||||
reqReader, err := getAria2RequestBody(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.Client.Request(
|
||||
"POST",
|
||||
"aria2/"+scope,
|
||||
reqReader,
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
}
|
||||
|
||||
func (s *slaveCaller) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
GroupOptions: options,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "task")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return "", serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return res.Data.(string), err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "status")
|
||||
if err != nil {
|
||||
return rpc.StatusInfo{}, err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return rpc.StatusInfo{}, serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
var status rpc.StatusInfo
|
||||
res.GobDecode(&status)
|
||||
|
||||
return status, err
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Cancel(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "cancel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) Select(task *model.Download, files []int) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
Files: files,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "select")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *slaveCaller) GetConfig() model.Aria2Option {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
return s.parent.Model.Aria2OptionsSerialized
|
||||
}
|
||||
|
||||
func (s *slaveCaller) DeleteTempFile(task *model.Download) error {
|
||||
s.parent.lock.RLock()
|
||||
defer s.parent.lock.RUnlock()
|
||||
|
||||
req := &serializer.SlaveAria2Call{
|
||||
Task: task,
|
||||
}
|
||||
|
||||
res, err := s.SendAria2Call(req, "delete")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAria2RequestBody(body *serializer.SlaveAria2Call) (io.Reader, error) {
|
||||
reqBodyEncoded, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return strings.NewReader(string(reqBodyEncoded)), nil
|
||||
}
|
||||
|
||||
// RemoteCallback 发送远程存储策略上传回调请求
|
||||
func RemoteCallback(url string, body serializer.UploadCallback) error {
|
||||
callbackBody, err := json.Marshal(struct {
|
||||
Data serializer.UploadCallback `json:"data"`
|
||||
}{
|
||||
Data: body,
|
||||
})
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Failed to encode callback content", err)
|
||||
}
|
||||
|
||||
resp := request.GeneralClient.Request(
|
||||
"POST",
|
||||
url,
|
||||
bytes.NewReader(callbackBody),
|
||||
request.WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second),
|
||||
request.WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)),
|
||||
)
|
||||
|
||||
if resp.Err != nil {
|
||||
return serializer.NewError(serializer.CodeCallbackError, "Slave cannot send callback request", resp.Err)
|
||||
}
|
||||
|
||||
// 解析回调服务端响应
|
||||
response, err := resp.DecodeResponse()
|
||||
if err != nil {
|
||||
msg := fmt.Sprintf("Slave cannot parse callback response from master (StatusCode=%d).", resp.Response.StatusCode)
|
||||
return serializer.NewError(serializer.CodeCallbackError, msg, err)
|
||||
}
|
||||
|
||||
if response.Code != 0 {
|
||||
return serializer.NewError(response.Code, response.Msg, errors.New(response.Error))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
156
pkg/conf/conf.go
Normal file
156
pkg/conf/conf.go
Normal file
@ -0,0 +1,156 @@
|
||||
package conf
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/go-ini/ini"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// database 数据库
|
||||
type database struct {
|
||||
Type string
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Name string
|
||||
TablePrefix string
|
||||
DBFile string
|
||||
Port int
|
||||
Charset string
|
||||
UnixSocket bool
|
||||
}
|
||||
|
||||
// system 系统通用配置
|
||||
type system struct {
|
||||
Mode string `validate:"eq=master|eq=slave"`
|
||||
Listen string `validate:"required"`
|
||||
Debug bool
|
||||
SessionSecret string
|
||||
HashIDSalt string
|
||||
GracePeriod int `validate:"gte=0"`
|
||||
ProxyHeader string `validate:"required_with=Listen"`
|
||||
}
|
||||
|
||||
type ssl struct {
|
||||
CertPath string `validate:"omitempty,required"`
|
||||
KeyPath string `validate:"omitempty,required"`
|
||||
Listen string `validate:"required"`
|
||||
}
|
||||
|
||||
type unix struct {
|
||||
Listen string
|
||||
Perm uint32
|
||||
}
|
||||
|
||||
// slave 作为slave存储端配置
|
||||
type slave struct {
|
||||
Secret string `validate:"omitempty,gte=64"`
|
||||
CallbackTimeout int `validate:"omitempty,gte=1"`
|
||||
SignatureTTL int `validate:"omitempty,gte=1"`
|
||||
}
|
||||
|
||||
// redis 配置
|
||||
type redis struct {
|
||||
Network string
|
||||
Server string
|
||||
User string
|
||||
Password string
|
||||
DB string
|
||||
}
|
||||
|
||||
// 跨域配置
|
||||
type cors struct {
|
||||
AllowOrigins []string
|
||||
AllowMethods []string
|
||||
AllowHeaders []string
|
||||
AllowCredentials bool
|
||||
ExposeHeaders []string
|
||||
SameSite string
|
||||
Secure bool
|
||||
}
|
||||
|
||||
var cfg *ini.File
|
||||
|
||||
const defaultConf = `[System]
|
||||
Debug = false
|
||||
Mode = master
|
||||
Listen = :5212
|
||||
SessionSecret = {SessionSecret}
|
||||
HashIDSalt = {HashIDSalt}
|
||||
`
|
||||
|
||||
// Init 初始化配置文件
|
||||
func Init(path string) {
|
||||
var err error
|
||||
|
||||
if path == "" || !util.Exists(path) {
|
||||
// 创建初始配置文件
|
||||
confContent := util.Replace(map[string]string{
|
||||
"{SessionSecret}": util.RandStringRunes(64),
|
||||
"{HashIDSalt}": util.RandStringRunes(64),
|
||||
}, defaultConf)
|
||||
f, err := util.CreatNestedFile(path)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to create config file: %s", err)
|
||||
}
|
||||
|
||||
// 写入配置文件
|
||||
_, err = f.WriteString(confContent)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to write config file: %s", err)
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
|
||||
cfg, err = ini.Load(path)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to parse config file %q: %s", path, err)
|
||||
}
|
||||
|
||||
sections := map[string]interface{}{
|
||||
"Database": DatabaseConfig,
|
||||
"System": SystemConfig,
|
||||
"SSL": SSLConfig,
|
||||
"UnixSocket": UnixConfig,
|
||||
"Redis": RedisConfig,
|
||||
"CORS": CORSConfig,
|
||||
"Slave": SlaveConfig,
|
||||
}
|
||||
for sectionName, sectionStruct := range sections {
|
||||
err = mapSection(sectionName, sectionStruct)
|
||||
if err != nil {
|
||||
util.Log().Panic("Failed to parse config section %q: %s", sectionName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 映射数据库配置覆盖
|
||||
for _, key := range cfg.Section("OptionOverwrite").Keys() {
|
||||
OptionOverwrite[key.Name()] = key.Value()
|
||||
}
|
||||
|
||||
// 重设log等级
|
||||
if !SystemConfig.Debug {
|
||||
util.Level = util.LevelInformational
|
||||
util.GloablLogger = nil
|
||||
util.Log()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// mapSection 将配置文件的 Section 映射到结构体上
|
||||
func mapSection(section string, confStruct interface{}) error {
|
||||
err := cfg.Section(section).MapTo(confStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证合法性
|
||||
validate := validator.New()
|
||||
err = validate.Struct(confStruct)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
55
pkg/conf/defaults.go
Normal file
55
pkg/conf/defaults.go
Normal file
@ -0,0 +1,55 @@
|
||||
package conf
|
||||
|
||||
// RedisConfig Redis服务器配置
|
||||
var RedisConfig = &redis{
|
||||
Network: "tcp",
|
||||
Server: "",
|
||||
Password: "",
|
||||
DB: "0",
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
var DatabaseConfig = &database{
|
||||
Type: "UNSET",
|
||||
Charset: "utf8",
|
||||
DBFile: "cloudreve.db",
|
||||
Port: 3306,
|
||||
UnixSocket: false,
|
||||
}
|
||||
|
||||
// SystemConfig 系统公用配置
|
||||
var SystemConfig = &system{
|
||||
Debug: false,
|
||||
Mode: "master",
|
||||
Listen: ":5212",
|
||||
ProxyHeader: "X-Forwarded-For",
|
||||
}
|
||||
|
||||
// CORSConfig 跨域配置
|
||||
var CORSConfig = &cors{
|
||||
AllowOrigins: []string{"UNSET"},
|
||||
AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"},
|
||||
AllowHeaders: []string{"Cookie", "X-Cr-Policy", "Authorization", "Content-Length", "Content-Type", "X-Cr-Path", "X-Cr-FileName"},
|
||||
AllowCredentials: false,
|
||||
ExposeHeaders: nil,
|
||||
SameSite: "Default",
|
||||
Secure: false,
|
||||
}
|
||||
|
||||
// SlaveConfig 从机配置
|
||||
var SlaveConfig = &slave{
|
||||
CallbackTimeout: 20,
|
||||
SignatureTTL: 60,
|
||||
}
|
||||
|
||||
var SSLConfig = &ssl{
|
||||
Listen: ":443",
|
||||
CertPath: "",
|
||||
KeyPath: "",
|
||||
}
|
||||
|
||||
var UnixConfig = &unix{
|
||||
Listen: "",
|
||||
}
|
||||
|
||||
var OptionOverwrite = map[string]interface{}{}
|
22
pkg/conf/version.go
Normal file
22
pkg/conf/version.go
Normal file
@ -0,0 +1,22 @@
|
||||
package conf
|
||||
|
||||
// plusVersion 增强版版本号
|
||||
const plusVersion = "+1.1"
|
||||
|
||||
// BackendVersion 当前后端版本号
|
||||
const BackendVersion = "3.8.3" + plusVersion
|
||||
|
||||
// KeyVersion 授权版本号
|
||||
const KeyVersion = "3.3.1"
|
||||
|
||||
// RequiredDBVersion 与当前版本匹配的数据库版本
|
||||
const RequiredDBVersion = "3.8.1+1.0-plus"
|
||||
|
||||
// RequiredStaticVersion 与当前版本匹配的静态资源版本
|
||||
const RequiredStaticVersion = "3.8.3" + plusVersion
|
||||
|
||||
// IsPlus 是否为Plus版本
|
||||
const IsPlus = "true"
|
||||
|
||||
// LastCommit 最后commit id
|
||||
const LastCommit = "88409cc"
|
99
pkg/crontab/collect.go
Normal file
99
pkg/crontab/collect.go
Normal file
@ -0,0 +1,99 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
func garbageCollect() {
|
||||
// 清理打包下载产生的临时文件
|
||||
collectArchiveFile()
|
||||
|
||||
// 清理过期的内置内存缓存
|
||||
if store, ok := cache.Store.(*cache.MemoStore); ok {
|
||||
collectCache(store)
|
||||
}
|
||||
|
||||
util.Log().Info("Crontab job \"cron_garbage_collect\" complete.")
|
||||
}
|
||||
|
||||
func collectArchiveFile() {
|
||||
// 读取有效期、目录设置
|
||||
tempPath := util.RelativePath(model.GetSettingByName("temp_path"))
|
||||
expires := model.GetIntSetting("download_timeout", 30)
|
||||
|
||||
// 列出文件
|
||||
root := filepath.Join(tempPath, "archive")
|
||||
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
|
||||
if err == nil && !info.IsDir() &&
|
||||
strings.HasPrefix(filepath.Base(path), "archive_") &&
|
||||
time.Now().Sub(info.ModTime()).Seconds() > float64(expires) {
|
||||
util.Log().Debug("Delete expired batch download temp file %q.", path)
|
||||
// 删除符合条件的文件
|
||||
if err := os.Remove(path); err != nil {
|
||||
util.Log().Debug("Failed to delete temp file %q: %s", path, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.Log().Debug("Crontab job cannot list temp batch download folder: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func collectCache(store *cache.MemoStore) {
|
||||
util.Log().Debug("Cleanup memory cache.")
|
||||
store.GarbageCollect()
|
||||
}
|
||||
|
||||
func uploadSessionCollect() {
|
||||
placeholders := model.GetUploadPlaceholderFiles(0)
|
||||
|
||||
// 将过期的上传会话按照用户分组
|
||||
userToFiles := make(map[uint][]uint)
|
||||
for _, file := range placeholders {
|
||||
_, sessionExist := cache.Get(filesystem.UploadSessionCachePrefix + *file.UploadSessionID)
|
||||
if sessionExist {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := userToFiles[file.UserID]; !ok {
|
||||
userToFiles[file.UserID] = make([]uint, 0)
|
||||
}
|
||||
|
||||
userToFiles[file.UserID] = append(userToFiles[file.UserID], file.ID)
|
||||
}
|
||||
|
||||
// 删除过期的会话
|
||||
for uid, filesIDs := range userToFiles {
|
||||
user, err := model.GetUserByID(uid)
|
||||
if err != nil {
|
||||
util.Log().Warning("Owner of the upload session cannot be found: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
fs, err := filesystem.NewFileSystem(&user)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to initialize filesystem: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if err = fs.Delete(context.Background(), []uint{}, filesIDs, false, false); err != nil {
|
||||
util.Log().Warning("Failed to delete upload session: %s", err)
|
||||
}
|
||||
|
||||
fs.Recycle()
|
||||
}
|
||||
|
||||
util.Log().Info("Crontab job \"cron_recycle_upload_session\" complete.")
|
||||
}
|
53
pkg/crontab/init.go
Normal file
53
pkg/crontab/init.go
Normal file
@ -0,0 +1,53 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
// Cron 定时任务
|
||||
var Cron *cron.Cron
|
||||
|
||||
// Reload 重新启动定时任务
|
||||
func Reload() {
|
||||
if Cron != nil {
|
||||
Cron.Stop()
|
||||
}
|
||||
Init()
|
||||
}
|
||||
|
||||
// Init 初始化定时任务
|
||||
func Init() {
|
||||
util.Log().Info("Initialize crontab jobs...")
|
||||
// 读取cron日程设置
|
||||
options := model.GetSettingByNames(
|
||||
"cron_garbage_collect",
|
||||
"cron_notify_user",
|
||||
"cron_ban_user",
|
||||
"cron_recycle_upload_session",
|
||||
)
|
||||
Cron := cron.New()
|
||||
for k, v := range options {
|
||||
var handler func()
|
||||
switch k {
|
||||
case "cron_garbage_collect":
|
||||
handler = garbageCollect
|
||||
case "cron_notify_user":
|
||||
handler = notifyExpiredVAS
|
||||
case "cron_ban_user":
|
||||
handler = banOverusedUser
|
||||
case "cron_recycle_upload_session":
|
||||
handler = uploadSessionCollect
|
||||
default:
|
||||
util.Log().Warning("Unknown crontab job type %q, skipping...", k)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, err := Cron.AddFunc(v, handler); err != nil {
|
||||
util.Log().Warning("Failed to start crontab job %q: %s", k, err)
|
||||
}
|
||||
|
||||
}
|
||||
Cron.Start()
|
||||
}
|
83
pkg/crontab/vas.go
Normal file
83
pkg/crontab/vas.go
Normal file
@ -0,0 +1,83 @@
|
||||
package crontab
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/email"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
func notifyExpiredVAS() {
|
||||
checkStoragePack()
|
||||
checkUserGroup()
|
||||
util.Log().Info("Crontab job \"cron_notify_user\" complete.")
|
||||
}
|
||||
|
||||
// banOverusedUser 封禁超出宽容期的用户
|
||||
func banOverusedUser() {
|
||||
users := model.GetTolerantExpiredUser()
|
||||
for _, user := range users {
|
||||
|
||||
// 清除最后通知日期标记
|
||||
user.ClearNotified()
|
||||
|
||||
// 检查容量是否超额
|
||||
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
|
||||
// 封禁用户
|
||||
user.SetStatus(model.OveruseBaned)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkUserGroup 检查已过期用户组
|
||||
func checkUserGroup() {
|
||||
users := model.GetGroupExpiredUsers()
|
||||
for _, user := range users {
|
||||
|
||||
// 将用户回退到初始用户组
|
||||
user.GroupFallback()
|
||||
|
||||
// 重新加载用户
|
||||
user, _ = model.GetUserByID(user.ID)
|
||||
|
||||
// 检查容量是否超额
|
||||
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
|
||||
// 如果超额,则通知用户
|
||||
sendNotification(&user, "用户组过期")
|
||||
// 更新最后通知日期
|
||||
user.Notified()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkStoragePack 检查已过期的容量包
|
||||
func checkStoragePack() {
|
||||
packs := model.GetExpiredStoragePack()
|
||||
for _, pack := range packs {
|
||||
// 删除过期的容量包
|
||||
pack.Delete()
|
||||
|
||||
//找到所属用户
|
||||
user, err := model.GetUserByID(pack.UserID)
|
||||
if err != nil {
|
||||
util.Log().Warning("Crontab job failed to get user info of [UID=%d]: %s", pack.UserID, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查容量是否超额
|
||||
if user.Storage > user.Group.MaxStorage+user.GetAvailablePackSize() {
|
||||
// 如果超额,则通知用户
|
||||
sendNotification(&user, "容量包过期")
|
||||
|
||||
// 更新最后通知日期
|
||||
user.Notified()
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sendNotification(user *model.User, reason string) {
|
||||
title, body := email.NewOveruseNotification(user.Nick, reason)
|
||||
if err := email.Send(user.Email, title, body); err != nil {
|
||||
util.Log().Warning("Failed to send notification email: %s", err)
|
||||
}
|
||||
}
|
52
pkg/email/init.go
Normal file
52
pkg/email/init.go
Normal file
@ -0,0 +1,52 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Client 默认的邮件发送客户端
|
||||
var Client Driver
|
||||
|
||||
// Lock 读写锁
|
||||
var Lock sync.RWMutex
|
||||
|
||||
// Init 初始化
|
||||
func Init() {
|
||||
util.Log().Debug("Initializing email sending queue...")
|
||||
Lock.Lock()
|
||||
defer Lock.Unlock()
|
||||
|
||||
if Client != nil {
|
||||
Client.Close()
|
||||
}
|
||||
|
||||
// 读取SMTP设置
|
||||
options := model.GetSettingByNames(
|
||||
"fromName",
|
||||
"fromAdress",
|
||||
"smtpHost",
|
||||
"replyTo",
|
||||
"smtpUser",
|
||||
"smtpPass",
|
||||
"smtpEncryption",
|
||||
)
|
||||
port := model.GetIntSetting("smtpPort", 25)
|
||||
keepAlive := model.GetIntSetting("mail_keepalive", 30)
|
||||
|
||||
client := NewSMTPClient(SMTPConfig{
|
||||
Name: options["fromName"],
|
||||
Address: options["fromAdress"],
|
||||
ReplyTo: options["replyTo"],
|
||||
Host: options["smtpHost"],
|
||||
Port: port,
|
||||
User: options["smtpUser"],
|
||||
Password: options["smtpPass"],
|
||||
Keepalive: keepAlive,
|
||||
Encryption: model.IsTrueVal(options["smtpEncryption"]),
|
||||
})
|
||||
|
||||
Client = client
|
||||
}
|
38
pkg/email/mail.go
Normal file
38
pkg/email/mail.go
Normal file
@ -0,0 +1,38 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Driver 邮件发送驱动
|
||||
type Driver interface {
|
||||
// Close 关闭驱动
|
||||
Close()
|
||||
// Send 发送邮件
|
||||
Send(to, title, body string) error
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrChanNotOpen 邮件队列未开启
|
||||
ErrChanNotOpen = errors.New("email queue is not started")
|
||||
// ErrNoActiveDriver 无可用邮件发送服务
|
||||
ErrNoActiveDriver = errors.New("no avaliable email provider")
|
||||
)
|
||||
|
||||
// Send 发送邮件
|
||||
func Send(to, title, body string) error {
|
||||
// 忽略通过QQ登录的邮箱
|
||||
if strings.HasSuffix(to, "@login.qq.com") {
|
||||
return nil
|
||||
}
|
||||
|
||||
Lock.RLock()
|
||||
defer Lock.RUnlock()
|
||||
|
||||
if Client == nil {
|
||||
return ErrNoActiveDriver
|
||||
}
|
||||
|
||||
return Client.Send(to, title, body)
|
||||
}
|
122
pkg/email/smtp.go
Normal file
122
pkg/email/smtp.go
Normal file
@ -0,0 +1,122 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/go-mail/mail"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// SMTP SMTP协议发送邮件
|
||||
type SMTP struct {
|
||||
Config SMTPConfig
|
||||
ch chan *mail.Message
|
||||
chOpen bool
|
||||
}
|
||||
|
||||
// SMTPConfig SMTP发送配置
|
||||
type SMTPConfig struct {
|
||||
Name string // 发送者名
|
||||
Address string // 发送者地址
|
||||
ReplyTo string // 回复地址
|
||||
Host string // 服务器主机名
|
||||
Port int // 服务器端口
|
||||
User string // 用户名
|
||||
Password string // 密码
|
||||
Encryption bool // 是否启用加密
|
||||
Keepalive int // SMTP 连接保留时长
|
||||
}
|
||||
|
||||
// NewSMTPClient 新建SMTP发送队列
|
||||
func NewSMTPClient(config SMTPConfig) *SMTP {
|
||||
client := &SMTP{
|
||||
Config: config,
|
||||
ch: make(chan *mail.Message, 30),
|
||||
chOpen: false,
|
||||
}
|
||||
|
||||
client.Init()
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// Send 发送邮件
|
||||
func (client *SMTP) Send(to, title, body string) error {
|
||||
if !client.chOpen {
|
||||
return ErrChanNotOpen
|
||||
}
|
||||
m := mail.NewMessage()
|
||||
m.SetAddressHeader("From", client.Config.Address, client.Config.Name)
|
||||
m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", title)
|
||||
m.SetHeader("Message-ID", util.StrConcat(`"<`, uuid.NewString(), `@`, `cloudreveplus`, `>"`))
|
||||
m.SetBody("text/html", body)
|
||||
client.ch <- m
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭发送队列
|
||||
func (client *SMTP) Close() {
|
||||
if client.ch != nil {
|
||||
close(client.ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Init 初始化发送队列
|
||||
func (client *SMTP) Init() {
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
client.chOpen = false
|
||||
util.Log().Error("Exception while sending email: %s, queue will be reset in 10 seconds.", err)
|
||||
time.Sleep(time.Duration(10) * time.Second)
|
||||
client.Init()
|
||||
}
|
||||
}()
|
||||
|
||||
d := mail.NewDialer(client.Config.Host, client.Config.Port, client.Config.User, client.Config.Password)
|
||||
d.Timeout = time.Duration(client.Config.Keepalive+5) * time.Second
|
||||
client.chOpen = true
|
||||
// 是否启用 SSL
|
||||
d.SSL = false
|
||||
if client.Config.Encryption {
|
||||
d.SSL = true
|
||||
}
|
||||
d.StartTLSPolicy = mail.OpportunisticStartTLS
|
||||
|
||||
var s mail.SendCloser
|
||||
var err error
|
||||
open := false
|
||||
for {
|
||||
select {
|
||||
case m, ok := <-client.ch:
|
||||
if !ok {
|
||||
util.Log().Debug("Email queue closing...")
|
||||
client.chOpen = false
|
||||
return
|
||||
}
|
||||
if !open {
|
||||
if s, err = d.Dial(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
open = true
|
||||
}
|
||||
if err := mail.Send(s, m); err != nil {
|
||||
util.Log().Warning("Failed to send email: %s", err)
|
||||
} else {
|
||||
util.Log().Debug("Email sent.")
|
||||
}
|
||||
// 长时间没有新邮件,则关闭SMTP连接
|
||||
case <-time.After(time.Duration(client.Config.Keepalive) * time.Second):
|
||||
if open {
|
||||
if err := s.Close(); err != nil {
|
||||
util.Log().Warning("Failed to close SMTP connection: %s", err)
|
||||
}
|
||||
open = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
50
pkg/email/template.go
Normal file
50
pkg/email/template.go
Normal file
@ -0,0 +1,50 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// NewOveruseNotification 新建超额提醒邮件
|
||||
func NewOveruseNotification(userName, reason string) (string, string) {
|
||||
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "over_used_template")
|
||||
replace := map[string]string{
|
||||
"{siteTitle}": options["siteName"],
|
||||
"{userName}": userName,
|
||||
"{notifyReason}": reason,
|
||||
"{siteUrl}": options["siteURL"],
|
||||
"{siteSecTitle}": options["siteTitle"],
|
||||
}
|
||||
return fmt.Sprintf("【%s】空间容量超额提醒", options["siteName"]),
|
||||
util.Replace(replace, options["over_used_template"])
|
||||
}
|
||||
|
||||
// NewActivationEmail 新建激活邮件
|
||||
func NewActivationEmail(userName, activateURL string) (string, string) {
|
||||
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template")
|
||||
replace := map[string]string{
|
||||
"{siteTitle}": options["siteName"],
|
||||
"{userName}": userName,
|
||||
"{activationUrl}": activateURL,
|
||||
"{siteUrl}": options["siteURL"],
|
||||
"{siteSecTitle}": options["siteTitle"],
|
||||
}
|
||||
return fmt.Sprintf("【%s】注册激活", options["siteName"]),
|
||||
util.Replace(replace, options["mail_activation_template"])
|
||||
}
|
||||
|
||||
// NewResetEmail 新建重设密码邮件
|
||||
func NewResetEmail(userName, resetURL string) (string, string) {
|
||||
options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_reset_pwd_template")
|
||||
replace := map[string]string{
|
||||
"{siteTitle}": options["siteName"],
|
||||
"{userName}": userName,
|
||||
"{resetUrl}": resetURL,
|
||||
"{siteUrl}": options["siteURL"],
|
||||
"{siteSecTitle}": options["siteTitle"],
|
||||
}
|
||||
return fmt.Sprintf("【%s】密码重置", options["siteName"]),
|
||||
util.Replace(replace, options["mail_reset_pwd_template"])
|
||||
}
|
309
pkg/filesystem/archive.go
Normal file
309
pkg/filesystem/archive.go
Normal file
@ -0,0 +1,309 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mholt/archiver/v4"
|
||||
)
|
||||
|
||||
/* ===============
|
||||
压缩/解压缩
|
||||
===============
|
||||
*/
|
||||
|
||||
// Compress 创建给定目录和文件的压缩文件
|
||||
func (fs *FileSystem) Compress(ctx context.Context, writer io.Writer, folderIDs, fileIDs []uint, isArchive bool) error {
|
||||
// 查找待压缩目录
|
||||
folders, err := model.GetFoldersByIDs(folderIDs, fs.User.ID)
|
||||
if err != nil && len(folderIDs) != 0 {
|
||||
return ErrDBListObjects
|
||||
}
|
||||
|
||||
// 查找待压缩文件
|
||||
files, err := model.GetFilesByIDs(fileIDs, fs.User.ID)
|
||||
if err != nil && len(fileIDs) != 0 {
|
||||
return ErrDBListObjects
|
||||
}
|
||||
|
||||
// 如果上下文限制了父目录,则进行检查
|
||||
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
|
||||
// 检查目录
|
||||
for _, folder := range folders {
|
||||
if *folder.ParentID != parent.ID {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
}
|
||||
|
||||
// 检查文件
|
||||
for _, file := range files {
|
||||
if file.FolderID != parent.ID {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试获取请求上下文,以便于后续检查用户取消任务
|
||||
reqContext := ctx
|
||||
ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context)
|
||||
if ok {
|
||||
reqContext = ginCtx.Request.Context()
|
||||
}
|
||||
|
||||
// 将顶级待处理对象的路径设为根路径
|
||||
for i := 0; i < len(folders); i++ {
|
||||
folders[i].Position = ""
|
||||
}
|
||||
for i := 0; i < len(files); i++ {
|
||||
files[i].Position = ""
|
||||
}
|
||||
|
||||
// 创建压缩文件Writer
|
||||
zipWriter := zip.NewWriter(writer)
|
||||
defer zipWriter.Close()
|
||||
|
||||
ctx = reqContext
|
||||
|
||||
// 压缩各个目录及文件
|
||||
for i := 0; i < len(folders); i++ {
|
||||
select {
|
||||
case <-reqContext.Done():
|
||||
// 取消压缩请求
|
||||
return ErrClientCanceled
|
||||
default:
|
||||
fs.doCompress(reqContext, nil, &folders[i], zipWriter, isArchive)
|
||||
}
|
||||
|
||||
}
|
||||
for i := 0; i < len(files); i++ {
|
||||
select {
|
||||
case <-reqContext.Done():
|
||||
// 取消压缩请求
|
||||
return ErrClientCanceled
|
||||
default:
|
||||
fs.doCompress(reqContext, &files[i], nil, zipWriter, isArchive)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fs *FileSystem) doCompress(ctx context.Context, file *model.File, folder *model.Folder, zipWriter *zip.Writer, isArchive bool) {
|
||||
// 如果对象是文件
|
||||
if file != nil {
|
||||
// 切换上传策略
|
||||
fs.Policy = file.GetPolicy()
|
||||
err := fs.DispatchHandler()
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to compress file %q: %s", file.Name, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件内容
|
||||
fileToZip, err := fs.Handler.Get(
|
||||
context.WithValue(ctx, fsctx.FileModelCtx, *file),
|
||||
file.SourceName,
|
||||
)
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to open %q: %s", file.Name, err)
|
||||
return
|
||||
}
|
||||
if closer, ok := fileToZip.(io.Closer); ok {
|
||||
defer closer.Close()
|
||||
}
|
||||
|
||||
// 创建压缩文件头
|
||||
header := &zip.FileHeader{
|
||||
Name: filepath.FromSlash(path.Join(file.Position, file.Name)),
|
||||
Modified: file.UpdatedAt,
|
||||
UncompressedSize64: file.Size,
|
||||
}
|
||||
|
||||
// 指定是压缩还是归档
|
||||
if isArchive {
|
||||
header.Method = zip.Store
|
||||
} else {
|
||||
header.Method = zip.Deflate
|
||||
}
|
||||
|
||||
writer, err := zipWriter.CreateHeader(header)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = io.Copy(writer, fileToZip)
|
||||
} else if folder != nil {
|
||||
// 对象是目录
|
||||
// 获取子文件
|
||||
subFiles, err := folder.GetChildFiles()
|
||||
if err == nil && len(subFiles) > 0 {
|
||||
for i := 0; i < len(subFiles); i++ {
|
||||
fs.doCompress(ctx, &subFiles[i], nil, zipWriter, isArchive)
|
||||
}
|
||||
|
||||
}
|
||||
// 获取子目录,继续递归遍历
|
||||
subFolders, err := folder.GetChildFolder()
|
||||
if err == nil && len(subFolders) > 0 {
|
||||
for i := 0; i < len(subFolders); i++ {
|
||||
fs.doCompress(ctx, nil, &subFolders[i], zipWriter, isArchive)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decompress 解压缩给定压缩文件到dst目录
|
||||
func (fs *FileSystem) Decompress(ctx context.Context, src, dst, encoding string) error {
|
||||
err := fs.ResetFileIfNotExist(ctx, src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tempZipFilePath := ""
|
||||
defer func() {
|
||||
// 结束时删除临时压缩文件
|
||||
if tempZipFilePath != "" {
|
||||
if err := os.Remove(tempZipFilePath); err != nil {
|
||||
util.Log().Warning("Failed to delete temp archive file %q: %s", tempZipFilePath, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 下载压缩文件到临时目录
|
||||
fileStream, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer fileStream.Close()
|
||||
|
||||
tempZipFilePath = filepath.Join(
|
||||
util.RelativePath(model.GetSettingByName("temp_path")),
|
||||
"decompress",
|
||||
fmt.Sprintf("archive_%d.zip", time.Now().UnixNano()),
|
||||
)
|
||||
|
||||
zipFile, err := util.CreatNestedFile(tempZipFilePath)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to create temp archive file %q: %s", tempZipFilePath, err)
|
||||
tempZipFilePath = ""
|
||||
return err
|
||||
}
|
||||
defer zipFile.Close()
|
||||
|
||||
// 下载前先判断是否是可解压的格式
|
||||
format, readStream, err := archiver.Identify(fs.FileTarget[0].SourceName, fileStream)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to detect compressed format of file %q: %s", fs.FileTarget[0].SourceName, err)
|
||||
return err
|
||||
}
|
||||
|
||||
extractor, ok := format.(archiver.Extractor)
|
||||
if !ok {
|
||||
return fmt.Errorf("file not an extractor %s", fs.FileTarget[0].SourceName)
|
||||
}
|
||||
|
||||
// 只有zip格式可以多个文件同时上传
|
||||
var isZip bool
|
||||
switch extractor.(type) {
|
||||
case archiver.Zip:
|
||||
extractor = archiver.Zip{TextEncoding: encoding}
|
||||
isZip = true
|
||||
}
|
||||
|
||||
// 除了zip必须下载到本地,其余的可以边下载边解压
|
||||
reader := readStream
|
||||
if isZip {
|
||||
_, err = io.Copy(zipFile, readStream)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to write temp archive file %q: %s", tempZipFilePath, err)
|
||||
return err
|
||||
}
|
||||
|
||||
fileStream.Close()
|
||||
|
||||
// 设置文件偏移量
|
||||
zipFile.Seek(0, io.SeekStart)
|
||||
reader = zipFile
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
parallel := model.GetIntSetting("max_parallel_transfer", 4)
|
||||
worker := make(chan int, parallel)
|
||||
for i := 0; i < parallel; i++ {
|
||||
worker <- i
|
||||
}
|
||||
|
||||
// 上传文件函数
|
||||
uploadFunc := func(fileStream io.ReadCloser, size int64, savePath, rawPath string) {
|
||||
defer func() {
|
||||
if isZip {
|
||||
worker <- 1
|
||||
wg.Done()
|
||||
}
|
||||
if err := recover(); err != nil {
|
||||
util.Log().Warning("Error while uploading files inside of archive file.")
|
||||
fmt.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
err := fs.UploadFromStream(ctx, &fsctx.FileStream{
|
||||
File: fileStream,
|
||||
Size: uint64(size),
|
||||
Name: path.Base(savePath),
|
||||
VirtualPath: path.Dir(savePath),
|
||||
}, true)
|
||||
fileStream.Close()
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to upload file %q in archive file: %s, skipping...", rawPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 解压缩文件,回调函数如果出错会停止解压的下一步进行,全部return nil
|
||||
err = extractor.Extract(ctx, reader, nil, func(ctx context.Context, f archiver.File) error {
|
||||
rawPath := util.FormSlash(f.NameInArchive)
|
||||
savePath := path.Join(dst, rawPath)
|
||||
// 路径是否合法
|
||||
if !strings.HasPrefix(savePath, util.FillSlash(path.Clean(dst))) {
|
||||
util.Log().Warning("%s: illegal file path", f.NameInArchive)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果是目录
|
||||
if f.FileInfo.IsDir() {
|
||||
fs.CreateDirectory(ctx, savePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 上传文件
|
||||
fileStream, err := f.Open()
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to open file %q in archive file: %s, skipping...", rawPath, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isZip {
|
||||
uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
|
||||
} else {
|
||||
<-worker
|
||||
wg.Add(1)
|
||||
go uploadFunc(fileStream, f.FileInfo.Size(), savePath, rawPath)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
return err
|
||||
|
||||
}
|
74
pkg/filesystem/chunk/backoff/backoff.go
Normal file
74
pkg/filesystem/chunk/backoff/backoff.go
Normal file
@ -0,0 +1,74 @@
|
||||
package backoff
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Backoff used for retry sleep backoff
|
||||
type Backoff interface {
|
||||
Next(err error) bool
|
||||
Reset()
|
||||
}
|
||||
|
||||
// ConstantBackoff implements Backoff interface with constant sleep time. If the error
|
||||
// is retryable and with `RetryAfter` defined, the `RetryAfter` will be used as sleep duration.
|
||||
type ConstantBackoff struct {
|
||||
Sleep time.Duration
|
||||
Max int
|
||||
|
||||
tried int
|
||||
}
|
||||
|
||||
func (c *ConstantBackoff) Next(err error) bool {
|
||||
c.tried++
|
||||
if c.tried > c.Max {
|
||||
return false
|
||||
}
|
||||
|
||||
var e *RetryableError
|
||||
if errors.As(err, &e) && e.RetryAfter > 0 {
|
||||
util.Log().Warning("Retryable error %q occurs in backoff, will sleep after %s.", e, e.RetryAfter)
|
||||
time.Sleep(e.RetryAfter)
|
||||
} else {
|
||||
time.Sleep(c.Sleep)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ConstantBackoff) Reset() {
|
||||
c.tried = 0
|
||||
}
|
||||
|
||||
type RetryableError struct {
|
||||
Err error
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
// NewRetryableErrorFromHeader constructs a new RetryableError from http response header
|
||||
// and existing error.
|
||||
func NewRetryableErrorFromHeader(err error, header http.Header) *RetryableError {
|
||||
retryAfter := header.Get("retry-after")
|
||||
if retryAfter == "" {
|
||||
retryAfter = "0"
|
||||
}
|
||||
|
||||
res := &RetryableError{
|
||||
Err: err,
|
||||
}
|
||||
|
||||
if retryAfterSecond, err := strconv.ParseInt(retryAfter, 10, 64); err == nil {
|
||||
res.RetryAfter = time.Duration(retryAfterSecond) * time.Second
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
func (e *RetryableError) Error() string {
|
||||
return fmt.Sprintf("retryable error with retry-after=%s: %s", e.RetryAfter, e.Err)
|
||||
}
|
167
pkg/filesystem/chunk/chunk.go
Normal file
167
pkg/filesystem/chunk/chunk.go
Normal file
@ -0,0 +1,167 @@
|
||||
package chunk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
const bufferTempPattern = "cdChunk.*.tmp"
|
||||
|
||||
// ChunkProcessFunc callback function for processing a chunk
|
||||
type ChunkProcessFunc func(c *ChunkGroup, chunk io.Reader) error
|
||||
|
||||
// ChunkGroup manage groups of chunks
|
||||
type ChunkGroup struct {
|
||||
file fsctx.FileHeader
|
||||
chunkSize uint64
|
||||
backoff backoff.Backoff
|
||||
enableRetryBuffer bool
|
||||
|
||||
fileInfo *fsctx.UploadTaskInfo
|
||||
currentIndex int
|
||||
chunkNum uint64
|
||||
bufferTemp *os.File
|
||||
}
|
||||
|
||||
func NewChunkGroup(file fsctx.FileHeader, chunkSize uint64, backoff backoff.Backoff, useBuffer bool) *ChunkGroup {
|
||||
c := &ChunkGroup{
|
||||
file: file,
|
||||
chunkSize: chunkSize,
|
||||
backoff: backoff,
|
||||
fileInfo: file.Info(),
|
||||
currentIndex: -1,
|
||||
enableRetryBuffer: useBuffer,
|
||||
}
|
||||
|
||||
if c.chunkSize == 0 {
|
||||
c.chunkSize = c.fileInfo.Size
|
||||
}
|
||||
|
||||
if c.fileInfo.Size == 0 {
|
||||
c.chunkNum = 1
|
||||
} else {
|
||||
c.chunkNum = c.fileInfo.Size / c.chunkSize
|
||||
if c.fileInfo.Size%c.chunkSize != 0 {
|
||||
c.chunkNum++
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// TempAvailable returns if current chunk temp file is available to be read
|
||||
func (c *ChunkGroup) TempAvailable() bool {
|
||||
if c.bufferTemp != nil {
|
||||
state, _ := c.bufferTemp.Stat()
|
||||
return state != nil && state.Size() == c.Length()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Process a chunk with retry logic
|
||||
func (c *ChunkGroup) Process(processor ChunkProcessFunc) error {
|
||||
reader := io.LimitReader(c.file, c.Length())
|
||||
|
||||
// If useBuffer is enabled, tee the reader to a temp file
|
||||
if c.enableRetryBuffer && c.bufferTemp == nil && !c.file.Seekable() {
|
||||
c.bufferTemp, _ = os.CreateTemp("", bufferTempPattern)
|
||||
reader = io.TeeReader(reader, c.bufferTemp)
|
||||
}
|
||||
|
||||
if c.bufferTemp != nil {
|
||||
defer func() {
|
||||
if c.bufferTemp != nil {
|
||||
c.bufferTemp.Close()
|
||||
os.Remove(c.bufferTemp.Name())
|
||||
c.bufferTemp = nil
|
||||
}
|
||||
}()
|
||||
|
||||
// if temp buffer file is available, use it
|
||||
if c.TempAvailable() {
|
||||
if _, err := c.bufferTemp.Seek(0, io.SeekStart); err != nil {
|
||||
return fmt.Errorf("failed to seek temp file back to chunk start: %w", err)
|
||||
}
|
||||
|
||||
util.Log().Debug("Chunk %d will be read from temp file %q.", c.Index(), c.bufferTemp.Name())
|
||||
reader = io.NopCloser(c.bufferTemp)
|
||||
}
|
||||
}
|
||||
|
||||
err := processor(c, reader)
|
||||
if err != nil {
|
||||
if c.enableRetryBuffer {
|
||||
request.BlackHole(reader)
|
||||
}
|
||||
|
||||
if err != context.Canceled && (c.file.Seekable() || c.TempAvailable()) && c.backoff.Next(err) {
|
||||
if c.file.Seekable() {
|
||||
if _, seekErr := c.file.Seek(c.Start(), io.SeekStart); seekErr != nil {
|
||||
return fmt.Errorf("failed to seek back to chunk start: %w, last error: %s", seekErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
util.Log().Debug("Retrying chunk %d, last error: %s", c.currentIndex, err)
|
||||
return c.Process(processor)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
util.Log().Debug("Chunk %d processed", c.currentIndex)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start returns the byte index of current chunk
|
||||
func (c *ChunkGroup) Start() int64 {
|
||||
return int64(uint64(c.Index()) * c.chunkSize)
|
||||
}
|
||||
|
||||
// Total returns the total length
|
||||
func (c *ChunkGroup) Total() int64 {
|
||||
return int64(c.fileInfo.Size)
|
||||
}
|
||||
|
||||
// Num returns the total chunk number
|
||||
func (c *ChunkGroup) Num() int {
|
||||
return int(c.chunkNum)
|
||||
}
|
||||
|
||||
// RangeHeader returns header value of Content-Range
|
||||
func (c *ChunkGroup) RangeHeader() string {
|
||||
return fmt.Sprintf("bytes %d-%d/%d", c.Start(), c.Start()+c.Length()-1, c.Total())
|
||||
}
|
||||
|
||||
// Index returns current chunk index, starts from 0
|
||||
func (c *ChunkGroup) Index() int {
|
||||
return c.currentIndex
|
||||
}
|
||||
|
||||
// Next switch to next chunk, returns whether all chunks are processed
|
||||
func (c *ChunkGroup) Next() bool {
|
||||
c.currentIndex++
|
||||
c.backoff.Reset()
|
||||
return c.currentIndex < int(c.chunkNum)
|
||||
}
|
||||
|
||||
// Length returns the length of current chunk
|
||||
func (c *ChunkGroup) Length() int64 {
|
||||
contentLength := c.chunkSize
|
||||
if c.Index() == int(c.chunkNum-1) {
|
||||
contentLength = c.fileInfo.Size - c.chunkSize*(c.chunkNum-1)
|
||||
}
|
||||
|
||||
return int64(contentLength)
|
||||
}
|
||||
|
||||
// IsLast returns if current chunk is the last one
|
||||
func (c *ChunkGroup) IsLast() bool {
|
||||
return c.Index() == int(c.chunkNum-1)
|
||||
}
|
427
pkg/filesystem/driver/cos/handler.go
Normal file
427
pkg/filesystem/driver/cos/handler.go
Normal file
@ -0,0 +1,427 @@
|
||||
package cos
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/google/go-querystring/query"
|
||||
cossdk "github.com/tencentyun/cos-go-sdk-v5"
|
||||
)
|
||||
|
||||
// UploadPolicy 腾讯云COS上传策略
|
||||
type UploadPolicy struct {
|
||||
Expiration string `json:"expiration"`
|
||||
Conditions []interface{} `json:"conditions"`
|
||||
}
|
||||
|
||||
// MetaData 文件元信息
|
||||
type MetaData struct {
|
||||
Size uint64
|
||||
CallbackKey string
|
||||
CallbackURL string
|
||||
}
|
||||
|
||||
type urlOption struct {
|
||||
Speed int `url:"x-cos-traffic-limit,omitempty"`
|
||||
ContentDescription string `url:"response-content-disposition,omitempty"`
|
||||
}
|
||||
|
||||
// Driver 腾讯云COS适配器模板
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
Client *cossdk.Client
|
||||
HTTPClient request.Client
|
||||
}
|
||||
|
||||
// List 列出COS文件
|
||||
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// 初始化列目录参数
|
||||
opt := &cossdk.BucketGetOptions{
|
||||
Prefix: strings.TrimPrefix(base, "/"),
|
||||
EncodingType: "",
|
||||
MaxKeys: 1000,
|
||||
}
|
||||
// 是否为递归列出
|
||||
if !recursive {
|
||||
opt.Delimiter = "/"
|
||||
}
|
||||
// 手动补齐结尾的slash
|
||||
if opt.Prefix != "" {
|
||||
opt.Prefix += "/"
|
||||
}
|
||||
|
||||
var (
|
||||
marker string
|
||||
objects []cossdk.Object
|
||||
commons []string
|
||||
)
|
||||
|
||||
for {
|
||||
res, _, err := handler.Client.Bucket.Get(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, res.Contents...)
|
||||
commons = append(commons, res.CommonPrefixes...)
|
||||
// 如果本次未列取完,则继续使用marker获取结果
|
||||
marker = res.NextMarker
|
||||
// marker 为空时结果列取完毕,跳出
|
||||
if marker == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// 处理目录
|
||||
for _, object := range commons {
|
||||
rel, err := filepath.Rel(opt.Prefix, object)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: 0,
|
||||
IsDir: true,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
// 处理文件
|
||||
for _, object := range objects {
|
||||
rel, err := filepath.Rel(opt.Prefix, object.Key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object.Key),
|
||||
Source: object.Key,
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: uint64(object.Size),
|
||||
IsDir: false,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
||||
}
|
||||
|
||||
// CORS 创建跨域策略
|
||||
func (handler Driver) CORS() error {
|
||||
_, err := handler.Client.Bucket.PutCORS(context.Background(), &cossdk.BucketPutCORSOptions{
|
||||
Rules: []cossdk.BucketCORSRule{{
|
||||
AllowedMethods: []string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
},
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedHeaders: []string{"*"},
|
||||
MaxAgeSeconds: 3600,
|
||||
ExposeHeaders: []string{},
|
||||
}},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
resp, err := handler.HTTPClient.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
opt := &cossdk.ObjectPutOptions{}
|
||||
_, err := handler.Client.Object.Put(ctx, file.Info().SavePath, file, opt)
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
obs := []cossdk.Object{}
|
||||
for _, v := range files {
|
||||
obs = append(obs, cossdk.Object{Key: v})
|
||||
}
|
||||
opt := &cossdk.ObjectDeleteMultiOptions{
|
||||
Objects: obs,
|
||||
Quiet: true,
|
||||
}
|
||||
|
||||
res, _, err := handler.Client.Object.DeleteMulti(context.Background(), opt)
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 整理删除结果
|
||||
failed := make([]string, 0, len(files))
|
||||
for _, v := range res.Errors {
|
||||
failed = append(failed, v.Key)
|
||||
}
|
||||
|
||||
if len(failed) == 0 {
|
||||
return failed, nil
|
||||
}
|
||||
|
||||
return failed, errors.New("delete failed")
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// quick check by extension name
|
||||
// https://cloud.tencent.com/document/product/436/44893
|
||||
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heif", "heic"}
|
||||
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
|
||||
supported = handler.Policy.OptionsSerialized.ThumbExts
|
||||
}
|
||||
|
||||
if !util.IsInExtensionList(supported, file.Name) || file.Size > (32<<(10*2)) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
var (
|
||||
thumbSize = [2]uint{400, 300}
|
||||
ok = false
|
||||
)
|
||||
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
|
||||
return nil, errors.New("failed to get thumbnail size")
|
||||
}
|
||||
|
||||
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
|
||||
|
||||
thumbParam := fmt.Sprintf("imageMogr2/thumbnail/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
|
||||
|
||||
source, err := handler.signSourceURL(
|
||||
ctx,
|
||||
file.SourceName,
|
||||
int64(model.GetIntSetting("preview_timeout", 60)),
|
||||
&urlOption{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
thumbURL, _ := url.Parse(source)
|
||||
thumbQuery := thumbURL.Query()
|
||||
thumbQuery.Add(thumbParam, "")
|
||||
thumbURL.RawQuery = thumbQuery.Encode()
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: thumbURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := ""
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
// 添加各项设置
|
||||
options := urlOption{}
|
||||
if speed > 0 {
|
||||
if speed < 819200 {
|
||||
speed = 819200
|
||||
}
|
||||
if speed > 838860800 {
|
||||
speed = 838860800
|
||||
}
|
||||
options.Speed = speed
|
||||
}
|
||||
if isDownload {
|
||||
options.ContentDescription = "attachment; filename=\"" + url.PathEscape(fileName) + "\""
|
||||
}
|
||||
|
||||
return handler.signSourceURL(ctx, path, ttl, &options)
|
||||
}
|
||||
|
||||
func (handler Driver) signSourceURL(ctx context.Context, path string, ttl int64, options *urlOption) (string, error) {
|
||||
cdnURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 公有空间不需要签名
|
||||
if !handler.Policy.IsPrivate {
|
||||
file, err := url.Parse(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 非签名URL不支持设置响应header
|
||||
options.ContentDescription = ""
|
||||
|
||||
optionQuery, err := query.Values(*options)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
file.RawQuery = optionQuery.Encode()
|
||||
sourceURL := cdnURL.ResolveReference(file)
|
||||
|
||||
return sourceURL.String(), nil
|
||||
}
|
||||
|
||||
presignedURL, err := handler.Client.Object.GetPresignedURL(ctx, http.MethodGet, path,
|
||||
handler.Policy.AccessKey, handler.Policy.SecretKey, time.Duration(ttl)*time.Second, options)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将最终生成的签名URL域名换成用户自定义的加速域名(如果有)
|
||||
presignedURL.Host = cdnURL.Host
|
||||
presignedURL.Scheme = cdnURL.Scheme
|
||||
|
||||
return presignedURL.String(), nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
// 生成回调地址
|
||||
siteURL := model.GetSiteURL()
|
||||
apiBaseURI, _ := url.Parse("/api/v3/callback/cos/" + uploadSession.Key)
|
||||
apiURL := siteURL.ResolveReference(apiBaseURI).String()
|
||||
|
||||
// 上传策略
|
||||
savePath := file.Info().SavePath
|
||||
startTime := time.Now()
|
||||
endTime := startTime.Add(time.Duration(ttl) * time.Second)
|
||||
keyTime := fmt.Sprintf("%d;%d", startTime.Unix(), endTime.Unix())
|
||||
postPolicy := UploadPolicy{
|
||||
Expiration: endTime.UTC().Format(time.RFC3339),
|
||||
Conditions: []interface{}{
|
||||
map[string]string{"bucket": handler.Policy.BucketName},
|
||||
map[string]string{"$key": savePath},
|
||||
map[string]string{"x-cos-meta-callback": apiURL},
|
||||
map[string]string{"x-cos-meta-key": uploadSession.Key},
|
||||
map[string]string{"q-sign-algorithm": "sha1"},
|
||||
map[string]string{"q-ak": handler.Policy.AccessKey},
|
||||
map[string]string{"q-sign-time": keyTime},
|
||||
},
|
||||
}
|
||||
|
||||
if handler.Policy.MaxSize > 0 {
|
||||
postPolicy.Conditions = append(postPolicy.Conditions,
|
||||
[]interface{}{"content-length-range", 0, handler.Policy.MaxSize})
|
||||
}
|
||||
|
||||
res, err := handler.getUploadCredential(ctx, postPolicy, keyTime, savePath)
|
||||
if err == nil {
|
||||
res.SessionID = uploadSession.Key
|
||||
res.Callback = apiURL
|
||||
res.UploadURLs = []string{handler.Policy.Server}
|
||||
}
|
||||
|
||||
return res, err
|
||||
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Meta 获取文件信息
|
||||
func (handler Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
|
||||
res, err := handler.Client.Object.Head(ctx, path, &cossdk.ObjectHeadOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &MetaData{
|
||||
Size: uint64(res.ContentLength),
|
||||
CallbackKey: res.Header.Get("x-cos-meta-key"),
|
||||
CallbackURL: res.Header.Get("x-cos-meta-callback"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (handler Driver) getUploadCredential(ctx context.Context, policy UploadPolicy, keyTime string, savePath string) (*serializer.UploadCredential, error) {
|
||||
// 编码上传策略
|
||||
policyJSON, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
|
||||
|
||||
// 签名上传策略
|
||||
hmacSign := hmac.New(sha1.New, []byte(handler.Policy.SecretKey))
|
||||
_, err = io.WriteString(hmacSign, keyTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signKey := fmt.Sprintf("%x", hmacSign.Sum(nil))
|
||||
|
||||
sha1Sign := sha1.New()
|
||||
_, err = sha1Sign.Write(policyJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stringToSign := fmt.Sprintf("%x", sha1Sign.Sum(nil))
|
||||
|
||||
// 最终签名
|
||||
hmacFinalSign := hmac.New(sha1.New, []byte(signKey))
|
||||
_, err = hmacFinalSign.Write([]byte(stringToSign))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signature := hmacFinalSign.Sum(nil)
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
Policy: policyEncoded,
|
||||
Path: savePath,
|
||||
AccessKey: handler.Policy.AccessKey,
|
||||
Credential: fmt.Sprintf("%x", signature),
|
||||
KeyTime: keyTime,
|
||||
}, nil
|
||||
}
|
134
pkg/filesystem/driver/cos/scf.go
Normal file
134
pkg/filesystem/driver/cos/scf.go
Normal file
@ -0,0 +1,134 @@
|
||||
package cos
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
scf "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/scf/v20180416"
|
||||
)
|
||||
|
||||
const scfFunc = `# -*- coding: utf8 -*-
|
||||
# SCF配置COS触发,向 Cloudreve 发送回调
|
||||
from qcloud_cos_v5 import CosConfig
|
||||
from qcloud_cos_v5 import CosS3Client
|
||||
from qcloud_cos_v5 import CosServiceError
|
||||
from qcloud_cos_v5 import CosClientError
|
||||
import sys
|
||||
import logging
|
||||
import requests
|
||||
|
||||
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main_handler(event, context):
|
||||
logger.info("start main handler")
|
||||
for record in event['Records']:
|
||||
try:
|
||||
if "x-cos-meta-callback" not in record['cos']['cosObject']['meta']:
|
||||
logger.info("Cannot find callback URL, skiped.")
|
||||
return 'Success'
|
||||
callback = record['cos']['cosObject']['meta']['x-cos-meta-callback']
|
||||
key = record['cos']['cosObject']['key']
|
||||
logger.info("Callback URL is " + callback)
|
||||
|
||||
r = requests.get(callback)
|
||||
print(r.text)
|
||||
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print('Error getting object {} callback url. '.format(key))
|
||||
raise e
|
||||
return "Fail"
|
||||
|
||||
return "Success"
|
||||
`
|
||||
|
||||
// CreateSCF 创建回调云函数
|
||||
func CreateSCF(policy *model.Policy, region string) error {
|
||||
// 初始化客户端
|
||||
credential := common.NewCredential(
|
||||
policy.AccessKey,
|
||||
policy.SecretKey,
|
||||
)
|
||||
cpf := profile.NewClientProfile()
|
||||
client, err := scf.NewClient(credential, region, cpf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建回调代码数据
|
||||
buff := &bytes.Buffer{}
|
||||
bs64 := base64.NewEncoder(base64.StdEncoding, buff)
|
||||
zipWriter := zip.NewWriter(bs64)
|
||||
header := zip.FileHeader{
|
||||
Name: "callback.py",
|
||||
Method: zip.Deflate,
|
||||
}
|
||||
writer, err := zipWriter.CreateHeader(&header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.Copy(writer, strings.NewReader(scfFunc))
|
||||
zipWriter.Close()
|
||||
|
||||
// 创建云函数
|
||||
req := scf.NewCreateFunctionRequest()
|
||||
funcName := "cloudreve_" + hashid.HashID(policy.ID, hashid.PolicyID) + strconv.FormatInt(time.Now().Unix(), 10)
|
||||
zipFileBytes, _ := ioutil.ReadAll(buff)
|
||||
zipFileStr := string(zipFileBytes)
|
||||
codeSource := "ZipFile"
|
||||
handler := "callback.main_handler"
|
||||
desc := "Cloudreve 用回调函数"
|
||||
timeout := int64(60)
|
||||
runtime := "Python3.6"
|
||||
req.FunctionName = &funcName
|
||||
req.Code = &scf.Code{
|
||||
ZipFile: &zipFileStr,
|
||||
}
|
||||
req.Handler = &handler
|
||||
req.Description = &desc
|
||||
req.Timeout = &timeout
|
||||
req.Runtime = &runtime
|
||||
req.CodeSource = &codeSource
|
||||
|
||||
_, err = client.CreateFunction(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
|
||||
// 创建触发器
|
||||
server, _ := url.Parse(policy.Server)
|
||||
triggerType := "cos"
|
||||
triggerDesc := `{"event":"cos:ObjectCreated:Post","filter":{"Prefix":"","Suffix":""}}`
|
||||
enable := "OPEN"
|
||||
|
||||
trigger := scf.NewCreateTriggerRequest()
|
||||
trigger.FunctionName = &funcName
|
||||
trigger.TriggerName = &server.Host
|
||||
trigger.Type = &triggerType
|
||||
trigger.TriggerDesc = &triggerDesc
|
||||
trigger.Enable = &enable
|
||||
|
||||
_, err = client.CreateTrigger(trigger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
73
pkg/filesystem/driver/googledrive/client.go
Normal file
73
pkg/filesystem/driver/googledrive/client.go
Normal file
@ -0,0 +1,73 @@
|
||||
package googledrive
|
||||
|
||||
import (
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"google.golang.org/api/drive/v3"
|
||||
)
|
||||
|
||||
// Client Google Drive client
|
||||
type Client struct {
|
||||
Endpoints *Endpoints
|
||||
Policy *model.Policy
|
||||
Credential *Credential
|
||||
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Redirect string
|
||||
|
||||
Request request.Client
|
||||
ClusterController cluster.Controller
|
||||
}
|
||||
|
||||
// Endpoints OneDrive客户端相关设置
|
||||
type Endpoints struct {
|
||||
UserConsentEndpoint string // OAuth认证的基URL
|
||||
TokenEndpoint string // OAuth token 基URL
|
||||
EndpointURL string // 接口请求的基URL
|
||||
}
|
||||
|
||||
const (
|
||||
TokenCachePrefix = "googledrive_"
|
||||
|
||||
oauthEndpoint = "https://oauth2.googleapis.com/token"
|
||||
userConsentBase = "https://accounts.google.com/o/oauth2/auth"
|
||||
v3DriveEndpoint = "https://www.googleapis.com/drive/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
// Defualt required scopes
|
||||
RequiredScope = []string{
|
||||
drive.DriveScope,
|
||||
"openid",
|
||||
"profile",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
}
|
||||
|
||||
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
|
||||
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
|
||||
)
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *model.Policy) (*Client, error) {
|
||||
client := &Client{
|
||||
Endpoints: &Endpoints{
|
||||
TokenEndpoint: oauthEndpoint,
|
||||
UserConsentEndpoint: userConsentBase,
|
||||
EndpointURL: v3DriveEndpoint,
|
||||
},
|
||||
Credential: &Credential{
|
||||
RefreshToken: policy.AccessKey,
|
||||
},
|
||||
Policy: policy,
|
||||
ClientID: policy.BucketName,
|
||||
ClientSecret: policy.SecretKey,
|
||||
Redirect: policy.OptionsSerialized.OauthRedirect,
|
||||
Request: request.NewClient(),
|
||||
ClusterController: cluster.DefaultController,
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
65
pkg/filesystem/driver/googledrive/handler.go
Normal file
65
pkg/filesystem/driver/googledrive/handler.go
Normal file
@ -0,0 +1,65 @@
|
||||
package googledrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Driver Google Drive 适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
HTTPClient request.Client
|
||||
}
|
||||
|
||||
// NewDriver 从存储策略初始化新的Driver实例
|
||||
func NewDriver(policy *model.Policy) (driver.Handler, error) {
|
||||
return &Driver{
|
||||
Policy: policy,
|
||||
HTTPClient: request.NewClient(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
154
pkg/filesystem/driver/googledrive/oauth.go
Normal file
154
pkg/filesystem/driver/googledrive/oauth.go
Normal file
@ -0,0 +1,154 @@
|
||||
package googledrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthURL 获取OAuth认证页面URL
|
||||
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
|
||||
query := url.Values{
|
||||
"client_id": {client.ClientID},
|
||||
"scope": {strings.Join(scope, " ")},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {client.Redirect},
|
||||
"access_type": {"offline"},
|
||||
"prompt": {"consent"},
|
||||
}
|
||||
|
||||
u, _ := url.Parse(client.Endpoints.UserConsentEndpoint)
|
||||
u.RawQuery = query.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ObtainToken 通过code或refresh_token兑换token
|
||||
func (client *Client) ObtainToken(ctx context.Context, code, refreshToken string) (*Credential, error) {
|
||||
body := url.Values{
|
||||
"client_id": {client.ClientID},
|
||||
"redirect_uri": {client.Redirect},
|
||||
"client_secret": {client.ClientSecret},
|
||||
}
|
||||
if code != "" {
|
||||
body.Add("grant_type", "authorization_code")
|
||||
body.Add("code", code)
|
||||
} else {
|
||||
body.Add("grant_type", "refresh_token")
|
||||
body.Add("refresh_token", refreshToken)
|
||||
}
|
||||
strBody := body.Encode()
|
||||
|
||||
res := client.Request.Request(
|
||||
"POST",
|
||||
client.Endpoints.TokenEndpoint,
|
||||
io.NopCloser(strings.NewReader(strBody)),
|
||||
request.WithHeader(http.Header{
|
||||
"Content-Type": {"application/x-www-form-urlencoded"}},
|
||||
),
|
||||
request.WithContentLength(int64(len(strBody))),
|
||||
)
|
||||
if res.Err != nil {
|
||||
return nil, res.Err
|
||||
}
|
||||
|
||||
respBody, err := res.GetResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
errResp OAuthError
|
||||
credential Credential
|
||||
decodeErr error
|
||||
)
|
||||
|
||||
if res.Response.StatusCode != 200 {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||
} else {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &credential)
|
||||
}
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
if errResp.ErrorType != "" {
|
||||
return nil, errResp
|
||||
}
|
||||
|
||||
return &credential, nil
|
||||
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
|
||||
if isSlave {
|
||||
return client.fetchCredentialFromMaster(ctx)
|
||||
}
|
||||
|
||||
oauth.GlobalMutex.Lock(client.Policy.ID)
|
||||
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
|
||||
|
||||
// 如果已存在凭证
|
||||
if client.Credential != nil && client.Credential.AccessToken != "" {
|
||||
// 检查已有凭证是否过期
|
||||
if client.Credential.ExpiresIn > time.Now().Unix() {
|
||||
// 未过期,不要更新
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试从缓存中获取凭证
|
||||
if cacheCredential, ok := cache.Get(TokenCachePrefix + client.ClientID); ok {
|
||||
credential := cacheCredential.(Credential)
|
||||
if credential.ExpiresIn > time.Now().Unix() {
|
||||
client.Credential = &credential
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 获取新的凭证
|
||||
if client.Credential == nil || client.Credential.RefreshToken == "" {
|
||||
// 无有效的RefreshToken
|
||||
util.Log().Error("Failed to refresh credential for policy %q, please login your Google account again.", client.Policy.Name)
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
credential, err := client.ObtainToken(ctx, "", client.Credential.RefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新有效期为绝对时间戳
|
||||
expires := credential.ExpiresIn - 60
|
||||
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
|
||||
// refresh token for Google Drive does not expire in production
|
||||
credential.RefreshToken = client.Credential.RefreshToken
|
||||
client.Credential = credential
|
||||
|
||||
// 更新缓存
|
||||
cache.Set(TokenCachePrefix+client.ClientID, *credential, int(expires))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) AccessToken() string {
|
||||
return client.Credential.AccessToken
|
||||
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
|
||||
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.Credential = &Credential{AccessToken: res}
|
||||
return nil
|
||||
}
|
43
pkg/filesystem/driver/googledrive/types.go
Normal file
43
pkg/filesystem/driver/googledrive/types.go
Normal file
@ -0,0 +1,43 @@
|
||||
package googledrive
|
||||
|
||||
import "encoding/gob"
|
||||
|
||||
// RespError 接口返回错误
|
||||
type RespError struct {
|
||||
APIError APIError `json:"error"`
|
||||
}
|
||||
|
||||
// APIError 接口返回的错误内容
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (err RespError) Error() string {
|
||||
return err.APIError.Message
|
||||
}
|
||||
|
||||
// Credential 获取token时返回的凭证
|
||||
type Credential struct {
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// OAuthError OAuth相关接口的错误响应
|
||||
type OAuthError struct {
|
||||
ErrorType string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (err OAuthError) Error() string {
|
||||
return err.ErrorDescription
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
52
pkg/filesystem/driver/handler.go
Normal file
52
pkg/filesystem/driver/handler.go
Normal file
@ -0,0 +1,52 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorThumbNotExist = fmt.Errorf("thumb not exist")
|
||||
ErrorThumbNotSupported = fmt.Errorf("thumb not supported")
|
||||
)
|
||||
|
||||
// Handler 存储策略适配器
|
||||
type Handler interface {
|
||||
// 上传文件, dst为文件存储路径,size 为文件大小。上下文关闭
|
||||
// 时,应取消上传并清理临时文件
|
||||
Put(ctx context.Context, file fsctx.FileHeader) error
|
||||
|
||||
// 删除一个或多个给定路径的文件,返回删除失败的文件路径列表及错误
|
||||
Delete(ctx context.Context, files []string) ([]string, error)
|
||||
|
||||
// 获取文件内容
|
||||
Get(ctx context.Context, path string) (response.RSCloser, error)
|
||||
|
||||
// 获取缩略图,可直接在ContentResponse中返回文件数据流,也可指
|
||||
// 定为重定向
|
||||
// 如果缩略图不存在, 且需要 Cloudreve 代理生成并上传,应返回 ErrorThumbNotExist,生
|
||||
// 成的缩略图文件存储规则与本机策略一致。
|
||||
// 如果不支持此文件的缩略图,并且不希望后续继续请求此缩略图,应返回 ErrorThumbNotSupported
|
||||
Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error)
|
||||
|
||||
// 获取外链/下载地址,
|
||||
// url - 站点本身地址,
|
||||
// isDownload - 是否直接下载
|
||||
Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error)
|
||||
|
||||
// Token 获取有效期为ttl的上传凭证和签名
|
||||
Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error)
|
||||
|
||||
// CancelToken 取消已经创建的有状态上传凭证
|
||||
CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error
|
||||
|
||||
// List 递归列取远程端path路径下文件、目录,不包含path本身,
|
||||
// 返回的对象路径以path作为起始根目录.
|
||||
// recursive - 是否递归列出
|
||||
List(ctx context.Context, path string, recursive bool) ([]response.Object, error)
|
||||
}
|
292
pkg/filesystem/driver/local/handler.go
Normal file
292
pkg/filesystem/driver/local/handler.go
Normal file
@ -0,0 +1,292 @@
|
||||
package local
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
const (
|
||||
Perm = 0744
|
||||
)
|
||||
|
||||
// Driver 本地策略适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
}
|
||||
|
||||
// List 递归列取给定物理路径下所有文件
|
||||
func (handler Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
var res []response.Object
|
||||
|
||||
// 取得起始路径
|
||||
root := util.RelativePath(filepath.FromSlash(path))
|
||||
|
||||
// 开始遍历路径下的文件、目录
|
||||
err := filepath.Walk(root,
|
||||
func(path string, info os.FileInfo, err error) error {
|
||||
// 跳过根目录
|
||||
if path == root {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to walk folder %q: %s", path, err)
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
// 将遍历对象的绝对路径转换为相对路径
|
||||
rel, err := filepath.Rel(root, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res = append(res, response.Object{
|
||||
Name: info.Name(),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Source: path,
|
||||
Size: uint64(info.Size()),
|
||||
IsDir: info.IsDir(),
|
||||
LastModify: info.ModTime(),
|
||||
})
|
||||
|
||||
// 如果非递归,则不步入目录
|
||||
if !recursive && info.IsDir() {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// Get 获取文件内容
|
||||
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 打开文件
|
||||
file, err := os.Open(util.RelativePath(path))
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to open file: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
fileInfo := file.Info()
|
||||
dst := util.RelativePath(filepath.FromSlash(fileInfo.SavePath))
|
||||
|
||||
// 如果非 Overwrite,则检查是否有重名冲突
|
||||
if fileInfo.Mode&fsctx.Overwrite != fsctx.Overwrite {
|
||||
if util.Exists(dst) {
|
||||
util.Log().Warning("File with the same name existed or unavailable: %s", dst)
|
||||
return errors.New("file with the same name existed or unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果目标目录不存在,创建
|
||||
basePath := filepath.Dir(dst)
|
||||
if !util.Exists(basePath) {
|
||||
err := os.MkdirAll(basePath, Perm)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to create directory: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
out *os.File
|
||||
err error
|
||||
)
|
||||
|
||||
openMode := os.O_CREATE | os.O_RDWR
|
||||
if fileInfo.Mode&fsctx.Append == fsctx.Append {
|
||||
openMode |= os.O_APPEND
|
||||
} else {
|
||||
openMode |= os.O_TRUNC
|
||||
}
|
||||
|
||||
out, err = os.OpenFile(dst, openMode, Perm)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to open or create file: %s", err)
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if fileInfo.Mode&fsctx.Append == fsctx.Append {
|
||||
stat, err := out.Stat()
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to read file info: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if uint64(stat.Size()) < fileInfo.AppendStart {
|
||||
return errors.New("size of unfinished uploaded chunks is not as expected")
|
||||
} else if uint64(stat.Size()) > fileInfo.AppendStart {
|
||||
out.Close()
|
||||
if err := handler.Truncate(ctx, dst, fileInfo.AppendStart); err != nil {
|
||||
return fmt.Errorf("failed to overwrite chunk: %w", err)
|
||||
}
|
||||
|
||||
out, err = os.OpenFile(dst, openMode, Perm)
|
||||
defer out.Close()
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to create or open file: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 写入文件内容
|
||||
_, err = io.Copy(out, file)
|
||||
return err
|
||||
}
|
||||
|
||||
func (handler Driver) Truncate(ctx context.Context, src string, size uint64) error {
|
||||
util.Log().Warning("Truncate file %q to [%d].", src, size)
|
||||
out, err := os.OpenFile(src, os.O_WRONLY, Perm)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to open file: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
defer out.Close()
|
||||
return out.Truncate(int64(size))
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
deleteFailed := make([]string, 0, len(files))
|
||||
var retErr error
|
||||
|
||||
for _, value := range files {
|
||||
filePath := util.RelativePath(filepath.FromSlash(value))
|
||||
if util.Exists(filePath) {
|
||||
err := os.Remove(filePath)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to delete file: %s", err)
|
||||
retErr = err
|
||||
deleteFailed = append(deleteFailed, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试删除文件的缩略图(如果有)
|
||||
_ = os.Remove(util.RelativePath(value + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")))
|
||||
}
|
||||
|
||||
return deleteFailed, retErr
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// Quick check thumb existence on master.
|
||||
if conf.SystemConfig.Mode == "master" && file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusNotExist {
|
||||
// Tell invoker to generate a thumb
|
||||
return nil, driver.ErrorThumbNotExist
|
||||
}
|
||||
|
||||
thumbFile, err := handler.Get(ctx, file.ThumbFile())
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
err = fmt.Errorf("thumb not exist: %w (%w)", err, driver.ErrorThumbNotExist)
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: false,
|
||||
Content: thumbFile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
file, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if !ok {
|
||||
return "", errors.New("failed to read file model context")
|
||||
}
|
||||
|
||||
var baseURL *url.URL
|
||||
// 是否启用了CDN
|
||||
if handler.Policy.BaseURL != "" {
|
||||
cdnURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
baseURL = cdnURL
|
||||
}
|
||||
|
||||
var (
|
||||
signedURI *url.URL
|
||||
err error
|
||||
)
|
||||
if isDownload {
|
||||
// 创建下载会话,将文件信息写入缓存
|
||||
downloadSessionID := util.RandStringRunes(16)
|
||||
err = cache.Set("download_"+downloadSessionID, file, int(ttl))
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeCacheOperation, "Failed to create download session", err)
|
||||
}
|
||||
|
||||
// 签名生成文件记录
|
||||
signedURI, err = auth.SignURI(
|
||||
auth.General,
|
||||
fmt.Sprintf("/api/v3/file/download/%s", downloadSessionID),
|
||||
ttl,
|
||||
)
|
||||
} else {
|
||||
// 签名生成文件记录
|
||||
signedURI, err = auth.SignURI(
|
||||
auth.General,
|
||||
fmt.Sprintf("/api/v3/file/get/%d/%s", file.ID, file.Name),
|
||||
ttl,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign url", err)
|
||||
}
|
||||
|
||||
finalURL := signedURI.String()
|
||||
if baseURL != nil {
|
||||
finalURL = baseURL.ResolveReference(signedURI).String()
|
||||
}
|
||||
|
||||
return finalURL, nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token,本地策略直接返回空值
|
||||
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
if util.Exists(uploadSession.SavePath) {
|
||||
return nil, errors.New("placeholder file already exist")
|
||||
}
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return nil
|
||||
}
|
595
pkg/filesystem/driver/onedrive/api.go
Normal file
595
pkg/filesystem/driver/onedrive/api.go
Normal file
@ -0,0 +1,595 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
const (
|
||||
// SmallFileSize 单文件上传接口最大尺寸
|
||||
SmallFileSize uint64 = 4 * 1024 * 1024
|
||||
// ChunkSize 服务端中转分片上传分片大小
|
||||
ChunkSize uint64 = 10 * 1024 * 1024
|
||||
// ListRetry 列取请求重试次数
|
||||
ListRetry = 1
|
||||
chunkRetrySleep = time.Second * 5
|
||||
|
||||
notFoundError = "itemNotFound"
|
||||
)
|
||||
|
||||
// GetSourcePath 获取文件的绝对路径
|
||||
func (info *FileInfo) GetSourcePath() string {
|
||||
res, err := url.PathUnescape(info.ParentReference.Path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimPrefix(
|
||||
path.Join(
|
||||
strings.TrimPrefix(res, "/drive/root:"),
|
||||
info.Name,
|
||||
),
|
||||
"/",
|
||||
)
|
||||
}
|
||||
|
||||
func (client *Client) getRequestURL(api string, opts ...Option) string {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
base, _ := url.Parse(client.Endpoints.EndpointURL)
|
||||
if base == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if options.useDriverResource {
|
||||
base.Path = path.Join(base.Path, client.Endpoints.DriverResource, api)
|
||||
} else {
|
||||
base.Path = path.Join(base.Path, api)
|
||||
}
|
||||
|
||||
return base.String()
|
||||
}
|
||||
|
||||
// ListChildren 根据路径列取子对象
|
||||
func (client *Client) ListChildren(ctx context.Context, path string) ([]FileInfo, error) {
|
||||
var requestURL string
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
if dst == "" {
|
||||
requestURL = client.getRequestURL("root/children")
|
||||
} else {
|
||||
requestURL = client.getRequestURL("root:/" + dst + ":/children")
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?$top=999999999", "", 200)
|
||||
if err != nil {
|
||||
retried := 0
|
||||
if v, ok := ctx.Value(fsctx.RetryCtx).(int); ok {
|
||||
retried = v
|
||||
}
|
||||
if retried < ListRetry {
|
||||
retried++
|
||||
util.Log().Debug("Failed to list path %q: %s, will retry in 5 seconds.", path, err)
|
||||
time.Sleep(time.Duration(5) * time.Second)
|
||||
return client.ListChildren(context.WithValue(ctx, fsctx.RetryCtx, retried), path)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
fileInfo ListResponse
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
return fileInfo.Value, nil
|
||||
}
|
||||
|
||||
// Meta 根据资源ID或文件路径获取文件元信息
|
||||
func (client *Client) Meta(ctx context.Context, id string, path string) (*FileInfo, error) {
|
||||
var requestURL string
|
||||
if id != "" {
|
||||
requestURL = client.getRequestURL("items/" + id)
|
||||
} else {
|
||||
dst := strings.TrimPrefix(path, "/")
|
||||
requestURL = client.getRequestURL("root:/" + dst)
|
||||
}
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL+"?expand=thumbnails", "", 200)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
fileInfo FileInfo
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &fileInfo)
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
return &fileInfo, nil
|
||||
|
||||
}
|
||||
|
||||
// CreateUploadSession 创建分片上传会话
|
||||
func (client *Client) CreateUploadSession(ctx context.Context, dst string, opts ...Option) (string, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/createUploadSession")
|
||||
body := map[string]map[string]interface{}{
|
||||
"item": {
|
||||
"@microsoft.graph.conflictBehavior": options.conflictBehavior,
|
||||
},
|
||||
}
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
|
||||
res, err := client.requestWithStr(ctx, "POST", requestURL, string(bodyBytes), 200)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
uploadSession UploadSessionResponse
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
return uploadSession.UploadURL, nil
|
||||
}
|
||||
|
||||
// GetSiteIDByURL 通过 SharePoint 站点 URL 获取站点ID
|
||||
func (client *Client) GetSiteIDByURL(ctx context.Context, siteUrl string) (string, error) {
|
||||
siteUrlParsed, err := url.Parse(siteUrl)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
hostName := siteUrlParsed.Hostname()
|
||||
relativePath := strings.Trim(siteUrlParsed.Path, "/")
|
||||
requestURL := client.getRequestURL(fmt.Sprintf("sites/%s:/%s", hostName, relativePath), WithDriverResource(false))
|
||||
res, reqErr := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if reqErr != nil {
|
||||
return "", reqErr
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
siteInfo Site
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &siteInfo)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
return siteInfo.ID, nil
|
||||
}
|
||||
|
||||
// GetUploadSessionStatus 查询上传会话状态
|
||||
func (client *Client) GetUploadSessionStatus(ctx context.Context, uploadURL string) (*UploadSessionResponse, error) {
|
||||
res, err := client.requestWithStr(ctx, "GET", uploadURL, "", 200)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
uploadSession UploadSessionResponse
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &uploadSession)
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
return &uploadSession, nil
|
||||
}
|
||||
|
||||
// UploadChunk 上传分片
|
||||
func (client *Client) UploadChunk(ctx context.Context, uploadURL string, content io.Reader, current *chunk.ChunkGroup) (*UploadSessionResponse, error) {
|
||||
res, err := client.request(
|
||||
ctx, "PUT", uploadURL, content,
|
||||
request.WithContentLength(current.Length()),
|
||||
request.WithHeader(http.Header{
|
||||
"Content-Range": {current.RangeHeader()},
|
||||
}),
|
||||
request.WithoutHeader([]string{"Authorization", "Content-Type"}),
|
||||
request.WithTimeout(0),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to upload OneDrive chunk #%d: %w", current.Index(), err)
|
||||
}
|
||||
|
||||
if current.IsLast() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
uploadRes UploadSessionResponse
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
return &uploadRes, nil
|
||||
}
|
||||
|
||||
// Upload 上传文件
|
||||
func (client *Client) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
fileInfo := file.Info()
|
||||
// 决定是否覆盖文件
|
||||
overwrite := "fail"
|
||||
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
|
||||
overwrite = "replace"
|
||||
}
|
||||
|
||||
size := int(fileInfo.Size)
|
||||
dst := fileInfo.SavePath
|
||||
|
||||
// 小文件,使用简单上传接口上传
|
||||
if size <= int(SmallFileSize) {
|
||||
_, err := client.SimpleUpload(ctx, dst, file, int64(size), WithConflictBehavior(overwrite))
|
||||
return err
|
||||
}
|
||||
|
||||
// 大文件,进行分片
|
||||
// 创建上传会话
|
||||
uploadURL, err := client.CreateUploadSession(ctx, dst, WithConflictBehavior(overwrite))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Initial chunk groups
|
||||
chunks := chunk.NewChunkGroup(file, client.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
|
||||
Max: model.GetIntSetting("chunk_retries", 5),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
_, err := client.UploadChunk(ctx, uploadURL, content, current)
|
||||
return err
|
||||
}
|
||||
|
||||
// upload chunks
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUploadSession 删除上传会话
|
||||
func (client *Client) DeleteUploadSession(ctx context.Context, uploadURL string) error {
|
||||
_, err := client.requestWithStr(ctx, "DELETE", uploadURL, "", 204)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SimpleUpload 上传小文件到dst
|
||||
func (client *Client) SimpleUpload(ctx context.Context, dst string, body io.Reader, size int64, opts ...Option) (*UploadResult, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("root:/" + dst + ":/content")
|
||||
requestURL += ("?@microsoft.graph.conflictBehavior=" + options.conflictBehavior)
|
||||
|
||||
res, err := client.request(ctx, "PUT", requestURL, body, request.WithContentLength(int64(size)),
|
||||
request.WithTimeout(0),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
uploadRes UploadResult
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &uploadRes)
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
return &uploadRes, nil
|
||||
}
|
||||
|
||||
// BatchDelete 并行删除给出的文件,返回删除失败的文件,及第一个遇到的错误。此方法将文件分为
|
||||
// 20个一组,调用Delete并行删除
|
||||
// TODO 测试
|
||||
func (client *Client) BatchDelete(ctx context.Context, dst []string) ([]string, error) {
|
||||
groupNum := len(dst)/20 + 1
|
||||
finalRes := make([]string, 0, len(dst))
|
||||
res := make([]string, 0, 20)
|
||||
var err error
|
||||
|
||||
for i := 0; i < groupNum; i++ {
|
||||
end := 20*i + 20
|
||||
if i == groupNum-1 {
|
||||
end = len(dst)
|
||||
}
|
||||
res, err = client.Delete(ctx, dst[20*i:end])
|
||||
finalRes = append(finalRes, res...)
|
||||
}
|
||||
|
||||
return finalRes, err
|
||||
}
|
||||
|
||||
// Delete 并行删除文件,返回删除失败的文件,及第一个遇到的错误,
|
||||
// 由于API限制,最多删除20个
|
||||
func (client *Client) Delete(ctx context.Context, dst []string) ([]string, error) {
|
||||
body := client.makeBatchDeleteRequestsBody(dst)
|
||||
res, err := client.requestWithStr(ctx, "POST", client.getRequestURL("$batch",
|
||||
WithDriverResource(false)), body, 200)
|
||||
if err != nil {
|
||||
return dst, err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
deleteRes BatchResponses
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &deleteRes)
|
||||
if decodeErr != nil {
|
||||
return dst, decodeErr
|
||||
}
|
||||
|
||||
// 取得删除失败的文件
|
||||
failed := getDeleteFailed(&deleteRes)
|
||||
if len(failed) != 0 {
|
||||
return failed, ErrDeleteFile
|
||||
}
|
||||
return failed, nil
|
||||
}
|
||||
|
||||
func getDeleteFailed(res *BatchResponses) []string {
|
||||
var failed = make([]string, 0, len(res.Responses))
|
||||
for _, v := range res.Responses {
|
||||
if v.Status != 204 && v.Status != 404 {
|
||||
failed = append(failed, v.ID)
|
||||
}
|
||||
}
|
||||
return failed
|
||||
}
|
||||
|
||||
// makeBatchDeleteRequestsBody 生成批量删除请求正文
|
||||
func (client *Client) makeBatchDeleteRequestsBody(files []string) string {
|
||||
req := BatchRequests{
|
||||
Requests: make([]BatchRequest, len(files)),
|
||||
}
|
||||
for i, v := range files {
|
||||
v = strings.TrimPrefix(v, "/")
|
||||
filePath, _ := url.Parse("/" + client.Endpoints.DriverResource + "/root:/")
|
||||
filePath.Path = path.Join(filePath.Path, v)
|
||||
req.Requests[i] = BatchRequest{
|
||||
ID: v,
|
||||
Method: "DELETE",
|
||||
URL: filePath.EscapedPath(),
|
||||
}
|
||||
}
|
||||
|
||||
res, _ := json.Marshal(req)
|
||||
return string(res)
|
||||
}
|
||||
|
||||
// GetThumbURL 获取给定尺寸的缩略图URL
|
||||
func (client *Client) GetThumbURL(ctx context.Context, dst string, w, h uint) (string, error) {
|
||||
dst = strings.TrimPrefix(dst, "/")
|
||||
requestURL := client.getRequestURL("root:/"+dst+":/thumbnails/0") + "/large"
|
||||
|
||||
res, err := client.requestWithStr(ctx, "GET", requestURL, "", 200)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var (
|
||||
decodeErr error
|
||||
thumbRes ThumbResponse
|
||||
)
|
||||
decodeErr = json.Unmarshal([]byte(res), &thumbRes)
|
||||
if decodeErr != nil {
|
||||
return "", decodeErr
|
||||
}
|
||||
|
||||
if thumbRes.URL != "" {
|
||||
return thumbRes.URL, nil
|
||||
}
|
||||
|
||||
if len(thumbRes.Value) == 1 {
|
||||
if res, ok := thumbRes.Value[0]["large"]; ok {
|
||||
return res.(map[string]interface{})["url"].(string), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", ErrThumbSizeNotFound
|
||||
}
|
||||
|
||||
// MonitorUpload 监控客户端分片上传进度
|
||||
func (client *Client) MonitorUpload(uploadURL, callbackKey, path string, size uint64, ttl int64) {
|
||||
// 回调完成通知chan
|
||||
callbackChan := mq.GlobalMQ.Subscribe(callbackKey, 1)
|
||||
defer mq.GlobalMQ.Unsubscribe(callbackKey, callbackChan)
|
||||
|
||||
timeout := model.GetIntSetting("onedrive_monitor_timeout", 600)
|
||||
interval := model.GetIntSetting("onedrive_callback_check", 20)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-callbackChan:
|
||||
util.Log().Debug("Client finished OneDrive callback.")
|
||||
return
|
||||
case <-time.After(time.Duration(ttl) * time.Second):
|
||||
// 上传会话到期,仍未完成上传,创建占位符
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to create placeholder file: %s", err)
|
||||
}
|
||||
return
|
||||
case <-time.After(time.Duration(timeout) * time.Second):
|
||||
util.Log().Debug("Checking OneDrive upload status.")
|
||||
status, err := client.GetUploadSessionStatus(context.Background(), uploadURL)
|
||||
|
||||
if err != nil {
|
||||
if resErr, ok := err.(*RespError); ok {
|
||||
if resErr.APIError.Code == notFoundError {
|
||||
util.Log().Debug("Upload completed, will check upload callback later.")
|
||||
select {
|
||||
case <-time.After(time.Duration(interval) * time.Second):
|
||||
util.Log().Warning("No callback is made, file will be deleted.")
|
||||
cache.Deletes([]string{callbackKey}, "callback_")
|
||||
_, err = client.Delete(context.Background(), []string{path})
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to delete file without callback: %s", err)
|
||||
}
|
||||
case <-callbackChan:
|
||||
util.Log().Debug("Client finished callback.")
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
util.Log().Debug("Failed to get upload session status: %s, continue next iteration.", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
// 成功获取分片上传状态,检查文件大小
|
||||
if len(status.NextExpectedRanges) == 0 {
|
||||
continue
|
||||
}
|
||||
sizeRange := strings.Split(
|
||||
status.NextExpectedRanges[len(status.NextExpectedRanges)-1],
|
||||
"-",
|
||||
)
|
||||
if len(sizeRange) != 2 {
|
||||
continue
|
||||
}
|
||||
uploadFullSize, _ := strconv.ParseUint(sizeRange[1], 10, 64)
|
||||
if (sizeRange[0] == "0" && sizeRange[1] == "") || uploadFullSize+1 != size {
|
||||
util.Log().Debug("Upload has not started, or uploaded file size not match, canceling upload session...")
|
||||
// 取消上传会话,实测OneDrive取消上传会话后,客户端还是可以上传,
|
||||
// 所以上传一个空文件占位,阻止客户端上传
|
||||
client.DeleteUploadSession(context.Background(), uploadURL)
|
||||
_, err := client.SimpleUpload(context.Background(), path, strings.NewReader(""), 0, WithConflictBehavior("replace"))
|
||||
if err != nil {
|
||||
util.Log().Debug("无法创建占位文件,%s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sysError(err error) *RespError {
|
||||
return &RespError{APIError: APIError{
|
||||
Code: "system",
|
||||
Message: err.Error(),
|
||||
}}
|
||||
}
|
||||
|
||||
func (client *Client) request(ctx context.Context, method string, url string, body io.Reader, option ...request.Option) (string, error) {
|
||||
// 获取凭证
|
||||
err := client.UpdateCredential(ctx, conf.SystemConfig.Mode == "slave")
|
||||
if err != nil {
|
||||
return "", sysError(err)
|
||||
}
|
||||
|
||||
option = append(option,
|
||||
request.WithHeader(http.Header{
|
||||
"Authorization": {"Bearer " + client.Credential.AccessToken},
|
||||
"Content-Type": {"application/json"},
|
||||
}),
|
||||
request.WithContext(ctx),
|
||||
request.WithTPSLimit(
|
||||
fmt.Sprintf("policy_%d", client.Policy.ID),
|
||||
client.Policy.OptionsSerialized.TPSLimit,
|
||||
client.Policy.OptionsSerialized.TPSLimitBurst,
|
||||
),
|
||||
)
|
||||
|
||||
// 发送请求
|
||||
res := client.Request.Request(
|
||||
method,
|
||||
url,
|
||||
body,
|
||||
option...,
|
||||
)
|
||||
|
||||
if res.Err != nil {
|
||||
return "", sysError(res.Err)
|
||||
}
|
||||
|
||||
respBody, err := res.GetResponse()
|
||||
if err != nil {
|
||||
return "", sysError(err)
|
||||
}
|
||||
|
||||
// 解析请求响应
|
||||
var (
|
||||
errResp RespError
|
||||
decodeErr error
|
||||
)
|
||||
// 如果有错误
|
||||
if res.Response.StatusCode < 200 || res.Response.StatusCode >= 300 {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||
if decodeErr != nil {
|
||||
util.Log().Debug("Onedrive returns unknown response: %s", respBody)
|
||||
return "", sysError(decodeErr)
|
||||
}
|
||||
|
||||
if res.Response.StatusCode == 429 {
|
||||
util.Log().Warning("OneDrive request is throttled.")
|
||||
return "", backoff.NewRetryableErrorFromHeader(&errResp, res.Response.Header)
|
||||
}
|
||||
|
||||
return "", &errResp
|
||||
}
|
||||
|
||||
return respBody, nil
|
||||
}
|
||||
|
||||
func (client *Client) requestWithStr(ctx context.Context, method string, url string, body string, expectedCode int) (string, error) {
|
||||
// 发送请求
|
||||
bodyReader := io.NopCloser(strings.NewReader(body))
|
||||
return client.request(ctx, method, url, bodyReader,
|
||||
request.WithContentLength(int64(len(body))),
|
||||
)
|
||||
}
|
78
pkg/filesystem/driver/onedrive/client.go
Normal file
78
pkg/filesystem/driver/onedrive/client.go
Normal file
@ -0,0 +1,78 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrAuthEndpoint 无法解析授权端点地址
|
||||
ErrAuthEndpoint = errors.New("failed to parse endpoint url")
|
||||
// ErrInvalidRefreshToken 上传策略无有效的RefreshToken
|
||||
ErrInvalidRefreshToken = errors.New("no valid refresh token in this policy")
|
||||
// ErrDeleteFile 无法删除文件
|
||||
ErrDeleteFile = errors.New("cannot delete file")
|
||||
// ErrClientCanceled 客户端取消操作
|
||||
ErrClientCanceled = errors.New("client canceled")
|
||||
// Desired thumb size not available
|
||||
ErrThumbSizeNotFound = errors.New("thumb size not found")
|
||||
)
|
||||
|
||||
// Client OneDrive客户端
|
||||
type Client struct {
|
||||
Endpoints *Endpoints
|
||||
Policy *model.Policy
|
||||
Credential *Credential
|
||||
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Redirect string
|
||||
|
||||
Request request.Client
|
||||
ClusterController cluster.Controller
|
||||
}
|
||||
|
||||
// Endpoints OneDrive客户端相关设置
|
||||
type Endpoints struct {
|
||||
OAuthURL string // OAuth认证的基URL
|
||||
OAuthEndpoints *oauthEndpoint
|
||||
EndpointURL string // 接口请求的基URL
|
||||
isInChina bool // 是否为世纪互联
|
||||
DriverResource string // 要使用的驱动器
|
||||
}
|
||||
|
||||
// NewClient 根据存储策略获取新的client
|
||||
func NewClient(policy *model.Policy) (*Client, error) {
|
||||
client := &Client{
|
||||
Endpoints: &Endpoints{
|
||||
OAuthURL: policy.BaseURL,
|
||||
EndpointURL: policy.Server,
|
||||
DriverResource: policy.OptionsSerialized.OdDriver,
|
||||
},
|
||||
Credential: &Credential{
|
||||
RefreshToken: policy.AccessKey,
|
||||
},
|
||||
Policy: policy,
|
||||
ClientID: policy.BucketName,
|
||||
ClientSecret: policy.SecretKey,
|
||||
Redirect: policy.OptionsSerialized.OauthRedirect,
|
||||
Request: request.NewClient(),
|
||||
ClusterController: cluster.DefaultController,
|
||||
}
|
||||
|
||||
if client.Endpoints.DriverResource == "" {
|
||||
client.Endpoints.DriverResource = "me/drive"
|
||||
}
|
||||
|
||||
oauthBase := client.getOAuthEndpoint()
|
||||
if oauthBase == nil {
|
||||
return nil, ErrAuthEndpoint
|
||||
}
|
||||
client.Endpoints.OAuthEndpoints = oauthBase
|
||||
|
||||
return client, nil
|
||||
}
|
238
pkg/filesystem/driver/onedrive/handler.go
Normal file
238
pkg/filesystem/driver/onedrive/handler.go
Normal file
@ -0,0 +1,238 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Driver OneDrive 适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
Client *Client
|
||||
HTTPClient request.Client
|
||||
}
|
||||
|
||||
// NewDriver 从存储策略初始化新的Driver实例
|
||||
func NewDriver(policy *model.Policy) (driver.Handler, error) {
|
||||
client, err := NewClient(policy)
|
||||
if policy.OptionsSerialized.ChunkSize == 0 {
|
||||
policy.OptionsSerialized.ChunkSize = 50 << 20 // 50MB
|
||||
}
|
||||
|
||||
return Driver{
|
||||
Policy: policy,
|
||||
Client: client,
|
||||
HTTPClient: request.NewClient(),
|
||||
}, err
|
||||
}
|
||||
|
||||
// List 列取项目
|
||||
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
base = strings.TrimPrefix(base, "/")
|
||||
// 列取子项目
|
||||
objects, _ := handler.Client.ListChildren(ctx, base)
|
||||
|
||||
// 获取真实的列取起始根目录
|
||||
rootPath := base
|
||||
if realBase, ok := ctx.Value(fsctx.PathCtx).(string); ok {
|
||||
rootPath = realBase
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, fsctx.PathCtx, base)
|
||||
}
|
||||
|
||||
// 整理结果
|
||||
res := make([]response.Object, 0, len(objects))
|
||||
for _, object := range objects {
|
||||
source := path.Join(base, object.Name)
|
||||
rel, err := filepath.Rel(rootPath, source)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: object.Name,
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Source: source,
|
||||
Size: object.Size,
|
||||
IsDir: object.Folder != nil,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// 递归列取子目录
|
||||
if recursive {
|
||||
for _, object := range objects {
|
||||
if object.Folder != nil {
|
||||
sub, _ := handler.List(ctx, path.Join(base, object.Name), recursive)
|
||||
res = append(res, sub...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(
|
||||
ctx,
|
||||
path,
|
||||
60,
|
||||
false,
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
resp, err := handler.HTTPClient.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
return handler.Client.Upload(ctx, file)
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
return handler.Client.BatchDelete(ctx, files)
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
var (
|
||||
thumbSize = [2]uint{400, 300}
|
||||
ok = false
|
||||
)
|
||||
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
|
||||
return nil, errors.New("failed to get thumbnail size")
|
||||
}
|
||||
|
||||
res, err := handler.Client.GetThumbURL(ctx, file.SourceName, thumbSize[0], thumbSize[1])
|
||||
if err != nil {
|
||||
var apiErr *RespError
|
||||
if errors.As(err, &apiErr); err == ErrThumbSizeNotFound || (apiErr != nil && apiErr.APIError.Code == notFoundError) {
|
||||
// OneDrive cannot generate thumbnail for this file
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
}
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: res,
|
||||
}, err
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler Driver) Source(
|
||||
ctx context.Context,
|
||||
path string,
|
||||
ttl int64,
|
||||
isDownload bool,
|
||||
speed int,
|
||||
) (string, error) {
|
||||
cacheKey := fmt.Sprintf("onedrive_source_%d_%s", handler.Policy.ID, path)
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
cacheKey = fmt.Sprintf("onedrive_source_file_%d_%d", file.UpdatedAt.Unix(), file.ID)
|
||||
}
|
||||
|
||||
// 尝试从缓存中查找
|
||||
if cachedURL, ok := cache.Get(cacheKey); ok {
|
||||
return handler.replaceSourceHost(cachedURL.(string))
|
||||
}
|
||||
|
||||
// 缓存不存在,重新获取
|
||||
res, err := handler.Client.Meta(ctx, "", path)
|
||||
if err == nil {
|
||||
// 写入新的缓存
|
||||
cache.Set(
|
||||
cacheKey,
|
||||
res.DownloadURL,
|
||||
model.GetIntSetting("onedrive_source_timeout", 1800),
|
||||
)
|
||||
return handler.replaceSourceHost(res.DownloadURL)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
func (handler Driver) replaceSourceHost(origin string) (string, error) {
|
||||
if handler.Policy.OptionsSerialized.OdProxy != "" {
|
||||
source, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
cdn, err := url.Parse(handler.Policy.OptionsSerialized.OdProxy)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 替换反代地址
|
||||
source.Scheme = cdn.Scheme
|
||||
source.Host = cdn.Host
|
||||
return source.String(), nil
|
||||
}
|
||||
|
||||
return origin, nil
|
||||
}
|
||||
|
||||
// Token 获取上传会话URL
|
||||
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
fileInfo := file.Info()
|
||||
|
||||
uploadURL, err := handler.Client.CreateUploadSession(ctx, fileInfo.SavePath, WithConflictBehavior("fail"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 监控回调及上传
|
||||
go handler.Client.MonitorUpload(uploadURL, uploadSession.Key, fileInfo.SavePath, fileInfo.Size, ttl)
|
||||
|
||||
uploadSession.UploadURL = uploadURL
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
|
||||
UploadURLs: []string{uploadURL},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return handler.Client.DeleteUploadSession(ctx, uploadSession.UploadURL)
|
||||
}
|
25
pkg/filesystem/driver/onedrive/lock.go
Normal file
25
pkg/filesystem/driver/onedrive/lock.go
Normal file
@ -0,0 +1,25 @@
|
||||
package onedrive
|
||||
|
||||
import "sync"
|
||||
|
||||
// CredentialLock 针对存储策略凭证的锁
|
||||
type CredentialLock interface {
|
||||
Lock(uint)
|
||||
Unlock(uint)
|
||||
}
|
||||
|
||||
var GlobalMutex = mutexMap{}
|
||||
|
||||
type mutexMap struct {
|
||||
locks sync.Map
|
||||
}
|
||||
|
||||
func (m *mutexMap) Lock(id uint) {
|
||||
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
func (m *mutexMap) Unlock(id uint) {
|
||||
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
192
pkg/filesystem/driver/onedrive/oauth.go
Normal file
192
pkg/filesystem/driver/onedrive/oauth.go
Normal file
@ -0,0 +1,192 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/oauth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Error 实现error接口
|
||||
func (err OAuthError) Error() string {
|
||||
return err.ErrorDescription
|
||||
}
|
||||
|
||||
// OAuthURL 获取OAuth认证页面URL
|
||||
func (client *Client) OAuthURL(ctx context.Context, scope []string) string {
|
||||
query := url.Values{
|
||||
"client_id": {client.ClientID},
|
||||
"scope": {strings.Join(scope, " ")},
|
||||
"response_type": {"code"},
|
||||
"redirect_uri": {client.Redirect},
|
||||
}
|
||||
client.Endpoints.OAuthEndpoints.authorize.RawQuery = query.Encode()
|
||||
return client.Endpoints.OAuthEndpoints.authorize.String()
|
||||
}
|
||||
|
||||
// getOAuthEndpoint 根据指定的AuthURL获取详细的认证接口地址
|
||||
func (client *Client) getOAuthEndpoint() *oauthEndpoint {
|
||||
base, err := url.Parse(client.Endpoints.OAuthURL)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var (
|
||||
token *url.URL
|
||||
authorize *url.URL
|
||||
)
|
||||
switch base.Host {
|
||||
case "login.live.com":
|
||||
token, _ = url.Parse("https://login.live.com/oauth20_token.srf")
|
||||
authorize, _ = url.Parse("https://login.live.com/oauth20_authorize.srf")
|
||||
case "login.chinacloudapi.cn":
|
||||
client.Endpoints.isInChina = true
|
||||
token, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.chinacloudapi.cn/common/oauth2/v2.0/authorize")
|
||||
default:
|
||||
token, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/token")
|
||||
authorize, _ = url.Parse("https://login.microsoftonline.com/common/oauth2/v2.0/authorize")
|
||||
}
|
||||
|
||||
return &oauthEndpoint{
|
||||
token: *token,
|
||||
authorize: *authorize,
|
||||
}
|
||||
}
|
||||
|
||||
// ObtainToken 通过code或refresh_token兑换token
|
||||
func (client *Client) ObtainToken(ctx context.Context, opts ...Option) (*Credential, error) {
|
||||
options := newDefaultOption()
|
||||
for _, o := range opts {
|
||||
o.apply(options)
|
||||
}
|
||||
|
||||
body := url.Values{
|
||||
"client_id": {client.ClientID},
|
||||
"redirect_uri": {client.Redirect},
|
||||
"client_secret": {client.ClientSecret},
|
||||
}
|
||||
if options.code != "" {
|
||||
body.Add("grant_type", "authorization_code")
|
||||
body.Add("code", options.code)
|
||||
} else {
|
||||
body.Add("grant_type", "refresh_token")
|
||||
body.Add("refresh_token", options.refreshToken)
|
||||
}
|
||||
strBody := body.Encode()
|
||||
|
||||
res := client.Request.Request(
|
||||
"POST",
|
||||
client.Endpoints.OAuthEndpoints.token.String(),
|
||||
ioutil.NopCloser(strings.NewReader(strBody)),
|
||||
request.WithHeader(http.Header{
|
||||
"Content-Type": {"application/x-www-form-urlencoded"}},
|
||||
),
|
||||
request.WithContentLength(int64(len(strBody))),
|
||||
)
|
||||
if res.Err != nil {
|
||||
return nil, res.Err
|
||||
}
|
||||
|
||||
respBody, err := res.GetResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
errResp OAuthError
|
||||
credential Credential
|
||||
decodeErr error
|
||||
)
|
||||
|
||||
if res.Response.StatusCode != 200 {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &errResp)
|
||||
} else {
|
||||
decodeErr = json.Unmarshal([]byte(respBody), &credential)
|
||||
}
|
||||
if decodeErr != nil {
|
||||
return nil, decodeErr
|
||||
}
|
||||
|
||||
if errResp.ErrorType != "" {
|
||||
return nil, errResp
|
||||
}
|
||||
|
||||
return &credential, nil
|
||||
|
||||
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *Client) UpdateCredential(ctx context.Context, isSlave bool) error {
|
||||
if isSlave {
|
||||
return client.fetchCredentialFromMaster(ctx)
|
||||
}
|
||||
|
||||
oauth.GlobalMutex.Lock(client.Policy.ID)
|
||||
defer oauth.GlobalMutex.Unlock(client.Policy.ID)
|
||||
|
||||
// 如果已存在凭证
|
||||
if client.Credential != nil && client.Credential.AccessToken != "" {
|
||||
// 检查已有凭证是否过期
|
||||
if client.Credential.ExpiresIn > time.Now().Unix() {
|
||||
// 未过期,不要更新
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试从缓存中获取凭证
|
||||
if cacheCredential, ok := cache.Get("onedrive_" + client.ClientID); ok {
|
||||
credential := cacheCredential.(Credential)
|
||||
if credential.ExpiresIn > time.Now().Unix() {
|
||||
client.Credential = &credential
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 获取新的凭证
|
||||
if client.Credential == nil || client.Credential.RefreshToken == "" {
|
||||
// 无有效的RefreshToken
|
||||
util.Log().Error("Failed to refresh credential for policy %q, please login your Microsoft account again.", client.Policy.Name)
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
|
||||
credential, err := client.ObtainToken(ctx, WithRefreshToken(client.Credential.RefreshToken))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新有效期为绝对时间戳
|
||||
expires := credential.ExpiresIn - 60
|
||||
credential.ExpiresIn = time.Now().Add(time.Duration(expires) * time.Second).Unix()
|
||||
client.Credential = credential
|
||||
|
||||
// 更新存储策略的 RefreshToken
|
||||
client.Policy.UpdateAccessKeyAndClearCache(credential.RefreshToken)
|
||||
|
||||
// 更新缓存
|
||||
cache.Set("onedrive_"+client.ClientID, *credential, int(expires))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) AccessToken() string {
|
||||
return client.Credential.AccessToken
|
||||
}
|
||||
|
||||
// UpdateCredential 更新凭证,并检查有效期
|
||||
func (client *Client) fetchCredentialFromMaster(ctx context.Context) error {
|
||||
res, err := client.ClusterController.GetPolicyOauthToken(client.Policy.MasterID, client.Policy.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client.Credential = &Credential{AccessToken: res}
|
||||
return nil
|
||||
}
|
59
pkg/filesystem/driver/onedrive/options.go
Normal file
59
pkg/filesystem/driver/onedrive/options.go
Normal file
@ -0,0 +1,59 @@
|
||||
package onedrive
|
||||
|
||||
import "time"
|
||||
|
||||
// Option 发送请求的额外设置
|
||||
type Option interface {
|
||||
apply(*options)
|
||||
}
|
||||
|
||||
type options struct {
|
||||
redirect string
|
||||
code string
|
||||
refreshToken string
|
||||
conflictBehavior string
|
||||
expires time.Time
|
||||
useDriverResource bool
|
||||
}
|
||||
|
||||
type optionFunc func(*options)
|
||||
|
||||
// WithCode 设置接口Code
|
||||
func WithCode(t string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.code = t
|
||||
})
|
||||
}
|
||||
|
||||
// WithRefreshToken 设置接口RefreshToken
|
||||
func WithRefreshToken(t string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.refreshToken = t
|
||||
})
|
||||
}
|
||||
|
||||
// WithConflictBehavior 设置文件重名后的处理方式
|
||||
func WithConflictBehavior(t string) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.conflictBehavior = t
|
||||
})
|
||||
}
|
||||
|
||||
// WithConflictBehavior 设置文件重名后的处理方式
|
||||
func WithDriverResource(t bool) Option {
|
||||
return optionFunc(func(o *options) {
|
||||
o.useDriverResource = t
|
||||
})
|
||||
}
|
||||
|
||||
func (f optionFunc) apply(o *options) {
|
||||
f(o)
|
||||
}
|
||||
|
||||
func newDefaultOption() *options {
|
||||
return &options{
|
||||
conflictBehavior: "fail",
|
||||
useDriverResource: true,
|
||||
expires: time.Now().UTC().Add(time.Duration(1) * time.Hour),
|
||||
}
|
||||
}
|
140
pkg/filesystem/driver/onedrive/types.go
Normal file
140
pkg/filesystem/driver/onedrive/types.go
Normal file
@ -0,0 +1,140 @@
|
||||
package onedrive
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// RespError 接口返回错误
|
||||
type RespError struct {
|
||||
APIError APIError `json:"error"`
|
||||
}
|
||||
|
||||
// APIError 接口返回的错误内容
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// UploadSessionResponse 分片上传会话
|
||||
type UploadSessionResponse struct {
|
||||
DataContext string `json:"@odata.context"`
|
||||
ExpirationDateTime string `json:"expirationDateTime"`
|
||||
NextExpectedRanges []string `json:"nextExpectedRanges"`
|
||||
UploadURL string `json:"uploadUrl"`
|
||||
}
|
||||
|
||||
// FileInfo 文件元信息
|
||||
type FileInfo struct {
|
||||
Name string `json:"name"`
|
||||
Size uint64 `json:"size"`
|
||||
Image imageInfo `json:"image"`
|
||||
ParentReference parentReference `json:"parentReference"`
|
||||
DownloadURL string `json:"@microsoft.graph.downloadUrl"`
|
||||
File *file `json:"file"`
|
||||
Folder *folder `json:"folder"`
|
||||
}
|
||||
|
||||
type file struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
}
|
||||
|
||||
type folder struct {
|
||||
ChildCount int `json:"childCount"`
|
||||
}
|
||||
|
||||
type imageInfo struct {
|
||||
Height int `json:"height"`
|
||||
Width int `json:"width"`
|
||||
}
|
||||
|
||||
type parentReference struct {
|
||||
Path string `json:"path"`
|
||||
Name string `json:"name"`
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// UploadResult 上传结果
|
||||
type UploadResult struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Size uint64 `json:"size"`
|
||||
}
|
||||
|
||||
// BatchRequests 批量操作请求
|
||||
type BatchRequests struct {
|
||||
Requests []BatchRequest `json:"requests"`
|
||||
}
|
||||
|
||||
// BatchRequest 批量操作单个请求
|
||||
type BatchRequest struct {
|
||||
ID string `json:"id"`
|
||||
Method string `json:"method"`
|
||||
URL string `json:"url"`
|
||||
Body interface{} `json:"body,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
}
|
||||
|
||||
// BatchResponses 批量操作响应
|
||||
type BatchResponses struct {
|
||||
Responses []BatchResponse `json:"responses"`
|
||||
}
|
||||
|
||||
// BatchResponse 批量操作单个响应
|
||||
type BatchResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// ThumbResponse 获取缩略图的响应
|
||||
type ThumbResponse struct {
|
||||
Value []map[string]interface{} `json:"value"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// ListResponse 列取子项目响应
|
||||
type ListResponse struct {
|
||||
Value []FileInfo `json:"value"`
|
||||
Context string `json:"@odata.context"`
|
||||
}
|
||||
|
||||
// oauthEndpoint OAuth接口地址
|
||||
type oauthEndpoint struct {
|
||||
token url.URL
|
||||
authorize url.URL
|
||||
}
|
||||
|
||||
// Credential 获取token时返回的凭证
|
||||
type Credential struct {
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// OAuthError OAuth相关接口的错误响应
|
||||
type OAuthError struct {
|
||||
ErrorType string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
}
|
||||
|
||||
// Site SharePoint 站点信息
|
||||
type Site struct {
|
||||
Description string `json:"description"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"displayName"`
|
||||
WebUrl string `json:"webUrl"`
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Credential{})
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (err RespError) Error() string {
|
||||
return err.APIError.Message
|
||||
}
|
117
pkg/filesystem/driver/oss/callback.go
Normal file
117
pkg/filesystem/driver/oss/callback.go
Normal file
@ -0,0 +1,117 @@
|
||||
package oss
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/md5"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
)
|
||||
|
||||
// GetPublicKey 从回调请求或缓存中获取OSS的回调签名公钥
|
||||
func GetPublicKey(r *http.Request) ([]byte, error) {
|
||||
var pubKey []byte
|
||||
|
||||
// 尝试从缓存中获取
|
||||
pub, exist := cache.Get("oss_public_key")
|
||||
if exist {
|
||||
return pub.([]byte), nil
|
||||
}
|
||||
|
||||
// 从请求中获取
|
||||
pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url"))
|
||||
if err != nil {
|
||||
return pubKey, err
|
||||
}
|
||||
|
||||
// 确保这个 public key 是由 OSS 颁发的
|
||||
if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") &&
|
||||
!strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") {
|
||||
return pubKey, errors.New("public key url invalid")
|
||||
}
|
||||
|
||||
// 获取公钥
|
||||
client := request.NewClient()
|
||||
body, err := client.Request("GET", string(pubURL), nil).
|
||||
CheckHTTPResponse(200).
|
||||
GetResponse()
|
||||
if err != nil {
|
||||
return pubKey, err
|
||||
}
|
||||
|
||||
// 写入缓存
|
||||
_ = cache.Set("oss_public_key", []byte(body), 86400*7)
|
||||
|
||||
return []byte(body), nil
|
||||
}
|
||||
|
||||
func getRequestMD5(r *http.Request) ([]byte, error) {
|
||||
var byteMD5 []byte
|
||||
|
||||
// 获取请求正文
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
r.Body.Close()
|
||||
if err != nil {
|
||||
return byteMD5, err
|
||||
}
|
||||
r.Body = ioutil.NopCloser(bytes.NewReader(body))
|
||||
|
||||
strURLPathDecode, err := url.PathUnescape(r.URL.Path)
|
||||
if err != nil {
|
||||
return byteMD5, err
|
||||
}
|
||||
|
||||
strAuth := fmt.Sprintf("%s\n%s", strURLPathDecode, string(body))
|
||||
md5Ctx := md5.New()
|
||||
md5Ctx.Write([]byte(strAuth))
|
||||
byteMD5 = md5Ctx.Sum(nil)
|
||||
|
||||
return byteMD5, nil
|
||||
}
|
||||
|
||||
// VerifyCallbackSignature 验证OSS回调请求
|
||||
func VerifyCallbackSignature(r *http.Request) error {
|
||||
bytePublicKey, err := GetPublicKey(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
byteMD5, err := getRequestMD5(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
strAuthorizationBase64 := r.Header.Get("authorization")
|
||||
if strAuthorizationBase64 == "" {
|
||||
return errors.New("no authorization field in Request header")
|
||||
}
|
||||
authorization, _ := base64.StdEncoding.DecodeString(strAuthorizationBase64)
|
||||
|
||||
pubBlock, _ := pem.Decode(bytePublicKey)
|
||||
if pubBlock == nil {
|
||||
return errors.New("pubBlock not exist")
|
||||
}
|
||||
pubInterface, err := x509.ParsePKIXPublicKey(pubBlock.Bytes)
|
||||
if (pubInterface == nil) || (err != nil) {
|
||||
return err
|
||||
}
|
||||
pub := pubInterface.(*rsa.PublicKey)
|
||||
|
||||
errorVerifyPKCS1v15 := rsa.VerifyPKCS1v15(pub, crypto.MD5, byteMD5, authorization)
|
||||
if errorVerifyPKCS1v15 != nil {
|
||||
return errorVerifyPKCS1v15
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
501
pkg/filesystem/driver/oss/handler.go
Normal file
501
pkg/filesystem/driver/oss/handler.go
Normal file
@ -0,0 +1,501 @@
|
||||
package oss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/HFO4/aliyun-oss-go-sdk/oss"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// UploadPolicy 阿里云OSS上传策略
|
||||
type UploadPolicy struct {
|
||||
Expiration string `json:"expiration"`
|
||||
Conditions []interface{} `json:"conditions"`
|
||||
}
|
||||
|
||||
// CallbackPolicy 回调策略
|
||||
type CallbackPolicy struct {
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
CallbackBody string `json:"callbackBody"`
|
||||
CallbackBodyType string `json:"callbackBodyType"`
|
||||
}
|
||||
|
||||
// Driver 阿里云OSS策略适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
client *oss.Client
|
||||
bucket *oss.Bucket
|
||||
HTTPClient request.Client
|
||||
}
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
chunkRetrySleep = time.Duration(5) * time.Second
|
||||
|
||||
// MultiPartUploadThreshold 服务端使用分片上传的阈值
|
||||
MultiPartUploadThreshold uint64 = 5 * (1 << 30) // 5GB
|
||||
// VersionID 文件版本标识
|
||||
VersionID key = iota
|
||||
)
|
||||
|
||||
func NewDriver(policy *model.Policy) (*Driver, error) {
|
||||
if policy.OptionsSerialized.ChunkSize == 0 {
|
||||
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
driver := &Driver{
|
||||
Policy: policy,
|
||||
HTTPClient: request.NewClient(),
|
||||
}
|
||||
|
||||
return driver, driver.InitOSSClient(false)
|
||||
}
|
||||
|
||||
// CORS 创建跨域策略
|
||||
func (handler *Driver) CORS() error {
|
||||
return handler.client.SetBucketCORS(handler.Policy.BucketName, []oss.CORSRule{
|
||||
{
|
||||
AllowedOrigin: []string{"*"},
|
||||
AllowedMethod: []string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
},
|
||||
ExposeHeader: []string{},
|
||||
AllowedHeader: []string{"*"},
|
||||
MaxAgeSeconds: 3600,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// InitOSSClient 初始化OSS鉴权客户端
|
||||
func (handler *Driver) InitOSSClient(forceUsePublicEndpoint bool) error {
|
||||
if handler.Policy == nil {
|
||||
return errors.New("empty policy")
|
||||
}
|
||||
|
||||
// 决定是否使用内网 Endpoint
|
||||
endpoint := handler.Policy.Server
|
||||
if handler.Policy.OptionsSerialized.ServerSideEndpoint != "" && !forceUsePublicEndpoint {
|
||||
endpoint = handler.Policy.OptionsSerialized.ServerSideEndpoint
|
||||
}
|
||||
|
||||
// 初始化客户端
|
||||
client, err := oss.New(endpoint, handler.Policy.AccessKey, handler.Policy.SecretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handler.client = client
|
||||
|
||||
// 初始化存储桶
|
||||
bucket, err := client.Bucket(handler.Policy.BucketName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handler.bucket = bucket
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出OSS上的文件
|
||||
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// 列取文件
|
||||
base = strings.TrimPrefix(base, "/")
|
||||
if base != "" {
|
||||
base += "/"
|
||||
}
|
||||
|
||||
var (
|
||||
delimiter string
|
||||
marker string
|
||||
objects []oss.ObjectProperties
|
||||
commons []string
|
||||
)
|
||||
if !recursive {
|
||||
delimiter = "/"
|
||||
}
|
||||
|
||||
for {
|
||||
subRes, err := handler.bucket.ListObjects(oss.Marker(marker), oss.Prefix(base),
|
||||
oss.MaxKeys(1000), oss.Delimiter(delimiter))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, subRes.Objects...)
|
||||
commons = append(commons, subRes.CommonPrefixes...)
|
||||
marker = subRes.NextMarker
|
||||
if marker == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// 处理目录
|
||||
for _, object := range commons {
|
||||
rel, err := filepath.Rel(base, object)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: 0,
|
||||
IsDir: true,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
// 处理文件
|
||||
for _, object := range objects {
|
||||
rel, err := filepath.Rel(base, object.Key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(object.Key, "/") {
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object.Key),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: 0,
|
||||
IsDir: true,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
} else {
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object.Key),
|
||||
Source: object.Key,
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: uint64(object.Size),
|
||||
IsDir: false,
|
||||
LastModify: object.LastModified,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 通过VersionID禁止缓存
|
||||
ctx = context.WithValue(ctx, VersionID, time.Now().UnixNano())
|
||||
|
||||
// 尽可能使用私有 Endpoint
|
||||
ctx = context.WithValue(ctx, fsctx.ForceUsePublicEndpointCtx, false)
|
||||
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
resp, err := handler.HTTPClient.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
fileInfo := file.Info()
|
||||
|
||||
// 凭证有效期
|
||||
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
|
||||
|
||||
// 是否允许覆盖
|
||||
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
|
||||
options := []oss.Option{
|
||||
oss.Expires(time.Now().Add(time.Duration(credentialTTL) * time.Second)),
|
||||
oss.ForbidOverWrite(!overwrite),
|
||||
}
|
||||
|
||||
// 小文件直接上传
|
||||
if fileInfo.Size < MultiPartUploadThreshold {
|
||||
return handler.bucket.PutObject(fileInfo.SavePath, file, options...)
|
||||
}
|
||||
|
||||
// 超过阈值时使用分片上传
|
||||
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
|
||||
Max: model.GetIntSetting("chunk_retries", 5),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
_, err := handler.bucket.UploadPart(imur, content, current.Length(), current.Index()+1)
|
||||
return err
|
||||
}
|
||||
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = handler.bucket.CompleteMultipartUpload(imur, oss.CompleteAll("yes"), oss.ForbidOverWrite(!overwrite))
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件
|
||||
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
// 删除文件
|
||||
delRes, err := handler.bucket.DeleteObjects(files)
|
||||
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 统计未删除的文件
|
||||
failed := util.SliceDifference(files, delRes.DeletedObjects)
|
||||
if len(failed) > 0 {
|
||||
return failed, errors.New("failed to delete")
|
||||
}
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// quick check by extension name
|
||||
// https://help.aliyun.com/document_detail/183902.html
|
||||
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "heic", "tiff", "avif"}
|
||||
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
|
||||
supported = handler.Policy.OptionsSerialized.ThumbExts
|
||||
}
|
||||
|
||||
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
// 初始化客户端
|
||||
if err := handler.InitOSSClient(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
thumbSize = [2]uint{400, 300}
|
||||
ok = false
|
||||
)
|
||||
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
|
||||
return nil, errors.New("failed to get thumbnail size")
|
||||
}
|
||||
|
||||
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
|
||||
|
||||
thumbParam := fmt.Sprintf("image/resize,m_lfit,h_%d,w_%d/quality,q_%d", thumbSize[1], thumbSize[0], thumbEncodeQuality)
|
||||
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, thumbParam)
|
||||
thumbOption := []oss.Option{oss.Process(thumbParam)}
|
||||
thumbURL, err := handler.signSourceURL(
|
||||
ctx,
|
||||
file.SourceName,
|
||||
int64(model.GetIntSetting("preview_timeout", 60)),
|
||||
thumbOption,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: thumbURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
// 初始化客户端
|
||||
usePublicEndpoint := true
|
||||
if forceUsePublicEndpoint, ok := ctx.Value(fsctx.ForceUsePublicEndpointCtx).(bool); ok {
|
||||
usePublicEndpoint = forceUsePublicEndpoint
|
||||
}
|
||||
if err := handler.InitOSSClient(usePublicEndpoint); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := ""
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
// 添加各项设置
|
||||
var signOptions = make([]oss.Option, 0, 2)
|
||||
if isDownload {
|
||||
signOptions = append(signOptions, oss.ResponseContentDisposition("attachment; filename=\""+url.PathEscape(fileName)+"\""))
|
||||
}
|
||||
if speed > 0 {
|
||||
// Byte 转换为 bit
|
||||
speed *= 8
|
||||
|
||||
// OSS对速度值有范围限制
|
||||
if speed < 819200 {
|
||||
speed = 819200
|
||||
}
|
||||
if speed > 838860800 {
|
||||
speed = 838860800
|
||||
}
|
||||
signOptions = append(signOptions, oss.TrafficLimitParam(int64(speed)))
|
||||
}
|
||||
|
||||
return handler.signSourceURL(ctx, path, ttl, signOptions)
|
||||
}
|
||||
|
||||
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64, options []oss.Option) (string, error) {
|
||||
signedURL, err := handler.bucket.SignURL(path, oss.HTTPGet, ttl, options...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将最终生成的签名URL域名换成用户自定义的加速域名(如果有)
|
||||
finalURL, err := url.Parse(signedURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 公有空间替换掉Key及不支持的头
|
||||
if !handler.Policy.IsPrivate {
|
||||
query := finalURL.Query()
|
||||
query.Del("OSSAccessKeyId")
|
||||
query.Del("Signature")
|
||||
query.Del("response-content-disposition")
|
||||
query.Del("x-oss-traffic-limit")
|
||||
finalURL.RawQuery = query.Encode()
|
||||
}
|
||||
|
||||
if handler.Policy.BaseURL != "" {
|
||||
cdnURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
finalURL.Host = cdnURL.Host
|
||||
finalURL.Scheme = cdnURL.Scheme
|
||||
}
|
||||
|
||||
return finalURL.String(), nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
// 初始化客户端
|
||||
if err := handler.InitOSSClient(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成回调地址
|
||||
siteURL := model.GetSiteURL()
|
||||
apiBaseURI, _ := url.Parse("/api/v3/callback/oss/" + uploadSession.Key)
|
||||
apiURL := siteURL.ResolveReference(apiBaseURI)
|
||||
|
||||
// 回调策略
|
||||
callbackPolicy := CallbackPolicy{
|
||||
CallbackURL: apiURL.String(),
|
||||
CallbackBody: `{"name":${x:fname},"source_name":${object},"size":${size},"pic_info":"${imageInfo.width},${imageInfo.height}"}`,
|
||||
CallbackBodyType: "application/json",
|
||||
}
|
||||
callbackPolicyJSON, err := json.Marshal(callbackPolicy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode callback policy: %w", err)
|
||||
}
|
||||
callbackPolicyEncoded := base64.StdEncoding.EncodeToString(callbackPolicyJSON)
|
||||
|
||||
// 初始化分片上传
|
||||
fileInfo := file.Info()
|
||||
options := []oss.Option{
|
||||
oss.Expires(time.Now().Add(time.Duration(ttl) * time.Second)),
|
||||
oss.ForbidOverWrite(true),
|
||||
oss.ContentType(fileInfo.DetectMimeType()),
|
||||
}
|
||||
imur, err := handler.bucket.InitiateMultipartUpload(fileInfo.SavePath, options...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize multipart upload: %w", err)
|
||||
}
|
||||
uploadSession.UploadID = imur.UploadID
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
|
||||
urls := make([]string, chunks.Num())
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPut, ttl,
|
||||
oss.PartNumber(c.Index()+1),
|
||||
oss.UploadID(imur.UploadID),
|
||||
oss.ContentType("application/octet-stream"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的URL
|
||||
completeURL, err := handler.bucket.SignURL(fileInfo.SavePath, oss.HTTPPost, ttl,
|
||||
oss.ContentType("application/octet-stream"),
|
||||
oss.UploadID(imur.UploadID),
|
||||
oss.Expires(time.Now().Add(time.Duration(ttl)*time.Second)),
|
||||
oss.CompleteAll("yes"),
|
||||
oss.ForbidOverWrite(true),
|
||||
oss.CallbackParam(callbackPolicyEncoded))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
|
||||
UploadID: imur.UploadID,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: completeURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return handler.bucket.AbortMultipartUpload(oss.InitiateMultipartUploadResult{UploadID: uploadSession.UploadID, Key: uploadSession.SavePath}, nil)
|
||||
}
|
354
pkg/filesystem/driver/qiniu/handler.go
Normal file
354
pkg/filesystem/driver/qiniu/handler.go
Normal file
@ -0,0 +1,354 @@
|
||||
package qiniu
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
)
|
||||
|
||||
// Driver 本地策略适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
mac *qbox.Mac
|
||||
cfg *storage.Config
|
||||
bucket *storage.BucketManager
|
||||
}
|
||||
|
||||
func NewDriver(policy *model.Policy) *Driver {
|
||||
if policy.OptionsSerialized.ChunkSize == 0 {
|
||||
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
mac := qbox.NewMac(policy.AccessKey, policy.SecretKey)
|
||||
cfg := &storage.Config{UseHTTPS: true}
|
||||
return &Driver{
|
||||
Policy: policy,
|
||||
mac: mac,
|
||||
cfg: cfg,
|
||||
bucket: storage.NewBucketManager(mac, cfg),
|
||||
}
|
||||
}
|
||||
|
||||
// List 列出给定路径下的文件
|
||||
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
base = strings.TrimPrefix(base, "/")
|
||||
if base != "" {
|
||||
base += "/"
|
||||
}
|
||||
|
||||
var (
|
||||
delimiter string
|
||||
marker string
|
||||
objects []storage.ListItem
|
||||
commons []string
|
||||
)
|
||||
if !recursive {
|
||||
delimiter = "/"
|
||||
}
|
||||
|
||||
for {
|
||||
entries, folders, nextMarker, hashNext, err := handler.bucket.ListFiles(
|
||||
handler.Policy.BucketName,
|
||||
base, delimiter, marker, 1000)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, entries...)
|
||||
commons = append(commons, folders...)
|
||||
if !hashNext {
|
||||
break
|
||||
}
|
||||
marker = nextMarker
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
// 处理目录
|
||||
for _, object := range commons {
|
||||
rel, err := filepath.Rel(base, object)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: 0,
|
||||
IsDir: true,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
// 处理文件
|
||||
for _, object := range objects {
|
||||
rel, err := filepath.Rel(base, object.Key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object.Key),
|
||||
Source: object.Key,
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: uint64(object.Fsize),
|
||||
IsDir: false,
|
||||
LastModify: time.Unix(object.PutTime/10000000, 0),
|
||||
})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 给文件名加上随机参数以强制拉取
|
||||
path = fmt.Sprintf("%s?v=%d", path, time.Now().UnixNano())
|
||||
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
client := request.NewClient()
|
||||
resp, err := client.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithHeader(
|
||||
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
|
||||
),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
// 凭证有效期
|
||||
credentialTTL := model.GetIntSetting("upload_session_timeout", 3600)
|
||||
|
||||
// 生成上传策略
|
||||
fileInfo := file.Info()
|
||||
scope := handler.Policy.BucketName
|
||||
if fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite {
|
||||
scope = fmt.Sprintf("%s:%s", handler.Policy.BucketName, fileInfo.SavePath)
|
||||
}
|
||||
|
||||
putPolicy := storage.PutPolicy{
|
||||
// 指定为覆盖策略
|
||||
Scope: scope,
|
||||
SaveKey: fileInfo.SavePath,
|
||||
ForceSaveKey: true,
|
||||
FsizeLimit: int64(fileInfo.Size),
|
||||
}
|
||||
// 是否开启了MIMEType限制
|
||||
if handler.Policy.OptionsSerialized.MimeType != "" {
|
||||
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
|
||||
}
|
||||
|
||||
// 生成上传凭证
|
||||
token, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, int64(credentialTTL), false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建上传表单
|
||||
cfg := storage.Config{}
|
||||
formUploader := storage.NewFormUploader(&cfg)
|
||||
ret := storage.PutRet{}
|
||||
putExtra := storage.PutExtra{
|
||||
Params: map[string]string{},
|
||||
}
|
||||
|
||||
// 开始上传
|
||||
err = formUploader.Put(ctx, &ret, token.Credential, fileInfo.SavePath, file, int64(fileInfo.Size), &putExtra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件
|
||||
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
// TODO 大于一千个文件需要分批发送
|
||||
deleteOps := make([]string, 0, len(files))
|
||||
for _, key := range files {
|
||||
deleteOps = append(deleteOps, storage.URIDelete(handler.Policy.BucketName, key))
|
||||
}
|
||||
|
||||
rets, err := handler.bucket.Batch(deleteOps)
|
||||
|
||||
// 处理删除结果
|
||||
if err != nil {
|
||||
failed := make([]string, 0, len(rets))
|
||||
for k, ret := range rets {
|
||||
if ret.Code != 200 && ret.Code != 612 {
|
||||
failed = append(failed, files[k])
|
||||
}
|
||||
}
|
||||
return failed, errors.New("删除失败")
|
||||
}
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// quick check by extension name
|
||||
// https://developer.qiniu.com/dora/api/basic-processing-images-imageview2
|
||||
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "tiff", "avif", "psd"}
|
||||
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
|
||||
supported = handler.Policy.OptionsSerialized.ThumbExts
|
||||
}
|
||||
|
||||
if !util.IsInExtensionList(supported, file.Name) || file.Size > (20<<(10*2)) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
var (
|
||||
thumbSize = [2]uint{400, 300}
|
||||
ok = false
|
||||
)
|
||||
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
|
||||
return nil, errors.New("failed to get thumbnail size")
|
||||
}
|
||||
|
||||
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
|
||||
|
||||
thumb := fmt.Sprintf("%s?imageView2/1/w/%d/h/%d/q/%d", file.SourceName, thumbSize[0], thumbSize[1], thumbEncodeQuality)
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: handler.signSourceURL(
|
||||
ctx,
|
||||
thumb,
|
||||
int64(model.GetIntSetting("preview_timeout", 60)),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := ""
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
// 加入下载相关设置
|
||||
if isDownload {
|
||||
path = path + "?attname=" + url.PathEscape(fileName)
|
||||
}
|
||||
|
||||
// 取得原始文件地址
|
||||
return handler.signSourceURL(ctx, path, ttl), nil
|
||||
}
|
||||
|
||||
func (handler *Driver) signSourceURL(ctx context.Context, path string, ttl int64) string {
|
||||
var sourceURL string
|
||||
if handler.Policy.IsPrivate {
|
||||
deadline := time.Now().Add(time.Second * time.Duration(ttl)).Unix()
|
||||
sourceURL = storage.MakePrivateURL(handler.mac, handler.Policy.BaseURL, path, deadline)
|
||||
} else {
|
||||
sourceURL = storage.MakePublicURL(handler.Policy.BaseURL, path)
|
||||
}
|
||||
return sourceURL
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
// 生成回调地址
|
||||
siteURL := model.GetSiteURL()
|
||||
apiBaseURI, _ := url.Parse("/api/v3/callback/qiniu/" + uploadSession.Key)
|
||||
apiURL := siteURL.ResolveReference(apiBaseURI)
|
||||
|
||||
// 创建上传策略
|
||||
fileInfo := file.Info()
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: handler.Policy.BucketName,
|
||||
CallbackURL: apiURL.String(),
|
||||
CallbackBody: `{"size":$(fsize),"pic_info":"$(imageInfo.width),$(imageInfo.height)"}`,
|
||||
CallbackBodyType: "application/json",
|
||||
SaveKey: fileInfo.SavePath,
|
||||
ForceSaveKey: true,
|
||||
FsizeLimit: int64(handler.Policy.MaxSize),
|
||||
}
|
||||
// 是否开启了MIMEType限制
|
||||
if handler.Policy.OptionsSerialized.MimeType != "" {
|
||||
putPolicy.MimeLimit = handler.Policy.OptionsSerialized.MimeType
|
||||
}
|
||||
|
||||
credential, err := handler.getUploadCredential(ctx, putPolicy, fileInfo, ttl, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to init parts: %w", err)
|
||||
}
|
||||
|
||||
credential.SessionID = uploadSession.Key
|
||||
credential.ChunkSize = handler.Policy.OptionsSerialized.ChunkSize
|
||||
|
||||
uploadSession.UploadURL = credential.UploadURLs[0]
|
||||
uploadSession.Credential = credential.Credential
|
||||
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
// getUploadCredential 签名上传策略并创建上传会话
|
||||
func (handler *Driver) getUploadCredential(ctx context.Context, policy storage.PutPolicy, file *fsctx.UploadTaskInfo, TTL int64, resume bool) (*serializer.UploadCredential, error) {
|
||||
// 上传凭证
|
||||
policy.Expires = uint64(TTL)
|
||||
upToken := policy.UploadToken(handler.mac)
|
||||
|
||||
// 初始化分片上传
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
upHost, err := resumeUploader.UpHost(handler.Policy.AccessKey, handler.Policy.BucketName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret := &storage.InitPartsRet{}
|
||||
if resume {
|
||||
err = resumeUploader.InitParts(ctx, upToken, upHost, handler.Policy.BucketName, file.SavePath, true, ret)
|
||||
}
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
UploadURLs: []string{upHost + "/buckets/" + handler.Policy.BucketName + "/objects/" + base64.URLEncoding.EncodeToString([]byte(file.SavePath)) + "/uploads/" + ret.UploadID},
|
||||
Credential: upToken,
|
||||
}, err
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
resumeUploader := storage.NewResumeUploaderV2(handler.cfg)
|
||||
return resumeUploader.Client.CallWith(ctx, nil, "DELETE", uploadSession.UploadURL, http.Header{"Authorization": {"UpToken " + uploadSession.Credential}}, nil, 0)
|
||||
}
|
195
pkg/filesystem/driver/remote/client.go
Normal file
195
pkg/filesystem/driver/remote/client.go
Normal file
@ -0,0 +1,195 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gofrs/uuid"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
basePath = "/api/v3/slave/"
|
||||
OverwriteHeader = auth.CrHeaderPrefix + "Overwrite"
|
||||
chunkRetrySleep = time.Duration(5) * time.Second
|
||||
)
|
||||
|
||||
// Client to operate uploading to remote slave server
|
||||
type Client interface {
|
||||
// CreateUploadSession creates remote upload session
|
||||
CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error
|
||||
// GetUploadURL signs an url for uploading file
|
||||
GetUploadURL(ttl int64, sessionID string) (string, string, error)
|
||||
// Upload uploads file to remote server
|
||||
Upload(ctx context.Context, file fsctx.FileHeader) error
|
||||
// DeleteUploadSession deletes remote upload session
|
||||
DeleteUploadSession(ctx context.Context, sessionID string) error
|
||||
}
|
||||
|
||||
// NewClient creates new Client from given policy
|
||||
func NewClient(policy *model.Policy) (Client, error) {
|
||||
authInstance := auth.HMACAuth{[]byte(policy.SecretKey)}
|
||||
serverURL, err := url.Parse(policy.Server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base, _ := url.Parse(basePath)
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
|
||||
return &remoteClient{
|
||||
policy: policy,
|
||||
authInstance: authInstance,
|
||||
httpClient: request.NewClient(
|
||||
request.WithEndpoint(serverURL.ResolveReference(base).String()),
|
||||
request.WithCredential(authInstance, int64(signTTL)),
|
||||
request.WithMasterMeta(),
|
||||
request.WithSlaveMeta(policy.AccessKey),
|
||||
),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type remoteClient struct {
|
||||
policy *model.Policy
|
||||
authInstance auth.Auth
|
||||
httpClient request.Client
|
||||
}
|
||||
|
||||
func (c *remoteClient) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
ttl := model.GetIntSetting("upload_session_timeout", 86400)
|
||||
fileInfo := file.Info()
|
||||
session := &serializer.UploadSession{
|
||||
Key: uuid.Must(uuid.NewV4()).String(),
|
||||
VirtualPath: fileInfo.VirtualPath,
|
||||
Name: fileInfo.FileName,
|
||||
Size: fileInfo.Size,
|
||||
SavePath: fileInfo.SavePath,
|
||||
LastModified: fileInfo.LastModified,
|
||||
Policy: *c.policy,
|
||||
}
|
||||
|
||||
// Create upload session
|
||||
overwrite := fileInfo.Mode&fsctx.Overwrite == fsctx.Overwrite
|
||||
if err := c.CreateUploadSession(ctx, session, int64(ttl), overwrite); err != nil {
|
||||
return fmt.Errorf("failed to create upload session: %w", err)
|
||||
}
|
||||
|
||||
// Initial chunk groups
|
||||
chunks := chunk.NewChunkGroup(file, c.policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{
|
||||
Max: model.GetIntSetting("chunk_retries", 5),
|
||||
Sleep: chunkRetrySleep,
|
||||
}, model.IsTrueVal(model.GetSettingByName("use_temp_chunk_buffer")))
|
||||
|
||||
uploadFunc := func(current *chunk.ChunkGroup, content io.Reader) error {
|
||||
return c.uploadChunk(ctx, session.Key, current.Index(), content, overwrite, current.Length())
|
||||
}
|
||||
|
||||
// upload chunks
|
||||
for chunks.Next() {
|
||||
if err := chunks.Process(uploadFunc); err != nil {
|
||||
if err := c.DeleteUploadSession(ctx, session.Key); err != nil {
|
||||
util.Log().Warning("failed to delete upload session: %s", err)
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to upload chunk #%d: %w", chunks.Index(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) DeleteUploadSession(ctx context.Context, sessionID string) error {
|
||||
resp, err := c.httpClient.Request(
|
||||
"DELETE",
|
||||
"upload/"+sessionID,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error {
|
||||
reqBodyEncoded, err := json.Marshal(map[string]interface{}{
|
||||
"session": session,
|
||||
"ttl": ttl,
|
||||
"overwrite": overwrite,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
resp, err := c.httpClient.Request(
|
||||
"PUT",
|
||||
"upload",
|
||||
bodyReader,
|
||||
request.WithContext(ctx),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) GetUploadURL(ttl int64, sessionID string) (string, string, error) {
|
||||
base, err := url.Parse(c.policy.Server)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
base.Path = path.Join(base.Path, basePath, "upload", sessionID)
|
||||
req, err := http.NewRequest("POST", base.String(), nil)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
req = auth.SignRequest(c.authInstance, req, ttl)
|
||||
return req.URL.String(), req.Header["Authorization"][0], nil
|
||||
}
|
||||
|
||||
func (c *remoteClient) uploadChunk(ctx context.Context, sessionID string, index int, chunk io.Reader, overwrite bool, size int64) error {
|
||||
resp, err := c.httpClient.Request(
|
||||
"POST",
|
||||
fmt.Sprintf("upload/%s?chunk=%d", sessionID, index),
|
||||
chunk,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
request.WithContentLength(size),
|
||||
request.WithHeader(map[string][]string{OverwriteHeader: {fmt.Sprintf("%t", overwrite)}}),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(resp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
311
pkg/filesystem/driver/remote/handler.go
Normal file
311
pkg/filesystem/driver/remote/handler.go
Normal file
@ -0,0 +1,311 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Driver 远程存储策略适配器
|
||||
type Driver struct {
|
||||
Client request.Client
|
||||
Policy *model.Policy
|
||||
AuthInstance auth.Auth
|
||||
|
||||
uploadClient Client
|
||||
}
|
||||
|
||||
// NewDriver initializes a new Driver from policy
|
||||
// TODO: refactor all method into upload client
|
||||
func NewDriver(policy *model.Policy) (*Driver, error) {
|
||||
client, err := NewClient(policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Driver{
|
||||
Policy: policy,
|
||||
Client: request.NewClient(),
|
||||
AuthInstance: auth.HMACAuth{[]byte(policy.SecretKey)},
|
||||
uploadClient: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// List 列取文件
|
||||
func (handler *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
var res []response.Object
|
||||
|
||||
reqBody := serializer.ListRequest{
|
||||
Path: path,
|
||||
Recursive: recursive,
|
||||
}
|
||||
reqBodyEncoded, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
// 发送列表请求
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
resp, err := handler.Client.Request(
|
||||
"POST",
|
||||
handler.getAPIUrl("list"),
|
||||
bodyReader,
|
||||
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
||||
request.WithMasterMeta(),
|
||||
).CheckHTTPResponse(200).DecodeResponse()
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
if resp.Code != 0 {
|
||||
return res, errors.New(resp.Error)
|
||||
}
|
||||
|
||||
if resStr, ok := resp.Data.(string); ok {
|
||||
err = json.Unmarshal([]byte(resStr), &res)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// getAPIUrl 获取接口请求地址
|
||||
func (handler *Driver) getAPIUrl(scope string, routes ...string) string {
|
||||
serverURL, err := url.Parse(handler.Policy.Server)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var controller *url.URL
|
||||
|
||||
switch scope {
|
||||
case "delete":
|
||||
controller, _ = url.Parse("/api/v3/slave/delete")
|
||||
case "thumb":
|
||||
controller, _ = url.Parse("/api/v3/slave/thumb")
|
||||
case "list":
|
||||
controller, _ = url.Parse("/api/v3/slave/list")
|
||||
default:
|
||||
controller = serverURL
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
controller.Path = path.Join(controller.Path, r)
|
||||
}
|
||||
|
||||
return serverURL.ResolveReference(controller).String()
|
||||
}
|
||||
|
||||
// Get 获取文件内容
|
||||
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 尝试获取速度限制
|
||||
speedLimit := 0
|
||||
if user, ok := ctx.Value(fsctx.UserCtx).(model.User); ok {
|
||||
speedLimit = user.Group.SpeedLimit
|
||||
}
|
||||
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, 0, true, speedLimit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
resp, err := handler.Client.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
request.WithMasterMeta(),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
return handler.uploadClient.Upload(ctx, file)
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
// 封装接口请求正文
|
||||
reqBody := serializer.RemoteDeleteRequest{
|
||||
Files: files,
|
||||
}
|
||||
reqBodyEncoded, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 发送删除请求
|
||||
bodyReader := strings.NewReader(string(reqBodyEncoded))
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
resp, err := handler.Client.Request(
|
||||
"POST",
|
||||
handler.getAPIUrl("delete"),
|
||||
bodyReader,
|
||||
request.WithCredential(handler.AuthInstance, int64(signTTL)),
|
||||
request.WithMasterMeta(),
|
||||
request.WithSlaveMeta(handler.Policy.AccessKey),
|
||||
).CheckHTTPResponse(200).GetResponse()
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 处理删除结果
|
||||
var reqResp serializer.Response
|
||||
err = json.Unmarshal([]byte(resp), &reqResp)
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
if reqResp.Code != 0 {
|
||||
var failedResp serializer.RemoteDeleteRequest
|
||||
if failed, ok := reqResp.Data.(string); ok {
|
||||
err = json.Unmarshal([]byte(failed), &failedResp)
|
||||
if err == nil {
|
||||
return failedResp.Files, errors.New(reqResp.Error)
|
||||
}
|
||||
}
|
||||
return files, errors.New("unknown format of returned response")
|
||||
}
|
||||
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// quick check by extension name
|
||||
supported := []string{"png", "jpg", "jpeg", "gif"}
|
||||
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
|
||||
supported = handler.Policy.OptionsSerialized.ThumbExts
|
||||
}
|
||||
|
||||
if !util.IsInExtensionList(supported, file.Name) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(file.SourceName))
|
||||
thumbURL := fmt.Sprintf("%s/%s/%s", handler.getAPIUrl("thumb"), sourcePath, filepath.Ext(file.Name))
|
||||
ttl := model.GetIntSetting("preview_timeout", 60)
|
||||
signedThumbURL, err := auth.SignURI(handler.AuthInstance, thumbURL, int64(ttl))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: signedThumbURL.String(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := "file"
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
serverURL, err := url.Parse(handler.Policy.Server)
|
||||
if err != nil {
|
||||
return "", errors.New("无法解析远程服务端地址")
|
||||
}
|
||||
|
||||
// 是否启用了CDN
|
||||
if handler.Policy.BaseURL != "" {
|
||||
cdnURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
serverURL = cdnURL
|
||||
}
|
||||
|
||||
var (
|
||||
signedURI *url.URL
|
||||
controller = "/api/v3/slave/download"
|
||||
)
|
||||
if !isDownload {
|
||||
controller = "/api/v3/slave/source"
|
||||
}
|
||||
|
||||
// 签名下载地址
|
||||
sourcePath := base64.RawURLEncoding.EncodeToString([]byte(path))
|
||||
signedURI, err = auth.SignURI(
|
||||
handler.AuthInstance,
|
||||
fmt.Sprintf("%s/%d/%s/%s", controller, speed, sourcePath, url.PathEscape(fileName)),
|
||||
ttl,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeEncryptError, "Failed to sign URL", err)
|
||||
}
|
||||
|
||||
finalURL := serverURL.ResolveReference(signedURI).String()
|
||||
return finalURL, nil
|
||||
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
siteURL := model.GetSiteURL()
|
||||
apiBaseURI, _ := url.Parse(path.Join("/api/v3/callback/remote", uploadSession.Key, uploadSession.CallbackSecret))
|
||||
apiURL := siteURL.ResolveReference(apiBaseURI)
|
||||
|
||||
// 在从机端创建上传会话
|
||||
uploadSession.Callback = apiURL.String()
|
||||
if err := handler.uploadClient.CreateUploadSession(ctx, uploadSession, ttl, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取上传地址
|
||||
uploadURL, sign, err := handler.uploadClient.GetUploadURL(ttl, uploadSession.Key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign upload url: %w", err)
|
||||
}
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
|
||||
UploadURLs: []string{uploadURL},
|
||||
Credential: sign,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return handler.uploadClient.DeleteUploadSession(ctx, uploadSession.Key)
|
||||
}
|
440
pkg/filesystem/driver/s3/handler.go
Normal file
440
pkg/filesystem/driver/s3/handler.go
Normal file
@ -0,0 +1,440 @@
|
||||
package s3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/chunk/backoff"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Driver 适配器模板
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
sess *session.Session
|
||||
svc *s3.S3
|
||||
}
|
||||
|
||||
// UploadPolicy S3上传策略
|
||||
type UploadPolicy struct {
|
||||
Expiration string `json:"expiration"`
|
||||
Conditions []interface{} `json:"conditions"`
|
||||
}
|
||||
|
||||
// MetaData 文件信息
|
||||
type MetaData struct {
|
||||
Size uint64
|
||||
Etag string
|
||||
}
|
||||
|
||||
func NewDriver(policy *model.Policy) (*Driver, error) {
|
||||
if policy.OptionsSerialized.ChunkSize == 0 {
|
||||
policy.OptionsSerialized.ChunkSize = 25 << 20 // 25 MB
|
||||
}
|
||||
|
||||
driver := &Driver{
|
||||
Policy: policy,
|
||||
}
|
||||
|
||||
return driver, driver.InitS3Client()
|
||||
}
|
||||
|
||||
// InitS3Client 初始化S3会话
|
||||
func (handler *Driver) InitS3Client() error {
|
||||
if handler.Policy == nil {
|
||||
return errors.New("empty policy")
|
||||
}
|
||||
|
||||
if handler.svc == nil {
|
||||
// 初始化会话
|
||||
sess, err := session.NewSession(&aws.Config{
|
||||
Credentials: credentials.NewStaticCredentials(handler.Policy.AccessKey, handler.Policy.SecretKey, ""),
|
||||
Endpoint: &handler.Policy.Server,
|
||||
Region: &handler.Policy.OptionsSerialized.Region,
|
||||
S3ForcePathStyle: &handler.Policy.OptionsSerialized.S3ForcePathStyle,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
handler.sess = sess
|
||||
handler.svc = s3.New(sess)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出给定路径下的文件
|
||||
func (handler *Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
// 初始化列目录参数
|
||||
base = strings.TrimPrefix(base, "/")
|
||||
if base != "" {
|
||||
base += "/"
|
||||
}
|
||||
|
||||
opt := &s3.ListObjectsInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Prefix: &base,
|
||||
MaxKeys: aws.Int64(1000),
|
||||
}
|
||||
|
||||
// 是否为递归列出
|
||||
if !recursive {
|
||||
opt.Delimiter = aws.String("/")
|
||||
}
|
||||
|
||||
var (
|
||||
objects []*s3.Object
|
||||
commons []*s3.CommonPrefix
|
||||
)
|
||||
|
||||
for {
|
||||
res, err := handler.svc.ListObjectsWithContext(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, res.Contents...)
|
||||
commons = append(commons, res.CommonPrefixes...)
|
||||
|
||||
// 如果本次未列取完,则继续使用marker获取结果
|
||||
if *res.IsTruncated {
|
||||
opt.Marker = res.NextMarker
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 处理列取结果
|
||||
res := make([]response.Object, 0, len(objects)+len(commons))
|
||||
|
||||
// 处理目录
|
||||
for _, object := range commons {
|
||||
rel, err := filepath.Rel(*opt.Prefix, *object.Prefix)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(*object.Prefix),
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: 0,
|
||||
IsDir: true,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
// 处理文件
|
||||
for _, object := range objects {
|
||||
rel, err := filepath.Rel(*opt.Prefix, *object.Key)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(*object.Key),
|
||||
Source: *object.Key,
|
||||
RelativePath: filepath.ToSlash(rel),
|
||||
Size: uint64(*object.Size),
|
||||
IsDir: false,
|
||||
LastModify: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
client := request.NewClient()
|
||||
resp, err := client.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithHeader(
|
||||
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
|
||||
),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
// 初始化客户端
|
||||
if err := handler.InitS3Client(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uploader := s3manager.NewUploader(handler.sess, func(u *s3manager.Uploader) {
|
||||
u.PartSize = int64(handler.Policy.OptionsSerialized.ChunkSize)
|
||||
})
|
||||
|
||||
dst := file.Info().SavePath
|
||||
_, err := uploader.Upload(&s3manager.UploadInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &dst,
|
||||
Body: io.LimitReader(file, int64(file.Info().Size)),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
failed := make([]string, 0, len(files))
|
||||
deleted := make([]string, 0, len(files))
|
||||
|
||||
keys := make([]*s3.ObjectIdentifier, 0, len(files))
|
||||
for _, file := range files {
|
||||
filePath := file
|
||||
keys = append(keys, &s3.ObjectIdentifier{Key: &filePath})
|
||||
}
|
||||
|
||||
// 发送异步删除请求
|
||||
res, err := handler.svc.DeleteObjects(
|
||||
&s3.DeleteObjectsInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Delete: &s3.Delete{
|
||||
Objects: keys,
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return files, err
|
||||
}
|
||||
|
||||
// 统计未删除的文件
|
||||
for _, deleteRes := range res.Deleted {
|
||||
deleted = append(deleted, *deleteRes.Key)
|
||||
}
|
||||
failed = util.SliceDifference(files, deleted)
|
||||
|
||||
return failed, nil
|
||||
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := ""
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
// 初始化客户端
|
||||
if err := handler.InitS3Client(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
contentDescription := aws.String("attachment; filename=\"" + url.PathEscape(fileName) + "\"")
|
||||
if !isDownload {
|
||||
contentDescription = nil
|
||||
}
|
||||
req, _ := handler.svc.GetObjectRequest(
|
||||
&s3.GetObjectInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &path,
|
||||
ResponseContentDisposition: contentDescription,
|
||||
})
|
||||
|
||||
signedURL, err := req.Presign(time.Duration(ttl) * time.Second)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将最终生成的签名URL域名换成用户自定义的加速域名(如果有)
|
||||
finalURL, err := url.Parse(signedURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 公有空间替换掉Key及不支持的头
|
||||
if !handler.Policy.IsPrivate {
|
||||
finalURL.RawQuery = ""
|
||||
}
|
||||
|
||||
if handler.Policy.BaseURL != "" {
|
||||
cdnURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
finalURL.Host = cdnURL.Host
|
||||
finalURL.Scheme = cdnURL.Scheme
|
||||
}
|
||||
|
||||
return finalURL.String(), nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
// 检查文件是否存在
|
||||
fileInfo := file.Info()
|
||||
if _, err := handler.Meta(ctx, fileInfo.SavePath); err == nil {
|
||||
return nil, fmt.Errorf("file already exist")
|
||||
}
|
||||
|
||||
// 创建分片上传
|
||||
expires := time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
res, err := handler.svc.CreateMultipartUpload(&s3.CreateMultipartUploadInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &fileInfo.SavePath,
|
||||
Expires: &expires,
|
||||
ContentType: aws.String(fileInfo.DetectMimeType()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create multipart upload: %w", err)
|
||||
}
|
||||
|
||||
uploadSession.UploadID = *res.UploadId
|
||||
|
||||
// 为每个分片签名上传 URL
|
||||
chunks := chunk.NewChunkGroup(file, handler.Policy.OptionsSerialized.ChunkSize, &backoff.ConstantBackoff{}, false)
|
||||
urls := make([]string, chunks.Num())
|
||||
for chunks.Next() {
|
||||
err := chunks.Process(func(c *chunk.ChunkGroup, chunk io.Reader) error {
|
||||
signedReq, _ := handler.svc.UploadPartRequest(&s3.UploadPartInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &fileInfo.SavePath,
|
||||
PartNumber: aws.Int64(int64(c.Index() + 1)),
|
||||
UploadId: res.UploadId,
|
||||
})
|
||||
|
||||
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
urls[c.Index()] = signedURL
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 签名完成分片上传的请求URL
|
||||
signedReq, _ := handler.svc.CompleteMultipartUploadRequest(&s3.CompleteMultipartUploadInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &fileInfo.SavePath,
|
||||
UploadId: res.UploadId,
|
||||
})
|
||||
|
||||
signedURL, err := signedReq.Presign(time.Duration(ttl) * time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成上传凭证
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
ChunkSize: handler.Policy.OptionsSerialized.ChunkSize,
|
||||
UploadID: *res.UploadId,
|
||||
UploadURLs: urls,
|
||||
CompleteURL: signedURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Meta 获取文件信息
|
||||
func (handler *Driver) Meta(ctx context.Context, path string) (*MetaData, error) {
|
||||
res, err := handler.svc.HeadObject(
|
||||
&s3.HeadObjectInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &path,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MetaData{
|
||||
Size: uint64(*res.ContentLength),
|
||||
Etag: *res.ETag,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// CORS 创建跨域策略
|
||||
func (handler *Driver) CORS() error {
|
||||
rule := s3.CORSRule{
|
||||
AllowedMethods: aws.StringSlice([]string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"HEAD",
|
||||
}),
|
||||
AllowedOrigins: aws.StringSlice([]string{"*"}),
|
||||
AllowedHeaders: aws.StringSlice([]string{"*"}),
|
||||
ExposeHeaders: aws.StringSlice([]string{"ETag"}),
|
||||
MaxAgeSeconds: aws.Int64(3600),
|
||||
}
|
||||
|
||||
_, err := handler.svc.PutBucketCors(&s3.PutBucketCorsInput{
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
CORSConfiguration: &s3.CORSConfiguration{
|
||||
CORSRules: []*s3.CORSRule{&rule},
|
||||
},
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
_, err := handler.svc.AbortMultipartUpload(&s3.AbortMultipartUploadInput{
|
||||
UploadId: &uploadSession.UploadID,
|
||||
Bucket: &handler.Policy.BucketName,
|
||||
Key: &uploadSession.SavePath,
|
||||
})
|
||||
return err
|
||||
}
|
7
pkg/filesystem/driver/shadow/masterinslave/errors.go
Normal file
7
pkg/filesystem/driver/shadow/masterinslave/errors.go
Normal file
@ -0,0 +1,7 @@
|
||||
package masterinslave
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
|
||||
)
|
60
pkg/filesystem/driver/shadow/masterinslave/handler.go
Normal file
60
pkg/filesystem/driver/shadow/masterinslave/handler.go
Normal file
@ -0,0 +1,60 @@
|
||||
package masterinslave
|
||||
|
||||
import (
|
||||
"context"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Driver 影子存储策略,用于在从机端上传文件
|
||||
type Driver struct {
|
||||
master cluster.Node
|
||||
handler driver.Handler
|
||||
policy *model.Policy
|
||||
}
|
||||
|
||||
// NewDriver 返回新的处理器
|
||||
func NewDriver(master cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
|
||||
return &Driver{
|
||||
master: master,
|
||||
handler: handler,
|
||||
policy: policy,
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
return d.handler.Put(ctx, file)
|
||||
}
|
||||
|
||||
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
return d.handler.Delete(ctx, files)
|
||||
}
|
||||
|
||||
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
return "", ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return nil
|
||||
}
|
9
pkg/filesystem/driver/shadow/slaveinmaster/errors.go
Normal file
9
pkg/filesystem/driver/shadow/slaveinmaster/errors.go
Normal file
@ -0,0 +1,9 @@
|
||||
package slaveinmaster
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotImplemented = errors.New("this method of shadowed policy is not implemented")
|
||||
ErrSlaveSrcPathNotExist = errors.New("cannot determine source file path in slave node")
|
||||
ErrWaitResultTimeout = errors.New("timeout waiting for slave transfer result")
|
||||
)
|
124
pkg/filesystem/driver/shadow/slaveinmaster/handler.go
Normal file
124
pkg/filesystem/driver/shadow/slaveinmaster/handler.go
Normal file
@ -0,0 +1,124 @@
|
||||
package slaveinmaster
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// Driver 影子存储策略,将上传任务指派给从机节点处理,并等待从机通知上传结果
|
||||
type Driver struct {
|
||||
node cluster.Node
|
||||
handler driver.Handler
|
||||
policy *model.Policy
|
||||
client request.Client
|
||||
}
|
||||
|
||||
// NewDriver 返回新的从机指派处理器
|
||||
func NewDriver(node cluster.Node, handler driver.Handler, policy *model.Policy) driver.Handler {
|
||||
var endpoint *url.URL
|
||||
if serverURL, err := url.Parse(node.DBModel().Server); err == nil {
|
||||
var controller *url.URL
|
||||
controller, _ = url.Parse("/api/v3/slave/")
|
||||
endpoint = serverURL.ResolveReference(controller)
|
||||
}
|
||||
|
||||
signTTL := model.GetIntSetting("slave_api_timeout", 60)
|
||||
return &Driver{
|
||||
node: node,
|
||||
handler: handler,
|
||||
policy: policy,
|
||||
client: request.NewClient(
|
||||
request.WithMasterMeta(),
|
||||
request.WithTimeout(time.Duration(signTTL)*time.Second),
|
||||
request.WithCredential(node.SlaveAuthInstance(), int64(signTTL)),
|
||||
request.WithEndpoint(endpoint.String()),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// Put 将ctx中指定的从机物理文件由从机上传到目标存储策略
|
||||
func (d *Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
fileInfo := file.Info()
|
||||
req := serializer.SlaveTransferReq{
|
||||
Src: fileInfo.Src,
|
||||
Dst: fileInfo.SavePath,
|
||||
Policy: d.policy,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 订阅转存结果
|
||||
resChan := mq.GlobalMQ.Subscribe(req.Hash(model.GetSettingByName("siteID")), 0)
|
||||
defer mq.GlobalMQ.Unsubscribe(req.Hash(model.GetSettingByName("siteID")), resChan)
|
||||
|
||||
res, err := d.client.Request("PUT", "task/transfer", bytes.NewReader(body)).
|
||||
CheckHTTPResponse(200).
|
||||
DecodeResponse()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return serializer.NewErrorFromResponse(res)
|
||||
}
|
||||
|
||||
// 等待转存结果或者超时
|
||||
waitTimeout := model.GetIntSetting("slave_transfer_timeout", 172800)
|
||||
select {
|
||||
case <-time.After(time.Duration(waitTimeout) * time.Second):
|
||||
return ErrWaitResultTimeout
|
||||
case msg := <-resChan:
|
||||
if msg.Event != serializer.SlaveTransferSuccess {
|
||||
return errors.New(msg.Content.(serializer.SlaveTransferResult).Error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
return d.handler.Delete(ctx, files)
|
||||
}
|
||||
|
||||
func (d *Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
return "", ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
func (d *Driver) List(ctx context.Context, path string, recursive bool) ([]response.Object, error) {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (d *Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return nil
|
||||
}
|
358
pkg/filesystem/driver/upyun/handler.go
Normal file
358
pkg/filesystem/driver/upyun/handler.go
Normal file
@ -0,0 +1,358 @@
|
||||
package upyun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/upyun/go-sdk/upyun"
|
||||
)
|
||||
|
||||
// UploadPolicy 又拍云上传策略
|
||||
type UploadPolicy struct {
|
||||
Bucket string `json:"bucket"`
|
||||
SaveKey string `json:"save-key"`
|
||||
Expiration int64 `json:"expiration"`
|
||||
CallbackURL string `json:"notify-url"`
|
||||
ContentLength uint64 `json:"content-length"`
|
||||
ContentLengthRange string `json:"content-length-range,omitempty"`
|
||||
AllowFileType string `json:"allow-file-type,omitempty"`
|
||||
}
|
||||
|
||||
// Driver 又拍云策略适配器
|
||||
type Driver struct {
|
||||
Policy *model.Policy
|
||||
}
|
||||
|
||||
func (handler Driver) List(ctx context.Context, base string, recursive bool) ([]response.Object, error) {
|
||||
base = strings.TrimPrefix(base, "/")
|
||||
|
||||
// 用于接受SDK返回对象的chan
|
||||
objChan := make(chan *upyun.FileInfo)
|
||||
objects := []*upyun.FileInfo{}
|
||||
|
||||
// 列取配置
|
||||
listConf := &upyun.GetObjectsConfig{
|
||||
Path: "/" + base,
|
||||
ObjectsChan: objChan,
|
||||
MaxListTries: 1,
|
||||
}
|
||||
// 递归列取时不限制递归次数
|
||||
if recursive {
|
||||
listConf.MaxListLevel = -1
|
||||
}
|
||||
|
||||
// 启动一个goroutine收集列取的对象信
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func(input chan *upyun.FileInfo, output *[]*upyun.FileInfo, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
file, ok := <-input
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
*output = append(*output, file)
|
||||
}
|
||||
}(objChan, &objects, wg)
|
||||
|
||||
up := upyun.NewUpYun(&upyun.UpYunConfig{
|
||||
Bucket: handler.Policy.BucketName,
|
||||
Operator: handler.Policy.AccessKey,
|
||||
Password: handler.Policy.SecretKey,
|
||||
})
|
||||
|
||||
err := up.List(listConf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// 汇总处理列取结果
|
||||
res := make([]response.Object, 0, len(objects))
|
||||
for _, object := range objects {
|
||||
res = append(res, response.Object{
|
||||
Name: path.Base(object.Name),
|
||||
RelativePath: object.Name,
|
||||
Source: path.Join(base, object.Name),
|
||||
Size: uint64(object.Size),
|
||||
IsDir: object.IsDir,
|
||||
LastModify: object.Time,
|
||||
})
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Get 获取文件
|
||||
func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 获取文件源地址
|
||||
downloadURL, err := handler.Source(ctx, path, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取文件数据流
|
||||
client := request.NewClient()
|
||||
resp, err := client.Request(
|
||||
"GET",
|
||||
downloadURL,
|
||||
nil,
|
||||
request.WithContext(ctx),
|
||||
request.WithHeader(
|
||||
http.Header{"Cache-Control": {"no-cache", "no-store", "must-revalidate"}},
|
||||
),
|
||||
request.WithTimeout(time.Duration(0)),
|
||||
).CheckHTTPResponse(200).GetRSCloser()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.SetFirstFakeChunk()
|
||||
|
||||
// 尝试自主获取文件大小
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
resp.SetContentLength(int64(file.Size))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
|
||||
}
|
||||
|
||||
// Put 将文件流保存到指定目录
|
||||
func (handler Driver) Put(ctx context.Context, file fsctx.FileHeader) error {
|
||||
defer file.Close()
|
||||
|
||||
up := upyun.NewUpYun(&upyun.UpYunConfig{
|
||||
Bucket: handler.Policy.BucketName,
|
||||
Operator: handler.Policy.AccessKey,
|
||||
Password: handler.Policy.SecretKey,
|
||||
})
|
||||
err := up.Put(&upyun.PutObjectConfig{
|
||||
Path: file.Info().SavePath,
|
||||
Reader: file,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 删除一个或多个文件,
|
||||
// 返回未删除的文件,及遇到的最后一个错误
|
||||
func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) {
|
||||
up := upyun.NewUpYun(&upyun.UpYunConfig{
|
||||
Bucket: handler.Policy.BucketName,
|
||||
Operator: handler.Policy.AccessKey,
|
||||
Password: handler.Policy.SecretKey,
|
||||
})
|
||||
|
||||
var (
|
||||
failed = make([]string, 0, len(files))
|
||||
lastErr error
|
||||
currentIndex = 0
|
||||
indexLock sync.Mutex
|
||||
failedLock sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
routineNum = 4
|
||||
)
|
||||
wg.Add(routineNum)
|
||||
|
||||
// upyun不支持批量操作,这里开四个协程并行操作
|
||||
for i := 0; i < routineNum; i++ {
|
||||
go func() {
|
||||
for {
|
||||
// 取得待删除文件
|
||||
indexLock.Lock()
|
||||
if currentIndex >= len(files) {
|
||||
// 所有文件处理完成
|
||||
wg.Done()
|
||||
indexLock.Unlock()
|
||||
return
|
||||
}
|
||||
path := files[currentIndex]
|
||||
currentIndex++
|
||||
indexLock.Unlock()
|
||||
|
||||
// 发送异步删除请求
|
||||
err := up.Delete(&upyun.DeleteObjectConfig{
|
||||
Path: path,
|
||||
Async: true,
|
||||
})
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
failedLock.Lock()
|
||||
lastErr = err
|
||||
failed = append(failed, path)
|
||||
failedLock.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return failed, lastErr
|
||||
}
|
||||
|
||||
// Thumb 获取文件缩略图
|
||||
func (handler Driver) Thumb(ctx context.Context, file *model.File) (*response.ContentResponse, error) {
|
||||
// quick check by extension name
|
||||
// https://help.upyun.com/knowledge-base/image/
|
||||
supported := []string{"png", "jpg", "jpeg", "gif", "bmp", "webp", "svg"}
|
||||
if len(handler.Policy.OptionsSerialized.ThumbExts) > 0 {
|
||||
supported = handler.Policy.OptionsSerialized.ThumbExts
|
||||
}
|
||||
|
||||
if !util.IsInExtensionList(supported, file.Name) {
|
||||
return nil, driver.ErrorThumbNotSupported
|
||||
}
|
||||
|
||||
var (
|
||||
thumbSize = [2]uint{400, 300}
|
||||
ok = false
|
||||
)
|
||||
if thumbSize, ok = ctx.Value(fsctx.ThumbSizeCtx).([2]uint); !ok {
|
||||
return nil, errors.New("failed to get thumbnail size")
|
||||
}
|
||||
|
||||
thumbEncodeQuality := model.GetIntSetting("thumb_encode_quality", 85)
|
||||
|
||||
thumbParam := fmt.Sprintf("!/fwfh/%dx%d/quality/%d", thumbSize[0], thumbSize[1], thumbEncodeQuality)
|
||||
thumbURL, err := handler.Source(ctx, file.SourceName+thumbParam, int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: thumbURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Source 获取外链URL
|
||||
func (handler Driver) Source(ctx context.Context, path string, ttl int64, isDownload bool, speed int) (string, error) {
|
||||
// 尝试从上下文获取文件名
|
||||
fileName := ""
|
||||
if file, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
fileName = file.Name
|
||||
}
|
||||
|
||||
sourceURL, err := url.Parse(handler.Policy.BaseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileKey, err := url.Parse(url.PathEscape(path))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sourceURL = sourceURL.ResolveReference(fileKey)
|
||||
|
||||
// 如果是下载文件URL
|
||||
if isDownload {
|
||||
query := sourceURL.Query()
|
||||
query.Add("_upd", fileName)
|
||||
sourceURL.RawQuery = query.Encode()
|
||||
}
|
||||
|
||||
return handler.signURL(ctx, sourceURL, ttl)
|
||||
}
|
||||
|
||||
func (handler Driver) signURL(ctx context.Context, path *url.URL, TTL int64) (string, error) {
|
||||
if !handler.Policy.IsPrivate {
|
||||
// 未开启Token防盗链时,直接返回
|
||||
return path.String(), nil
|
||||
}
|
||||
|
||||
etime := time.Now().Add(time.Duration(TTL) * time.Second).Unix()
|
||||
signStr := fmt.Sprintf(
|
||||
"%s&%d&%s",
|
||||
handler.Policy.OptionsSerialized.Token,
|
||||
etime,
|
||||
path.Path,
|
||||
)
|
||||
signMd5 := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
|
||||
finalSign := signMd5[12:20] + strconv.FormatInt(etime, 10)
|
||||
|
||||
// 将签名添加到URL中
|
||||
query := path.Query()
|
||||
query.Add("_upt", finalSign)
|
||||
path.RawQuery = query.Encode()
|
||||
|
||||
return path.String(), nil
|
||||
}
|
||||
|
||||
// Token 获取上传策略和认证Token
|
||||
func (handler Driver) Token(ctx context.Context, ttl int64, uploadSession *serializer.UploadSession, file fsctx.FileHeader) (*serializer.UploadCredential, error) {
|
||||
// 生成回调地址
|
||||
siteURL := model.GetSiteURL()
|
||||
apiBaseURI, _ := url.Parse("/api/v3/callback/upyun/" + uploadSession.Key)
|
||||
apiURL := siteURL.ResolveReference(apiBaseURI)
|
||||
|
||||
// 上传策略
|
||||
fileInfo := file.Info()
|
||||
putPolicy := UploadPolicy{
|
||||
Bucket: handler.Policy.BucketName,
|
||||
// TODO escape
|
||||
SaveKey: fileInfo.SavePath,
|
||||
Expiration: time.Now().Add(time.Duration(ttl) * time.Second).Unix(),
|
||||
CallbackURL: apiURL.String(),
|
||||
ContentLength: fileInfo.Size,
|
||||
ContentLengthRange: fmt.Sprintf("0,%d", fileInfo.Size),
|
||||
AllowFileType: strings.Join(handler.Policy.OptionsSerialized.FileType, ","),
|
||||
}
|
||||
|
||||
// 生成上传凭证
|
||||
policyJSON, err := json.Marshal(putPolicy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
policyEncoded := base64.StdEncoding.EncodeToString(policyJSON)
|
||||
|
||||
// 生成签名
|
||||
elements := []string{"POST", "/" + handler.Policy.BucketName, policyEncoded}
|
||||
signStr := handler.Sign(ctx, elements)
|
||||
|
||||
return &serializer.UploadCredential{
|
||||
SessionID: uploadSession.Key,
|
||||
Policy: policyEncoded,
|
||||
Credential: signStr,
|
||||
UploadURLs: []string{"https://v0.api.upyun.com/" + handler.Policy.BucketName},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 取消上传凭证
|
||||
func (handler Driver) CancelToken(ctx context.Context, uploadSession *serializer.UploadSession) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sign 计算又拍云的签名头
|
||||
func (handler Driver) Sign(ctx context.Context, elements []string) string {
|
||||
password := fmt.Sprintf("%x", md5.Sum([]byte(handler.Policy.SecretKey)))
|
||||
mac := hmac.New(sha1.New, []byte(password))
|
||||
value := strings.Join(elements, "&")
|
||||
mac.Write([]byte(value))
|
||||
signStr := base64.StdEncoding.EncodeToString((mac.Sum(nil)))
|
||||
return fmt.Sprintf("UPYUN %s:%s", handler.Policy.AccessKey, signStr)
|
||||
}
|
25
pkg/filesystem/errors.go
Normal file
25
pkg/filesystem/errors.go
Normal file
@ -0,0 +1,25 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnknownPolicyType = serializer.NewError(serializer.CodeInternalSetting, "Unknown policy type", nil)
|
||||
ErrFileSizeTooBig = serializer.NewError(serializer.CodeFileTooLarge, "File is too large", nil)
|
||||
ErrFileExtensionNotAllowed = serializer.NewError(serializer.CodeFileTypeNotAllowed, "File type not allowed", nil)
|
||||
ErrInsufficientCapacity = serializer.NewError(serializer.CodeInsufficientCapacity, "Insufficient capacity", nil)
|
||||
ErrIllegalObjectName = serializer.NewError(serializer.CodeIllegalObjectName, "Invalid object name", nil)
|
||||
ErrClientCanceled = errors.New("Client canceled operation")
|
||||
ErrRootProtected = serializer.NewError(serializer.CodeRootProtected, "Root protected", nil)
|
||||
ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "Failed to create file record", nil)
|
||||
ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "Object existed", nil)
|
||||
ErrFileUploadSessionExisted = serializer.NewError(serializer.CodeConflictUploadOngoing, "Upload session existed", nil)
|
||||
ErrPathNotExist = serializer.NewError(serializer.CodeParentNotExist, "Path not exist", nil)
|
||||
ErrObjectNotExist = serializer.NewError(serializer.CodeParentNotExist, "Object not exist", nil)
|
||||
ErrIO = serializer.NewError(serializer.CodeIOFailed, "Failed to read file data", nil)
|
||||
ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "Failed to list object records", nil)
|
||||
ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "Failed to delete object records", nil)
|
||||
ErrOneObjectOnly = serializer.ParamErr("You can only copy one object at the same time", nil)
|
||||
)
|
387
pkg/filesystem/file.go
Normal file
387
pkg/filesystem/file.go
Normal file
@ -0,0 +1,387 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/juju/ratelimit"
|
||||
)
|
||||
|
||||
/* ============
|
||||
文件相关
|
||||
============
|
||||
*/
|
||||
|
||||
// 限速后的ReaderSeeker
|
||||
type lrs struct {
|
||||
response.RSCloser
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (r lrs) Read(p []byte) (int, error) {
|
||||
return r.r.Read(p)
|
||||
}
|
||||
|
||||
// withSpeedLimit 给原有的ReadSeeker加上限速
|
||||
func (fs *FileSystem) withSpeedLimit(rs response.RSCloser) response.RSCloser {
|
||||
// 如果用户组有速度限制,就返回限制流速的ReaderSeeker
|
||||
if fs.User.Group.SpeedLimit != 0 {
|
||||
speed := fs.User.Group.SpeedLimit
|
||||
bucket := ratelimit.NewBucketWithRate(float64(speed), int64(speed))
|
||||
lrs := lrs{rs, ratelimit.Reader(rs, bucket)}
|
||||
return lrs
|
||||
}
|
||||
// 否则返回原始流
|
||||
return rs
|
||||
|
||||
}
|
||||
|
||||
// AddFile 新增文件记录
|
||||
func (fs *FileSystem) AddFile(ctx context.Context, parent *model.Folder, file fsctx.FileHeader) (*model.File, error) {
|
||||
// 添加文件记录前的钩子
|
||||
err := fs.Trigger(ctx, "BeforeAddFile", file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uploadInfo := file.Info()
|
||||
newFile := model.File{
|
||||
Name: uploadInfo.FileName,
|
||||
SourceName: uploadInfo.SavePath,
|
||||
UserID: fs.User.ID,
|
||||
Size: uploadInfo.Size,
|
||||
FolderID: parent.ID,
|
||||
PolicyID: fs.Policy.ID,
|
||||
MetadataSerialized: uploadInfo.Metadata,
|
||||
UploadSessionID: uploadInfo.UploadSessionID,
|
||||
}
|
||||
|
||||
err = newFile.Create()
|
||||
|
||||
if err != nil {
|
||||
if err := fs.Trigger(ctx, "AfterValidateFailed", file); err != nil {
|
||||
util.Log().Debug("AfterValidateFailed hook execution failed: %s", err)
|
||||
}
|
||||
return nil, ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
fs.User.Storage += newFile.Size
|
||||
return &newFile, nil
|
||||
}
|
||||
|
||||
// GetPhysicalFileContent 根据文件物理路径获取文件流
|
||||
func (fs *FileSystem) GetPhysicalFileContent(ctx context.Context, path string) (response.RSCloser, error) {
|
||||
// 重设上传策略
|
||||
fs.Policy = &model.Policy{Type: "local"}
|
||||
_ = fs.DispatchHandler()
|
||||
|
||||
// 获取文件流
|
||||
rs, err := fs.Handler.Get(ctx, path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fs.withSpeedLimit(rs), nil
|
||||
}
|
||||
|
||||
// Preview 预览文件
|
||||
//
|
||||
// path - 文件虚拟路径
|
||||
// isText - 是否为文本文件,文本文件会忽略重定向,直接由
|
||||
// 服务端拉取中转给用户,故会对文件大小进行限制
|
||||
func (fs *FileSystem) Preview(ctx context.Context, id uint, isText bool) (*response.ContentResponse, error) {
|
||||
err := fs.resetFileIDIfNotExist(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果是文本文件预览,需要检查大小限制
|
||||
sizeLimit := model.GetIntSetting("maxEditSize", 2<<20)
|
||||
if isText && fs.FileTarget[0].Size > uint64(sizeLimit) {
|
||||
return nil, ErrFileSizeTooBig
|
||||
}
|
||||
|
||||
// 是否直接返回文件内容
|
||||
if isText || fs.Policy.IsDirectlyPreview() {
|
||||
resp, err := fs.GetDownloadContent(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response.ContentResponse{
|
||||
Redirect: false,
|
||||
Content: resp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 否则重定向到签名的预览URL
|
||||
ttl := model.GetIntSetting("preview_timeout", 60)
|
||||
previewURL, err := fs.SignURL(ctx, &fs.FileTarget[0], int64(ttl), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response.ContentResponse{
|
||||
Redirect: true,
|
||||
URL: previewURL,
|
||||
MaxAge: ttl,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// GetDownloadContent 获取用于下载的文件流
|
||||
func (fs *FileSystem) GetDownloadContent(ctx context.Context, id uint) (response.RSCloser, error) {
|
||||
// 获取原始文件流
|
||||
rs, err := fs.GetContent(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回限速处理后的文件流
|
||||
return fs.withSpeedLimit(rs), nil
|
||||
|
||||
}
|
||||
|
||||
// GetContent 获取文件内容,path为虚拟路径
|
||||
func (fs *FileSystem) GetContent(ctx context.Context, id uint) (response.RSCloser, error) {
|
||||
err := fs.resetFileIDIfNotExist(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0])
|
||||
|
||||
// 获取文件流
|
||||
rs, err := fs.Handler.Get(ctx, fs.FileTarget[0].SourceName)
|
||||
if err != nil {
|
||||
return nil, ErrIO.WithError(err)
|
||||
}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
// deleteGroupedFile 对分组好的文件执行删除操作,
|
||||
// 返回每个分组失败的文件列表
|
||||
func (fs *FileSystem) deleteGroupedFile(ctx context.Context, files map[uint][]*model.File) map[uint][]string {
|
||||
// 失败的文件列表
|
||||
// TODO 并行删除
|
||||
failed := make(map[uint][]string, len(files))
|
||||
thumbs := make([]string, 0)
|
||||
|
||||
for policyID, toBeDeletedFiles := range files {
|
||||
// 列举出需要物理删除的文件的物理路径
|
||||
sourceNamesAll := make([]string, 0, len(toBeDeletedFiles))
|
||||
uploadSessions := make([]*serializer.UploadSession, 0, len(toBeDeletedFiles))
|
||||
|
||||
for i := 0; i < len(toBeDeletedFiles); i++ {
|
||||
sourceNamesAll = append(sourceNamesAll, toBeDeletedFiles[i].SourceName)
|
||||
|
||||
if toBeDeletedFiles[i].UploadSessionID != nil {
|
||||
if session, ok := cache.Get(UploadSessionCachePrefix + *toBeDeletedFiles[i].UploadSessionID); ok {
|
||||
uploadSession := session.(serializer.UploadSession)
|
||||
uploadSessions = append(uploadSessions, &uploadSession)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if sidecar thumb file exist
|
||||
if model.IsTrueVal(toBeDeletedFiles[i].MetadataSerialized[model.ThumbSidecarMetadataKey]) {
|
||||
thumbs = append(thumbs, toBeDeletedFiles[i].ThumbFile())
|
||||
}
|
||||
}
|
||||
|
||||
// 切换上传策略
|
||||
fs.Policy = toBeDeletedFiles[0].GetPolicy()
|
||||
err := fs.DispatchHandler()
|
||||
if err != nil {
|
||||
failed[policyID] = sourceNamesAll
|
||||
continue
|
||||
}
|
||||
|
||||
// 取消上传会话
|
||||
for _, upSession := range uploadSessions {
|
||||
if err := fs.Handler.CancelToken(ctx, upSession); err != nil {
|
||||
util.Log().Warning("Failed to cancel upload session for %q: %s", upSession.Name, err)
|
||||
}
|
||||
|
||||
cache.Deletes([]string{upSession.Key}, UploadSessionCachePrefix)
|
||||
}
|
||||
|
||||
// 执行删除
|
||||
toBeDeletedSrcs := append(sourceNamesAll, thumbs...)
|
||||
failedFile, _ := fs.Handler.Delete(ctx, toBeDeletedSrcs)
|
||||
|
||||
// Exclude failed results related to thumb file
|
||||
failed[policyID] = util.SliceDifference(failedFile, thumbs)
|
||||
}
|
||||
|
||||
return failed
|
||||
}
|
||||
|
||||
// GroupFileByPolicy 将目标文件按照存储策略分组
|
||||
func (fs *FileSystem) GroupFileByPolicy(ctx context.Context, files []model.File) map[uint][]*model.File {
|
||||
var policyGroup = make(map[uint][]*model.File)
|
||||
|
||||
for key := range files {
|
||||
if file, ok := policyGroup[files[key].PolicyID]; ok {
|
||||
// 如果已存在分组,直接追加
|
||||
policyGroup[files[key].PolicyID] = append(file, &files[key])
|
||||
} else {
|
||||
// 分组不存在,创建
|
||||
policyGroup[files[key].PolicyID] = make([]*model.File, 0)
|
||||
policyGroup[files[key].PolicyID] = append(policyGroup[files[key].PolicyID], &files[key])
|
||||
}
|
||||
}
|
||||
|
||||
return policyGroup
|
||||
}
|
||||
|
||||
// GetDownloadURL 创建文件下载链接, timeout 为数据库中存储过期时间的字段
|
||||
func (fs *FileSystem) GetDownloadURL(ctx context.Context, id uint, timeout string) (string, error) {
|
||||
err := fs.resetFileIDIfNotExist(ctx, id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
fileTarget := &fs.FileTarget[0]
|
||||
|
||||
// 生成下載地址
|
||||
ttl := model.GetIntSetting(timeout, 60)
|
||||
source, err := fs.SignURL(
|
||||
ctx,
|
||||
fileTarget,
|
||||
int64(ttl),
|
||||
true,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// GetSource 获取可直接访问文件的外链地址
|
||||
func (fs *FileSystem) GetSource(ctx context.Context, fileID uint) (string, error) {
|
||||
// 查找文件记录
|
||||
err := fs.resetFileIDIfNotExist(ctx, fileID)
|
||||
if err != nil {
|
||||
return "", ErrObjectNotExist.WithError(err)
|
||||
}
|
||||
|
||||
// 检查存储策略是否可以获得外链
|
||||
if !fs.Policy.IsOriginLinkEnable {
|
||||
return "", serializer.NewError(
|
||||
serializer.CodePolicyNotAllowed,
|
||||
"This policy is not enabled for getting source link",
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
source, err := fs.SignURL(ctx, &fs.FileTarget[0], 0, false)
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
|
||||
}
|
||||
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// SignURL 签名文件原始 URL
|
||||
func (fs *FileSystem) SignURL(ctx context.Context, file *model.File, ttl int64, isDownload bool) (string, error) {
|
||||
fs.FileTarget = []model.File{*file}
|
||||
ctx = context.WithValue(ctx, fsctx.FileModelCtx, *file)
|
||||
|
||||
err := fs.resetPolicyToFirstFile(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 签名最终URL
|
||||
// 生成外链地址
|
||||
source, err := fs.Handler.Source(ctx, fs.FileTarget[0].SourceName, ttl, isDownload, fs.User.Group.SpeedLimit)
|
||||
if err != nil {
|
||||
return "", serializer.NewError(serializer.CodeNotSet, "Failed to get source link", err)
|
||||
}
|
||||
|
||||
return source, nil
|
||||
}
|
||||
|
||||
// ResetFileIfNotExist 重设当前目标文件为 path,如果当前目标为空
|
||||
func (fs *FileSystem) ResetFileIfNotExist(ctx context.Context, path string) error {
|
||||
// 找到文件
|
||||
if len(fs.FileTarget) == 0 {
|
||||
exist, file := fs.IsFileExist(path)
|
||||
if !exist {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
fs.FileTarget = []model.File{*file}
|
||||
}
|
||||
|
||||
// 将当前存储策略重设为文件使用的
|
||||
return fs.resetPolicyToFirstFile(ctx)
|
||||
}
|
||||
|
||||
// ResetFileIfNotExist 重设当前目标文件为 id,如果当前目标为空
|
||||
func (fs *FileSystem) resetFileIDIfNotExist(ctx context.Context, id uint) error {
|
||||
// 找到文件
|
||||
if len(fs.FileTarget) == 0 {
|
||||
file, err := model.GetFilesByIDs([]uint{id}, fs.User.ID)
|
||||
if err != nil || len(file) == 0 {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
fs.FileTarget = []model.File{file[0]}
|
||||
}
|
||||
|
||||
// 如果上下文限制了父目录,则进行检查
|
||||
if parent, ok := ctx.Value(fsctx.LimitParentCtx).(*model.Folder); ok {
|
||||
if parent.ID != fs.FileTarget[0].FolderID {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
}
|
||||
|
||||
// 将当前存储策略重设为文件使用的
|
||||
return fs.resetPolicyToFirstFile(ctx)
|
||||
}
|
||||
|
||||
// resetPolicyToFirstFile 将当前存储策略重设为第一个目标文件文件使用的
|
||||
func (fs *FileSystem) resetPolicyToFirstFile(ctx context.Context) error {
|
||||
if len(fs.FileTarget) == 0 {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
|
||||
// 从机模式不进行操作
|
||||
if conf.SystemConfig.Mode == "slave" {
|
||||
return nil
|
||||
}
|
||||
|
||||
fs.Policy = fs.FileTarget[0].GetPolicy()
|
||||
err := fs.DispatchHandler()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Search 搜索文件
|
||||
func (fs *FileSystem) Search(ctx context.Context, keywords ...interface{}) ([]serializer.Object, error) {
|
||||
parents := make([]uint, 0)
|
||||
|
||||
// 如果限定了根目录,则只在这个根目录下搜索。
|
||||
if fs.Root != nil {
|
||||
allFolders, err := model.GetRecursiveChildFolder([]uint{fs.Root.ID}, fs.User.ID, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list all folders: %w", err)
|
||||
}
|
||||
|
||||
for _, folder := range allFolders {
|
||||
parents = append(parents, folder.ID)
|
||||
}
|
||||
}
|
||||
|
||||
files, _ := model.GetFilesByKeywords(fs.User.ID, parents, keywords...)
|
||||
fs.SetTargetFile(&files)
|
||||
|
||||
return fs.listObjects(ctx, "/", files, nil, nil), nil
|
||||
}
|
295
pkg/filesystem/filesystem.go
Normal file
295
pkg/filesystem/filesystem.go
Normal file
@ -0,0 +1,295 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/cos"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/googledrive"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/oss"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/qiniu"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/remote"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/s3"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/masterinslave"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/shadow/slaveinmaster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/upyun"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/gin-gonic/gin"
|
||||
cossdk "github.com/tencentyun/cos-go-sdk-v5"
|
||||
)
|
||||
|
||||
// FSPool 文件系统资源池
|
||||
var FSPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &FileSystem{}
|
||||
},
|
||||
}
|
||||
|
||||
// FileSystem 管理文件的文件系统
|
||||
type FileSystem struct {
|
||||
// 文件系统所有者
|
||||
User *model.User
|
||||
// 操作文件使用的存储策略
|
||||
Policy *model.Policy
|
||||
// 当前正在处理的文件对象
|
||||
FileTarget []model.File
|
||||
// 当前正在处理的目录对象
|
||||
DirTarget []model.Folder
|
||||
// 相对根目录
|
||||
Root *model.Folder
|
||||
// 互斥锁
|
||||
Lock sync.Mutex
|
||||
|
||||
/*
|
||||
钩子函数
|
||||
*/
|
||||
Hooks map[string][]Hook
|
||||
|
||||
/*
|
||||
文件系统处理适配器
|
||||
*/
|
||||
Handler driver.Handler
|
||||
|
||||
// 回收锁
|
||||
recycleLock sync.Mutex
|
||||
}
|
||||
|
||||
// getEmptyFS 从pool中获取新的FileSystem
|
||||
func getEmptyFS() *FileSystem {
|
||||
fs := FSPool.Get().(*FileSystem)
|
||||
return fs
|
||||
}
|
||||
|
||||
// Recycle 回收FileSystem资源
|
||||
func (fs *FileSystem) Recycle() {
|
||||
fs.recycleLock.Lock()
|
||||
fs.reset()
|
||||
FSPool.Put(fs)
|
||||
}
|
||||
|
||||
// reset 重设文件系统,以便回收使用
|
||||
func (fs *FileSystem) reset() {
|
||||
fs.User = nil
|
||||
fs.CleanTargets()
|
||||
fs.Policy = nil
|
||||
fs.Hooks = nil
|
||||
fs.Handler = nil
|
||||
fs.Root = nil
|
||||
fs.Lock = sync.Mutex{}
|
||||
fs.recycleLock = sync.Mutex{}
|
||||
}
|
||||
|
||||
// NewFileSystem 初始化一个文件系统
|
||||
func NewFileSystem(user *model.User) (*FileSystem, error) {
|
||||
fs := getEmptyFS()
|
||||
fs.User = user
|
||||
fs.Policy = user.GetPolicyID(nil)
|
||||
|
||||
// 分配存储策略适配器
|
||||
err := fs.DispatchHandler()
|
||||
|
||||
return fs, err
|
||||
}
|
||||
|
||||
// NewAnonymousFileSystem 初始化匿名文件系统
|
||||
func NewAnonymousFileSystem() (*FileSystem, error) {
|
||||
fs := getEmptyFS()
|
||||
fs.User = &model.User{}
|
||||
|
||||
// 如果是主机模式下,则为匿名文件系统分配游客用户组
|
||||
if conf.SystemConfig.Mode == "master" {
|
||||
anonymousGroup, err := model.GetGroupByID(3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fs.User.Group = anonymousGroup
|
||||
} else {
|
||||
// 从机模式下,分配本地策略处理器
|
||||
fs.Handler = local.Driver{}
|
||||
}
|
||||
|
||||
return fs, nil
|
||||
}
|
||||
|
||||
// DispatchHandler 根据存储策略分配文件适配器
|
||||
func (fs *FileSystem) DispatchHandler() error {
|
||||
handler, err := getNewPolicyHandler(fs.Policy)
|
||||
fs.Handler = handler
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getNewPolicyHandler 根据存储策略类型字段获取处理器
|
||||
func getNewPolicyHandler(policy *model.Policy) (driver.Handler, error) {
|
||||
if policy == nil {
|
||||
return nil, ErrUnknownPolicyType
|
||||
}
|
||||
|
||||
switch policy.Type {
|
||||
case "mock", "anonymous":
|
||||
return nil, nil
|
||||
case "local":
|
||||
return local.Driver{
|
||||
Policy: policy,
|
||||
}, nil
|
||||
case "remote":
|
||||
return remote.NewDriver(policy)
|
||||
case "qiniu":
|
||||
return qiniu.NewDriver(policy), nil
|
||||
case "oss":
|
||||
return oss.NewDriver(policy)
|
||||
case "upyun":
|
||||
return upyun.Driver{
|
||||
Policy: policy,
|
||||
}, nil
|
||||
case "onedrive":
|
||||
return onedrive.NewDriver(policy)
|
||||
case "cos":
|
||||
u, _ := url.Parse(policy.Server)
|
||||
b := &cossdk.BaseURL{BucketURL: u}
|
||||
return cos.Driver{
|
||||
Policy: policy,
|
||||
Client: cossdk.NewClient(b, &http.Client{
|
||||
Transport: &cossdk.AuthorizationTransport{
|
||||
SecretID: policy.AccessKey,
|
||||
SecretKey: policy.SecretKey,
|
||||
},
|
||||
}),
|
||||
HTTPClient: request.NewClient(),
|
||||
}, nil
|
||||
case "s3":
|
||||
return s3.NewDriver(policy)
|
||||
case "googledrive":
|
||||
return googledrive.NewDriver(policy)
|
||||
default:
|
||||
return nil, ErrUnknownPolicyType
|
||||
}
|
||||
}
|
||||
|
||||
// NewFileSystemFromContext 从gin.Context创建文件系统
|
||||
func NewFileSystemFromContext(c *gin.Context) (*FileSystem, error) {
|
||||
user, exist := c.Get("user")
|
||||
if !exist {
|
||||
return NewAnonymousFileSystem()
|
||||
}
|
||||
fs, err := NewFileSystem(user.(*model.User))
|
||||
return fs, err
|
||||
}
|
||||
|
||||
// NewFileSystemFromCallback 从gin.Context创建回调用文件系统
|
||||
func NewFileSystemFromCallback(c *gin.Context) (*FileSystem, error) {
|
||||
fs, err := NewFileSystemFromContext(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取回调会话
|
||||
callbackSessionRaw, ok := c.Get(UploadSessionCtx)
|
||||
if !ok {
|
||||
return nil, errors.New("upload session not exist")
|
||||
}
|
||||
callbackSession := callbackSessionRaw.(*serializer.UploadSession)
|
||||
|
||||
// 重新指向上传策略
|
||||
fs.Policy = &callbackSession.Policy
|
||||
err = fs.DispatchHandler()
|
||||
|
||||
return fs, err
|
||||
}
|
||||
|
||||
// SwitchToSlaveHandler 将负责上传的 Handler 切换为从机节点
|
||||
func (fs *FileSystem) SwitchToSlaveHandler(node cluster.Node) {
|
||||
fs.Handler = slaveinmaster.NewDriver(node, fs.Handler, fs.Policy)
|
||||
}
|
||||
|
||||
// SwitchToShadowHandler 将负责上传的 Handler 切换为从机节点转存使用的影子处理器
|
||||
func (fs *FileSystem) SwitchToShadowHandler(master cluster.Node, masterURL, masterID string) {
|
||||
switch fs.Policy.Type {
|
||||
case "local":
|
||||
fs.Policy.Type = "remote"
|
||||
fs.Policy.Server = masterURL
|
||||
fs.Policy.AccessKey = fmt.Sprintf("%d", master.ID())
|
||||
fs.Policy.SecretKey = master.DBModel().MasterKey
|
||||
fs.DispatchHandler()
|
||||
case "onedrive":
|
||||
fs.Policy.MasterID = masterID
|
||||
}
|
||||
|
||||
fs.Handler = masterinslave.NewDriver(master, fs.Handler, fs.Policy)
|
||||
}
|
||||
|
||||
// SetTargetFile 设置当前处理的目标文件
|
||||
func (fs *FileSystem) SetTargetFile(files *[]model.File) {
|
||||
if len(fs.FileTarget) == 0 {
|
||||
fs.FileTarget = *files
|
||||
} else {
|
||||
fs.FileTarget = append(fs.FileTarget, *files...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// SetTargetDir 设置当前处理的目标目录
|
||||
func (fs *FileSystem) SetTargetDir(dirs *[]model.Folder) {
|
||||
if len(fs.DirTarget) == 0 {
|
||||
fs.DirTarget = *dirs
|
||||
} else {
|
||||
fs.DirTarget = append(fs.DirTarget, *dirs...)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// SetTargetFileByIDs 根据文件ID设置目标文件,忽略用户ID
|
||||
func (fs *FileSystem) SetTargetFileByIDs(ids []uint) error {
|
||||
files, err := model.GetFilesByIDs(ids, 0)
|
||||
if err != nil || len(files) == 0 {
|
||||
return ErrFileExisted.WithError(err)
|
||||
}
|
||||
fs.SetTargetFile(&files)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTargetByInterface 根据 model.File 或者 model.Folder 设置目标对象
|
||||
// TODO 测试
|
||||
func (fs *FileSystem) SetTargetByInterface(target interface{}) error {
|
||||
if file, ok := target.(*model.File); ok {
|
||||
fs.SetTargetFile(&[]model.File{*file})
|
||||
return nil
|
||||
}
|
||||
if folder, ok := target.(*model.Folder); ok {
|
||||
fs.SetTargetDir(&[]model.Folder{*folder})
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
|
||||
// CleanTargets 清空目标
|
||||
func (fs *FileSystem) CleanTargets() {
|
||||
fs.FileTarget = fs.FileTarget[:0]
|
||||
fs.DirTarget = fs.DirTarget[:0]
|
||||
}
|
||||
|
||||
// SetPolicyFromPath 根据给定路径尝试设定偏好存储策略
|
||||
func (fs *FileSystem) SetPolicyFromPath(filePath string) error {
|
||||
_, parent := fs.getClosedParent(filePath)
|
||||
// 尝试获取并重设存储策略
|
||||
fs.Policy = fs.User.GetPolicyID(parent)
|
||||
return fs.DispatchHandler()
|
||||
}
|
||||
|
||||
// SetPolicyFromPreference 尝试设定偏好存储策略
|
||||
func (fs *FileSystem) SetPolicyFromPreference(preference uint) error {
|
||||
// 尝试获取并重设存储策略
|
||||
fs.Policy = fs.User.GetPolicyByPreference(preference)
|
||||
return fs.DispatchHandler()
|
||||
}
|
44
pkg/filesystem/fsctx/context.go
Normal file
44
pkg/filesystem/fsctx/context.go
Normal file
@ -0,0 +1,44 @@
|
||||
package fsctx
|
||||
|
||||
type key int
|
||||
|
||||
const (
|
||||
// GinCtx Gin的上下文
|
||||
GinCtx key = iota
|
||||
// PathCtx 文件或目录的虚拟路径
|
||||
PathCtx
|
||||
// FileModelCtx 文件数据库模型
|
||||
FileModelCtx
|
||||
// FolderModelCtx 目录数据库模型
|
||||
FolderModelCtx
|
||||
// HTTPCtx HTTP请求的上下文
|
||||
HTTPCtx
|
||||
// UploadPolicyCtx 上传策略,一般为slave模式下使用
|
||||
UploadPolicyCtx
|
||||
// UserCtx 用户
|
||||
UserCtx
|
||||
// ThumbSizeCtx 缩略图尺寸
|
||||
ThumbSizeCtx
|
||||
// FileSizeCtx 文件大小
|
||||
FileSizeCtx
|
||||
// ShareKeyCtx 分享文件的 HashID
|
||||
ShareKeyCtx
|
||||
// LimitParentCtx 限制父目录
|
||||
LimitParentCtx
|
||||
// IgnoreDirectoryConflictCtx 忽略目录重名冲突
|
||||
IgnoreDirectoryConflictCtx
|
||||
// RetryCtx 失败重试次数
|
||||
RetryCtx
|
||||
// ForceUsePublicEndpointCtx 强制使用公网 Endpoint
|
||||
ForceUsePublicEndpointCtx
|
||||
// CancelFuncCtx Context 取消函數
|
||||
CancelFuncCtx
|
||||
// 文件在从机节点中的路径
|
||||
SlaveSrcPath
|
||||
// Webdav目标名称
|
||||
WebdavDstName
|
||||
// WebDAVCtx WebDAV
|
||||
WebDAVCtx
|
||||
// WebDAV反代Url
|
||||
WebDAVProxyUrlCtx
|
||||
)
|
123
pkg/filesystem/fsctx/stream.go
Normal file
123
pkg/filesystem/fsctx/stream.go
Normal file
@ -0,0 +1,123 @@
|
||||
package fsctx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/HFO4/aliyun-oss-go-sdk/oss"
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WriteMode int
|
||||
|
||||
const (
|
||||
Overwrite WriteMode = 0x00001
|
||||
// Append 只适用于本地策略
|
||||
Append WriteMode = 0x00002
|
||||
Nop WriteMode = 0x00004
|
||||
)
|
||||
|
||||
type UploadTaskInfo struct {
|
||||
Size uint64
|
||||
MimeType string
|
||||
FileName string
|
||||
VirtualPath string
|
||||
Mode WriteMode
|
||||
Metadata map[string]string
|
||||
LastModified *time.Time
|
||||
SavePath string
|
||||
UploadSessionID *string
|
||||
AppendStart uint64
|
||||
Model interface{}
|
||||
Src string
|
||||
}
|
||||
|
||||
// Get mimetype of uploaded file, if it's not defined, detect it from file name
|
||||
func (u *UploadTaskInfo) DetectMimeType() string {
|
||||
if u.MimeType != "" {
|
||||
return u.MimeType
|
||||
}
|
||||
|
||||
return oss.TypeByExtension(u.FileName)
|
||||
}
|
||||
|
||||
// FileHeader 上传来的文件数据处理器
|
||||
type FileHeader interface {
|
||||
io.Reader
|
||||
io.Closer
|
||||
io.Seeker
|
||||
Info() *UploadTaskInfo
|
||||
SetSize(uint64)
|
||||
SetModel(fileModel interface{})
|
||||
Seekable() bool
|
||||
}
|
||||
|
||||
// FileStream 用户传来的文件
|
||||
type FileStream struct {
|
||||
Mode WriteMode
|
||||
LastModified *time.Time
|
||||
Metadata map[string]string
|
||||
File io.ReadCloser
|
||||
Seeker io.Seeker
|
||||
Size uint64
|
||||
VirtualPath string
|
||||
Name string
|
||||
MimeType string
|
||||
SavePath string
|
||||
UploadSessionID *string
|
||||
AppendStart uint64
|
||||
Model interface{}
|
||||
Src string
|
||||
}
|
||||
|
||||
func (file *FileStream) Read(p []byte) (n int, err error) {
|
||||
if file.File != nil {
|
||||
return file.File.Read(p)
|
||||
}
|
||||
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
func (file *FileStream) Close() error {
|
||||
if file.File != nil {
|
||||
return file.File.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (file *FileStream) Seek(offset int64, whence int) (int64, error) {
|
||||
if file.Seekable() {
|
||||
return file.Seeker.Seek(offset, whence)
|
||||
}
|
||||
|
||||
return 0, errors.New("no seeker")
|
||||
}
|
||||
|
||||
func (file *FileStream) Seekable() bool {
|
||||
return file.Seeker != nil
|
||||
}
|
||||
|
||||
func (file *FileStream) Info() *UploadTaskInfo {
|
||||
return &UploadTaskInfo{
|
||||
Size: file.Size,
|
||||
MimeType: file.MimeType,
|
||||
FileName: file.Name,
|
||||
VirtualPath: file.VirtualPath,
|
||||
Mode: file.Mode,
|
||||
Metadata: file.Metadata,
|
||||
LastModified: file.LastModified,
|
||||
SavePath: file.SavePath,
|
||||
UploadSessionID: file.UploadSessionID,
|
||||
AppendStart: file.AppendStart,
|
||||
Model: file.Model,
|
||||
Src: file.Src,
|
||||
}
|
||||
}
|
||||
|
||||
func (file *FileStream) SetSize(size uint64) {
|
||||
file.Size = size
|
||||
}
|
||||
|
||||
func (file *FileStream) SetModel(fileModel interface{}) {
|
||||
file.Model = fileModel
|
||||
}
|
320
pkg/filesystem/hooks.go
Normal file
320
pkg/filesystem/hooks.go
Normal file
@ -0,0 +1,320 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/local"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
// Hook 钩子函数
|
||||
type Hook func(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error
|
||||
|
||||
// Use 注入钩子
|
||||
func (fs *FileSystem) Use(name string, hook Hook) {
|
||||
if fs.Hooks == nil {
|
||||
fs.Hooks = make(map[string][]Hook)
|
||||
}
|
||||
if _, ok := fs.Hooks[name]; ok {
|
||||
fs.Hooks[name] = append(fs.Hooks[name], hook)
|
||||
return
|
||||
}
|
||||
fs.Hooks[name] = []Hook{hook}
|
||||
}
|
||||
|
||||
// CleanHooks 清空钩子,name为空表示全部清空
|
||||
func (fs *FileSystem) CleanHooks(name string) {
|
||||
if name == "" {
|
||||
fs.Hooks = nil
|
||||
} else {
|
||||
delete(fs.Hooks, name)
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger 触发钩子,遇到第一个错误时
|
||||
// 返回错误,后续钩子不会继续执行
|
||||
func (fs *FileSystem) Trigger(ctx context.Context, name string, file fsctx.FileHeader) error {
|
||||
if hooks, ok := fs.Hooks[name]; ok {
|
||||
for _, hook := range hooks {
|
||||
err := hook(ctx, fs, file)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to execute hook:%s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookValidateFile 一系列对文件检验的集合
|
||||
func HookValidateFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
fileInfo := file.Info()
|
||||
|
||||
// 验证单文件尺寸
|
||||
if !fs.ValidateFileSize(ctx, fileInfo.Size) {
|
||||
return ErrFileSizeTooBig
|
||||
}
|
||||
|
||||
// 验证文件名
|
||||
if !fs.ValidateLegalName(ctx, fileInfo.FileName) {
|
||||
return ErrIllegalObjectName
|
||||
}
|
||||
|
||||
// 验证扩展名
|
||||
if !fs.ValidateExtension(ctx, fileInfo.FileName) {
|
||||
return ErrFileExtensionNotAllowed
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// HookResetPolicy 重设存储策略为上下文已有文件
|
||||
func HookResetPolicy(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if !ok {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
|
||||
fs.Policy = originFile.GetPolicy()
|
||||
return fs.DispatchHandler()
|
||||
}
|
||||
|
||||
// HookValidateCapacity 验证用户容量
|
||||
func HookValidateCapacity(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
// 验证并扣除容量
|
||||
if fs.User.GetRemainingCapacity() < file.Info().Size {
|
||||
return ErrInsufficientCapacity
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookValidateCapacityDiff 根据原有文件和新文件的大小验证用户容量
|
||||
func HookValidateCapacityDiff(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
|
||||
originFile := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
newFileSize := newFile.Info().Size
|
||||
|
||||
if newFileSize > originFile.Size {
|
||||
return HookValidateCapacity(ctx, fs, newFile)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookDeleteTempFile 删除已保存的临时文件
|
||||
func HookDeleteTempFile(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
// 删除临时文件
|
||||
_, err := fs.Handler.Delete(ctx, []string{file.Info().SavePath})
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to clean-up temp files: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookCleanFileContent 清空文件内容
|
||||
func HookCleanFileContent(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
// 清空内容
|
||||
return fs.Handler.Put(ctx, &fsctx.FileStream{
|
||||
File: ioutil.NopCloser(strings.NewReader("")),
|
||||
SavePath: file.Info().SavePath,
|
||||
Size: 0,
|
||||
Mode: fsctx.Overwrite,
|
||||
})
|
||||
}
|
||||
|
||||
// HookClearFileSize 将原始文件的尺寸设为0
|
||||
func HookClearFileSize(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if !ok {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
return originFile.UpdateSize(0)
|
||||
}
|
||||
|
||||
// HookCancelContext 取消上下文
|
||||
func HookCancelContext(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
cancelFunc, ok := ctx.Value(fsctx.CancelFuncCtx).(context.CancelFunc)
|
||||
if ok {
|
||||
cancelFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookUpdateSourceName 更新文件SourceName
|
||||
func HookUpdateSourceName(ctx context.Context, fs *FileSystem, file fsctx.FileHeader) error {
|
||||
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if !ok {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
return originFile.UpdateSourceName(originFile.SourceName)
|
||||
}
|
||||
|
||||
// GenericAfterUpdate 文件内容更新后
|
||||
func GenericAfterUpdate(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
|
||||
// 更新文件尺寸
|
||||
originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File)
|
||||
if !ok {
|
||||
return ErrObjectNotExist
|
||||
}
|
||||
|
||||
newFile.SetModel(&originFile)
|
||||
|
||||
err := originFile.UpdateSize(newFile.Info().Size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SlaveAfterUpload Slave模式下上传完成钩子
|
||||
func SlaveAfterUpload(session *serializer.UploadSession) Hook {
|
||||
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
if session.Callback == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 发送回调请求
|
||||
callbackBody := serializer.UploadCallback{}
|
||||
return cluster.RemoteCallback(session.Callback, callbackBody)
|
||||
}
|
||||
}
|
||||
|
||||
// GenericAfterUpload 文件上传完成后,包含数据库操作
|
||||
func GenericAfterUpload(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
fileInfo := fileHeader.Info()
|
||||
|
||||
// 创建或查找根目录
|
||||
folder, err := fs.CreateDirectory(ctx, fileInfo.VirtualPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if ok, file := fs.IsChildFileExist(
|
||||
folder,
|
||||
fileInfo.FileName,
|
||||
); ok {
|
||||
if file.UploadSessionID != nil {
|
||||
return ErrFileUploadSessionExisted
|
||||
}
|
||||
|
||||
return ErrFileExisted
|
||||
}
|
||||
|
||||
// 向数据库中插入记录
|
||||
file, err := fs.AddFile(ctx, folder, fileHeader)
|
||||
if err != nil {
|
||||
return ErrInsertFileRecord
|
||||
}
|
||||
fileHeader.SetModel(file)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookGenerateThumb 生成缩略图
|
||||
// func HookGenerateThumb(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
// // 异步尝试生成缩略图
|
||||
// fileMode := fileHeader.Info().Model.(*model.File)
|
||||
// if fs.Policy.IsThumbGenerateNeeded() {
|
||||
// fs.recycleLock.Lock()
|
||||
// go func() {
|
||||
// defer fs.recycleLock.Unlock()
|
||||
// _, _ = fs.Handler.Delete(ctx, []string{fileMode.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
|
||||
// fs.GenerateThumbnail(ctx, fileMode)
|
||||
// }()
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// HookClearFileHeaderSize 将FileHeader大小设定为0
|
||||
func HookClearFileHeaderSize(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
fileHeader.SetSize(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HookTruncateFileTo 将物理文件截断至 size
|
||||
func HookTruncateFileTo(size uint64) Hook {
|
||||
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
if handler, ok := fs.Handler.(local.Driver); ok {
|
||||
return handler.Truncate(ctx, fileHeader.Info().SavePath, size)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// HookChunkUploadFinished 单个分片上传结束后
|
||||
func HookChunkUploaded(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
fileInfo := fileHeader.Info()
|
||||
|
||||
// 更新文件大小
|
||||
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart + fileInfo.Size)
|
||||
}
|
||||
|
||||
// HookChunkUploadFailed 单个分片上传失败后
|
||||
func HookChunkUploadFailed(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
fileInfo := fileHeader.Info()
|
||||
|
||||
// 更新文件大小
|
||||
return fileInfo.Model.(*model.File).UpdateSize(fileInfo.AppendStart)
|
||||
}
|
||||
|
||||
// HookPopPlaceholderToFile 将占位文件提升为正式文件
|
||||
func HookPopPlaceholderToFile(picInfo string) Hook {
|
||||
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
fileInfo := fileHeader.Info()
|
||||
fileModel := fileInfo.Model.(*model.File)
|
||||
return fileModel.PopChunkToFile(fileInfo.LastModified, picInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// HookChunkUploadFinished 分片上传结束后处理文件
|
||||
func HookDeleteUploadSession(id string) Hook {
|
||||
return func(ctx context.Context, fs *FileSystem, fileHeader fsctx.FileHeader) error {
|
||||
cache.Deletes([]string{id}, UploadSessionCachePrefix)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewWebdavAfterUploadHook 每次创建一个新的钩子函数 rclone 在 PUT 请求里有 OC-Checksum 字符串
|
||||
// 和 X-OC-Mtime
|
||||
func NewWebdavAfterUploadHook(request *http.Request) func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
|
||||
var modtime time.Time
|
||||
if timeVal := request.Header.Get("X-OC-Mtime"); timeVal != "" {
|
||||
timeUnix, err := strconv.ParseInt(timeVal, 10, 64)
|
||||
if err == nil {
|
||||
modtime = time.Unix(timeUnix, 0)
|
||||
}
|
||||
}
|
||||
checksum := request.Header.Get("OC-Checksum")
|
||||
|
||||
return func(ctx context.Context, fs *FileSystem, newFile fsctx.FileHeader) error {
|
||||
file := newFile.Info().Model.(*model.File)
|
||||
if !modtime.IsZero() {
|
||||
err := model.DB.Model(file).UpdateColumn("updated_at", modtime).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if checksum != "" {
|
||||
return file.UpdateMetadata(map[string]string{
|
||||
model.ChecksumMetadataKey: checksum,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
218
pkg/filesystem/image.go
Normal file
218
pkg/filesystem/image.go
Normal file
@ -0,0 +1,218 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"runtime"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
/* ================
|
||||
图像处理相关
|
||||
================
|
||||
*/
|
||||
|
||||
// GetThumb 获取文件的缩略图
|
||||
func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentResponse, error) {
|
||||
// 根据 ID 查找文件
|
||||
err := fs.resetFileIDIfNotExist(ctx, id)
|
||||
if err != nil {
|
||||
return nil, ErrObjectNotExist
|
||||
}
|
||||
|
||||
file := fs.FileTarget[0]
|
||||
if !file.ShouldLoadThumb() {
|
||||
return nil, ErrObjectNotExist
|
||||
}
|
||||
|
||||
w, h := fs.GenerateThumbnailSize(0, 0)
|
||||
ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h})
|
||||
ctx = context.WithValue(ctx, fsctx.FileModelCtx, file)
|
||||
res, err := fs.Handler.Thumb(ctx, &file)
|
||||
if errors.Is(err, driver.ErrorThumbNotExist) {
|
||||
// Regenerate thumb if the thumb is not initialized yet
|
||||
if generateErr := fs.generateThumbnail(ctx, &file); generateErr == nil {
|
||||
res, err = fs.Handler.Thumb(ctx, &file)
|
||||
} else {
|
||||
err = generateErr
|
||||
}
|
||||
} else if errors.Is(err, driver.ErrorThumbNotSupported) {
|
||||
// Policy handler explicitly indicates thumb not available, check if proxy is enabled
|
||||
if fs.Policy.CouldProxyThumb() {
|
||||
// if thumb id marked as existed, redirect to "sidecar" thumb file.
|
||||
if file.MetadataSerialized != nil &&
|
||||
file.MetadataSerialized[model.ThumbStatusMetadataKey] == model.ThumbStatusExist {
|
||||
// redirect to sidecar file
|
||||
res = &response.ContentResponse{
|
||||
Redirect: true,
|
||||
}
|
||||
res.URL, err = fs.Handler.Source(ctx, file.ThumbFile(), int64(model.GetIntSetting("preview_timeout", 60)), false, 0)
|
||||
} else {
|
||||
// if not exist, generate and upload the sidecar thumb.
|
||||
if err = fs.generateThumbnail(ctx, &file); err == nil {
|
||||
return fs.GetThumb(ctx, id)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// thumb not supported and proxy is disabled, mark as not available
|
||||
_ = updateThumbStatus(&file, model.ThumbStatusNotAvailable)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil && conf.SystemConfig.Mode == "master" {
|
||||
res.MaxAge = model.GetIntSetting("preview_timeout", 60)
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
// thumbPool 要使用的任务池
|
||||
var thumbPool *Pool
|
||||
var once sync.Once
|
||||
|
||||
// Pool 带有最大配额的任务池
|
||||
type Pool struct {
|
||||
// 容量
|
||||
worker chan int
|
||||
}
|
||||
|
||||
// Init 初始化任务池
|
||||
func getThumbWorker() *Pool {
|
||||
once.Do(func() {
|
||||
maxWorker := model.GetIntSetting("thumb_max_task_count", -1)
|
||||
if maxWorker <= 0 {
|
||||
maxWorker = runtime.GOMAXPROCS(0)
|
||||
}
|
||||
thumbPool = &Pool{
|
||||
worker: make(chan int, maxWorker),
|
||||
}
|
||||
util.Log().Debug("Initialize thumbnails task queue with: WorkerNum = %d", maxWorker)
|
||||
})
|
||||
return thumbPool
|
||||
}
|
||||
func (pool *Pool) addWorker() {
|
||||
pool.worker <- 1
|
||||
util.Log().Debug("Worker added to thumbnails task queue.")
|
||||
}
|
||||
func (pool *Pool) releaseWorker() {
|
||||
util.Log().Debug("Worker released from thumbnails task queue.")
|
||||
<-pool.worker
|
||||
}
|
||||
|
||||
// generateThumbnail generates thumb for given file, upload the thumb file back with given suffix
|
||||
func (fs *FileSystem) generateThumbnail(ctx context.Context, file *model.File) error {
|
||||
// 新建上下文
|
||||
newCtx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
// TODO: check file size
|
||||
|
||||
if file.Size > uint64(model.GetIntSetting("thumb_max_src_size", 31457280)) {
|
||||
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
|
||||
return errors.New("file too large")
|
||||
}
|
||||
|
||||
getThumbWorker().addWorker()
|
||||
defer getThumbWorker().releaseWorker()
|
||||
|
||||
// 获取文件数据
|
||||
source, err := fs.Handler.Get(newCtx, file.SourceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("faield to fetch original file %q: %w", file.SourceName, err)
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
// Provide file source path for local policy files
|
||||
src := ""
|
||||
if conf.SystemConfig.Mode == "slave" || file.GetPolicy().Type == "local" {
|
||||
src = file.SourceName
|
||||
}
|
||||
|
||||
thumbRes, err := thumb.Generators.Generate(ctx, source, src, file.Name, model.GetSettingByNames(
|
||||
"thumb_width",
|
||||
"thumb_height",
|
||||
"thumb_builtin_enabled",
|
||||
"thumb_vips_enabled",
|
||||
"thumb_ffmpeg_enabled",
|
||||
"thumb_libreoffice_enabled",
|
||||
))
|
||||
if err != nil {
|
||||
_ = updateThumbStatus(file, model.ThumbStatusNotAvailable)
|
||||
return fmt.Errorf("failed to generate thumb for %q: %w", file.Name, err)
|
||||
}
|
||||
|
||||
defer os.Remove(thumbRes.Path)
|
||||
|
||||
thumbFile, err := os.Open(thumbRes.Path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open temp thumb %q: %w", thumbRes.Path, err)
|
||||
}
|
||||
|
||||
defer thumbFile.Close()
|
||||
fileInfo, err := thumbFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat temp thumb %q: %w", thumbRes.Path, err)
|
||||
}
|
||||
|
||||
if err = fs.Handler.Put(newCtx, &fsctx.FileStream{
|
||||
Mode: fsctx.Overwrite,
|
||||
File: thumbFile,
|
||||
Seeker: thumbFile,
|
||||
Size: uint64(fileInfo.Size()),
|
||||
SavePath: file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb"),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("failed to save thumb for %q: %w", file.Name, err)
|
||||
}
|
||||
|
||||
if model.IsTrueVal(model.GetSettingByName("thumb_gc_after_gen")) {
|
||||
util.Log().Debug("generateThumbnail runtime.GC")
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
// Mark this file as thumb available
|
||||
err = updateThumbStatus(file, model.ThumbStatusExist)
|
||||
|
||||
// 失败时删除缩略图文件
|
||||
if err != nil {
|
||||
_, _ = fs.Handler.Delete(newCtx, []string{file.SourceName + model.GetSettingByNameWithDefault("thumb_file_suffix", "._thumb")})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateThumbnailSize 获取要生成的缩略图的尺寸
|
||||
func (fs *FileSystem) GenerateThumbnailSize(w, h int) (uint, uint) {
|
||||
return uint(model.GetIntSetting("thumb_width", 400)), uint(model.GetIntSetting("thumb_height", 300))
|
||||
}
|
||||
|
||||
func updateThumbStatus(file *model.File, status string) error {
|
||||
if file.Model.ID > 0 {
|
||||
meta := map[string]string{
|
||||
model.ThumbStatusMetadataKey: status,
|
||||
}
|
||||
|
||||
if status == model.ThumbStatusExist {
|
||||
meta[model.ThumbSidecarMetadataKey] = "true"
|
||||
}
|
||||
|
||||
return file.UpdateMetadata(meta)
|
||||
} else {
|
||||
if file.MetadataSerialized == nil {
|
||||
file.MetadataSerialized = map[string]string{}
|
||||
}
|
||||
|
||||
file.MetadataSerialized[model.ThumbStatusMetadataKey] = status
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
479
pkg/filesystem/manage.go
Normal file
479
pkg/filesystem/manage.go
Normal file
@ -0,0 +1,479 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/hashid"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
/* =================
|
||||
文件/目录管理
|
||||
=================
|
||||
*/
|
||||
|
||||
// Rename 重命名对象
|
||||
func (fs *FileSystem) Rename(ctx context.Context, dir, file []uint, new string) (err error) {
|
||||
// 验证新名字
|
||||
if !fs.ValidateLegalName(ctx, new) || (len(file) > 0 && !fs.ValidateExtension(ctx, new)) {
|
||||
return ErrIllegalObjectName
|
||||
}
|
||||
|
||||
// 如果源对象是文件
|
||||
if len(file) > 0 {
|
||||
fileObject, err := model.GetFilesByIDs([]uint{file[0]}, fs.User.ID)
|
||||
if err != nil || len(fileObject) == 0 {
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
err = fileObject[0].Rename(new)
|
||||
if err != nil {
|
||||
return ErrFileExisted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(dir) > 0 {
|
||||
folderObject, err := model.GetFoldersByIDs([]uint{dir[0]}, fs.User.ID)
|
||||
if err != nil || len(folderObject) == 0 {
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
err = folderObject[0].Rename(new)
|
||||
if err != nil {
|
||||
return ErrFileExisted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
// Copy 复制src目录下的文件或目录到dst,
|
||||
// 暂时只支持单文件
|
||||
func (fs *FileSystem) Copy(ctx context.Context, dirs, files []uint, src, dst string) error {
|
||||
// 获取目的目录
|
||||
isDstExist, dstFolder := fs.IsPathExist(dst)
|
||||
isSrcExist, srcFolder := fs.IsPathExist(src)
|
||||
// 不存在时返回空的结果
|
||||
if !isDstExist || !isSrcExist {
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
// 记录复制的文件的总容量
|
||||
var newUsedStorage uint64
|
||||
|
||||
// 设置webdav目标名
|
||||
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
|
||||
dstFolder.WebdavDstName = dstName
|
||||
}
|
||||
|
||||
// 复制目录
|
||||
if len(dirs) > 0 {
|
||||
subFileSizes, err := srcFolder.CopyFolderTo(dirs[0], dstFolder)
|
||||
if err != nil {
|
||||
return ErrObjectNotExist.WithError(err)
|
||||
}
|
||||
newUsedStorage += subFileSizes
|
||||
}
|
||||
|
||||
// 复制文件
|
||||
if len(files) > 0 {
|
||||
subFileSizes, err := srcFolder.MoveOrCopyFileTo(files, dstFolder, true)
|
||||
if err != nil {
|
||||
return ErrObjectNotExist.WithError(err)
|
||||
}
|
||||
newUsedStorage += subFileSizes
|
||||
}
|
||||
|
||||
// 扣除容量
|
||||
fs.User.IncreaseStorageWithoutCheck(newUsedStorage)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Move 移动文件和目录, 将id列表dirs和files从src移动至dst
|
||||
func (fs *FileSystem) Move(ctx context.Context, dirs, files []uint, src, dst string) error {
|
||||
// 获取目的目录
|
||||
isDstExist, dstFolder := fs.IsPathExist(dst)
|
||||
isSrcExist, srcFolder := fs.IsPathExist(src)
|
||||
// 不存在时返回空的结果
|
||||
if !isDstExist || !isSrcExist {
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
// 设置webdav目标名
|
||||
if dstName, ok := ctx.Value(fsctx.WebdavDstName).(string); ok {
|
||||
dstFolder.WebdavDstName = dstName
|
||||
}
|
||||
|
||||
// 处理目录及子文件移动
|
||||
err := srcFolder.MoveFolderTo(dirs, dstFolder)
|
||||
if err != nil {
|
||||
return ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
// 处理文件移动
|
||||
_, err = srcFolder.MoveOrCopyFileTo(files, dstFolder, false)
|
||||
if err != nil {
|
||||
return ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
// 移动文件
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete 递归删除对象, force 为 true 时强制删除文件记录,忽略物理删除是否成功;
|
||||
// unlink 为 true 时只删除虚拟文件系统的文件记录,不删除物理文件。
|
||||
func (fs *FileSystem) Delete(ctx context.Context, dirs, files []uint, force, unlink bool) error {
|
||||
// 已删除的文件ID
|
||||
var deletedFiles = make([]*model.File, 0, len(fs.FileTarget))
|
||||
// 删除失败的文件的父目录ID
|
||||
|
||||
// 所有文件的ID
|
||||
var allFiles = make([]*model.File, 0, len(fs.FileTarget))
|
||||
|
||||
// 列出要删除的目录
|
||||
if len(dirs) > 0 {
|
||||
err := fs.ListDeleteDirs(ctx, dirs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 列出要删除的文件
|
||||
if len(files) > 0 {
|
||||
err := fs.ListDeleteFiles(ctx, files)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 去除待删除文件中包含软连接的部分
|
||||
filesToBeDelete, err := model.RemoveFilesWithSoftLinks(fs.FileTarget)
|
||||
if err != nil {
|
||||
return ErrDBListObjects.WithError(err)
|
||||
}
|
||||
|
||||
// 根据存储策略将文件分组
|
||||
policyGroup := fs.GroupFileByPolicy(ctx, filesToBeDelete)
|
||||
|
||||
// 按照存储策略分组删除对象
|
||||
failed := make(map[uint][]string)
|
||||
if !unlink {
|
||||
failed = fs.deleteGroupedFile(ctx, policyGroup)
|
||||
}
|
||||
|
||||
// 整理删除结果
|
||||
for i := 0; i < len(fs.FileTarget); i++ {
|
||||
if !util.ContainsString(failed[fs.FileTarget[i].PolicyID], fs.FileTarget[i].SourceName) {
|
||||
// 已成功删除的文件
|
||||
deletedFiles = append(deletedFiles, &fs.FileTarget[i])
|
||||
}
|
||||
|
||||
// 全部文件
|
||||
allFiles = append(allFiles, &fs.FileTarget[i])
|
||||
}
|
||||
|
||||
// 如果强制删除,则将全部文件视为删除成功
|
||||
if force {
|
||||
deletedFiles = allFiles
|
||||
}
|
||||
|
||||
// 删除文件记录
|
||||
err = model.DeleteFiles(deletedFiles, fs.User.ID)
|
||||
if err != nil {
|
||||
return ErrDBDeleteObjects.WithError(err)
|
||||
}
|
||||
|
||||
// 删除文件记录对应的分享记录
|
||||
// TODO 先取消分享再删除文件
|
||||
deletedFileIDs := make([]uint, len(deletedFiles))
|
||||
for k, file := range deletedFiles {
|
||||
deletedFileIDs[k] = file.ID
|
||||
}
|
||||
|
||||
model.DeleteShareBySourceIDs(deletedFileIDs, false)
|
||||
|
||||
// 如果文件全部删除成功,继续删除目录
|
||||
if len(deletedFiles) == len(allFiles) {
|
||||
var allFolderIDs = make([]uint, 0, len(fs.DirTarget))
|
||||
for _, value := range fs.DirTarget {
|
||||
allFolderIDs = append(allFolderIDs, value.ID)
|
||||
}
|
||||
err = model.DeleteFolderByIDs(allFolderIDs)
|
||||
if err != nil {
|
||||
return ErrDBDeleteObjects.WithError(err)
|
||||
}
|
||||
|
||||
// 删除目录记录对应的分享记录
|
||||
model.DeleteShareBySourceIDs(allFolderIDs, true)
|
||||
}
|
||||
|
||||
if notDeleted := len(fs.FileTarget) - len(deletedFiles); notDeleted > 0 {
|
||||
return serializer.NewError(
|
||||
serializer.CodeNotFullySuccess,
|
||||
fmt.Sprintf("Failed to delete %d file(s).", notDeleted),
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDeleteDirs 递归列出要删除目录,及目录下所有文件
|
||||
func (fs *FileSystem) ListDeleteDirs(ctx context.Context, ids []uint) error {
|
||||
// 列出所有递归子目录
|
||||
folders, err := model.GetRecursiveChildFolder(ids, fs.User.ID, true)
|
||||
if err != nil {
|
||||
return ErrDBListObjects.WithError(err)
|
||||
}
|
||||
|
||||
// 忽略根目录
|
||||
for i := 0; i < len(folders); i++ {
|
||||
if folders[i].ParentID == nil {
|
||||
folders = append(folders[:i], folders[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fs.SetTargetDir(&folders)
|
||||
|
||||
// 检索目录下的子文件
|
||||
files, err := model.GetChildFilesOfFolders(&folders)
|
||||
if err != nil {
|
||||
return ErrDBListObjects.WithError(err)
|
||||
}
|
||||
fs.SetTargetFile(&files)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListDeleteFiles 根据给定的路径列出要删除的文件
|
||||
func (fs *FileSystem) ListDeleteFiles(ctx context.Context, ids []uint) error {
|
||||
files, err := model.GetFilesByIDs(ids, fs.User.ID)
|
||||
if err != nil {
|
||||
return ErrDBListObjects.WithError(err)
|
||||
}
|
||||
fs.SetTargetFile(&files)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 列出路径下的内容,
|
||||
// pathProcessor为最终对象路径的处理钩子。
|
||||
// 有些情况下(如在分享页面列对象)时,
|
||||
// 路径需要截取掉被分享目录路径之前的部分。
|
||||
func (fs *FileSystem) List(ctx context.Context, dirPath string, pathProcessor func(string) string) ([]serializer.Object, error) {
|
||||
// 获取父目录
|
||||
isExist, folder := fs.IsPathExist(dirPath)
|
||||
if !isExist {
|
||||
return nil, ErrPathNotExist
|
||||
}
|
||||
fs.SetTargetDir(&[]model.Folder{*folder})
|
||||
|
||||
var parentPath = path.Join(folder.Position, folder.Name)
|
||||
var childFolders []model.Folder
|
||||
var childFiles []model.File
|
||||
|
||||
// 获取子目录
|
||||
childFolders, _ = folder.GetChildFolder()
|
||||
|
||||
// 获取子文件
|
||||
childFiles, _ = folder.GetChildFiles()
|
||||
|
||||
return fs.listObjects(ctx, parentPath, childFiles, childFolders, pathProcessor), nil
|
||||
}
|
||||
|
||||
// ListPhysical 列出存储策略中的外部目录
|
||||
// TODO:测试
|
||||
func (fs *FileSystem) ListPhysical(ctx context.Context, dirPath string) ([]serializer.Object, error) {
|
||||
if err := fs.DispatchHandler(); fs.Policy == nil || err != nil {
|
||||
return nil, ErrUnknownPolicyType
|
||||
}
|
||||
|
||||
// 存储策略不支持列取时,返回空结果
|
||||
if !fs.Policy.CanStructureBeListed() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 列取路径
|
||||
objects, err := fs.Handler.List(ctx, dirPath, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
folders []model.Folder
|
||||
)
|
||||
for _, object := range objects {
|
||||
if object.IsDir {
|
||||
folders = append(folders, model.Folder{
|
||||
Name: object.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return fs.listObjects(ctx, dirPath, nil, folders, nil), nil
|
||||
}
|
||||
|
||||
func (fs *FileSystem) listObjects(ctx context.Context, parent string, files []model.File, folders []model.Folder, pathProcessor func(string) string) []serializer.Object {
|
||||
// 分享文件的ID
|
||||
shareKey := ""
|
||||
if key, ok := ctx.Value(fsctx.ShareKeyCtx).(string); ok {
|
||||
shareKey = key
|
||||
}
|
||||
|
||||
// 汇总处理结果
|
||||
objects := make([]serializer.Object, 0, len(files)+len(folders))
|
||||
|
||||
// 所有对象的父目录
|
||||
var processedPath string
|
||||
|
||||
for _, subFolder := range folders {
|
||||
// 路径处理钩子,
|
||||
// 所有对象父目录都是一样的,所以只处理一次
|
||||
if processedPath == "" {
|
||||
if pathProcessor != nil {
|
||||
processedPath = pathProcessor(parent)
|
||||
} else {
|
||||
processedPath = parent
|
||||
}
|
||||
}
|
||||
|
||||
objects = append(objects, serializer.Object{
|
||||
ID: hashid.HashID(subFolder.ID, hashid.FolderID),
|
||||
Name: subFolder.Name,
|
||||
Path: processedPath,
|
||||
Size: 0,
|
||||
Type: "dir",
|
||||
Date: subFolder.UpdatedAt,
|
||||
CreateDate: subFolder.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if processedPath == "" {
|
||||
if pathProcessor != nil {
|
||||
processedPath = pathProcessor(parent)
|
||||
} else {
|
||||
processedPath = parent
|
||||
}
|
||||
}
|
||||
|
||||
if file.UploadSessionID == nil {
|
||||
newFile := serializer.Object{
|
||||
ID: hashid.HashID(file.ID, hashid.FileID),
|
||||
Name: file.Name,
|
||||
Path: processedPath,
|
||||
Thumb: file.ShouldLoadThumb(),
|
||||
Size: file.Size,
|
||||
Type: "file",
|
||||
Date: file.UpdatedAt,
|
||||
SourceEnabled: file.GetPolicy().IsOriginLinkEnable,
|
||||
CreateDate: file.CreatedAt,
|
||||
}
|
||||
if shareKey != "" {
|
||||
newFile.Key = shareKey
|
||||
}
|
||||
objects = append(objects, newFile)
|
||||
}
|
||||
}
|
||||
|
||||
return objects
|
||||
}
|
||||
|
||||
// CreateDirectory 根据给定的完整创建目录,支持递归创建。如果目录已存在,则直接
|
||||
// 返回已存在的目录。
|
||||
func (fs *FileSystem) CreateDirectory(ctx context.Context, fullPath string) (*model.Folder, error) {
|
||||
if fullPath == "." || fullPath == "" {
|
||||
return nil, ErrRootProtected
|
||||
}
|
||||
|
||||
if fullPath == "/" {
|
||||
if fs.Root != nil {
|
||||
return fs.Root, nil
|
||||
}
|
||||
return fs.User.Root()
|
||||
}
|
||||
|
||||
// 获取要创建目录的父路径和目录名
|
||||
fullPath = path.Clean(fullPath)
|
||||
base := path.Dir(fullPath)
|
||||
dir := path.Base(fullPath)
|
||||
|
||||
// 去掉结尾空格
|
||||
dir = strings.TrimRight(dir, " ")
|
||||
|
||||
// 检查目录名是否合法
|
||||
if !fs.ValidateLegalName(ctx, dir) {
|
||||
return nil, ErrIllegalObjectName
|
||||
}
|
||||
|
||||
// 父目录是否存在
|
||||
isExist, parent := fs.IsPathExist(base)
|
||||
if !isExist {
|
||||
newParent, err := fs.CreateDirectory(ctx, base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parent = newParent
|
||||
}
|
||||
|
||||
// 是否有同名文件
|
||||
if ok, _ := fs.IsChildFileExist(parent, dir); ok {
|
||||
return nil, ErrFileExisted
|
||||
}
|
||||
|
||||
// 创建目录
|
||||
newFolder := model.Folder{
|
||||
Name: dir,
|
||||
ParentID: &parent.ID,
|
||||
OwnerID: fs.User.ID,
|
||||
}
|
||||
_, err := newFolder.Create()
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create folder: %w", err)
|
||||
}
|
||||
|
||||
return &newFolder, nil
|
||||
}
|
||||
|
||||
// SaveTo 将别人分享的文件转存到目标路径下
|
||||
func (fs *FileSystem) SaveTo(ctx context.Context, path string) error {
|
||||
// 获取父目录
|
||||
isExist, folder := fs.IsPathExist(path)
|
||||
if !isExist {
|
||||
return ErrPathNotExist
|
||||
}
|
||||
|
||||
var (
|
||||
totalSize uint64
|
||||
err error
|
||||
)
|
||||
|
||||
if len(fs.DirTarget) > 0 {
|
||||
totalSize, err = fs.DirTarget[0].CopyFolderTo(fs.DirTarget[0].ID, folder)
|
||||
} else {
|
||||
parent := model.Folder{
|
||||
OwnerID: fs.FileTarget[0].UserID,
|
||||
}
|
||||
parent.ID = fs.FileTarget[0].FolderID
|
||||
totalSize, err = parent.MoveOrCopyFileTo([]uint{fs.FileTarget[0].ID}, folder, true)
|
||||
}
|
||||
|
||||
// 扣除用户容量
|
||||
fs.User.IncreaseStorageWithoutCheck(totalSize)
|
||||
if err != nil {
|
||||
return ErrFileExisted.WithError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
25
pkg/filesystem/oauth/mutex.go
Normal file
25
pkg/filesystem/oauth/mutex.go
Normal file
@ -0,0 +1,25 @@
|
||||
package oauth
|
||||
|
||||
import "sync"
|
||||
|
||||
// CredentialLock 针对存储策略凭证的锁
|
||||
type CredentialLock interface {
|
||||
Lock(uint)
|
||||
Unlock(uint)
|
||||
}
|
||||
|
||||
var GlobalMutex = mutexMap{}
|
||||
|
||||
type mutexMap struct {
|
||||
locks sync.Map
|
||||
}
|
||||
|
||||
func (m *mutexMap) Lock(id uint) {
|
||||
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||
lock.(*sync.Mutex).Lock()
|
||||
}
|
||||
|
||||
func (m *mutexMap) Unlock(id uint) {
|
||||
lock, _ := m.locks.LoadOrStore(id, &sync.Mutex{})
|
||||
lock.(*sync.Mutex).Unlock()
|
||||
}
|
8
pkg/filesystem/oauth/token.go
Normal file
8
pkg/filesystem/oauth/token.go
Normal file
@ -0,0 +1,8 @@
|
||||
package oauth
|
||||
|
||||
import "context"
|
||||
|
||||
type TokenProvider interface {
|
||||
UpdateCredential(ctx context.Context, isSlave bool) error
|
||||
AccessToken() string
|
||||
}
|
84
pkg/filesystem/path.go
Normal file
84
pkg/filesystem/path.go
Normal file
@ -0,0 +1,84 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"path"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
/* =================
|
||||
路径/目录相关
|
||||
=================
|
||||
*/
|
||||
|
||||
// IsPathExist 返回给定目录是否存在
|
||||
// 如果存在就返回目录
|
||||
func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) {
|
||||
tracedEnd, currentFolder := fs.getClosedParent(path)
|
||||
if tracedEnd {
|
||||
return true, currentFolder
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (fs *FileSystem) getClosedParent(path string) (bool, *model.Folder) {
|
||||
pathList := util.SplitPath(path)
|
||||
if len(pathList) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 递归步入目录
|
||||
var currentFolder *model.Folder
|
||||
|
||||
// 如果已设定跟目录对象,则从给定目录向下遍历
|
||||
if fs.Root != nil {
|
||||
currentFolder = fs.Root
|
||||
}
|
||||
|
||||
for _, folderName := range pathList {
|
||||
var err error
|
||||
|
||||
// 根目录
|
||||
if folderName == "/" {
|
||||
if currentFolder != nil {
|
||||
continue
|
||||
}
|
||||
currentFolder, err = fs.User.Root()
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
} else {
|
||||
nextFolder, err := currentFolder.GetChild(folderName)
|
||||
if err != nil {
|
||||
return false, currentFolder
|
||||
}
|
||||
|
||||
currentFolder = nextFolder
|
||||
}
|
||||
}
|
||||
|
||||
return true, currentFolder
|
||||
}
|
||||
|
||||
// IsFileExist 返回给定路径的文件是否存在
|
||||
func (fs *FileSystem) IsFileExist(fullPath string) (bool, *model.File) {
|
||||
basePath := path.Dir(fullPath)
|
||||
fileName := path.Base(fullPath)
|
||||
|
||||
// 获得父目录
|
||||
exist, parent := fs.IsPathExist(basePath)
|
||||
if !exist {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
file, err := parent.GetChildFile(fileName)
|
||||
|
||||
return err == nil, file
|
||||
}
|
||||
|
||||
// IsChildFileExist 确定folder目录下是否有名为name的文件
|
||||
func (fs *FileSystem) IsChildFileExist(folder *model.Folder, name string) (bool, *model.File) {
|
||||
file, err := folder.GetChildFile(name)
|
||||
return err == nil, file
|
||||
}
|
102
pkg/filesystem/relocate.go
Normal file
102
pkg/filesystem/relocate.go
Normal file
@ -0,0 +1,102 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
/* ================
|
||||
存储策略迁移
|
||||
================
|
||||
*/
|
||||
|
||||
// Relocate 将目标文件转移到当前存储策略下
|
||||
func (fs *FileSystem) Relocate(ctx context.Context, files []model.File, policy *model.Policy) error {
|
||||
// 重设存储策略为要转移的目的策略
|
||||
fs.Policy = policy
|
||||
if err := fs.DispatchHandler(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将目前文件根据存储策略分组
|
||||
fileGroup := fs.GroupFileByPolicy(ctx, files)
|
||||
|
||||
// 按照存储策略分组处理每个文件
|
||||
for _, fileList := range fileGroup {
|
||||
// 如果存储策略一样,则跳过
|
||||
if fileList[0].GetPolicy().ID == fs.Policy.ID {
|
||||
util.Log().Debug("Skip relocating %d file(s), since they are already in desired policy.",
|
||||
len(fileList))
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取当前存储策略的处理器
|
||||
currentPolicy, _ := model.GetPolicyByID(fileList[0].PolicyID)
|
||||
currentHandler, err := getNewPolicyHandler(¤tPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录转移完毕需要删除的文件
|
||||
toBeDeleted := make([]model.File, 0, len(fileList))
|
||||
|
||||
// 循环处理每一个文件
|
||||
// for id, r := 0, len(fileList); id < r; id++ {
|
||||
for id, _ := range fileList {
|
||||
// 验证文件是否符合新存储策略的规定
|
||||
if err := HookValidateFile(ctx, fs, fileList[id]); err != nil {
|
||||
util.Log().Debug("File %q failed to pass validators in new policy %q, skipping.",
|
||||
fileList[id].Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 为文件生成新存储策略下的物理路径
|
||||
savePath := fs.GenerateSavePath(ctx, fileList[id])
|
||||
|
||||
// 获取原始文件
|
||||
src, err := currentHandler.Get(ctx, fileList[id].SourceName)
|
||||
if err != nil {
|
||||
util.Log().Debug("Failed to get file %q: %s, skipping.",
|
||||
fileList[id].Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 转存到新的存储策略
|
||||
if err := fs.Handler.Put(ctx, &fsctx.FileStream{
|
||||
File: src,
|
||||
SavePath: savePath,
|
||||
Size: fileList[id].Size,
|
||||
}); err != nil {
|
||||
util.Log().Debug("Failed to migrate file %q: %s, skipping.",
|
||||
fileList[id].Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
toBeDeleted = append(toBeDeleted, *fileList[id])
|
||||
|
||||
// 更新文件信息
|
||||
fileList[id].Relocate(savePath, fs.Policy.ID)
|
||||
}
|
||||
|
||||
// 排除带有软链接的文件
|
||||
toBeDeletedClean, err := model.RemoveFilesWithSoftLinks(toBeDeleted)
|
||||
if err != nil {
|
||||
util.Log().Warning("Failed to check soft links: %s", err)
|
||||
}
|
||||
|
||||
deleteSourceNames := make([]string, 0, len(toBeDeleted))
|
||||
for i := 0; i < len(toBeDeletedClean); i++ {
|
||||
deleteSourceNames = append(deleteSourceNames, toBeDeletedClean[i].SourceName)
|
||||
}
|
||||
|
||||
// 删除原始策略中的文件
|
||||
if _, err := currentHandler.Delete(ctx, deleteSourceNames); err != nil {
|
||||
util.Log().Warning("Cannot delete files in origin policy after relocating: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
32
pkg/filesystem/response/common.go
Normal file
32
pkg/filesystem/response/common.go
Normal file
@ -0,0 +1,32 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ContentResponse 获取文件内容类方法的通用返回值。
|
||||
// 有些上传策略需要重定向,
|
||||
// 有些直接写文件数据到浏览器
|
||||
type ContentResponse struct {
|
||||
Redirect bool
|
||||
Content RSCloser
|
||||
URL string
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// RSCloser 存储策略适配器返回的文件流,有些策略需要带有Closer
|
||||
type RSCloser interface {
|
||||
io.ReadSeeker
|
||||
io.Closer
|
||||
}
|
||||
|
||||
// Object 列出文件、目录时返回的对象
|
||||
type Object struct {
|
||||
Name string `json:"name"`
|
||||
RelativePath string `json:"relative_path"`
|
||||
Source string `json:"source"`
|
||||
Size uint64 `json:"size"`
|
||||
IsDir bool `json:"is_dir"`
|
||||
LastModify time.Time `json:"last_modify"`
|
||||
}
|
0
pkg/filesystem/tests/file1.txt
Normal file
0
pkg/filesystem/tests/file1.txt
Normal file
0
pkg/filesystem/tests/file2.txt
Normal file
0
pkg/filesystem/tests/file2.txt
Normal file
BIN
pkg/filesystem/tests/test.zip
Normal file
BIN
pkg/filesystem/tests/test.zip
Normal file
Binary file not shown.
243
pkg/filesystem/upload.go
Normal file
243
pkg/filesystem/upload.go
Normal file
@ -0,0 +1,243 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gofrs/uuid"
|
||||
)
|
||||
|
||||
/* ================
|
||||
上传处理相关
|
||||
================
|
||||
*/
|
||||
|
||||
const (
|
||||
UploadSessionMetaKey = "upload_session"
|
||||
UploadSessionCtx = "uploadSession"
|
||||
UserCtx = "user"
|
||||
UploadSessionCachePrefix = "callback_"
|
||||
)
|
||||
|
||||
// Upload 上传文件
|
||||
func (fs *FileSystem) Upload(ctx context.Context, file *fsctx.FileStream) (err error) {
|
||||
// 上传前的钩子
|
||||
err = fs.Trigger(ctx, "BeforeUpload", file)
|
||||
if err != nil {
|
||||
request.BlackHole(file)
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成文件名和路径,
|
||||
var savePath string
|
||||
if file.SavePath == "" {
|
||||
// 如果是更新操作就从上下文中获取
|
||||
if originFile, ok := ctx.Value(fsctx.FileModelCtx).(model.File); ok {
|
||||
savePath = originFile.SourceName
|
||||
} else {
|
||||
savePath = fs.GenerateSavePath(ctx, file)
|
||||
}
|
||||
file.SavePath = savePath
|
||||
}
|
||||
|
||||
// 保存文件
|
||||
if file.Mode&fsctx.Nop != fsctx.Nop {
|
||||
// 处理客户端未完成上传时,关闭连接
|
||||
go fs.CancelUpload(ctx, savePath, file)
|
||||
|
||||
err = fs.Handler.Put(ctx, file)
|
||||
if err != nil {
|
||||
fs.Trigger(ctx, "AfterUploadFailed", file)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 上传完成后的钩子
|
||||
err = fs.Trigger(ctx, "AfterUpload", file)
|
||||
|
||||
if err != nil {
|
||||
// 上传完成后续处理失败
|
||||
followUpErr := fs.Trigger(ctx, "AfterValidateFailed", file)
|
||||
// 失败后再失败...
|
||||
if followUpErr != nil {
|
||||
util.Log().Debug("AfterValidateFailed hook execution failed: %s", followUpErr)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateSavePath 生成要存放文件的路径
|
||||
// TODO 完善测试
|
||||
func (fs *FileSystem) GenerateSavePath(ctx context.Context, file fsctx.FileHeader) string {
|
||||
fileInfo := file.Info()
|
||||
return path.Join(
|
||||
fs.Policy.GeneratePath(
|
||||
fs.User.Model.ID,
|
||||
fileInfo.VirtualPath,
|
||||
),
|
||||
fs.Policy.GenerateFileName(
|
||||
fs.User.Model.ID,
|
||||
fileInfo.FileName,
|
||||
),
|
||||
)
|
||||
|
||||
}
|
||||
|
||||
// CancelUpload 监测客户端取消上传
|
||||
func (fs *FileSystem) CancelUpload(ctx context.Context, path string, file fsctx.FileHeader) {
|
||||
var reqContext context.Context
|
||||
if ginCtx, ok := ctx.Value(fsctx.GinCtx).(*gin.Context); ok {
|
||||
reqContext = ginCtx.Request.Context()
|
||||
} else if reqCtx, ok := ctx.Value(fsctx.HTTPCtx).(context.Context); ok {
|
||||
reqContext = reqCtx
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-reqContext.Done():
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// 客户端正常关闭,不执行操作
|
||||
default:
|
||||
// 客户端取消上传,删除临时文件
|
||||
util.Log().Debug("Client canceled upload.")
|
||||
if fs.Hooks["AfterUploadCanceled"] == nil {
|
||||
return
|
||||
}
|
||||
err := fs.Trigger(ctx, "AfterUploadCanceled", file)
|
||||
if err != nil {
|
||||
util.Log().Debug("AfterUploadCanceled hook execution failed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUploadSession 创建上传会话
|
||||
func (fs *FileSystem) CreateUploadSession(ctx context.Context, file *fsctx.FileStream) (*serializer.UploadCredential, error) {
|
||||
// 获取相关有效期设置
|
||||
callBackSessionTTL := model.GetIntSetting("upload_session_timeout", 86400)
|
||||
|
||||
callbackKey := uuid.Must(uuid.NewV4()).String()
|
||||
fileSize := file.Size
|
||||
|
||||
// 创建占位的文件,同时校验文件信息
|
||||
file.Mode = fsctx.Nop
|
||||
if callbackKey != "" {
|
||||
file.UploadSessionID = &callbackKey
|
||||
}
|
||||
|
||||
fs.Use("BeforeUpload", HookValidateFile)
|
||||
fs.Use("BeforeUpload", HookValidateCapacity)
|
||||
|
||||
// 验证文件规格
|
||||
if err := fs.Upload(ctx, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uploadSession := &serializer.UploadSession{
|
||||
Key: callbackKey,
|
||||
UID: fs.User.ID,
|
||||
Policy: *fs.Policy,
|
||||
VirtualPath: file.VirtualPath,
|
||||
Name: file.Name,
|
||||
Size: fileSize,
|
||||
SavePath: file.SavePath,
|
||||
LastModified: file.LastModified,
|
||||
CallbackSecret: util.RandStringRunes(32),
|
||||
}
|
||||
|
||||
// 获取上传凭证
|
||||
credential, err := fs.Handler.Token(ctx, int64(callBackSessionTTL), uploadSession, file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建占位符
|
||||
if !fs.Policy.IsUploadPlaceholderWithSize() {
|
||||
fs.Use("AfterUpload", HookClearFileHeaderSize)
|
||||
}
|
||||
fs.Use("AfterUpload", GenericAfterUpload)
|
||||
ctx = context.WithValue(ctx, fsctx.IgnoreDirectoryConflictCtx, true)
|
||||
if err := fs.Upload(ctx, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建回调会话
|
||||
err = cache.Set(
|
||||
UploadSessionCachePrefix+callbackKey,
|
||||
*uploadSession,
|
||||
callBackSessionTTL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 补全上传凭证其他信息
|
||||
credential.Expires = time.Now().Add(time.Duration(callBackSessionTTL) * time.Second).Unix()
|
||||
|
||||
return credential, nil
|
||||
}
|
||||
|
||||
// UploadFromStream 从文件流上传文件
|
||||
func (fs *FileSystem) UploadFromStream(ctx context.Context, file *fsctx.FileStream, resetPolicy bool) error {
|
||||
// 给文件系统分配钩子
|
||||
fs.Lock.Lock()
|
||||
if resetPolicy {
|
||||
err := fs.SetPolicyFromPath(file.VirtualPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if fs.Hooks == nil {
|
||||
fs.Use("BeforeUpload", HookValidateFile)
|
||||
fs.Use("BeforeUpload", HookValidateCapacity)
|
||||
fs.Use("AfterUploadCanceled", HookDeleteTempFile)
|
||||
fs.Use("AfterUpload", GenericAfterUpload)
|
||||
fs.Use("AfterValidateFailed", HookDeleteTempFile)
|
||||
}
|
||||
fs.Lock.Unlock()
|
||||
|
||||
// 开始上传
|
||||
return fs.Upload(ctx, file)
|
||||
}
|
||||
|
||||
// UploadFromPath 将本机已有文件上传到用户的文件系统
|
||||
func (fs *FileSystem) UploadFromPath(ctx context.Context, src, dst string, mode fsctx.WriteMode) error {
|
||||
file, err := os.Open(util.RelativePath(src))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 获取源文件大小
|
||||
fi, err := file.Stat()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size := fi.Size()
|
||||
|
||||
// 开始上传
|
||||
return fs.UploadFromStream(ctx, &fsctx.FileStream{
|
||||
File: file,
|
||||
Seeker: file,
|
||||
Size: uint64(size),
|
||||
Name: path.Base(dst),
|
||||
VirtualPath: path.Dir(dst),
|
||||
Mode: mode,
|
||||
}, true)
|
||||
}
|
66
pkg/filesystem/validator.go
Normal file
66
pkg/filesystem/validator.go
Normal file
@ -0,0 +1,66 @@
|
||||
package filesystem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
)
|
||||
|
||||
/* ==========
|
||||
验证器
|
||||
==========
|
||||
*/
|
||||
|
||||
// 文件/路径名保留字符
|
||||
var reservedCharacter = []string{"\\", "?", "*", "<", "\"", ":", ">", "/", "|"}
|
||||
|
||||
// ValidateLegalName 验证文件名/文件夹名是否合法
|
||||
func (fs *FileSystem) ValidateLegalName(ctx context.Context, name string) bool {
|
||||
// 是否包含保留字符
|
||||
for _, value := range reservedCharacter {
|
||||
if strings.Contains(name, value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 是否超出长度限制
|
||||
if len(name) >= 256 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 是否为空限制
|
||||
if len(name) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 结尾不能是空格
|
||||
if strings.HasSuffix(name, " ") {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidateFileSize 验证上传的文件大小是否超出限制
|
||||
func (fs *FileSystem) ValidateFileSize(ctx context.Context, size uint64) bool {
|
||||
if fs.Policy.MaxSize == 0 {
|
||||
return true
|
||||
}
|
||||
return size <= fs.Policy.MaxSize
|
||||
}
|
||||
|
||||
// ValidateCapacity 验证并扣除用户容量
|
||||
func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool {
|
||||
return fs.User.IncreaseStorage(size)
|
||||
}
|
||||
|
||||
// ValidateExtension 验证文件扩展名
|
||||
func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bool {
|
||||
// 不需要验证
|
||||
if len(fs.Policy.OptionsSerialized.FileType) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return util.IsInExtensionList(fs.Policy.OptionsSerialized.FileType, fileName)
|
||||
}
|
70
pkg/hashid/hash.go
Normal file
70
pkg/hashid/hash.go
Normal file
@ -0,0 +1,70 @@
|
||||
package hashid
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/conf"
|
||||
"github.com/speps/go-hashids"
|
||||
)
|
||||
|
||||
// ID类型
|
||||
const (
|
||||
ShareID = iota // 分享
|
||||
UserID // 用户
|
||||
FileID // 文件ID
|
||||
FolderID // 目录ID
|
||||
TagID // 标签ID
|
||||
PolicyID // 存储策略ID
|
||||
SourceLinkID
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrTypeNotMatch ID类型不匹配
|
||||
ErrTypeNotMatch = errors.New("mismatched ID type.")
|
||||
)
|
||||
|
||||
// HashEncode 对给定数据计算HashID
|
||||
func HashEncode(v []int) (string, error) {
|
||||
hd := hashids.NewData()
|
||||
hd.Salt = conf.SystemConfig.HashIDSalt
|
||||
|
||||
h, err := hashids.NewWithData(hd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id, err := h.Encode(v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// HashDecode 对给定数据计算原始数据
|
||||
func HashDecode(raw string) ([]int, error) {
|
||||
hd := hashids.NewData()
|
||||
hd.Salt = conf.SystemConfig.HashIDSalt
|
||||
|
||||
h, err := hashids.NewWithData(hd)
|
||||
if err != nil {
|
||||
return []int{}, err
|
||||
}
|
||||
|
||||
return h.DecodeWithError(raw)
|
||||
|
||||
}
|
||||
|
||||
// HashID 计算数据库内主键对应的HashID
|
||||
func HashID(id uint, t int) string {
|
||||
v, _ := HashEncode([]int{int(id), t})
|
||||
return v
|
||||
}
|
||||
|
||||
// DecodeHashID 计算HashID对应的数据库ID
|
||||
func DecodeHashID(id string, t int) (uint, error) {
|
||||
v, _ := HashDecode(id)
|
||||
if len(v) != 2 || v[1] != t {
|
||||
return 0, ErrTypeNotMatch
|
||||
}
|
||||
return uint(v[0]), nil
|
||||
}
|
37
pkg/mocks/cachemock/mock.go
Normal file
37
pkg/mocks/cachemock/mock.go
Normal file
@ -0,0 +1,37 @@
|
||||
package cachemock
|
||||
|
||||
import "github.com/stretchr/testify/mock"
|
||||
|
||||
type CacheClientMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Set(key string, value interface{}, ttl int) error {
|
||||
return c.Called(key, value, ttl).Error(0)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Get(key string) (interface{}, bool) {
|
||||
args := c.Called(key)
|
||||
return args.Get(0), args.Bool(1)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Gets(keys []string, prefix string) (map[string]interface{}, []string) {
|
||||
args := c.Called(keys, prefix)
|
||||
return args.Get(0).(map[string]interface{}), args.Get(1).([]string)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Sets(values map[string]interface{}, prefix string) error {
|
||||
return c.Called(values).Error(0)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Delete(keys []string, prefix string) error {
|
||||
return c.Called(keys, prefix).Error(0)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Persist(path string) error {
|
||||
return c.Called(path).Error(0)
|
||||
}
|
||||
|
||||
func (c CacheClientMock) Restore(path string) error {
|
||||
return c.Called(path).Error(0)
|
||||
}
|
43
pkg/mocks/controllermock/c.go
Normal file
43
pkg/mocks/controllermock/c.go
Normal file
@ -0,0 +1,43 @@
|
||||
package controllermock
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/mq"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type SlaveControllerMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) HandleHeartBeat(pingReq *serializer.NodePingReq) (serializer.NodePingResp, error) {
|
||||
args := s.Called(pingReq)
|
||||
return args.Get(0).(serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetAria2Instance(s2 string) (common.Aria2, error) {
|
||||
args := s.Called(s2)
|
||||
return args.Get(0).(common.Aria2), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) SendNotification(s3 string, s2 string, message mq.Message) error {
|
||||
args := s.Called(s3, s2, message)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) SubmitTask(s3 string, i interface{}, s2 string, f func(interface{})) error {
|
||||
args := s.Called(s3, i, s2, f)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetMasterInfo(s2 string) (*cluster.MasterInfo, error) {
|
||||
args := s.Called(s2)
|
||||
return args.Get(0).(*cluster.MasterInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (s SlaveControllerMock) GetPolicyOauthToken(s2 string, u uint) (string, error) {
|
||||
args := s.Called(s2, u)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
151
pkg/mocks/mocks.go
Normal file
151
pkg/mocks/mocks.go
Normal file
@ -0,0 +1,151 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/balancer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cluster"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/task"
|
||||
testMock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type NodePoolMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n NodePoolMock) BalanceNodeByFeature(feature string, lb balancer.Balancer) (error, cluster.Node) {
|
||||
args := n.Called(feature, lb)
|
||||
return args.Error(0), args.Get(1).(cluster.Node)
|
||||
}
|
||||
|
||||
func (n NodePoolMock) GetNodeByID(id uint) cluster.Node {
|
||||
args := n.Called(id)
|
||||
if res, ok := args.Get(0).(cluster.Node); ok {
|
||||
return res
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n NodePoolMock) Add(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n NodePoolMock) Delete(id uint) {
|
||||
n.Called(id)
|
||||
}
|
||||
|
||||
type NodeMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (n NodeMock) Init(node *model.Node) {
|
||||
n.Called(node)
|
||||
}
|
||||
|
||||
func (n NodeMock) IsFeatureEnabled(feature string) bool {
|
||||
args := n.Called(feature)
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) SubscribeStatusChange(callback func(isActive bool, id uint)) {
|
||||
n.Called(callback)
|
||||
}
|
||||
|
||||
func (n NodeMock) Ping(req *serializer.NodePingReq) (*serializer.NodePingResp, error) {
|
||||
args := n.Called(req)
|
||||
return args.Get(0).(*serializer.NodePingResp), args.Error(1)
|
||||
}
|
||||
|
||||
func (n NodeMock) IsActive() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) GetAria2Instance() common.Aria2 {
|
||||
args := n.Called()
|
||||
return args.Get(0).(common.Aria2)
|
||||
}
|
||||
|
||||
func (n NodeMock) ID() uint {
|
||||
args := n.Called()
|
||||
return args.Get(0).(uint)
|
||||
}
|
||||
|
||||
func (n NodeMock) Kill() {
|
||||
n.Called()
|
||||
}
|
||||
|
||||
func (n NodeMock) IsMater() bool {
|
||||
args := n.Called()
|
||||
return args.Bool(0)
|
||||
}
|
||||
|
||||
func (n NodeMock) MasterAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n NodeMock) SlaveAuthInstance() auth.Auth {
|
||||
args := n.Called()
|
||||
return args.Get(0).(auth.Auth)
|
||||
}
|
||||
|
||||
func (n NodeMock) DBModel() *model.Node {
|
||||
args := n.Called()
|
||||
return args.Get(0).(*model.Node)
|
||||
}
|
||||
|
||||
type Aria2Mock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Init() error {
|
||||
args := a.Called()
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) CreateTask(task *model.Download, options map[string]interface{}) (string, error) {
|
||||
args := a.Called(task, options)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Status(task *model.Download) (rpc.StatusInfo, error) {
|
||||
args := a.Called(task)
|
||||
return args.Get(0).(rpc.StatusInfo), args.Error(1)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Cancel(task *model.Download) error {
|
||||
args := a.Called(task)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) Select(task *model.Download, files []int) error {
|
||||
args := a.Called(task, files)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) GetConfig() model.Aria2Option {
|
||||
args := a.Called()
|
||||
return args.Get(0).(model.Aria2Option)
|
||||
}
|
||||
|
||||
func (a Aria2Mock) DeleteTempFile(download *model.Download) error {
|
||||
args := a.Called(download)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
type TaskPoolMock struct {
|
||||
testMock.Mock
|
||||
}
|
||||
|
||||
func (t TaskPoolMock) Add(num int) {
|
||||
t.Called(num)
|
||||
}
|
||||
|
||||
func (t TaskPoolMock) Submit(job task.Job) {
|
||||
t.Called(job)
|
||||
}
|
33
pkg/mocks/remoteclientmock/mock.go
Normal file
33
pkg/mocks/remoteclientmock/mock.go
Normal file
@ -0,0 +1,33 @@
|
||||
package remoteclientmock
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type RemoteClientMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (r *RemoteClientMock) CreateUploadSession(ctx context.Context, session *serializer.UploadSession, ttl int64, overwrite bool) error {
|
||||
return r.Called(ctx, session, ttl, overwrite).Error(0)
|
||||
}
|
||||
|
||||
func (r *RemoteClientMock) GetUploadURL(ttl int64, sessionID string) (string, string, error) {
|
||||
args := r.Called(ttl, sessionID)
|
||||
|
||||
return args.String(0), args.String(1), args.Error(2)
|
||||
}
|
||||
|
||||
func (r *RemoteClientMock) Upload(ctx context.Context, file fsctx.FileHeader) error {
|
||||
args := r.Called(ctx, file)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (r *RemoteClientMock) DeleteUploadSession(ctx context.Context, sessionID string) error {
|
||||
args := r.Called(ctx, sessionID)
|
||||
return args.Error(0)
|
||||
}
|
15
pkg/mocks/requestmock/request.go
Normal file
15
pkg/mocks/requestmock/request.go
Normal file
@ -0,0 +1,15 @@
|
||||
package requestmock
|
||||
|
||||
import (
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
)
|
||||
|
||||
type RequestMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (r RequestMock) Request(method, target string, body io.Reader, opts ...request.Option) *request.Response {
|
||||
return r.Called(method, target, body, opts).Get(0).(*request.Response)
|
||||
}
|
25
pkg/mocks/thumbmock/thumb.go
Normal file
25
pkg/mocks/thumbmock/thumb.go
Normal file
@ -0,0 +1,25 @@
|
||||
package thumbmock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/thumb"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"io"
|
||||
)
|
||||
|
||||
type GeneratorMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (g GeneratorMock) Generate(ctx context.Context, file io.Reader, src string, name string, options map[string]string) (*thumb.Result, error) {
|
||||
res := g.Called(ctx, file, src, name, options)
|
||||
return res.Get(0).(*thumb.Result), res.Error(1)
|
||||
}
|
||||
|
||||
func (g GeneratorMock) Priority() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (g GeneratorMock) EnableFlag() string {
|
||||
return "thumb_vips_enabled"
|
||||
}
|
21
pkg/mocks/wopimock/mock.go
Normal file
21
pkg/mocks/wopimock/mock.go
Normal file
@ -0,0 +1,21 @@
|
||||
package wopimock
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/wopi"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
type WopiClientMock struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (w *WopiClientMock) NewSession(user uint, file *model.File, action wopi.ActonType) (*wopi.Session, error) {
|
||||
args := w.Called(user, file, action)
|
||||
return args.Get(0).(*wopi.Session), args.Error(1)
|
||||
}
|
||||
|
||||
func (w *WopiClientMock) AvailableExts() []string {
|
||||
args := w.Called()
|
||||
return args.Get(0).([]string)
|
||||
}
|
160
pkg/mq/mq.go
Normal file
160
pkg/mq/mq.go
Normal file
@ -0,0 +1,160 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/common"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Message 消息事件正文
|
||||
type Message struct {
|
||||
// 消息触发者
|
||||
TriggeredBy string
|
||||
|
||||
// 事件标识
|
||||
Event string
|
||||
|
||||
// 消息正文
|
||||
Content interface{}
|
||||
}
|
||||
|
||||
type CallbackFunc func(Message)
|
||||
|
||||
// MQ 消息队列
|
||||
type MQ interface {
|
||||
rpc.Notifier
|
||||
|
||||
// 发布一个消息
|
||||
Publish(string, Message)
|
||||
|
||||
// 订阅一个消息主题
|
||||
Subscribe(string, int) <-chan Message
|
||||
|
||||
// 订阅一个消息主题,注册触发回调函数
|
||||
SubscribeCallback(string, CallbackFunc)
|
||||
|
||||
// 取消订阅一个消息主题
|
||||
Unsubscribe(string, <-chan Message)
|
||||
}
|
||||
|
||||
var GlobalMQ = NewMQ()
|
||||
|
||||
func NewMQ() MQ {
|
||||
return &inMemoryMQ{
|
||||
topics: make(map[string][]chan Message),
|
||||
callbacks: make(map[string][]CallbackFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
gob.Register(Message{})
|
||||
gob.Register([]rpc.Event{})
|
||||
}
|
||||
|
||||
type inMemoryMQ struct {
|
||||
topics map[string][]chan Message
|
||||
callbacks map[string][]CallbackFunc
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Publish(topic string, message Message) {
|
||||
i.RLock()
|
||||
subscribersChan, okChan := i.topics[topic]
|
||||
subscribersCallback, okCallback := i.callbacks[topic]
|
||||
i.RUnlock()
|
||||
|
||||
if okChan {
|
||||
go func(subscribersChan []chan Message) {
|
||||
for i := 0; i < len(subscribersChan); i++ {
|
||||
select {
|
||||
case subscribersChan[i] <- message:
|
||||
case <-time.After(time.Millisecond * 500):
|
||||
}
|
||||
}
|
||||
}(subscribersChan)
|
||||
|
||||
}
|
||||
|
||||
if okCallback {
|
||||
for i := 0; i < len(subscribersCallback); i++ {
|
||||
go subscribersCallback[i](message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Subscribe(topic string, buffer int) <-chan Message {
|
||||
ch := make(chan Message, buffer)
|
||||
i.Lock()
|
||||
i.topics[topic] = append(i.topics[topic], ch)
|
||||
i.Unlock()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) SubscribeCallback(topic string, callbackFunc CallbackFunc) {
|
||||
i.Lock()
|
||||
i.callbacks[topic] = append(i.callbacks[topic], callbackFunc)
|
||||
i.Unlock()
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Unsubscribe(topic string, sub <-chan Message) {
|
||||
i.Lock()
|
||||
defer i.Unlock()
|
||||
|
||||
subscribers, ok := i.topics[topic]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var newSubs []chan Message
|
||||
for _, subscriber := range subscribers {
|
||||
if subscriber == sub {
|
||||
continue
|
||||
}
|
||||
newSubs = append(newSubs, subscriber)
|
||||
}
|
||||
|
||||
i.topics[topic] = newSubs
|
||||
}
|
||||
|
||||
func (i *inMemoryMQ) Aria2Notify(events []rpc.Event, status int) {
|
||||
for _, event := range events {
|
||||
i.Publish(event.Gid, Message{
|
||||
TriggeredBy: event.Gid,
|
||||
Event: strconv.FormatInt(int64(status), 10),
|
||||
Content: events,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// OnDownloadStart 下载开始
|
||||
func (i *inMemoryMQ) OnDownloadStart(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Downloading)
|
||||
}
|
||||
|
||||
// OnDownloadPause 下载暂停
|
||||
func (i *inMemoryMQ) OnDownloadPause(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Paused)
|
||||
}
|
||||
|
||||
// OnDownloadStop 下载停止
|
||||
func (i *inMemoryMQ) OnDownloadStop(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Canceled)
|
||||
}
|
||||
|
||||
// OnDownloadComplete 下载完成
|
||||
func (i *inMemoryMQ) OnDownloadComplete(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Complete)
|
||||
}
|
||||
|
||||
// OnDownloadError 下载出错
|
||||
func (i *inMemoryMQ) OnDownloadError(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Error)
|
||||
}
|
||||
|
||||
// OnBtDownloadComplete BT下载完成
|
||||
func (i *inMemoryMQ) OnBtDownloadComplete(events []rpc.Event) {
|
||||
i.Aria2Notify(events, common.Complete)
|
||||
}
|
43
pkg/payment/alipay.go
Normal file
43
pkg/payment/alipay.go
Normal file
@ -0,0 +1,43 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
alipay "github.com/smartwalle/alipay/v3"
|
||||
)
|
||||
|
||||
// Alipay 支付宝当面付支付处理
|
||||
type Alipay struct {
|
||||
Client *alipay.Client
|
||||
}
|
||||
|
||||
// Create 创建订单
|
||||
func (pay *Alipay) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
|
||||
gateway, _ := url.Parse("/api/v3/callback/alipay")
|
||||
var p = alipay.TradePreCreate{
|
||||
Trade: alipay.Trade{
|
||||
NotifyURL: model.GetSiteURL().ResolveReference(gateway).String(),
|
||||
Subject: order.Name,
|
||||
OutTradeNo: order.OrderNo,
|
||||
TotalAmount: fmt.Sprintf("%.2f", float64(order.Price*order.Num)/100),
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := order.Create(); err != nil {
|
||||
return nil, ErrInsertOrder.WithError(err)
|
||||
}
|
||||
|
||||
res, err := pay.Client.TradePreCreate(p)
|
||||
if err != nil {
|
||||
return nil, ErrIssueOrder.WithError(err)
|
||||
}
|
||||
|
||||
return &OrderCreateRes{
|
||||
Payment: true,
|
||||
QRCode: res.QRCode,
|
||||
ID: order.OrderNo,
|
||||
}, nil
|
||||
}
|
93
pkg/payment/custom.go
Normal file
93
pkg/payment/custom.go
Normal file
@ -0,0 +1,93 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/auth"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/request"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/qiniu/go-sdk/v7/sms/bytes"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Custom payment client
|
||||
type Custom struct {
|
||||
client request.Client
|
||||
endpoint string
|
||||
authClient auth.Auth
|
||||
}
|
||||
|
||||
const (
|
||||
paymentTTL = 3600 * 24 // 24h
|
||||
CallbackSessionPrefix = "custom_payment_callback_"
|
||||
)
|
||||
|
||||
func newCustomClient(endpoint, secret string) *Custom {
|
||||
authClient := auth.HMACAuth{
|
||||
SecretKey: []byte(secret),
|
||||
}
|
||||
return &Custom{
|
||||
endpoint: endpoint,
|
||||
authClient: auth.General,
|
||||
client: request.NewClient(
|
||||
request.WithCredential(authClient, paymentTTL),
|
||||
request.WithMasterMeta(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// Request body from Cloudreve to create a new payment
|
||||
type NewCustomOrderRequest struct {
|
||||
Name string `json:"name"` // Order name
|
||||
OrderNo string `json:"order_no"` // Order number
|
||||
NotifyURL string `json:"notify_url"` // Payment callback url
|
||||
Amount int64 `json:"amount"` // Order total amount
|
||||
}
|
||||
|
||||
// Create a new payment
|
||||
func (pay *Custom) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
|
||||
callbackID := uuid.Must(uuid.NewV4())
|
||||
gateway, _ := url.Parse(fmt.Sprintf("/api/v3/callback/custom/%s/%s", order.OrderNo, callbackID))
|
||||
callback, err := auth.SignURI(pay.authClient, model.GetSiteURL().ResolveReference(gateway).String(), paymentTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to sign callback url: %w", err)
|
||||
}
|
||||
|
||||
cache.Set(CallbackSessionPrefix+callbackID.String(), order.OrderNo, paymentTTL)
|
||||
|
||||
body := &NewCustomOrderRequest{
|
||||
Name: order.Name,
|
||||
OrderNo: order.OrderNo,
|
||||
NotifyURL: callback.String(),
|
||||
Amount: int64(order.Price * order.Num),
|
||||
}
|
||||
bodyJson, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode body: %w", err)
|
||||
}
|
||||
|
||||
res, err := pay.client.Request("POST", pay.endpoint, bytes.NewReader(bodyJson)).
|
||||
CheckHTTPResponse(http.StatusOK).DecodeResponse()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request payment gateway: %w", err)
|
||||
}
|
||||
|
||||
if res.Code != 0 {
|
||||
return nil, errors.New(res.Error)
|
||||
}
|
||||
|
||||
if _, err := order.Create(); err != nil {
|
||||
return nil, ErrInsertOrder.WithError(err)
|
||||
}
|
||||
|
||||
return &OrderCreateRes{
|
||||
Payment: true,
|
||||
QRCode: res.Data.(string),
|
||||
ID: order.OrderNo,
|
||||
}, nil
|
||||
}
|
171
pkg/payment/order.go
Normal file
171
pkg/payment/order.go
Normal file
@ -0,0 +1,171 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/iGoogle-ink/gopay/wechat/v3"
|
||||
"github.com/qingwg/payjs"
|
||||
"github.com/smartwalle/alipay/v3"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrUnknownPaymentMethod 未知支付方式
|
||||
ErrUnknownPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "Unknown payment method", nil)
|
||||
// ErrUnsupportedPaymentMethod 未知支付方式
|
||||
ErrUnsupportedPaymentMethod = serializer.NewError(serializer.CodeInternalSetting, "This order cannot be paid with this method", nil)
|
||||
// ErrInsertOrder 无法插入订单记录
|
||||
ErrInsertOrder = serializer.NewError(serializer.CodeDBError, "Failed to insert order record", nil)
|
||||
// ErrScoreNotEnough 积分不足
|
||||
ErrScoreNotEnough = serializer.NewError(serializer.CodeInsufficientCredit, "", nil)
|
||||
// ErrCreateStoragePack 无法创建容量包
|
||||
ErrCreateStoragePack = serializer.NewError(serializer.CodeDBError, "Failed to create storage pack record", nil)
|
||||
// ErrGroupConflict 用户组冲突
|
||||
ErrGroupConflict = serializer.NewError(serializer.CodeGroupConflict, "", nil)
|
||||
// ErrGroupInvalid 用户组冲突
|
||||
ErrGroupInvalid = serializer.NewError(serializer.CodeGroupInvalid, "", nil)
|
||||
// ErrAdminFulfillGroup 管理员无法购买用户组
|
||||
ErrAdminFulfillGroup = serializer.NewError(serializer.CodeFulfillAdminGroup, "", nil)
|
||||
// ErrUpgradeGroup 用户组冲突
|
||||
ErrUpgradeGroup = serializer.NewError(serializer.CodeDBError, "Failed to update user's group", nil)
|
||||
// ErrUInitPayment 无法初始化支付实例
|
||||
ErrUInitPayment = serializer.NewError(serializer.CodeInternalSetting, "Failed to initialize payment client", nil)
|
||||
// ErrIssueOrder 订单接口请求失败
|
||||
ErrIssueOrder = serializer.NewError(serializer.CodeInternalSetting, "Failed to create order", nil)
|
||||
// ErrOrderNotFound 订单不存在
|
||||
ErrOrderNotFound = serializer.NewError(serializer.CodeNotFound, "", nil)
|
||||
)
|
||||
|
||||
// Pay 支付处理接口
|
||||
type Pay interface {
|
||||
Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error)
|
||||
}
|
||||
|
||||
// OrderCreateRes 订单创建结果
|
||||
type OrderCreateRes struct {
|
||||
Payment bool `json:"payment"` // 是否需要支付
|
||||
ID string `json:"id,omitempty"` // 订单号
|
||||
QRCode string `json:"qr_code,omitempty"` // 支付二维码指向的地址
|
||||
}
|
||||
|
||||
// NewPaymentInstance 获取新的支付实例
|
||||
func NewPaymentInstance(method string) (Pay, error) {
|
||||
switch method {
|
||||
case "score":
|
||||
return &ScorePayment{}, nil
|
||||
case "alipay":
|
||||
options := model.GetSettingByNames("alipay_enabled", "appid", "appkey", "shopid")
|
||||
if options["alipay_enabled"] != "1" {
|
||||
return nil, ErrUnknownPaymentMethod
|
||||
}
|
||||
|
||||
// 初始化支付宝客户端
|
||||
var client, err = alipay.New(options["appid"], options["appkey"], true)
|
||||
if err != nil {
|
||||
return nil, ErrUInitPayment.WithError(err)
|
||||
}
|
||||
|
||||
// 加载支付宝公钥
|
||||
err = client.LoadAliPayPublicKey(options["shopid"])
|
||||
if err != nil {
|
||||
return nil, ErrUInitPayment.WithError(err)
|
||||
}
|
||||
|
||||
return &Alipay{Client: client}, nil
|
||||
case "payjs":
|
||||
options := model.GetSettingByNames("payjs_enabled", "payjs_secret", "payjs_id")
|
||||
if options["payjs_enabled"] != "1" {
|
||||
return nil, ErrUnknownPaymentMethod
|
||||
}
|
||||
|
||||
callback, _ := url.Parse("/api/v3/callback/payjs")
|
||||
payjsConfig := &payjs.Config{
|
||||
Key: options["payjs_secret"],
|
||||
MchID: options["payjs_id"],
|
||||
NotifyUrl: model.GetSiteURL().ResolveReference(callback).String(),
|
||||
}
|
||||
|
||||
return &PayJSClient{Client: payjs.New(payjsConfig)}, nil
|
||||
case "wechat":
|
||||
options := model.GetSettingByNames("wechat_enabled", "wechat_appid", "wechat_mchid", "wechat_serial_no", "wechat_api_key", "wechat_pk_content")
|
||||
if options["wechat_enabled"] != "1" {
|
||||
return nil, ErrUnknownPaymentMethod
|
||||
}
|
||||
client, err := wechat.NewClientV3(options["wechat_appid"], options["wechat_mchid"], options["wechat_serial_no"], options["wechat_api_key"], options["wechat_pk_content"])
|
||||
if err != nil {
|
||||
return nil, ErrUInitPayment.WithError(err)
|
||||
}
|
||||
|
||||
return &Wechat{Client: client, ApiV3Key: options["wechat_api_key"]}, nil
|
||||
case "custom":
|
||||
options := model.GetSettingByNames("custom_payment_enabled", "custom_payment_endpoint", "custom_payment_secret")
|
||||
if !model.IsTrueVal(options["custom_payment_enabled"]) {
|
||||
return nil, ErrUnknownPaymentMethod
|
||||
}
|
||||
|
||||
return newCustomClient(options["custom_payment_endpoint"], options["custom_payment_secret"]), nil
|
||||
default:
|
||||
return nil, ErrUnknownPaymentMethod
|
||||
}
|
||||
}
|
||||
|
||||
// NewOrder 创建新订单
|
||||
func NewOrder(pack *serializer.PackProduct, group *serializer.GroupProducts, num int, method string, user *model.User) (*OrderCreateRes, error) {
|
||||
// 获取支付实例
|
||||
pay, err := NewPaymentInstance(method)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
orderType int
|
||||
productID int64
|
||||
title string
|
||||
price int
|
||||
)
|
||||
if pack != nil {
|
||||
orderType = model.PackOrderType
|
||||
productID = pack.ID
|
||||
title = pack.Name
|
||||
price = pack.Price
|
||||
} else if group != nil {
|
||||
if err := checkGroupUpgrade(user, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orderType = model.GroupOrderType
|
||||
productID = group.ID
|
||||
title = group.Name
|
||||
price = group.Price
|
||||
} else {
|
||||
orderType = model.ScoreOrderType
|
||||
productID = 0
|
||||
title = fmt.Sprintf("%d 积分", num)
|
||||
price = model.GetIntSetting("score_price", 1)
|
||||
}
|
||||
|
||||
// 创建订单记录
|
||||
order := &model.Order{
|
||||
UserID: user.ID,
|
||||
OrderNo: orderID(),
|
||||
Type: orderType,
|
||||
Method: method,
|
||||
ProductID: productID,
|
||||
Num: num,
|
||||
Name: fmt.Sprintf("%s - %s", model.GetSettingByName("siteName"), title),
|
||||
Price: price,
|
||||
Status: model.OrderUnpaid,
|
||||
}
|
||||
|
||||
return pay.Create(order, pack, group, user)
|
||||
}
|
||||
|
||||
func orderID() string {
|
||||
return fmt.Sprintf("%s%d",
|
||||
time.Now().Format("20060102150405"),
|
||||
100000+rand.Intn(900000),
|
||||
)
|
||||
}
|
31
pkg/payment/payjs.go
Normal file
31
pkg/payment/payjs.go
Normal file
@ -0,0 +1,31 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/qingwg/payjs"
|
||||
)
|
||||
|
||||
// PayJSClient PayJS支付处理
|
||||
type PayJSClient struct {
|
||||
Client *payjs.PayJS
|
||||
}
|
||||
|
||||
// Create 创建订单
|
||||
func (pay *PayJSClient) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
|
||||
if _, err := order.Create(); err != nil {
|
||||
return nil, ErrInsertOrder.WithError(err)
|
||||
}
|
||||
|
||||
PayNative := pay.Client.GetNative()
|
||||
res, err := PayNative.Create(int64(order.Price*order.Num), order.Name, order.OrderNo, "", "")
|
||||
if err != nil {
|
||||
return nil, ErrIssueOrder.WithError(err)
|
||||
}
|
||||
|
||||
return &OrderCreateRes{
|
||||
Payment: true,
|
||||
QRCode: res.CodeUrl,
|
||||
ID: order.OrderNo,
|
||||
}, nil
|
||||
}
|
137
pkg/payment/purchase.go
Normal file
137
pkg/payment/purchase.go
Normal file
@ -0,0 +1,137 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GivePack 创建容量包
|
||||
func GivePack(user *model.User, packInfo *serializer.PackProduct, num int) error {
|
||||
timeNow := time.Now()
|
||||
expires := timeNow.Add(time.Duration(packInfo.Time*int64(num)) * time.Second)
|
||||
pack := model.StoragePack{
|
||||
Name: packInfo.Name,
|
||||
UserID: user.ID,
|
||||
ActiveTime: &timeNow,
|
||||
ExpiredTime: &expires,
|
||||
Size: packInfo.Size,
|
||||
}
|
||||
if _, err := pack.Create(); err != nil {
|
||||
return ErrCreateStoragePack.WithError(err)
|
||||
}
|
||||
cache.Deletes([]string{strconv.FormatUint(uint64(user.ID), 10)}, "pack_size_")
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkGroupUpgrade(user *model.User, groupInfo *serializer.GroupProducts) error {
|
||||
if user.Group.ID == 1 {
|
||||
return ErrAdminFulfillGroup
|
||||
}
|
||||
|
||||
// 检查用户是否已有未过期用户
|
||||
if user.PreviousGroupID != 0 && user.GroupID != groupInfo.GroupID {
|
||||
return ErrGroupConflict
|
||||
}
|
||||
|
||||
// 用户组不能相同
|
||||
if user.GroupID == groupInfo.GroupID && user.PreviousGroupID == 0 {
|
||||
return ErrGroupInvalid
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GiveGroup 升级用户组
|
||||
func GiveGroup(user *model.User, groupInfo *serializer.GroupProducts, num int) error {
|
||||
if err := checkGroupUpgrade(user, groupInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeNow := time.Now()
|
||||
expires := timeNow.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
|
||||
if user.PreviousGroupID != 0 {
|
||||
expires = user.GroupExpires.Add(time.Duration(groupInfo.Time*int64(num)) * time.Second)
|
||||
}
|
||||
|
||||
if err := user.UpgradeGroup(groupInfo.GroupID, &expires); err != nil {
|
||||
return ErrUpgradeGroup.WithError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GiveScore 积分充值
|
||||
func GiveScore(user *model.User, num int) error {
|
||||
user.AddScore(num)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GiveProduct “发货”
|
||||
func GiveProduct(user *model.User, pack *serializer.PackProduct, group *serializer.GroupProducts, num int) error {
|
||||
if pack != nil {
|
||||
return GivePack(user, pack, num)
|
||||
} else if group != nil {
|
||||
return GiveGroup(user, group, num)
|
||||
} else {
|
||||
return GiveScore(user, num)
|
||||
}
|
||||
}
|
||||
|
||||
// OrderPaid 订单已支付处理
|
||||
func OrderPaid(orderNo string) error {
|
||||
order, err := model.GetOrderByNo(orderNo)
|
||||
if err != nil || order.Status == model.OrderPaid {
|
||||
return ErrOrderNotFound.WithError(err)
|
||||
}
|
||||
|
||||
// 更新订单状态为 已支付
|
||||
order.UpdateStatus(model.OrderPaid)
|
||||
|
||||
user, err := model.GetActiveUserByID(order.UserID)
|
||||
if err != nil {
|
||||
return serializer.NewError(serializer.CodeUserNotFound, "", err)
|
||||
}
|
||||
|
||||
// 查询商品
|
||||
options := model.GetSettingByNames("pack_data", "group_sell_data")
|
||||
|
||||
var (
|
||||
packs []serializer.PackProduct
|
||||
groups []serializer.GroupProducts
|
||||
)
|
||||
if err := json.Unmarshal([]byte(options["pack_data"]), &packs); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(options["group_sell_data"]), &groups); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 查找要购买的商品
|
||||
var (
|
||||
pack *serializer.PackProduct
|
||||
group *serializer.GroupProducts
|
||||
)
|
||||
if order.Type == model.GroupOrderType {
|
||||
for _, v := range groups {
|
||||
if v.ID == order.ProductID {
|
||||
group = &v
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if order.Type == model.PackOrderType {
|
||||
for _, v := range packs {
|
||||
if v.ID == order.ProductID {
|
||||
pack = &v
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// "发货"
|
||||
return GiveProduct(&user, pack, group, order.Num)
|
||||
|
||||
}
|
45
pkg/payment/score.go
Normal file
45
pkg/payment/score.go
Normal file
@ -0,0 +1,45 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
)
|
||||
|
||||
// ScorePayment 积分支付处理
|
||||
type ScorePayment struct {
|
||||
}
|
||||
|
||||
// Create 创建新订单
|
||||
func (pay *ScorePayment) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
|
||||
if pack != nil {
|
||||
order.Price = pack.Score
|
||||
} else {
|
||||
order.Price = group.Score
|
||||
}
|
||||
|
||||
// 检查此订单是否可用积分支付
|
||||
if order.Price == 0 {
|
||||
return nil, ErrUnsupportedPaymentMethod
|
||||
}
|
||||
|
||||
// 扣除用户积分
|
||||
if !user.PayScore(order.Price * order.Num) {
|
||||
return nil, ErrScoreNotEnough
|
||||
}
|
||||
|
||||
// 商品“发货”
|
||||
if err := GiveProduct(user, pack, group, order.Num); err != nil {
|
||||
user.AddScore(order.Price * order.Num)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建订单记录
|
||||
order.Status = model.OrderPaid
|
||||
if _, err := order.Create(); err != nil {
|
||||
return nil, ErrInsertOrder.WithError(err)
|
||||
}
|
||||
|
||||
return &OrderCreateRes{
|
||||
Payment: false,
|
||||
}, nil
|
||||
}
|
88
pkg/payment/wechat.go
Normal file
88
pkg/payment/wechat.go
Normal file
@ -0,0 +1,88 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"errors"
|
||||
model "github.com/cloudreve/Cloudreve/v3/models"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/cache"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/serializer"
|
||||
"github.com/cloudreve/Cloudreve/v3/pkg/util"
|
||||
"github.com/iGoogle-ink/gopay"
|
||||
"github.com/iGoogle-ink/gopay/wechat/v3"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Wechat 微信扫码支付接口
|
||||
type Wechat struct {
|
||||
Client *wechat.ClientV3
|
||||
ApiV3Key string
|
||||
}
|
||||
|
||||
// Create 创建订单
|
||||
func (pay *Wechat) Create(order *model.Order, pack *serializer.PackProduct, group *serializer.GroupProducts, user *model.User) (*OrderCreateRes, error) {
|
||||
gateway, _ := url.Parse("/api/v3/callback/wechat")
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.
|
||||
Set("description", order.Name).
|
||||
Set("out_trade_no", order.OrderNo).
|
||||
Set("notify_url", model.GetSiteURL().ResolveReference(gateway).String()).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", int64(order.Price*order.Num)).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
|
||||
wxRsp, err := pay.Client.V3TransactionNative(bm)
|
||||
if err != nil {
|
||||
return nil, ErrIssueOrder.WithError(err)
|
||||
}
|
||||
|
||||
if wxRsp.Code == wechat.Success {
|
||||
if _, err := order.Create(); err != nil {
|
||||
return nil, ErrInsertOrder.WithError(err)
|
||||
}
|
||||
|
||||
return &OrderCreateRes{
|
||||
Payment: true,
|
||||
QRCode: wxRsp.Response.CodeUrl,
|
||||
ID: order.OrderNo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, ErrIssueOrder.WithError(errors.New(wxRsp.Error))
|
||||
}
|
||||
|
||||
// GetPlatformCert 获取微信平台证书
|
||||
func (pay *Wechat) GetPlatformCert() string {
|
||||
if cert, ok := cache.Get("wechat_platform_cert"); ok {
|
||||
return cert.(string)
|
||||
}
|
||||
|
||||
res, err := pay.Client.GetPlatformCerts()
|
||||
if err == nil {
|
||||
// 使用反馈证书中启用时间较晚的
|
||||
var (
|
||||
currentLatest *time.Time
|
||||
currentCert string
|
||||
)
|
||||
for _, cert := range res.Certs {
|
||||
effectiveTime, err := time.Parse("2006-01-02T15:04:05-0700", cert.EffectiveTime)
|
||||
if err != nil {
|
||||
if currentLatest == nil {
|
||||
currentLatest = &effectiveTime
|
||||
currentCert = cert.PublicKey
|
||||
continue
|
||||
}
|
||||
if currentLatest.Before(effectiveTime) {
|
||||
currentLatest = &effectiveTime
|
||||
currentCert = cert.PublicKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cache.Set("wechat_platform_cert", currentCert, 3600*10)
|
||||
return currentCert
|
||||
}
|
||||
|
||||
util.Log().Debug("Failed to get Wechat Pay platform certificate: %s", err)
|
||||
return ""
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user