diff --git a/storage/webdav/client._test.go b/storage/webdav/client._test.go new file mode 100644 index 0000000..8808881 --- /dev/null +++ b/storage/webdav/client._test.go @@ -0,0 +1,130 @@ +package webdav + +import ( + "context" + "net/http/httptest" + "os" + "path" + "path/filepath" + "strings" + "testing" + + "golang.org/x/net/webdav" +) + +func setupWebDAVServer(t *testing.T) (*httptest.Server, string) { + t.Helper() + tempDir, err := os.MkdirTemp("", "webdav_test") + if err != nil { + t.Fatalf("mk temp dir failed: %v", err) + } + + handler := &webdav.Handler{ + Prefix: "/", + FileSystem: webdav.Dir(tempDir), + LockSystem: webdav.NewMemLS(), + } + + server := httptest.NewServer(handler) + return server, tempDir +} + +func TestMkDirAndExists(t *testing.T) { + server, tempDir := setupWebDAVServer(t) + defer os.RemoveAll(tempDir) + defer server.Close() + + client := NewClient(server.URL, "", "", nil) + ctx := context.Background() + + testpaths := []string{"testdir", "testdir/subdir", "testdir/子目录", "/testdir/测试路径/测试路径2"} + for _, p := range testpaths { + exists, err := client.Exists(ctx, p) + if err != nil { + t.Fatalf("Call Exists Err: %v", err) + } + if exists { + t.Fatalf("Dir should not exist") + } + + if err := client.MkDir(ctx, p); err != nil { + t.Fatalf("Call MkDir Err: %v", err) + } + + exists, err = client.Exists(ctx, p) + if err != nil { + t.Fatalf("Call Exists Err: %v", err) + } + if !exists { + t.Fatalf("Dir should exist") + } + } + +} + +func TestWriteFile(t *testing.T) { + server, tempDir := setupWebDAVServer(t) + defer os.RemoveAll(tempDir) + defer server.Close() + + client := NewClient(server.URL, "", "", nil) + ctx := context.Background() + + testCases := []struct { + remotePath string + content string + }{ + { + remotePath: "hello.txt", + content: "Hello webdav", + }, + { + remotePath: "nested/dir/test.txt", + content: "Nested file", + }, + { + remotePath: "empty.txt", + content: "", + }, + { + remotePath: "unicode.txt", + content: "测试", + }, + } + + for _, tc := range testCases { + t.Run(tc.remotePath, func(t *testing.T) { + dir := path.Dir(tc.remotePath) + if dir != "." { + if err := client.MkDir(ctx, dir); err != nil { + t.Fatalf("创建目录 %s 失败: %v", dir, err) + } + } + + if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(tc.content)); err != nil { + t.Fatalf("写入文件 %s 失败: %v", tc.remotePath, err) + } + + localPath := filepath.Join(tempDir, tc.remotePath) + data, err := os.ReadFile(localPath) + if err != nil { + t.Fatalf("读取文件 %s 失败: %v", localPath, err) + } + if string(data) != tc.content { + t.Fatalf("文件内容不匹配: got %s, want %s", string(data), tc.content) + } + + appended := tc.content + " Overwritten." + if err := client.WriteFile(ctx, tc.remotePath, strings.NewReader(appended)); err != nil { + t.Fatalf("覆盖写入文件 %s 失败: %v", tc.remotePath, err) + } + data, err = os.ReadFile(localPath) + if err != nil { + t.Fatalf("读取覆盖后的文件 %s 失败: %v", localPath, err) + } + if string(data) != appended { + t.Fatalf("文件覆盖后的内容不匹配: got %s, want %s", string(data), appended) + } + }) + } +} diff --git a/storage/webdav/client.go b/storage/webdav/client.go index b4ddcc1..8c2b4ae 100644 --- a/storage/webdav/client.go +++ b/storage/webdav/client.go @@ -48,18 +48,55 @@ func (c *Client) doRequest(ctx context.Context, method, url string, body io.Read return c.httpClient.Do(req) } -func (c *Client) MkDir(ctx context.Context, dirPath string) error { - url := c.BaseURL + dirPath - resp, err := c.doRequest(ctx, "MKCOL", url, nil) +func (c *Client) Exists(ctx context.Context, remotePath string) (bool, error) { + url := c.BaseURL + remotePath + resp, err := c.doRequest(ctx, "PROPFIND", url, nil) if err != nil { - return err + return false, err } defer resp.Body.Close() if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return true, nil + } + if resp.StatusCode == http.StatusNotFound { + return false, nil + } + return false, fmt.Errorf("PROPFIND: %s", resp.Status) +} + +func (c *Client) MkDir(ctx context.Context, dirPath string) error { + dirPath = strings.Trim(dirPath, "/") + if dirPath == "" { return nil } - return fmt.Errorf("MKCOL: %s", resp.Status) + parts := strings.Split(dirPath, "/") + currentPath := "" + for i, part := range parts { + if i > 0 { + currentPath += "/" + } + currentPath += part + + exists, err := c.Exists(ctx, currentPath) + if err != nil { + return err + } + if exists { + continue + } + url := c.BaseURL + currentPath + resp, err := c.doRequest(ctx, "MKCOL", url, nil) + if err != nil { + return err + } + resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("MKCOL %s: %s", currentPath, resp.Status) + } + } + return nil } func (c *Client) WriteFile(ctx context.Context, remotePath string, content io.Reader) error {