Fix the dependency issue (#231)

This commit is contained in:
Robbie Zhang
2018-06-21 12:09:42 -07:00
committed by GitHub
parent 027b76651d
commit 6ec1098bb8
16629 changed files with 74837 additions and 4975021 deletions

View File

@@ -1,11 +0,0 @@
pkg/ is a collection of utility packages used by the Docker project without being specific to its internals.
Utility packages are kept separate from the docker core codebase to keep it as small and concise as possible.
If some utilities grow larger and their APIs stabilize, they may be moved to their own repository under the
Docker organization, to facilitate re-use by other projects. However that is not the priority.
The directory `pkg` is named after the same directory in the camlistore project. Since Brad is a core
Go maintainer, we thought it made sense to copy his methods for organizing Go code :) Thanks Brad!
Because utility packages are small and neatly separated from the rest of the codebase, they are a good
place to start for aspiring maintainers and contributors. Get in touch if you want to help maintain them!

View File

@@ -1,76 +0,0 @@
// Package aaparser is a convenience package interacting with `apparmor_parser`.
package aaparser
import (
"fmt"
"os/exec"
"path/filepath"
"strconv"
"strings"
)
const (
binary = "apparmor_parser"
)
// GetVersion returns the major and minor version of apparmor_parser.
func GetVersion() (int, int, error) {
output, err := cmd("", "--version")
if err != nil {
return -1, -1, err
}
return parseVersion(string(output))
}
// LoadProfile runs `apparmor_parser -r -W` on a specified apparmor profile to
// replace and write it to disk.
func LoadProfile(profilePath string) error {
_, err := cmd(filepath.Dir(profilePath), "-r", "-W", filepath.Base(profilePath))
if err != nil {
return err
}
return nil
}
// cmd runs `apparmor_parser` with the passed arguments.
func cmd(dir string, arg ...string) (string, error) {
c := exec.Command(binary, arg...)
c.Dir = dir
output, err := c.CombinedOutput()
if err != nil {
return "", fmt.Errorf("running `%s %s` failed with output: %s\nerror: %v", c.Path, strings.Join(c.Args, " "), string(output), err)
}
return string(output), nil
}
// parseVersion takes the output from `apparmor_parser --version` and returns
// the major and minor version for `apparor_parser`.
func parseVersion(output string) (int, int, error) {
// output is in the form of the following:
// AppArmor parser version 2.9.1
// Copyright (C) 1999-2008 Novell Inc.
// Copyright 2009-2012 Canonical Ltd.
lines := strings.SplitN(output, "\n", 2)
words := strings.Split(lines[0], " ")
version := words[len(words)-1]
// split by major minor version
v := strings.Split(version, ".")
if len(v) < 2 {
return -1, -1, fmt.Errorf("parsing major minor version failed for output: `%s`", output)
}
majorVersion, err := strconv.Atoi(v[0])
if err != nil {
return -1, -1, err
}
minorVersion, err := strconv.Atoi(v[1])
if err != nil {
return -1, -1, err
}
return majorVersion, minorVersion, nil
}

View File

@@ -1,65 +0,0 @@
package aaparser
import (
"testing"
)
type versionExpected struct {
output string
major int
minor int
}
func TestParseVersion(t *testing.T) {
versions := []versionExpected{
{
output: `AppArmor parser version 2.10
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 10,
},
{
output: `AppArmor parser version 2.8
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 8,
},
{
output: `AppArmor parser version 2.20
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 20,
},
{
output: `AppArmor parser version 2.05
Copyright (C) 1999-2008 Novell Inc.
Copyright 2009-2012 Canonical Ltd.
`,
major: 2,
minor: 5,
},
}
for _, v := range versions {
major, minor, err := parseVersion(v.output)
if err != nil {
t.Fatalf("expected error to be nil for %#v, got: %v", v, err)
}
if major != v.major {
t.Fatalf("expected major version to be %d, was %d, for: %#v\n", v.major, major, v)
}
if minor != v.minor {
t.Fatalf("expected minor version to be %d, was %d, for: %#v\n", v.minor, minor, v)
}
}
}

View File

@@ -1 +0,0 @@
This code provides helper functions for dealing with archive files.

File diff suppressed because it is too large Load Diff

View File

@@ -1,60 +0,0 @@
// +build !windows
package archive
import (
"os"
"testing"
)
func TestCanonicalTarNameForPath(t *testing.T) {
cases := []struct{ in, expected string }{
{"foo", "foo"},
{"foo/bar", "foo/bar"},
{"foo/dir/", "foo/dir/"},
}
for _, v := range cases {
if out, err := CanonicalTarNameForPath(v.in); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestCanonicalTarName(t *testing.T) {
cases := []struct {
in string
isDir bool
expected string
}{
{"foo", false, "foo"},
{"foo", true, "foo/"},
{"foo/bar", false, "foo/bar"},
{"foo/bar", true, "foo/bar/"},
}
for _, v := range cases {
if out, err := canonicalTarName(v.in, v.isDir); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestChmodTarEntry(t *testing.T) {
cases := []struct {
in, expected os.FileMode
}{
{0000, 0000},
{0777, 0777},
{0644, 0644},
{0755, 0755},
{0444, 0444},
}
for _, v := range cases {
if out := chmodTarEntry(v.in); out != v.expected {
t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out)
}
}
}

View File

@@ -1,87 +0,0 @@
// +build windows
package archive
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
)
func TestCopyFileWithInvalidDest(t *testing.T) {
folder, err := ioutil.TempDir("", "docker-archive-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(folder)
dest := "c:dest"
srcFolder := filepath.Join(folder, "src")
src := filepath.Join(folder, "src", "src")
err = os.MkdirAll(srcFolder, 0740)
if err != nil {
t.Fatal(err)
}
ioutil.WriteFile(src, []byte("content"), 0777)
err = CopyWithTar(src, dest)
if err == nil {
t.Fatalf("archiver.CopyWithTar should throw an error on invalid dest.")
}
}
func TestCanonicalTarNameForPath(t *testing.T) {
cases := []struct {
in, expected string
shouldFail bool
}{
{"foo", "foo", false},
{"foo/bar", "___", true}, // unix-styled windows path must fail
{`foo\bar`, "foo/bar", false},
}
for _, v := range cases {
if out, err := CanonicalTarNameForPath(v.in); err != nil && !v.shouldFail {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if v.shouldFail && err == nil {
t.Fatalf("canonical path call should have failed with error. in=%s out=%s", v.in, out)
} else if !v.shouldFail && out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestCanonicalTarName(t *testing.T) {
cases := []struct {
in string
isDir bool
expected string
}{
{"foo", false, "foo"},
{"foo", true, "foo/"},
{`foo\bar`, false, "foo/bar"},
{`foo\bar`, true, "foo/bar/"},
}
for _, v := range cases {
if out, err := canonicalTarName(v.in, v.isDir); err != nil {
t.Fatalf("cannot get canonical name for path: %s: %v", v.in, err)
} else if out != v.expected {
t.Fatalf("wrong canonical tar name. expected:%s got:%s", v.expected, out)
}
}
}
func TestChmodTarEntry(t *testing.T) {
cases := []struct {
in, expected os.FileMode
}{
{0000, 0111},
{0777, 0755},
{0644, 0755},
{0755, 0755},
{0444, 0555},
}
for _, v := range cases {
if out := chmodTarEntry(v.in); out != v.expected {
t.Fatalf("wrong chmod. expected:%v got:%v", v.expected, out)
}
}
}

View File

@@ -1,127 +0,0 @@
package archive
import (
"archive/tar"
"fmt"
"io"
"io/ioutil"
"os"
"path"
"sort"
"testing"
)
func TestHardLinkOrder(t *testing.T) {
names := []string{"file1.txt", "file2.txt", "file3.txt"}
msg := []byte("Hey y'all")
// Create dir
src, err := ioutil.TempDir("", "docker-hardlink-test-src-")
if err != nil {
t.Fatal(err)
}
//defer os.RemoveAll(src)
for _, name := range names {
func() {
fh, err := os.Create(path.Join(src, name))
if err != nil {
t.Fatal(err)
}
defer fh.Close()
if _, err = fh.Write(msg); err != nil {
t.Fatal(err)
}
}()
}
// Create dest, with changes that includes hardlinks
dest, err := ioutil.TempDir("", "docker-hardlink-test-dest-")
if err != nil {
t.Fatal(err)
}
os.RemoveAll(dest) // we just want the name, at first
if err := copyDir(src, dest); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dest)
for _, name := range names {
for i := 0; i < 5; i++ {
if err := os.Link(path.Join(dest, name), path.Join(dest, fmt.Sprintf("%s.link%d", name, i))); err != nil {
t.Fatal(err)
}
}
}
// get changes
changes, err := ChangesDirs(dest, src)
if err != nil {
t.Fatal(err)
}
// sort
sort.Sort(changesByPath(changes))
// ExportChanges
ar, err := ExportChanges(dest, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
hdrs, err := walkHeaders(ar)
if err != nil {
t.Fatal(err)
}
// reverse sort
sort.Sort(sort.Reverse(changesByPath(changes)))
// ExportChanges
arRev, err := ExportChanges(dest, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
hdrsRev, err := walkHeaders(arRev)
if err != nil {
t.Fatal(err)
}
// line up the two sets
sort.Sort(tarHeaders(hdrs))
sort.Sort(tarHeaders(hdrsRev))
// compare Size and LinkName
for i := range hdrs {
if hdrs[i].Name != hdrsRev[i].Name {
t.Errorf("headers - expected name %q; but got %q", hdrs[i].Name, hdrsRev[i].Name)
}
if hdrs[i].Size != hdrsRev[i].Size {
t.Errorf("headers - %q expected size %d; but got %d", hdrs[i].Name, hdrs[i].Size, hdrsRev[i].Size)
}
if hdrs[i].Typeflag != hdrsRev[i].Typeflag {
t.Errorf("headers - %q expected type %d; but got %d", hdrs[i].Name, hdrs[i].Typeflag, hdrsRev[i].Typeflag)
}
if hdrs[i].Linkname != hdrsRev[i].Linkname {
t.Errorf("headers - %q expected linkname %q; but got %q", hdrs[i].Name, hdrs[i].Linkname, hdrsRev[i].Linkname)
}
}
}
type tarHeaders []tar.Header
func (th tarHeaders) Len() int { return len(th) }
func (th tarHeaders) Swap(i, j int) { th[j], th[i] = th[i], th[j] }
func (th tarHeaders) Less(i, j int) bool { return th[i].Name < th[j].Name }
func walkHeaders(r io.Reader) ([]tar.Header, error) {
t := tar.NewReader(r)
headers := []tar.Header{}
for {
hdr, err := t.Next()
if err != nil {
if err == io.EOF {
break
}
return headers, err
}
headers = append(headers, *hdr)
}
return headers, nil
}

View File

@@ -1,527 +0,0 @@
package archive
import (
"io/ioutil"
"os"
"os/exec"
"path"
"sort"
"testing"
"time"
)
func max(x, y int) int {
if x >= y {
return x
}
return y
}
func copyDir(src, dst string) error {
cmd := exec.Command("cp", "-a", src, dst)
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type FileType uint32
const (
Regular FileType = iota
Dir
Symlink
)
type FileData struct {
filetype FileType
path string
contents string
permissions os.FileMode
}
func createSampleDir(t *testing.T, root string) {
files := []FileData{
{Regular, "file1", "file1\n", 0600},
{Regular, "file2", "file2\n", 0666},
{Regular, "file3", "file3\n", 0404},
{Regular, "file4", "file4\n", 0600},
{Regular, "file5", "file5\n", 0600},
{Regular, "file6", "file6\n", 0600},
{Regular, "file7", "file7\n", 0600},
{Dir, "dir1", "", 0740},
{Regular, "dir1/file1-1", "file1-1\n", 01444},
{Regular, "dir1/file1-2", "file1-2\n", 0666},
{Dir, "dir2", "", 0700},
{Regular, "dir2/file2-1", "file2-1\n", 0666},
{Regular, "dir2/file2-2", "file2-2\n", 0666},
{Dir, "dir3", "", 0700},
{Regular, "dir3/file3-1", "file3-1\n", 0666},
{Regular, "dir3/file3-2", "file3-2\n", 0666},
{Dir, "dir4", "", 0700},
{Regular, "dir4/file3-1", "file4-1\n", 0666},
{Regular, "dir4/file3-2", "file4-2\n", 0666},
{Symlink, "symlink1", "target1", 0666},
{Symlink, "symlink2", "target2", 0666},
{Symlink, "symlink3", root + "/file1", 0666},
{Symlink, "symlink4", root + "/symlink3", 0666},
{Symlink, "dirSymlink", root + "/dir1", 0740},
}
now := time.Now()
for _, info := range files {
p := path.Join(root, info.path)
if info.filetype == Dir {
if err := os.MkdirAll(p, info.permissions); err != nil {
t.Fatal(err)
}
} else if info.filetype == Regular {
if err := ioutil.WriteFile(p, []byte(info.contents), info.permissions); err != nil {
t.Fatal(err)
}
} else if info.filetype == Symlink {
if err := os.Symlink(info.contents, p); err != nil {
t.Fatal(err)
}
}
if info.filetype != Symlink {
// Set a consistent ctime, atime for all files and dirs
if err := os.Chtimes(p, now, now); err != nil {
t.Fatal(err)
}
}
}
}
func TestChangeString(t *testing.T) {
modifiyChange := Change{"change", ChangeModify}
toString := modifiyChange.String()
if toString != "C change" {
t.Fatalf("String() of a change with ChangeModifiy Kind should have been %s but was %s", "C change", toString)
}
addChange := Change{"change", ChangeAdd}
toString = addChange.String()
if toString != "A change" {
t.Fatalf("String() of a change with ChangeAdd Kind should have been %s but was %s", "A change", toString)
}
deleteChange := Change{"change", ChangeDelete}
toString = deleteChange.String()
if toString != "D change" {
t.Fatalf("String() of a change with ChangeDelete Kind should have been %s but was %s", "D change", toString)
}
}
func TestChangesWithNoChanges(t *testing.T) {
rwLayer, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rwLayer)
layer, err := ioutil.TempDir("", "docker-changes-test-layer")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(layer)
createSampleDir(t, layer)
changes, err := Changes([]string{layer}, rwLayer)
if err != nil {
t.Fatal(err)
}
if len(changes) != 0 {
t.Fatalf("Changes with no difference should have detect no changes, but detected %d", len(changes))
}
}
func TestChangesWithChanges(t *testing.T) {
// Mock the readonly layer
layer, err := ioutil.TempDir("", "docker-changes-test-layer")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(layer)
createSampleDir(t, layer)
os.MkdirAll(path.Join(layer, "dir1/subfolder"), 0740)
// Mock the RW layer
rwLayer, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(rwLayer)
// Create a folder in RW layer
dir1 := path.Join(rwLayer, "dir1")
os.MkdirAll(dir1, 0740)
deletedFile := path.Join(dir1, ".wh.file1-2")
ioutil.WriteFile(deletedFile, []byte{}, 0600)
modifiedFile := path.Join(dir1, "file1-1")
ioutil.WriteFile(modifiedFile, []byte{0x00}, 01444)
// Let's add a subfolder for a newFile
subfolder := path.Join(dir1, "subfolder")
os.MkdirAll(subfolder, 0740)
newFile := path.Join(subfolder, "newFile")
ioutil.WriteFile(newFile, []byte{}, 0740)
changes, err := Changes([]string{layer}, rwLayer)
if err != nil {
t.Fatal(err)
}
expectedChanges := []Change{
{"/dir1", ChangeModify},
{"/dir1/file1-1", ChangeModify},
{"/dir1/file1-2", ChangeDelete},
{"/dir1/subfolder", ChangeModify},
{"/dir1/subfolder/newFile", ChangeAdd},
}
checkChanges(expectedChanges, changes, t)
}
// See https://github.com/hyperhq/hypercli/pull/13590
func TestChangesWithChangesGH13590(t *testing.T) {
baseLayer, err := ioutil.TempDir("", "docker-changes-test.")
defer os.RemoveAll(baseLayer)
dir3 := path.Join(baseLayer, "dir1/dir2/dir3")
os.MkdirAll(dir3, 07400)
file := path.Join(dir3, "file.txt")
ioutil.WriteFile(file, []byte("hello"), 0666)
layer, err := ioutil.TempDir("", "docker-changes-test2.")
defer os.RemoveAll(layer)
// Test creating a new file
if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil {
t.Fatalf("Cmd failed: %q", err)
}
os.Remove(path.Join(layer, "dir1/dir2/dir3/file.txt"))
file = path.Join(layer, "dir1/dir2/dir3/file1.txt")
ioutil.WriteFile(file, []byte("bye"), 0666)
changes, err := Changes([]string{baseLayer}, layer)
if err != nil {
t.Fatal(err)
}
expectedChanges := []Change{
{"/dir1/dir2/dir3", ChangeModify},
{"/dir1/dir2/dir3/file1.txt", ChangeAdd},
}
checkChanges(expectedChanges, changes, t)
// Now test changing a file
layer, err = ioutil.TempDir("", "docker-changes-test3.")
defer os.RemoveAll(layer)
if err := copyDir(baseLayer+"/dir1", layer+"/"); err != nil {
t.Fatalf("Cmd failed: %q", err)
}
file = path.Join(layer, "dir1/dir2/dir3/file.txt")
ioutil.WriteFile(file, []byte("bye"), 0666)
changes, err = Changes([]string{baseLayer}, layer)
if err != nil {
t.Fatal(err)
}
expectedChanges = []Change{
{"/dir1/dir2/dir3/file.txt", ChangeModify},
}
checkChanges(expectedChanges, changes, t)
}
// Create an directory, copy it, make sure we report no changes between the two
func TestChangesDirsEmpty(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(src)
createSampleDir(t, src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
if len(changes) != 0 {
t.Fatalf("Reported changes for identical dirs: %v", changes)
}
os.RemoveAll(src)
os.RemoveAll(dst)
}
func mutateSampleDir(t *testing.T, root string) {
// Remove a regular file
if err := os.RemoveAll(path.Join(root, "file1")); err != nil {
t.Fatal(err)
}
// Remove a directory
if err := os.RemoveAll(path.Join(root, "dir1")); err != nil {
t.Fatal(err)
}
// Remove a symlink
if err := os.RemoveAll(path.Join(root, "symlink1")); err != nil {
t.Fatal(err)
}
// Rewrite a file
if err := ioutil.WriteFile(path.Join(root, "file2"), []byte("fileNN\n"), 0777); err != nil {
t.Fatal(err)
}
// Replace a file
if err := os.RemoveAll(path.Join(root, "file3")); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(path.Join(root, "file3"), []byte("fileMM\n"), 0404); err != nil {
t.Fatal(err)
}
// Touch file
if err := os.Chtimes(path.Join(root, "file4"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
// Replace file with dir
if err := os.RemoveAll(path.Join(root, "file5")); err != nil {
t.Fatal(err)
}
if err := os.MkdirAll(path.Join(root, "file5"), 0666); err != nil {
t.Fatal(err)
}
// Create new file
if err := ioutil.WriteFile(path.Join(root, "filenew"), []byte("filenew\n"), 0777); err != nil {
t.Fatal(err)
}
// Create new dir
if err := os.MkdirAll(path.Join(root, "dirnew"), 0766); err != nil {
t.Fatal(err)
}
// Create a new symlink
if err := os.Symlink("targetnew", path.Join(root, "symlinknew")); err != nil {
t.Fatal(err)
}
// Change a symlink
if err := os.RemoveAll(path.Join(root, "symlink2")); err != nil {
t.Fatal(err)
}
if err := os.Symlink("target2change", path.Join(root, "symlink2")); err != nil {
t.Fatal(err)
}
// Replace dir with file
if err := os.RemoveAll(path.Join(root, "dir2")); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(path.Join(root, "dir2"), []byte("dir2\n"), 0777); err != nil {
t.Fatal(err)
}
// Touch dir
if err := os.Chtimes(path.Join(root, "dir3"), time.Now().Add(time.Second), time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
}
func TestChangesDirsMutated(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
createSampleDir(t, src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(src)
defer os.RemoveAll(dst)
mutateSampleDir(t, dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
sort.Sort(changesByPath(changes))
expectedChanges := []Change{
{"/dir1", ChangeDelete},
{"/dir2", ChangeModify},
{"/dirnew", ChangeAdd},
{"/file1", ChangeDelete},
{"/file2", ChangeModify},
{"/file3", ChangeModify},
{"/file4", ChangeModify},
{"/file5", ChangeModify},
{"/filenew", ChangeAdd},
{"/symlink1", ChangeDelete},
{"/symlink2", ChangeModify},
{"/symlinknew", ChangeAdd},
}
for i := 0; i < max(len(changes), len(expectedChanges)); i++ {
if i >= len(expectedChanges) {
t.Fatalf("unexpected change %s\n", changes[i].String())
}
if i >= len(changes) {
t.Fatalf("no change for expected change %s\n", expectedChanges[i].String())
}
if changes[i].Path == expectedChanges[i].Path {
if changes[i] != expectedChanges[i] {
t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String())
}
} else if changes[i].Path < expectedChanges[i].Path {
t.Fatalf("unexpected change %s\n", changes[i].String())
} else {
t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String())
}
}
}
func TestApplyLayer(t *testing.T) {
src, err := ioutil.TempDir("", "docker-changes-test")
if err != nil {
t.Fatal(err)
}
createSampleDir(t, src)
defer os.RemoveAll(src)
dst := src + "-copy"
if err := copyDir(src, dst); err != nil {
t.Fatal(err)
}
mutateSampleDir(t, dst)
defer os.RemoveAll(dst)
changes, err := ChangesDirs(dst, src)
if err != nil {
t.Fatal(err)
}
layer, err := ExportChanges(dst, changes, nil, nil)
if err != nil {
t.Fatal(err)
}
layerCopy, err := NewTempArchive(layer, "")
if err != nil {
t.Fatal(err)
}
if _, err := ApplyLayer(src, layerCopy); err != nil {
t.Fatal(err)
}
changes2, err := ChangesDirs(src, dst)
if err != nil {
t.Fatal(err)
}
if len(changes2) != 0 {
t.Fatalf("Unexpected differences after reapplying mutation: %v", changes2)
}
}
func TestChangesSizeWithHardlinks(t *testing.T) {
srcDir, err := ioutil.TempDir("", "docker-test-srcDir")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(srcDir)
destDir, err := ioutil.TempDir("", "docker-test-destDir")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(destDir)
creationSize, err := prepareUntarSourceDirectory(100, destDir, true)
if err != nil {
t.Fatal(err)
}
changes, err := ChangesDirs(destDir, srcDir)
if err != nil {
t.Fatal(err)
}
got := ChangesSize(destDir, changes)
if got != int64(creationSize) {
t.Errorf("Expected %d bytes of changes, got %d", creationSize, got)
}
}
func TestChangesSizeWithNoChanges(t *testing.T) {
size := ChangesSize("/tmp", nil)
if size != 0 {
t.Fatalf("ChangesSizes with no changes should be 0, was %d", size)
}
}
func TestChangesSizeWithOnlyDeleteChanges(t *testing.T) {
changes := []Change{
{Path: "deletedPath", Kind: ChangeDelete},
}
size := ChangesSize("/tmp", changes)
if size != 0 {
t.Fatalf("ChangesSizes with only delete changes should be 0, was %d", size)
}
}
func TestChangesSize(t *testing.T) {
parentPath, err := ioutil.TempDir("", "docker-changes-test")
defer os.RemoveAll(parentPath)
addition := path.Join(parentPath, "addition")
if err := ioutil.WriteFile(addition, []byte{0x01, 0x01, 0x01}, 0744); err != nil {
t.Fatal(err)
}
modification := path.Join(parentPath, "modification")
if err = ioutil.WriteFile(modification, []byte{0x01, 0x01, 0x01}, 0744); err != nil {
t.Fatal(err)
}
changes := []Change{
{Path: "addition", Kind: ChangeAdd},
{Path: "modification", Kind: ChangeModify},
}
size := ChangesSize(parentPath, changes)
if size != 6 {
t.Fatalf("Expected 6 bytes of changes, got %d", size)
}
}
func checkChanges(expectedChanges, changes []Change, t *testing.T) {
sort.Sort(changesByPath(expectedChanges))
sort.Sort(changesByPath(changes))
for i := 0; i < max(len(changes), len(expectedChanges)); i++ {
if i >= len(expectedChanges) {
t.Fatalf("unexpected change %s\n", changes[i].String())
}
if i >= len(changes) {
t.Fatalf("no change for expected change %s\n", expectedChanges[i].String())
}
if changes[i].Path == expectedChanges[i].Path {
if changes[i] != expectedChanges[i] {
t.Fatalf("Wrong change for %s, expected %s, got %s\n", changes[i].Path, changes[i].String(), expectedChanges[i].String())
}
} else if changes[i].Path < expectedChanges[i].Path {
t.Fatalf("unexpected change %s\n", changes[i].String())
} else {
t.Fatalf("no change for expected change %s != %s\n", expectedChanges[i].String(), changes[i].String())
}
}
}

View File

@@ -1,974 +0,0 @@
package archive
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
)
func removeAllPaths(paths ...string) {
for _, path := range paths {
os.RemoveAll(path)
}
}
func getTestTempDirs(t *testing.T) (tmpDirA, tmpDirB string) {
var err error
if tmpDirA, err = ioutil.TempDir("", "archive-copy-test"); err != nil {
t.Fatal(err)
}
if tmpDirB, err = ioutil.TempDir("", "archive-copy-test"); err != nil {
t.Fatal(err)
}
return
}
func isNotDir(err error) bool {
return strings.Contains(err.Error(), "not a directory")
}
func joinTrailingSep(pathElements ...string) string {
joined := filepath.Join(pathElements...)
return fmt.Sprintf("%s%c", joined, filepath.Separator)
}
func fileContentsEqual(t *testing.T, filenameA, filenameB string) (err error) {
t.Logf("checking for equal file contents: %q and %q\n", filenameA, filenameB)
fileA, err := os.Open(filenameA)
if err != nil {
return
}
defer fileA.Close()
fileB, err := os.Open(filenameB)
if err != nil {
return
}
defer fileB.Close()
hasher := sha256.New()
if _, err = io.Copy(hasher, fileA); err != nil {
return
}
hashA := hasher.Sum(nil)
hasher.Reset()
if _, err = io.Copy(hasher, fileB); err != nil {
return
}
hashB := hasher.Sum(nil)
if !bytes.Equal(hashA, hashB) {
err = fmt.Errorf("file content hashes not equal - expected %s, got %s", hex.EncodeToString(hashA), hex.EncodeToString(hashB))
}
return
}
func dirContentsEqual(t *testing.T, newDir, oldDir string) (err error) {
t.Logf("checking for equal directory contents: %q and %q\n", newDir, oldDir)
var changes []Change
if changes, err = ChangesDirs(newDir, oldDir); err != nil {
return
}
if len(changes) != 0 {
err = fmt.Errorf("expected no changes between directories, but got: %v", changes)
}
return
}
func logDirContents(t *testing.T, dirPath string) {
logWalkedPaths := filepath.WalkFunc(func(path string, info os.FileInfo, err error) error {
if err != nil {
t.Errorf("stat error for path %q: %s", path, err)
return nil
}
if info.IsDir() {
path = joinTrailingSep(path)
}
t.Logf("\t%s", path)
return nil
})
t.Logf("logging directory contents: %q", dirPath)
if err := filepath.Walk(dirPath, logWalkedPaths); err != nil {
t.Fatal(err)
}
}
func testCopyHelper(t *testing.T, srcPath, dstPath string) (err error) {
t.Logf("copying from %q to %q (not follow symbol link)", srcPath, dstPath)
return CopyResource(srcPath, dstPath, false)
}
func testCopyHelperFSym(t *testing.T, srcPath, dstPath string) (err error) {
t.Logf("copying from %q to %q (follow symbol link)", srcPath, dstPath)
return CopyResource(srcPath, dstPath, true)
}
// Basic assumptions about SRC and DST:
// 1. SRC must exist.
// 2. If SRC ends with a trailing separator, it must be a directory.
// 3. DST parent directory must exist.
// 4. If DST exists as a file, it must not end with a trailing separator.
// First get these easy error cases out of the way.
// Test for error when SRC does not exist.
func TestCopyErrSrcNotExists(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
if _, err := CopyInfoSourcePath(filepath.Join(tmpDirA, "file1"), false); !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
}
// Test for error when SRC ends in a trailing
// path separator but it exists as a file.
func TestCopyErrSrcNotDir(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
if _, err := CopyInfoSourcePath(joinTrailingSep(tmpDirA, "file1"), false); !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
}
// Test for error when SRC is a valid file or directory,
// but the DST parent directory does not exist.
func TestCopyErrDstParentNotExists(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false}
// Try with a file source.
content, err := TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
// Copy to a file whose parent does not exist.
if err = CopyTo(content, srcInfo, filepath.Join(tmpDirB, "fakeParentDir", "file1")); err == nil {
t.Fatal("expected IsNotExist error, but got nil instead")
}
if !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
// Try with a directory source.
srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true}
content, err = TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
// Copy to a directory whose parent does not exist.
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "fakeParentDir", "fakeDstDir")); err == nil {
t.Fatal("expected IsNotExist error, but got nil instead")
}
if !os.IsNotExist(err) {
t.Fatalf("expected IsNotExist error, but got %T: %s", err, err)
}
}
// Test for error when DST ends in a trailing
// path separator but exists as a file.
func TestCopyErrDstNotDir(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
// Try with a file source.
srcInfo := CopyInfo{Path: filepath.Join(tmpDirA, "file1"), Exists: true, IsDir: false}
content, err := TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil {
t.Fatal("expected IsNotDir error, but got nil instead")
}
if !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
// Try with a directory source.
srcInfo = CopyInfo{Path: filepath.Join(tmpDirA, "dir1"), Exists: true, IsDir: true}
content, err = TarResource(srcInfo)
if err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
defer content.Close()
if err = CopyTo(content, srcInfo, joinTrailingSep(tmpDirB, "file1")); err == nil {
t.Fatal("expected IsNotDir error, but got nil instead")
}
if !isNotDir(err) {
t.Fatalf("expected IsNotDir error, but got %T: %s", err, err)
}
}
// Possibilities are reduced to the remaining 10 cases:
//
// case | srcIsDir | onlyDirContents | dstExists | dstIsDir | dstTrSep | action
// ===================================================================================================
// A | no | - | no | - | no | create file
// B | no | - | no | - | yes | error
// C | no | - | yes | no | - | overwrite file
// D | no | - | yes | yes | - | create file in dst dir
// E | yes | no | no | - | - | create dir, copy contents
// F | yes | no | yes | no | - | error
// G | yes | no | yes | yes | - | copy dir and contents
// H | yes | yes | no | - | - | create dir, copy contents
// I | yes | yes | yes | no | - | error
// J | yes | yes | yes | yes | - | copy dir contents
//
// A. SRC specifies a file and DST (no trailing path separator) doesn't
// exist. This should create a file with the name DST and copy the
// contents of the source file into it.
func TestCopyCaseA(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcPath := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "itWorks.txt")
var err error
if err = testCopyHelper(t, srcPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
os.Remove(dstPath)
symlinkPath := filepath.Join(tmpDirA, "symlink3")
symlinkPath1 := filepath.Join(tmpDirA, "symlink4")
linkTarget := filepath.Join(tmpDirA, "file1")
if err = testCopyHelperFSym(t, symlinkPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
os.Remove(dstPath)
if err = testCopyHelperFSym(t, symlinkPath1, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// B. SRC specifies a file and DST (with trailing path separator) doesn't
// exist. This should cause an error because the copy operation cannot
// create a directory when copying a single file.
func TestCopyCaseB(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcPath := filepath.Join(tmpDirA, "file1")
dstDir := joinTrailingSep(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcPath, dstDir); err == nil {
t.Fatal("expected ErrDirNotExists error, but got nil instead")
}
if err != ErrDirNotExists {
t.Fatalf("expected ErrDirNotExists error, but got %T: %s", err, err)
}
symlinkPath := filepath.Join(tmpDirA, "symlink3")
if err = testCopyHelperFSym(t, symlinkPath, dstDir); err == nil {
t.Fatal("expected ErrDirNotExists error, but got nil instead")
}
if err != ErrDirNotExists {
t.Fatalf("expected ErrDirNotExists error, but got %T: %s", err, err)
}
}
// C. SRC specifies a file and DST exists as a file. This should overwrite
// the file at DST with the contents of the source file.
func TestCopyCaseC(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "file2")
var err error
// Ensure they start out different.
if err = fileContentsEqual(t, srcPath, dstPath); err == nil {
t.Fatal("expected different file contents")
}
if err = testCopyHelper(t, srcPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
}
// C. Symbol link following version:
// SRC specifies a file and DST exists as a file. This should overwrite
// the file at DST with the contents of the source file.
func TestCopyCaseCFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
symlinkPathBad := filepath.Join(tmpDirA, "symlink1")
symlinkPath := filepath.Join(tmpDirA, "symlink3")
linkTarget := filepath.Join(tmpDirA, "file1")
dstPath := filepath.Join(tmpDirB, "file2")
var err error
// first to test broken link
if err = testCopyHelperFSym(t, symlinkPathBad, dstPath); err == nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
// test symbol link -> symbol link -> target
// Ensure they start out different.
if err = fileContentsEqual(t, linkTarget, dstPath); err == nil {
t.Fatal("expected different file contents")
}
if err = testCopyHelperFSym(t, symlinkPath, dstPath); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// D. SRC specifies a file and DST exists as a directory. This should place
// a copy of the source file inside it using the basename from SRC. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseD(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "file1")
dstDir := filepath.Join(tmpDirB, "dir1")
dstPath := filepath.Join(dstDir, "file1")
var err error
// Ensure that dstPath doesn't exist.
if _, err = os.Stat(dstPath); !os.IsNotExist(err) {
t.Fatalf("did not expect dstPath %q to exist", dstPath)
}
if err = testCopyHelper(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir1")
if err = testCopyHelper(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, srcPath, dstPath); err != nil {
t.Fatal(err)
}
}
// D. Symbol link following version:
// SRC specifies a file and DST exists as a directory. This should place
// a copy of the source file inside it using the basename from SRC. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseDFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcPath := filepath.Join(tmpDirA, "symlink4")
linkTarget := filepath.Join(tmpDirA, "file1")
dstDir := filepath.Join(tmpDirB, "dir1")
dstPath := filepath.Join(dstDir, "symlink4")
var err error
// Ensure that dstPath doesn't exist.
if _, err = os.Stat(dstPath); !os.IsNotExist(err) {
t.Fatalf("did not expect dstPath %q to exist", dstPath)
}
if err = testCopyHelperFSym(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir1")
if err = testCopyHelperFSym(t, srcPath, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = fileContentsEqual(t, linkTarget, dstPath); err != nil {
t.Fatal(err)
}
}
// E. SRC specifies a directory and DST does not exist. This should create a
// directory at DST and copy the contents of the SRC directory into the DST
// directory. Ensure this works whether DST has a trailing path separator or
// not.
func TestCopyCaseE(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
}
// E. Symbol link following version:
// SRC specifies a directory and DST does not exist. This should create a
// directory at DST and copy the contents of the SRC directory into the DST
// directory. Ensure this works whether DST has a trailing path separator or
// not.
func TestCopyCaseEFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := filepath.Join(tmpDirA, "dirSymlink")
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
}
// F. SRC specifies a directory and DST exists as a file. This should cause an
// error as it is not possible to overwrite a file with a directory.
func TestCopyCaseF(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dir1")
symSrcDir := filepath.Join(tmpDirA, "dirSymlink")
dstFile := filepath.Join(tmpDirB, "file1")
var err error
if err = testCopyHelper(t, srcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
// now test with symbol link
if err = testCopyHelperFSym(t, symSrcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
}
// G. SRC specifies a directory and DST exists as a directory. This should copy
// the SRC directory and all its contents to the DST directory. Ensure this
// works whether DST has a trailing path separator or not.
func TestCopyCaseG(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir2")
resultDir := filepath.Join(dstDir, "dir1")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, srcDir); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir2")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, srcDir); err != nil {
t.Fatal(err)
}
}
// G. Symbol link version:
// SRC specifies a directory and DST exists as a directory. This should copy
// the SRC directory and all its contents to the DST directory. Ensure this
// works whether DST has a trailing path separator or not.
func TestCopyCaseGFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := filepath.Join(tmpDirA, "dirSymlink")
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir2")
resultDir := filepath.Join(dstDir, "dirSymlink")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, linkTarget); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir2")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, resultDir, linkTarget); err != nil {
t.Fatal(err)
}
}
// H. SRC specifies a directory's contents only and DST does not exist. This
// should create a directory at DST and copy the contents of the SRC
// directory (but not the directory itself) into the DST directory. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseH(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
}
// H. Symbol link following version:
// SRC specifies a directory's contents only and DST does not exist. This
// should create a directory at DST and copy the contents of the SRC
// directory (but not the directory itself) into the DST directory. Ensure
// this works whether DST has a trailing path separator or not.
func TestCopyCaseHFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A with some sample files and directories.
createSampleDir(t, tmpDirA)
srcDir := joinTrailingSep(tmpDirA, "dirSymlink") + "."
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "testDir")
var err error
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "testDir")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Log("dir contents not equal")
logDirContents(t, tmpDirA)
logDirContents(t, tmpDirB)
t.Fatal(err)
}
}
// I. SRC specifies a directory's contents only and DST exists as a file. This
// should cause an error as it is not possible to overwrite a file with a
// directory.
func TestCopyCaseI(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
symSrcDir := filepath.Join(tmpDirB, "dirSymlink")
dstFile := filepath.Join(tmpDirB, "file1")
var err error
if err = testCopyHelper(t, srcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
// now try with symbol link of dir
if err = testCopyHelperFSym(t, symSrcDir, dstFile); err == nil {
t.Fatal("expected ErrCannotCopyDir error, but got nil instead")
}
if err != ErrCannotCopyDir {
t.Fatalf("expected ErrCannotCopyDir error, but got %T: %s", err, err)
}
}
// J. SRC specifies a directory's contents only and DST exists as a directory.
// This should copy the contents of the SRC directory (but not the directory
// itself) into the DST directory. Ensure this works whether DST has a
// trailing path separator or not.
func TestCopyCaseJ(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dir1") + "."
dstDir := filepath.Join(tmpDirB, "dir5")
var err error
// first to create an empty dir
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir5")
if err = testCopyHelper(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, srcDir); err != nil {
t.Fatal(err)
}
}
// J. Symbol link following version:
// SRC specifies a directory's contents only and DST exists as a directory.
// This should copy the contents of the SRC directory (but not the directory
// itself) into the DST directory. Ensure this works whether DST has a
// trailing path separator or not.
func TestCopyCaseJFSym(t *testing.T) {
tmpDirA, tmpDirB := getTestTempDirs(t)
defer removeAllPaths(tmpDirA, tmpDirB)
// Load A and B with some sample files and directories.
createSampleDir(t, tmpDirA)
createSampleDir(t, tmpDirB)
srcDir := joinTrailingSep(tmpDirA, "dirSymlink") + "."
linkTarget := filepath.Join(tmpDirA, "dir1")
dstDir := filepath.Join(tmpDirB, "dir5")
var err error
// first to create an empty dir
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
// Now try again but using a trailing path separator for dstDir.
if err = os.RemoveAll(dstDir); err != nil {
t.Fatalf("unable to remove dstDir: %s", err)
}
if err = os.MkdirAll(dstDir, os.FileMode(0755)); err != nil {
t.Fatalf("unable to make dstDir: %s", err)
}
dstDir = joinTrailingSep(tmpDirB, "dir5")
if err = testCopyHelperFSym(t, srcDir, dstDir); err != nil {
t.Fatalf("unexpected error %T: %s", err, err)
}
if err = dirContentsEqual(t, dstDir, linkTarget); err != nil {
t.Fatal(err)
}
}

View File

@@ -1,370 +0,0 @@
package archive
import (
"archive/tar"
"io"
"io/ioutil"
"os"
"path/filepath"
"reflect"
"testing"
"github.com/hyperhq/hypercli/pkg/ioutils"
)
func TestApplyLayerInvalidFilenames(t *testing.T) {
for i, headers := range [][]*tar.Header{
{
{
Name: "../victim/dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{
{
// Note the leading slash
Name: "/../victim/slash-dotdot",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidFilenames", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidHardlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeLink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeLink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (hardlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try reading victim/hello (hardlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // Try removing victim directory (hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeLink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidHardlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerInvalidSymlink(t *testing.T) {
for i, headers := range [][]*tar.Header{
{ // try reading victim/hello (../)
{
Name: "dotdot",
Typeflag: tar.TypeSymlink,
Linkname: "../victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (/../)
{
Name: "slash-dotdot",
Typeflag: tar.TypeSymlink,
// Note the leading slash
Linkname: "/../victim/hello",
Mode: 0644,
},
},
{ // try writing victim/file
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim/file",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "symlink",
Typeflag: tar.TypeSymlink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try reading victim/hello (symlink, hardlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "hardlink",
Typeflag: tar.TypeLink,
Linkname: "loophole-victim/hello",
Mode: 0644,
},
},
{ // try removing victim directory (symlink)
{
Name: "loophole-victim",
Typeflag: tar.TypeSymlink,
Linkname: "../victim",
Mode: 0755,
},
{
Name: "loophole-victim",
Typeflag: tar.TypeReg,
Mode: 0644,
},
},
} {
if err := testBreakout("applylayer", "docker-TestApplyLayerInvalidSymlink", headers); err != nil {
t.Fatalf("i=%d. %v", i, err)
}
}
}
func TestApplyLayerWhiteouts(t *testing.T) {
wd, err := ioutil.TempDir("", "graphdriver-test-whiteouts")
if err != nil {
return
}
defer os.RemoveAll(wd)
base := []string{
".baz",
"bar/",
"bar/bax",
"bar/bay/",
"baz",
"foo/",
"foo/.abc",
"foo/.bcd/",
"foo/.bcd/a",
"foo/cde/",
"foo/cde/def",
"foo/cde/efg",
"foo/fgh",
"foobar",
}
type tcase struct {
change, expected []string
}
tcases := []tcase{
{
base,
base,
},
{
[]string{
".bay",
".wh.baz",
"foo/",
"foo/.bce",
"foo/.wh..wh..opq",
"foo/cde/",
"foo/cde/efg",
},
[]string{
".bay",
".baz",
"bar/",
"bar/bax",
"bar/bay/",
"foo/",
"foo/.bce",
"foo/cde/",
"foo/cde/efg",
"foobar",
},
},
{
[]string{
".bay",
".wh..baz",
".wh.foobar",
"foo/",
"foo/.abc",
"foo/.wh.cde",
"bar/",
},
[]string{
".bay",
"bar/",
"bar/bax",
"bar/bay/",
"foo/",
"foo/.abc",
"foo/.bce",
},
},
{
[]string{
".abc",
".wh..wh..opq",
"foobar",
},
[]string{
".abc",
"foobar",
},
},
}
for i, tc := range tcases {
l, err := makeTestLayer(tc.change)
if err != nil {
t.Fatal(err)
}
_, err = UnpackLayer(wd, l, nil)
if err != nil {
t.Fatal(err)
}
err = l.Close()
if err != nil {
t.Fatal(err)
}
paths, err := readDirContents(wd)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(tc.expected, paths) {
t.Fatalf("invalid files for layer %d: expected %q, got %q", i, tc.expected, paths)
}
}
}
func makeTestLayer(paths []string) (rc io.ReadCloser, err error) {
tmpDir, err := ioutil.TempDir("", "graphdriver-test-mklayer")
if err != nil {
return
}
defer func() {
if err != nil {
os.RemoveAll(tmpDir)
}
}()
for _, p := range paths {
if p[len(p)-1] == filepath.Separator {
if err = os.MkdirAll(filepath.Join(tmpDir, p), 0700); err != nil {
return
}
} else {
if err = ioutil.WriteFile(filepath.Join(tmpDir, p), nil, 0600); err != nil {
return
}
}
}
archive, err := Tar(tmpDir, Uncompressed)
if err != nil {
return
}
return ioutils.NewReadCloserWrapper(archive, func() error {
err := archive.Close()
os.RemoveAll(tmpDir)
return err
}), nil
}
func readDirContents(root string) ([]string, error) {
var files []string
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if path == root {
return nil
}
rel, err := filepath.Rel(root, path)
if err != nil {
return err
}
if info.IsDir() {
rel = rel + "/"
}
files = append(files, rel)
return nil
})
if err != nil {
return nil, err
}
return files, nil
}

View File

@@ -1,166 +0,0 @@
package archive
import (
"archive/tar"
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"time"
)
var testUntarFns = map[string]func(string, io.Reader) error{
"untar": func(dest string, r io.Reader) error {
return Untar(r, dest, nil)
},
"applylayer": func(dest string, r io.Reader) error {
_, err := ApplyLayer(dest, Reader(r))
return err
},
}
// testBreakout is a helper function that, within the provided `tmpdir` directory,
// creates a `victim` folder with a generated `hello` file in it.
// `untar` extracts to a directory named `dest`, the tar file created from `headers`.
//
// Here are the tested scenarios:
// - removed `victim` folder (write)
// - removed files from `victim` folder (write)
// - new files in `victim` folder (write)
// - modified files in `victim` folder (write)
// - file in `dest` with same content as `victim/hello` (read)
//
// When using testBreakout make sure you cover one of the scenarios listed above.
func testBreakout(untarFn string, tmpdir string, headers []*tar.Header) error {
tmpdir, err := ioutil.TempDir("", tmpdir)
if err != nil {
return err
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := os.Mkdir(dest, 0755); err != nil {
return err
}
victim := filepath.Join(tmpdir, "victim")
if err := os.Mkdir(victim, 0755); err != nil {
return err
}
hello := filepath.Join(victim, "hello")
helloData, err := time.Now().MarshalText()
if err != nil {
return err
}
if err := ioutil.WriteFile(hello, helloData, 0644); err != nil {
return err
}
helloStat, err := os.Stat(hello)
if err != nil {
return err
}
reader, writer := io.Pipe()
go func() {
t := tar.NewWriter(writer)
for _, hdr := range headers {
t.WriteHeader(hdr)
}
t.Close()
}()
untar := testUntarFns[untarFn]
if untar == nil {
return fmt.Errorf("could not find untar function %q in testUntarFns", untarFn)
}
if err := untar(dest, reader); err != nil {
if _, ok := err.(breakoutError); !ok {
// If untar returns an error unrelated to an archive breakout,
// then consider this an unexpected error and abort.
return err
}
// Here, untar detected the breakout.
// Let's move on verifying that indeed there was no breakout.
fmt.Printf("breakoutError: %v\n", err)
}
// Check victim folder
f, err := os.Open(victim)
if err != nil {
// codepath taken if victim folder was removed
return fmt.Errorf("archive breakout: error reading %q: %v", victim, err)
}
defer f.Close()
// Check contents of victim folder
//
// We are only interested in getting 2 files from the victim folder, because if all is well
// we expect only one result, the `hello` file. If there is a second result, it cannot
// hold the same name `hello` and we assume that a new file got created in the victim folder.
// That is enough to detect an archive breakout.
names, err := f.Readdirnames(2)
if err != nil {
// codepath taken if victim is not a folder
return fmt.Errorf("archive breakout: error reading directory content of %q: %v", victim, err)
}
for _, name := range names {
if name != "hello" {
// codepath taken if new file was created in victim folder
return fmt.Errorf("archive breakout: new file %q", name)
}
}
// Check victim/hello
f, err = os.Open(hello)
if err != nil {
// codepath taken if read permissions were removed
return fmt.Errorf("archive breakout: could not lstat %q: %v", hello, err)
}
defer f.Close()
b, err := ioutil.ReadAll(f)
if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
if helloStat.IsDir() != fi.IsDir() ||
// TODO: cannot check for fi.ModTime() change
helloStat.Mode() != fi.Mode() ||
helloStat.Size() != fi.Size() ||
!bytes.Equal(helloData, b) {
// codepath taken if hello has been modified
return fmt.Errorf("archive breakout: file %q has been modified. Contents: expected=%q, got=%q. FileInfo: expected=%#v, got=%#v", hello, helloData, b, helloStat, fi)
}
// Check that nothing in dest/ has the same content as victim/hello.
// Since victim/hello was generated with time.Now(), it is safe to assume
// that any file whose content matches exactly victim/hello, managed somehow
// to access victim/hello.
return filepath.Walk(dest, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
if err != nil {
// skip directory if error
return filepath.SkipDir
}
// enter directory
return nil
}
if err != nil {
// skip file if error
return nil
}
b, err := ioutil.ReadFile(path)
if err != nil {
// Houston, we have a problem. Aborting (space)walk.
return err
}
if bytes.Equal(helloData, b) {
return fmt.Errorf("archive breakout: file %q has been accessed via %q", hello, path)
}
return nil
})
}

View File

@@ -1,98 +0,0 @@
package archive
import (
"archive/tar"
"bytes"
"io"
"testing"
)
func TestGenerateEmptyFile(t *testing.T) {
archive, err := Generate("emptyFile")
if err != nil {
t.Fatal(err)
}
if archive == nil {
t.Fatal("The generated archive should not be nil.")
}
expectedFiles := [][]string{
{"emptyFile", ""},
}
tr := tar.NewReader(archive)
actualFiles := make([][]string, 0, 10)
i := 0
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(tr)
content := buf.String()
actualFiles = append(actualFiles, []string{hdr.Name, content})
i++
}
if len(actualFiles) != len(expectedFiles) {
t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles))
}
for i := 0; i < len(expectedFiles); i++ {
actual := actualFiles[i]
expected := expectedFiles[i]
if actual[0] != expected[0] {
t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0])
}
if actual[1] != expected[1] {
t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1])
}
}
}
func TestGenerateWithContent(t *testing.T) {
archive, err := Generate("file", "content")
if err != nil {
t.Fatal(err)
}
if archive == nil {
t.Fatal("The generated archive should not be nil.")
}
expectedFiles := [][]string{
{"file", "content"},
}
tr := tar.NewReader(archive)
actualFiles := make([][]string, 0, 10)
i := 0
for {
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(tr)
content := buf.String()
actualFiles = append(actualFiles, []string{hdr.Name, content})
i++
}
if len(actualFiles) != len(expectedFiles) {
t.Fatalf("Number of expected file %d, got %d.", len(expectedFiles), len(actualFiles))
}
for i := 0; i < len(expectedFiles); i++ {
actual := actualFiles[i]
expected := expectedFiles[i]
if actual[0] != expected[0] {
t.Fatalf("Expected name '%s', Actual name '%s'", expected[0], actual[0])
}
if actual[1] != expected[1] {
t.Fatalf("Expected content '%s', Actual content '%s'", expected[1], actual[1])
}
}
}

View File

@@ -1,54 +0,0 @@
package authorization
const (
// AuthZApiRequest is the url for daemon request authorization
AuthZApiRequest = "AuthZPlugin.AuthZReq"
// AuthZApiResponse is the url for daemon response authorization
AuthZApiResponse = "AuthZPlugin.AuthZRes"
// AuthZApiImplements is the name of the interface all AuthZ plugins implement
AuthZApiImplements = "authz"
)
// Request holds data required for authZ plugins
type Request struct {
// User holds the user extracted by AuthN mechanism
User string `json:"User,omitempty"`
// UserAuthNMethod holds the mechanism used to extract user details (e.g., krb)
UserAuthNMethod string `json:"UserAuthNMethod,omitempty"`
// RequestMethod holds the HTTP method (GET/POST/PUT)
RequestMethod string `json:"RequestMethod,omitempty"`
// RequestUri holds the full HTTP uri (e.g., /v1.21/version)
RequestURI string `json:"RequestUri,omitempty"`
// RequestBody stores the raw request body sent to the docker daemon
RequestBody []byte `json:"RequestBody,omitempty"`
// RequestHeaders stores the raw request headers sent to the docker daemon
RequestHeaders map[string]string `json:"RequestHeaders,omitempty"`
// ResponseStatusCode stores the status code returned from docker daemon
ResponseStatusCode int `json:"ResponseStatusCode,omitempty"`
// ResponseBody stores the raw response body sent from docker daemon
ResponseBody []byte `json:"ResponseBody,omitempty"`
// ResponseHeaders stores the response headers sent to the docker daemon
ResponseHeaders map[string]string `json:"ResponseHeaders,omitempty"`
}
// Response represents authZ plugin response
type Response struct {
// Allow indicating whether the user is allowed or not
Allow bool `json:"Allow"`
// Msg stores the authorization message
Msg string `json:"Msg,omitempty"`
// Err stores a message in case there's an error
Err string `json:"Err,omitempty"`
}

View File

@@ -1,168 +0,0 @@
package authorization
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"strings"
"github.com/Sirupsen/logrus"
"github.com/hyperhq/hypercli/pkg/ioutils"
)
const maxBodySize = 1048576 // 1MB
// NewCtx creates new authZ context, it is used to store authorization information related to a specific docker
// REST http session
// A context provides two method:
// Authenticate Request:
// Call authZ plugins with current REST request and AuthN response
// Request contains full HTTP packet sent to the docker daemon
// https://docs.docker.com/reference/api/docker_remote_api/
//
// Authenticate Response:
// Call authZ plugins with full info about current REST request, REST response and AuthN response
// The response from this method may contains content that overrides the daemon response
// This allows authZ plugins to filter privileged content
//
// If multiple authZ plugins are specified, the block/allow decision is based on ANDing all plugin results
// For response manipulation, the response from each plugin is piped between plugins. Plugin execution order
// is determined according to daemon parameters
func NewCtx(authZPlugins []Plugin, user, userAuthNMethod, requestMethod, requestURI string) *Ctx {
return &Ctx{
plugins: authZPlugins,
user: user,
userAuthNMethod: userAuthNMethod,
requestMethod: requestMethod,
requestURI: requestURI,
}
}
// Ctx stores a a single request-response interaction context
type Ctx struct {
user string
userAuthNMethod string
requestMethod string
requestURI string
plugins []Plugin
// authReq stores the cached request object for the current transaction
authReq *Request
}
// AuthZRequest authorized the request to the docker daemon using authZ plugins
func (ctx *Ctx) AuthZRequest(w http.ResponseWriter, r *http.Request) error {
var body []byte
if sendBody(ctx.requestURI, r.Header) {
if r.ContentLength < maxBodySize {
var err error
body, r.Body, err = drainBody(r.Body)
if err != nil {
return err
}
}
}
var h bytes.Buffer
if err := r.Header.Write(&h); err != nil {
return err
}
ctx.authReq = &Request{
User: ctx.user,
UserAuthNMethod: ctx.userAuthNMethod,
RequestMethod: ctx.requestMethod,
RequestURI: ctx.requestURI,
RequestBody: body,
RequestHeaders: headers(r.Header),
}
for _, plugin := range ctx.plugins {
logrus.Debugf("AuthZ request using plugin %s", plugin.Name())
authRes, err := plugin.AuthZRequest(ctx.authReq)
if err != nil {
return fmt.Errorf("plugin %s failed with error: %s", plugin.Name(), err)
}
if !authRes.Allow {
return fmt.Errorf("authorization denied by plugin %s: %s", plugin.Name(), authRes.Msg)
}
}
return nil
}
// AuthZResponse authorized and manipulates the response from docker daemon using authZ plugins
func (ctx *Ctx) AuthZResponse(rm ResponseModifier, r *http.Request) error {
ctx.authReq.ResponseStatusCode = rm.StatusCode()
ctx.authReq.ResponseHeaders = headers(rm.Header())
if sendBody(ctx.requestURI, rm.Header()) {
ctx.authReq.ResponseBody = rm.RawBody()
}
for _, plugin := range ctx.plugins {
logrus.Debugf("AuthZ response using plugin %s", plugin.Name())
authRes, err := plugin.AuthZResponse(ctx.authReq)
if err != nil {
return fmt.Errorf("plugin %s failed with error: %s", plugin.Name(), err)
}
if !authRes.Allow {
return fmt.Errorf("authorization denied by plugin %s: %s", plugin.Name(), authRes.Msg)
}
}
rm.Flush()
return nil
}
// drainBody dump the body, it reads the body data into memory and
// see go sources /go/src/net/http/httputil/dump.go
func drainBody(body io.ReadCloser) ([]byte, io.ReadCloser, error) {
bufReader := bufio.NewReaderSize(body, maxBodySize)
newBody := ioutils.NewReadCloserWrapper(bufReader, func() error { return body.Close() })
data, err := bufReader.Peek(maxBodySize)
if err != io.EOF {
// This means the request is larger than our max
if err == bufio.ErrBufferFull {
return nil, newBody, nil
}
// This means we had an error reading
return nil, nil, err
}
return data, newBody, nil
}
// sendBody returns true when request/response body should be sent to AuthZPlugin
func sendBody(url string, header http.Header) bool {
// Skip body for auth endpoint
if strings.HasSuffix(url, "/auth") {
return false
}
// body is sent only for text or json messages
v := header.Get("Content-Type")
return strings.HasPrefix(v, "text/") || v == "application/json"
}
// headers returns flatten version of the http headers excluding authorization
func headers(header http.Header) map[string]string {
v := make(map[string]string, 0)
for k, values := range header {
// Skip authorization headers
if strings.EqualFold(k, "Authorization") || strings.EqualFold(k, "X-Registry-Config") || strings.EqualFold(k, "X-Registry-Auth") {
continue
}
for _, val := range values {
v[k] = val
}
}
return v
}

View File

@@ -1,233 +0,0 @@
package authorization
import (
"encoding/json"
"io/ioutil"
"log"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"testing"
"github.com/docker/go-connections/tlsconfig"
"github.com/gorilla/mux"
"github.com/hyperhq/hypercli/pkg/plugins"
)
const pluginAddress = "authzplugin.sock"
func TestAuthZRequestPluginError(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
RequestURI: "www.authz.com",
RequestMethod: "GET",
RequestHeaders: map[string]string{"header": "value"},
}
server.replayResponse = Response{
Err: "an error",
}
actualResponse, err := authZPlugin.AuthZRequest(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestAuthZRequestPlugin(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
RequestURI: "www.authz.com",
RequestMethod: "GET",
RequestHeaders: map[string]string{"header": "value"},
}
server.replayResponse = Response{
Allow: true,
Msg: "Sample message",
}
actualResponse, err := authZPlugin.AuthZRequest(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestAuthZResponsePlugin(t *testing.T) {
server := authZPluginTestServer{t: t}
go server.start()
defer server.stop()
authZPlugin := createTestPlugin(t)
request := Request{
User: "user",
RequestBody: []byte("sample body"),
}
server.replayResponse = Response{
Allow: true,
Msg: "Sample message",
}
actualResponse, err := authZPlugin.AuthZResponse(&request)
if err != nil {
t.Fatalf("Failed to authorize request %v", err)
}
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
t.Fatalf("Response must be equal")
}
if !reflect.DeepEqual(request, server.recordedRequest) {
t.Fatalf("Requests must be equal")
}
}
func TestResponseModifier(t *testing.T) {
r := httptest.NewRecorder()
m := NewResponseModifier(r)
m.Header().Set("h1", "v1")
m.Write([]byte("body"))
m.WriteHeader(500)
m.Flush()
if r.Header().Get("h1") != "v1" {
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
}
if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
t.Fatalf("Body value must exists %s", r.Body.Bytes())
}
if r.Code != 500 {
t.Fatalf("Status code must be correct %d", r.Code)
}
}
func TestResponseModifierOverride(t *testing.T) {
r := httptest.NewRecorder()
m := NewResponseModifier(r)
m.Header().Set("h1", "v1")
m.Write([]byte("body"))
m.WriteHeader(500)
overrideHeader := make(http.Header)
overrideHeader.Add("h1", "v2")
overrideHeaderBytes, err := json.Marshal(overrideHeader)
if err != nil {
t.Fatalf("override header failed %v", err)
}
m.OverrideHeader(overrideHeaderBytes)
m.OverrideBody([]byte("override body"))
m.OverrideStatusCode(404)
m.Flush()
if r.Header().Get("h1") != "v2" {
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
}
if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
t.Fatalf("Body value must exists %s", r.Body.Bytes())
}
if r.Code != 404 {
t.Fatalf("Status code must be correct %d", r.Code)
}
}
// createTestPlugin creates a new sample authorization plugin
func createTestPlugin(t *testing.T) *authorizationPlugin {
plugin := &plugins.Plugin{Name: "authz"}
pwd, err := os.Getwd()
if err != nil {
log.Fatal(err)
}
plugin.Client, err = plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), tlsconfig.Options{InsecureSkipVerify: true})
if err != nil {
t.Fatalf("Failed to create client %v", err)
}
return &authorizationPlugin{name: "plugin", plugin: plugin}
}
// AuthZPluginTestServer is a simple server that implements the authZ plugin interface
type authZPluginTestServer struct {
listener net.Listener
t *testing.T
// request stores the request sent from the daemon to the plugin
recordedRequest Request
// response stores the response sent from the plugin to the daemon
replayResponse Response
}
// start starts the test server that implements the plugin
func (t *authZPluginTestServer) start() {
r := mux.NewRouter()
os.Remove(pluginAddress)
l, err := net.ListenUnix("unix", &net.UnixAddr{Name: pluginAddress, Net: "unix"})
if err != nil {
t.t.Fatalf("Failed to listen %v", err)
}
t.listener = l
r.HandleFunc("/Plugin.Activate", t.activate)
r.HandleFunc("/"+AuthZApiRequest, t.auth)
r.HandleFunc("/"+AuthZApiResponse, t.auth)
t.listener, err = net.Listen("tcp", pluginAddress)
server := http.Server{Handler: r, Addr: pluginAddress}
server.Serve(l)
}
// stop stops the test server that implements the plugin
func (t *authZPluginTestServer) stop() {
os.Remove(pluginAddress)
if t.listener != nil {
t.listener.Close()
}
}
// auth is a used to record/replay the authentication api messages
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
t.recordedRequest = Request{}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
json.Unmarshal(body, &t.recordedRequest)
b, err := json.Marshal(t.replayResponse)
if err != nil {
log.Fatal(err)
}
w.Write(b)
}
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
if err != nil {
log.Fatal(err)
}
w.Write(b)
}

View File

@@ -1,83 +0,0 @@
package authorization
import "github.com/hyperhq/hypercli/pkg/plugins"
// Plugin allows third party plugins to authorize requests and responses
// in the context of docker API
type Plugin interface {
// Name returns the registered plugin name
Name() string
// AuthZRequest authorize the request from the client to the daemon
AuthZRequest(*Request) (*Response, error)
// AuthZResponse authorize the response from the daemon to the client
AuthZResponse(*Request) (*Response, error)
}
// NewPlugins constructs and initialize the authorization plugins based on plugin names
func NewPlugins(names []string) []Plugin {
plugins := []Plugin{}
pluginsMap := make(map[string]struct{})
for _, name := range names {
if _, ok := pluginsMap[name]; ok {
continue
}
pluginsMap[name] = struct{}{}
plugins = append(plugins, newAuthorizationPlugin(name))
}
return plugins
}
// authorizationPlugin is an internal adapter to docker plugin system
type authorizationPlugin struct {
plugin *plugins.Plugin
name string
}
func newAuthorizationPlugin(name string) Plugin {
return &authorizationPlugin{name: name}
}
func (a *authorizationPlugin) Name() string {
return a.name
}
func (a *authorizationPlugin) AuthZRequest(authReq *Request) (*Response, error) {
if err := a.initPlugin(); err != nil {
return nil, err
}
authRes := &Response{}
if err := a.plugin.Client.Call(AuthZApiRequest, authReq, authRes); err != nil {
return nil, err
}
return authRes, nil
}
func (a *authorizationPlugin) AuthZResponse(authReq *Request) (*Response, error) {
if err := a.initPlugin(); err != nil {
return nil, err
}
authRes := &Response{}
if err := a.plugin.Client.Call(AuthZApiResponse, authReq, authRes); err != nil {
return nil, err
}
return authRes, nil
}
// initPlugin initialize the authorization plugin if needed
func (a *authorizationPlugin) initPlugin() error {
// Lazy loading of plugins
if a.plugin == nil {
var err error
a.plugin, err = plugins.Get(a.name, AuthZApiImplements)
if err != nil {
return err
}
}
return nil
}

View File

@@ -1,136 +0,0 @@
package authorization
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net"
"net/http"
)
// ResponseModifier allows authorization plugins to read and modify the content of the http.response
type ResponseModifier interface {
http.ResponseWriter
// RawBody returns the current http content
RawBody() []byte
// RawHeaders returns the current content of the http headers
RawHeaders() ([]byte, error)
// StatusCode returns the current status code
StatusCode() int
// OverrideBody replace the body of the HTTP reply
OverrideBody(b []byte)
// OverrideHeader replace the headers of the HTTP reply
OverrideHeader(b []byte) error
// OverrideStatusCode replaces the status code of the HTTP reply
OverrideStatusCode(statusCode int)
// Flush flushes all data to the HTTP response
Flush() error
}
// NewResponseModifier creates a wrapper to an http.ResponseWriter to allow inspecting and modifying the content
func NewResponseModifier(rw http.ResponseWriter) ResponseModifier {
return &responseModifier{rw: rw, header: make(http.Header)}
}
// responseModifier is used as an adapter to http.ResponseWriter in order to manipulate and explore
// the http request/response from docker daemon
type responseModifier struct {
// The original response writer
rw http.ResponseWriter
status int
// body holds the response body
body []byte
// header holds the response header
header http.Header
// statusCode holds the response status code
statusCode int
}
// WriterHeader stores the http status code
func (rm *responseModifier) WriteHeader(s int) {
rm.statusCode = s
}
// Header returns the internal http header
func (rm *responseModifier) Header() http.Header {
return rm.header
}
// Header returns the internal http header
func (rm *responseModifier) StatusCode() int {
return rm.statusCode
}
// Override replace the body of the HTTP reply
func (rm *responseModifier) OverrideBody(b []byte) {
rm.body = b
}
func (rm *responseModifier) OverrideStatusCode(statusCode int) {
rm.statusCode = statusCode
}
// Override replace the headers of the HTTP reply
func (rm *responseModifier) OverrideHeader(b []byte) error {
header := http.Header{}
if err := json.Unmarshal(b, &header); err != nil {
return err
}
rm.header = header
return nil
}
// Write stores the byte array inside content
func (rm *responseModifier) Write(b []byte) (int, error) {
rm.body = append(rm.body, b...)
return len(b), nil
}
// Body returns the response body
func (rm *responseModifier) RawBody() []byte {
return rm.body
}
func (rm *responseModifier) RawHeaders() ([]byte, error) {
var b bytes.Buffer
if err := rm.header.Write(&b); err != nil {
return nil, err
}
return b.Bytes(), nil
}
// Hijack returns the internal connection of the wrapped http.ResponseWriter
func (rm *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rm.rw.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("Internal reponse writer doesn't support the Hijacker interface")
}
return hijacker.Hijack()
}
// Flush flushes all data to the HTTP response
func (rm *responseModifier) Flush() error {
// Copy the status code
if rm.statusCode > 0 {
rm.rw.WriteHeader(rm.statusCode)
}
// Copy the header
for k, vv := range rm.header {
for _, v := range vv {
rm.rw.Header().Add(k, v)
}
}
// Write body
_, err := rm.rw.Write(rm.body)
return err
}

View File

@@ -1,49 +0,0 @@
package broadcaster
import (
"io"
"sync"
)
// Unbuffered accumulates multiple io.WriteCloser by stream.
type Unbuffered struct {
mu sync.Mutex
writers []io.WriteCloser
}
// Add adds new io.WriteCloser.
func (w *Unbuffered) Add(writer io.WriteCloser) {
w.mu.Lock()
w.writers = append(w.writers, writer)
w.mu.Unlock()
}
// Write writes bytes to all writers. Failed writers will be evicted during
// this call.
func (w *Unbuffered) Write(p []byte) (n int, err error) {
w.mu.Lock()
var evict []int
for i, sw := range w.writers {
if n, err := sw.Write(p); err != nil || n != len(p) {
// On error, evict the writer
evict = append(evict, i)
}
}
for n, i := range evict {
w.writers = append(w.writers[:i-n], w.writers[i-n+1:]...)
}
w.mu.Unlock()
return len(p), nil
}
// Clean closes and removes all writers. Last non-eol-terminated part of data
// will be saved.
func (w *Unbuffered) Clean() error {
w.mu.Lock()
for _, sw := range w.writers {
sw.Close()
}
w.writers = nil
w.mu.Unlock()
return nil
}

View File

@@ -1,162 +0,0 @@
package broadcaster
import (
"bytes"
"errors"
"strings"
"testing"
)
type dummyWriter struct {
buffer bytes.Buffer
failOnWrite bool
}
func (dw *dummyWriter) Write(p []byte) (n int, err error) {
if dw.failOnWrite {
return 0, errors.New("Fake fail")
}
return dw.buffer.Write(p)
}
func (dw *dummyWriter) String() string {
return dw.buffer.String()
}
func (dw *dummyWriter) Close() error {
return nil
}
func TestUnbuffered(t *testing.T) {
writer := new(Unbuffered)
// Test 1: Both bufferA and bufferB should contain "foo"
bufferA := &dummyWriter{}
writer.Add(bufferA)
bufferB := &dummyWriter{}
writer.Add(bufferB)
writer.Write([]byte("foo"))
if bufferA.String() != "foo" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferB.String() != "foo" {
t.Errorf("Buffer contains %v", bufferB.String())
}
// Test2: bufferA and bufferB should contain "foobar",
// while bufferC should only contain "bar"
bufferC := &dummyWriter{}
writer.Add(bufferC)
writer.Write([]byte("bar"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferB.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferB.String())
}
if bufferC.String() != "bar" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Test3: Test eviction on failure
bufferA.failOnWrite = true
writer.Write([]byte("fail"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferC.String() != "barfail" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Even though we reset the flag, no more writes should go in there
bufferA.failOnWrite = false
writer.Write([]byte("test"))
if bufferA.String() != "foobar" {
t.Errorf("Buffer contains %v", bufferA.String())
}
if bufferC.String() != "barfailtest" {
t.Errorf("Buffer contains %v", bufferC.String())
}
// Test4: Test eviction on multiple simultaneous failures
bufferB.failOnWrite = true
bufferC.failOnWrite = true
bufferD := &dummyWriter{}
writer.Add(bufferD)
writer.Write([]byte("yo"))
writer.Write([]byte("ink"))
if strings.Contains(bufferB.String(), "yoink") {
t.Errorf("bufferB received write. contents: %q", bufferB)
}
if strings.Contains(bufferC.String(), "yoink") {
t.Errorf("bufferC received write. contents: %q", bufferC)
}
if g, w := bufferD.String(), "yoink"; g != w {
t.Errorf("bufferD = %q, want %q", g, w)
}
writer.Clean()
}
type devNullCloser int
func (d devNullCloser) Close() error {
return nil
}
func (d devNullCloser) Write(buf []byte) (int, error) {
return len(buf), nil
}
// This test checks for races. It is only useful when run with the race detector.
func TestRaceUnbuffered(t *testing.T) {
writer := new(Unbuffered)
c := make(chan bool)
go func() {
writer.Add(devNullCloser(0))
c <- true
}()
writer.Write([]byte("hello"))
<-c
}
func BenchmarkUnbuffered(b *testing.B) {
writer := new(Unbuffered)
setUpWriter := func() {
for i := 0; i < 100; i++ {
writer.Add(devNullCloser(0))
writer.Add(devNullCloser(0))
writer.Add(devNullCloser(0))
}
}
testLine := "Line that thinks that it is log line from docker"
var buf bytes.Buffer
for i := 0; i < 100; i++ {
buf.Write([]byte(testLine + "\n"))
}
// line without eol
buf.Write([]byte(testLine))
testText := buf.Bytes()
b.SetBytes(int64(5 * len(testText)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
setUpWriter()
b.StartTimer()
for j := 0; j < 5; j++ {
if _, err := writer.Write(testText); err != nil {
b.Fatal(err)
}
}
b.StopTimer()
writer.Clean()
b.StartTimer()
}
}

View File

@@ -1,381 +0,0 @@
package chrootarchive
import (
"bytes"
"fmt"
"hash/crc32"
"io"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/hyperhq/hypercli/pkg/archive"
"github.com/hyperhq/hypercli/pkg/reexec"
"github.com/hyperhq/hypercli/pkg/system"
)
func init() {
reexec.Init()
}
func TestChrootTarUntar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "lolo"), []byte("hello lolo"), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
if err := Untar(stream, dest, &archive.TarOptions{ExcludePatterns: []string{"lolo"}}); err != nil {
t.Fatal(err)
}
}
// gh#10426: Verify the fix for having a huge excludes list (like on `docker load` with large # of
// local images)
func TestChrootUntarWithHugeExcludesList(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarHugeExcludes")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "toto"), []byte("hello toto"), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
options := &archive.TarOptions{}
//65534 entries of 64-byte strings ~= 4MB of environment space which should overflow
//on most systems when passed via environment or command line arguments
excludes := make([]string, 65534, 65534)
for i := 0; i < 65534; i++ {
excludes[i] = strings.Repeat(string(i), 64)
}
options.ExcludePatterns = excludes
if err := Untar(stream, dest, options); err != nil {
t.Fatal(err)
}
}
func TestChrootUntarEmptyArchive(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchive")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
if err := Untar(nil, tmpdir, nil); err == nil {
t.Fatal("expected error on empty archive")
}
}
func prepareSourceDirectory(numberOfFiles int, targetPath string, makeSymLinks bool) (int, error) {
fileData := []byte("fooo")
for n := 0; n < numberOfFiles; n++ {
fileName := fmt.Sprintf("file-%d", n)
if err := ioutil.WriteFile(filepath.Join(targetPath, fileName), fileData, 0700); err != nil {
return 0, err
}
if makeSymLinks {
if err := os.Symlink(filepath.Join(targetPath, fileName), filepath.Join(targetPath, fileName+"-link")); err != nil {
return 0, err
}
}
}
totalSize := numberOfFiles * len(fileData)
return totalSize, nil
}
func getHash(filename string) (uint32, error) {
stream, err := ioutil.ReadFile(filename)
if err != nil {
return 0, err
}
hash := crc32.NewIEEE()
hash.Write(stream)
return hash.Sum32(), nil
}
func compareDirectories(src string, dest string) error {
changes, err := archive.ChangesDirs(dest, src)
if err != nil {
return err
}
if len(changes) > 0 {
return fmt.Errorf("Unexpected differences after untar: %v", changes)
}
return nil
}
func compareFiles(src string, dest string) error {
srcHash, err := getHash(src)
if err != nil {
return err
}
destHash, err := getHash(dest)
if err != nil {
return err
}
if srcHash != destHash {
return fmt.Errorf("%s is different from %s", src, dest)
}
return nil
}
func TestChrootTarUntarWithSymlink(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootTarUntarWithSymlink")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := TarUntar(src, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
}
func TestChrootCopyWithTar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyWithTar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
// Copy directory
dest := filepath.Join(tmpdir, "dest")
if err := CopyWithTar(src, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
// Copy file
srcfile := filepath.Join(src, "file-1")
dest = filepath.Join(tmpdir, "destFile")
destfile := filepath.Join(dest, "file-1")
if err := CopyWithTar(srcfile, destfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcfile, destfile); err != nil {
t.Fatal(err)
}
// Copy symbolic link
srcLinkfile := filepath.Join(src, "file-1-link")
dest = filepath.Join(tmpdir, "destSymlink")
destLinkfile := filepath.Join(dest, "file-1-link")
if err := CopyWithTar(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
}
func TestChrootCopyFileWithTar(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootCopyFileWithTar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
// Copy directory
dest := filepath.Join(tmpdir, "dest")
if err := CopyFileWithTar(src, dest); err == nil {
t.Fatal("Expected error on copying directory")
}
// Copy file
srcfile := filepath.Join(src, "file-1")
dest = filepath.Join(tmpdir, "destFile")
destfile := filepath.Join(dest, "file-1")
if err := CopyFileWithTar(srcfile, destfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcfile, destfile); err != nil {
t.Fatal(err)
}
// Copy symbolic link
srcLinkfile := filepath.Join(src, "file-1-link")
dest = filepath.Join(tmpdir, "destSymlink")
destLinkfile := filepath.Join(dest, "file-1-link")
if err := CopyFileWithTar(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
if err := compareFiles(srcLinkfile, destLinkfile); err != nil {
t.Fatal(err)
}
}
func TestChrootUntarPath(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarPath")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if _, err := prepareSourceDirectory(10, src, true); err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
// Untar a directory
if err := UntarPath(src, dest); err == nil {
t.Fatal("Expected error on untaring a directory")
}
// Untar a tar file
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
buf := new(bytes.Buffer)
buf.ReadFrom(stream)
tarfile := filepath.Join(tmpdir, "src.tar")
if err := ioutil.WriteFile(tarfile, buf.Bytes(), 0644); err != nil {
t.Fatal(err)
}
if err := UntarPath(tarfile, dest); err != nil {
t.Fatal(err)
}
if err := compareDirectories(src, dest); err != nil {
t.Fatal(err)
}
}
type slowEmptyTarReader struct {
size int
offset int
chunkSize int
}
// Read is a slow reader of an empty tar (like the output of "tar c --files-from /dev/null")
func (s *slowEmptyTarReader) Read(p []byte) (int, error) {
time.Sleep(100 * time.Millisecond)
count := s.chunkSize
if len(p) < s.chunkSize {
count = len(p)
}
for i := 0; i < count; i++ {
p[i] = 0
}
s.offset += count
if s.offset > s.size {
return count, io.EOF
}
return count, nil
}
func TestChrootUntarEmptyArchiveFromSlowReader(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootUntarEmptyArchiveFromSlowReader")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024}
if err := Untar(stream, dest, nil); err != nil {
t.Fatal(err)
}
}
func TestChrootApplyEmptyArchiveFromSlowReader(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyEmptyArchiveFromSlowReader")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
stream := &slowEmptyTarReader{size: 10240, chunkSize: 1024}
if _, err := ApplyLayer(dest, stream); err != nil {
t.Fatal(err)
}
}
func TestChrootApplyDotDotFile(t *testing.T) {
tmpdir, err := ioutil.TempDir("", "docker-TestChrootApplyDotDotFile")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpdir)
src := filepath.Join(tmpdir, "src")
if err := system.MkdirAll(src, 0700); err != nil {
t.Fatal(err)
}
if err := ioutil.WriteFile(filepath.Join(src, "..gitme"), []byte(""), 0644); err != nil {
t.Fatal(err)
}
stream, err := archive.Tar(src, archive.Uncompressed)
if err != nil {
t.Fatal(err)
}
dest := filepath.Join(tmpdir, "dest")
if err := system.MkdirAll(dest, 0700); err != nil {
t.Fatal(err)
}
if _, err := ApplyLayer(dest, stream); err != nil {
t.Fatal(err)
}
}

View File

@@ -1,807 +0,0 @@
// +build linux
package devicemapper
import (
"errors"
"fmt"
"os"
"runtime"
"syscall"
"unsafe"
"github.com/Sirupsen/logrus"
)
// DevmapperLogger defines methods for logging with devicemapper.
type DevmapperLogger interface {
DMLog(level int, file string, line int, dmError int, message string)
}
const (
deviceCreate TaskType = iota
deviceReload
deviceRemove
deviceRemoveAll
deviceSuspend
deviceResume
deviceInfo
deviceDeps
deviceRename
deviceVersion
deviceStatus
deviceTable
deviceWaitevent
deviceList
deviceClear
deviceMknodes
deviceListVersions
deviceTargetMsg
deviceSetGeometry
)
const (
addNodeOnResume AddNodeType = iota
addNodeOnCreate
)
// List of errors returned when using devicemapper.
var (
ErrTaskRun = errors.New("dm_task_run failed")
ErrTaskSetName = errors.New("dm_task_set_name failed")
ErrTaskSetMessage = errors.New("dm_task_set_message failed")
ErrTaskSetAddNode = errors.New("dm_task_set_add_node failed")
ErrTaskSetRo = errors.New("dm_task_set_ro failed")
ErrTaskAddTarget = errors.New("dm_task_add_target failed")
ErrTaskSetSector = errors.New("dm_task_set_sector failed")
ErrTaskGetDeps = errors.New("dm_task_get_deps failed")
ErrTaskGetInfo = errors.New("dm_task_get_info failed")
ErrTaskGetDriverVersion = errors.New("dm_task_get_driver_version failed")
ErrTaskDeferredRemove = errors.New("dm_task_deferred_remove failed")
ErrTaskSetCookie = errors.New("dm_task_set_cookie failed")
ErrNilCookie = errors.New("cookie ptr can't be nil")
ErrGetBlockSize = errors.New("Can't get block size")
ErrUdevWait = errors.New("wait on udev cookie failed")
ErrSetDevDir = errors.New("dm_set_dev_dir failed")
ErrGetLibraryVersion = errors.New("dm_get_library_version failed")
ErrCreateRemoveTask = errors.New("Can't create task of type deviceRemove")
ErrRunRemoveDevice = errors.New("running RemoveDevice failed")
ErrInvalidAddNode = errors.New("Invalid AddNode type")
ErrBusy = errors.New("Device is Busy")
ErrDeviceIDExists = errors.New("Device Id Exists")
ErrEnxio = errors.New("No such device or address")
)
var (
dmSawBusy bool
dmSawExist bool
dmSawEnxio bool // No Such Device or Address
)
type (
// Task represents a devicemapper task (like lvcreate, etc.) ; a task is needed for each ioctl
// command to execute.
Task struct {
unmanaged *cdmTask
}
// Deps represents dependents (layer) of a device.
Deps struct {
Count uint32
Filler uint32
Device []uint64
}
// Info represents information about a device.
Info struct {
Exists int
Suspended int
LiveTable int
InactiveTable int
OpenCount int32
EventNr uint32
Major uint32
Minor uint32
ReadOnly int
TargetCount int32
DeferredRemove int
}
// TaskType represents a type of task
TaskType int
// AddNodeType represents a type of node to be added
AddNodeType int
)
// DeviceIDExists returns whether error conveys the information about device Id already
// exist or not. This will be true if device creation or snap creation
// operation fails if device or snap device already exists in pool.
// Current implementation is little crude as it scans the error string
// for exact pattern match. Replacing it with more robust implementation
// is desirable.
func DeviceIDExists(err error) bool {
return fmt.Sprint(err) == fmt.Sprint(ErrDeviceIDExists)
}
func (t *Task) destroy() {
if t != nil {
DmTaskDestroy(t.unmanaged)
runtime.SetFinalizer(t, nil)
}
}
// TaskCreateNamed is a convenience function for TaskCreate when a name
// will be set on the task as well
func TaskCreateNamed(t TaskType, name string) (*Task, error) {
task := TaskCreate(t)
if task == nil {
return nil, fmt.Errorf("devicemapper: Can't create task of type %d", int(t))
}
if err := task.setName(name); err != nil {
return nil, fmt.Errorf("devicemapper: Can't set task name %s", name)
}
return task, nil
}
// TaskCreate initializes a devicemapper task of tasktype
func TaskCreate(tasktype TaskType) *Task {
Ctask := DmTaskCreate(int(tasktype))
if Ctask == nil {
return nil
}
task := &Task{unmanaged: Ctask}
runtime.SetFinalizer(task, (*Task).destroy)
return task
}
func (t *Task) run() error {
if res := DmTaskRun(t.unmanaged); res != 1 {
return ErrTaskRun
}
return nil
}
func (t *Task) setName(name string) error {
if res := DmTaskSetName(t.unmanaged, name); res != 1 {
return ErrTaskSetName
}
return nil
}
func (t *Task) setMessage(message string) error {
if res := DmTaskSetMessage(t.unmanaged, message); res != 1 {
return ErrTaskSetMessage
}
return nil
}
func (t *Task) setSector(sector uint64) error {
if res := DmTaskSetSector(t.unmanaged, sector); res != 1 {
return ErrTaskSetSector
}
return nil
}
func (t *Task) setCookie(cookie *uint, flags uint16) error {
if cookie == nil {
return ErrNilCookie
}
if res := DmTaskSetCookie(t.unmanaged, cookie, flags); res != 1 {
return ErrTaskSetCookie
}
return nil
}
func (t *Task) setAddNode(addNode AddNodeType) error {
if addNode != addNodeOnResume && addNode != addNodeOnCreate {
return ErrInvalidAddNode
}
if res := DmTaskSetAddNode(t.unmanaged, addNode); res != 1 {
return ErrTaskSetAddNode
}
return nil
}
func (t *Task) setRo() error {
if res := DmTaskSetRo(t.unmanaged); res != 1 {
return ErrTaskSetRo
}
return nil
}
func (t *Task) addTarget(start, size uint64, ttype, params string) error {
if res := DmTaskAddTarget(t.unmanaged, start, size,
ttype, params); res != 1 {
return ErrTaskAddTarget
}
return nil
}
func (t *Task) getDeps() (*Deps, error) {
var deps *Deps
if deps = DmTaskGetDeps(t.unmanaged); deps == nil {
return nil, ErrTaskGetDeps
}
return deps, nil
}
func (t *Task) getInfo() (*Info, error) {
info := &Info{}
if res := DmTaskGetInfo(t.unmanaged, info); res != 1 {
return nil, ErrTaskGetInfo
}
return info, nil
}
func (t *Task) getInfoWithDeferred() (*Info, error) {
info := &Info{}
if res := DmTaskGetInfoWithDeferred(t.unmanaged, info); res != 1 {
return nil, ErrTaskGetInfo
}
return info, nil
}
func (t *Task) getDriverVersion() (string, error) {
res := DmTaskGetDriverVersion(t.unmanaged)
if res == "" {
return "", ErrTaskGetDriverVersion
}
return res, nil
}
func (t *Task) getNextTarget(next unsafe.Pointer) (nextPtr unsafe.Pointer, start uint64,
length uint64, targetType string, params string) {
return DmGetNextTarget(t.unmanaged, next, &start, &length,
&targetType, &params),
start, length, targetType, params
}
// UdevWait waits for any processes that are waiting for udev to complete the specified cookie.
func UdevWait(cookie *uint) error {
if res := DmUdevWait(*cookie); res != 1 {
logrus.Debugf("devicemapper: Failed to wait on udev cookie %d", *cookie)
return ErrUdevWait
}
return nil
}
// LogInitVerbose is an interface to initialize the verbose logger for the device mapper library.
func LogInitVerbose(level int) {
DmLogInitVerbose(level)
}
var dmLogger DevmapperLogger
// LogInit initializes the logger for the device mapper library.
func LogInit(logger DevmapperLogger) {
dmLogger = logger
LogWithErrnoInit()
}
// SetDevDir sets the dev folder for the device mapper library (usually /dev).
func SetDevDir(dir string) error {
if res := DmSetDevDir(dir); res != 1 {
logrus.Debugf("devicemapper: Error dm_set_dev_dir")
return ErrSetDevDir
}
return nil
}
// GetLibraryVersion returns the device mapper library version.
func GetLibraryVersion() (string, error) {
var version string
if res := DmGetLibraryVersion(&version); res != 1 {
return "", ErrGetLibraryVersion
}
return version, nil
}
// UdevSyncSupported returns whether device-mapper is able to sync with udev
//
// This is essential otherwise race conditions can arise where both udev and
// device-mapper attempt to create and destroy devices.
func UdevSyncSupported() bool {
return DmUdevGetSyncSupport() != 0
}
// UdevSetSyncSupport allows setting whether the udev sync should be enabled.
// The return bool indicates the state of whether the sync is enabled.
func UdevSetSyncSupport(enable bool) bool {
if enable {
DmUdevSetSyncSupport(1)
} else {
DmUdevSetSyncSupport(0)
}
return UdevSyncSupported()
}
// CookieSupported returns whether the version of device-mapper supports the
// use of cookie's in the tasks.
// This is largely a lower level call that other functions use.
func CookieSupported() bool {
return DmCookieSupported() != 0
}
// RemoveDevice is a useful helper for cleaning up a device.
func RemoveDevice(name string) error {
task, err := TaskCreateNamed(deviceRemove, name)
if task == nil {
return err
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can not set cookie: %s", err)
}
defer UdevWait(&cookie)
dmSawBusy = false // reset before the task is run
if err = task.run(); err != nil {
if dmSawBusy {
return ErrBusy
}
return fmt.Errorf("devicemapper: Error running RemoveDevice %s", err)
}
return nil
}
// RemoveDeviceDeferred is a useful helper for cleaning up a device, but deferred.
func RemoveDeviceDeferred(name string) error {
logrus.Debugf("devicemapper: RemoveDeviceDeferred START(%s)", name)
defer logrus.Debugf("devicemapper: RemoveDeviceDeferred END(%s)", name)
task, err := TaskCreateNamed(deviceRemove, name)
if task == nil {
return err
}
if err := DmTaskDeferredRemove(task.unmanaged); err != 1 {
return ErrTaskDeferredRemove
}
if err = task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running RemoveDeviceDeferred %s", err)
}
return nil
}
// CancelDeferredRemove cancels a deferred remove for a device.
func CancelDeferredRemove(deviceName string) error {
task, err := TaskCreateNamed(deviceTargetMsg, deviceName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("@cancel_deferred_remove")); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawBusy = false
dmSawEnxio = false
if err := task.run(); err != nil {
// A device might be being deleted already
if dmSawBusy {
return ErrBusy
} else if dmSawEnxio {
return ErrEnxio
}
return fmt.Errorf("devicemapper: Error running CancelDeferredRemove %s", err)
}
return nil
}
// GetBlockDeviceSize returns the size of a block device identified by the specified file.
func GetBlockDeviceSize(file *os.File) (uint64, error) {
size, err := ioctlBlkGetSize64(file.Fd())
if err != nil {
logrus.Errorf("devicemapper: Error getblockdevicesize: %s", err)
return 0, ErrGetBlockSize
}
return uint64(size), nil
}
// BlockDeviceDiscard runs discard for the given path.
// This is used as a workaround for the kernel not discarding block so
// on the thin pool when we remove a thinp device, so we do it
// manually
func BlockDeviceDiscard(path string) error {
file, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
return err
}
defer file.Close()
size, err := GetBlockDeviceSize(file)
if err != nil {
return err
}
if err := ioctlBlkDiscard(file.Fd(), 0, size); err != nil {
return err
}
// Without this sometimes the remove of the device that happens after
// discard fails with EBUSY.
syscall.Sync()
return nil
}
// CreatePool is the programmatic example of "dmsetup create".
// It creates a device with the specified poolName, data and metadata file and block size.
func CreatePool(poolName string, dataFile, metadataFile *os.File, poolBlockSize uint32) error {
task, err := TaskCreateNamed(deviceCreate, poolName)
if task == nil {
return err
}
size, err := GetBlockDeviceSize(dataFile)
if err != nil {
return fmt.Errorf("devicemapper: Can't get data size %s", err)
}
params := fmt.Sprintf("%s %s %d 32768 1 skip_block_zeroing", metadataFile.Name(), dataFile.Name(), poolBlockSize)
if err := task.addTarget(0, size/512, "thin-pool", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
var cookie uint
var flags uint16
flags = DmUdevDisableSubsystemRulesFlag | DmUdevDisableDiskRulesFlag | DmUdevDisableOtherRulesFlag
if err := task.setCookie(&cookie, flags); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate (CreatePool) %s", err)
}
return nil
}
// ReloadPool is the programmatic example of "dmsetup reload".
// It reloads the table with the specified poolName, data and metadata file and block size.
func ReloadPool(poolName string, dataFile, metadataFile *os.File, poolBlockSize uint32) error {
task, err := TaskCreateNamed(deviceReload, poolName)
if task == nil {
return err
}
size, err := GetBlockDeviceSize(dataFile)
if err != nil {
return fmt.Errorf("devicemapper: Can't get data size %s", err)
}
params := fmt.Sprintf("%s %s %d 32768 1 skip_block_zeroing", metadataFile.Name(), dataFile.Name(), poolBlockSize)
if err := task.addTarget(0, size/512, "thin-pool", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate %s", err)
}
return nil
}
// GetDeps is the programmatic example of "dmsetup deps".
// It outputs a list of devices referenced by the live table for the specified device.
func GetDeps(name string) (*Deps, error) {
task, err := TaskCreateNamed(deviceDeps, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getDeps()
}
// GetInfo is the programmatic example of "dmsetup info".
// It outputs some brief information about the device.
func GetInfo(name string) (*Info, error) {
task, err := TaskCreateNamed(deviceInfo, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getInfo()
}
// GetInfoWithDeferred is the programmatic example of "dmsetup info", but deferred.
// It outputs some brief information about the device.
func GetInfoWithDeferred(name string) (*Info, error) {
task, err := TaskCreateNamed(deviceInfo, name)
if task == nil {
return nil, err
}
if err := task.run(); err != nil {
return nil, err
}
return task.getInfoWithDeferred()
}
// GetDriverVersion is the programmatic example of "dmsetup version".
// It outputs version information of the driver.
func GetDriverVersion() (string, error) {
task := TaskCreate(deviceVersion)
if task == nil {
return "", fmt.Errorf("devicemapper: Can't create deviceVersion task")
}
if err := task.run(); err != nil {
return "", err
}
return task.getDriverVersion()
}
// GetStatus is the programmatic example of "dmsetup status".
// It outputs status information for the specified device name.
func GetStatus(name string) (uint64, uint64, string, string, error) {
task, err := TaskCreateNamed(deviceStatus, name)
if task == nil {
logrus.Debugf("devicemapper: GetStatus() Error TaskCreateNamed: %s", err)
return 0, 0, "", "", err
}
if err := task.run(); err != nil {
logrus.Debugf("devicemapper: GetStatus() Error Run: %s", err)
return 0, 0, "", "", err
}
devinfo, err := task.getInfo()
if err != nil {
logrus.Debugf("devicemapper: GetStatus() Error GetInfo: %s", err)
return 0, 0, "", "", err
}
if devinfo.Exists == 0 {
logrus.Debugf("devicemapper: GetStatus() Non existing device %s", name)
return 0, 0, "", "", fmt.Errorf("devicemapper: Non existing device %s", name)
}
_, start, length, targetType, params := task.getNextTarget(unsafe.Pointer(nil))
return start, length, targetType, params, nil
}
// GetTable is the programmatic example for "dmsetup table".
// It outputs the current table for the specified device name.
func GetTable(name string) (uint64, uint64, string, string, error) {
task, err := TaskCreateNamed(deviceTable, name)
if task == nil {
logrus.Debugf("devicemapper: GetTable() Error TaskCreateNamed: %s", err)
return 0, 0, "", "", err
}
if err := task.run(); err != nil {
logrus.Debugf("devicemapper: GetTable() Error Run: %s", err)
return 0, 0, "", "", err
}
devinfo, err := task.getInfo()
if err != nil {
logrus.Debugf("devicemapper: GetTable() Error GetInfo: %s", err)
return 0, 0, "", "", err
}
if devinfo.Exists == 0 {
logrus.Debugf("devicemapper: GetTable() Non existing device %s", name)
return 0, 0, "", "", fmt.Errorf("devicemapper: Non existing device %s", name)
}
_, start, length, targetType, params := task.getNextTarget(unsafe.Pointer(nil))
return start, length, targetType, params, nil
}
// SetTransactionID sets a transaction id for the specified device name.
func SetTransactionID(poolName string, oldID uint64, newID uint64) error {
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("set_transaction_id %d %d", oldID, newID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running SetTransactionID %s", err)
}
return nil
}
// SuspendDevice is the programmatic example of "dmsetup suspend".
// It suspends the specified device.
func SuspendDevice(name string) error {
task, err := TaskCreateNamed(deviceSuspend, name)
if task == nil {
return err
}
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceSuspend %s", err)
}
return nil
}
// ResumeDevice is the programmatic example of "dmsetup resume".
// It un-suspends the specified device.
func ResumeDevice(name string) error {
task, err := TaskCreateNamed(deviceResume, name)
if task == nil {
return err
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceResume %s", err)
}
return nil
}
// CreateDevice creates a device with the specified poolName with the specified device id.
func CreateDevice(poolName string, deviceID int) error {
logrus.Debugf("devicemapper: CreateDevice(poolName=%v, deviceID=%v)", poolName, deviceID)
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("create_thin %d", deviceID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawExist = false // reset before the task is run
if err := task.run(); err != nil {
// Caller wants to know about ErrDeviceIDExists so that it can try with a different device id.
if dmSawExist {
return ErrDeviceIDExists
}
return fmt.Errorf("devicemapper: Error running CreateDevice %s", err)
}
return nil
}
// DeleteDevice deletes a device with the specified poolName with the specified device id.
func DeleteDevice(poolName string, deviceID int) error {
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
return err
}
if err := task.setSector(0); err != nil {
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("delete %d", deviceID)); err != nil {
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawBusy = false
if err := task.run(); err != nil {
if dmSawBusy {
return ErrBusy
}
return fmt.Errorf("devicemapper: Error running DeleteDevice %s", err)
}
return nil
}
// ActivateDevice activates the device identified by the specified
// poolName, name and deviceID with the specified size.
func ActivateDevice(poolName string, name string, deviceID int, size uint64) error {
return activateDevice(poolName, name, deviceID, size, "")
}
// ActivateDeviceWithExternal activates the device identified by the specified
// poolName, name and deviceID with the specified size.
func ActivateDeviceWithExternal(poolName string, name string, deviceID int, size uint64, external string) error {
return activateDevice(poolName, name, deviceID, size, external)
}
func activateDevice(poolName string, name string, deviceID int, size uint64, external string) error {
task, err := TaskCreateNamed(deviceCreate, name)
if task == nil {
return err
}
var params string
if len(external) > 0 {
params = fmt.Sprintf("%s %d %s", poolName, deviceID, external)
} else {
params = fmt.Sprintf("%s %d", poolName, deviceID)
}
if err := task.addTarget(0, size/512, "thin", params); err != nil {
return fmt.Errorf("devicemapper: Can't add target %s", err)
}
if err := task.setAddNode(addNodeOnCreate); err != nil {
return fmt.Errorf("devicemapper: Can't add node %s", err)
}
var cookie uint
if err := task.setCookie(&cookie, 0); err != nil {
return fmt.Errorf("devicemapper: Can't set cookie %s", err)
}
defer UdevWait(&cookie)
if err := task.run(); err != nil {
return fmt.Errorf("devicemapper: Error running deviceCreate (ActivateDevice) %s", err)
}
return nil
}
// CreateSnapDevice creates a snapshot based on the device identified by the baseName and baseDeviceId,
func CreateSnapDevice(poolName string, deviceID int, baseName string, baseDeviceID int) error {
devinfo, _ := GetInfo(baseName)
doSuspend := devinfo != nil && devinfo.Exists != 0
if doSuspend {
if err := SuspendDevice(baseName); err != nil {
return err
}
}
task, err := TaskCreateNamed(deviceTargetMsg, poolName)
if task == nil {
if doSuspend {
ResumeDevice(baseName)
}
return err
}
if err := task.setSector(0); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
return fmt.Errorf("devicemapper: Can't set sector %s", err)
}
if err := task.setMessage(fmt.Sprintf("create_snap %d %d", deviceID, baseDeviceID)); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
return fmt.Errorf("devicemapper: Can't set message %s", err)
}
dmSawExist = false // reset before the task is run
if err := task.run(); err != nil {
if doSuspend {
ResumeDevice(baseName)
}
// Caller wants to know about ErrDeviceIDExists so that it can try with a different device id.
if dmSawExist {
return ErrDeviceIDExists
}
return fmt.Errorf("devicemapper: Error running deviceCreate (createSnapDevice) %s", err)
}
if doSuspend {
if err := ResumeDevice(baseName); err != nil {
return err
}
}
return nil
}

View File

@@ -1,35 +0,0 @@
// +build linux
package devicemapper
import "C"
import (
"strings"
)
// Due to the way cgo works this has to be in a separate file, as devmapper.go has
// definitions in the cgo block, which is incompatible with using "//export"
// DevmapperLogCallback exports the devmapper log callback for cgo.
//export DevmapperLogCallback
func DevmapperLogCallback(level C.int, file *C.char, line C.int, dmErrnoOrClass C.int, message *C.char) {
msg := C.GoString(message)
if level < 7 {
if strings.Contains(msg, "busy") {
dmSawBusy = true
}
if strings.Contains(msg, "File exists") {
dmSawExist = true
}
if strings.Contains(msg, "No such device or address") {
dmSawEnxio = true
}
}
if dmLogger != nil {
dmLogger.DMLog(int(level), C.GoString(file), int(line), int(dmErrnoOrClass), msg)
}
}

View File

@@ -1,251 +0,0 @@
// +build linux
package devicemapper
/*
#cgo LDFLAGS: -L. -ldevmapper
#include <libdevmapper.h>
#include <linux/fs.h> // FIXME: present only for BLKGETSIZE64, maybe we can remove it?
// FIXME: Can't we find a way to do the logging in pure Go?
extern void DevmapperLogCallback(int level, char *file, int line, int dm_errno_or_class, char *str);
static void log_cb(int level, const char *file, int line, int dm_errno_or_class, const char *f, ...)
{
char buffer[256];
va_list ap;
va_start(ap, f);
vsnprintf(buffer, 256, f, ap);
va_end(ap);
DevmapperLogCallback(level, (char *)file, line, dm_errno_or_class, buffer);
}
static void log_with_errno_init()
{
dm_log_with_errno_init(log_cb);
}
*/
import "C"
import (
"reflect"
"unsafe"
)
type (
cdmTask C.struct_dm_task
)
// IOCTL consts
const (
BlkGetSize64 = C.BLKGETSIZE64
BlkDiscard = C.BLKDISCARD
)
// Devicemapper cookie flags.
const (
DmUdevDisableSubsystemRulesFlag = C.DM_UDEV_DISABLE_SUBSYSTEM_RULES_FLAG
DmUdevDisableDiskRulesFlag = C.DM_UDEV_DISABLE_DISK_RULES_FLAG
DmUdevDisableOtherRulesFlag = C.DM_UDEV_DISABLE_OTHER_RULES_FLAG
DmUdevDisableLibraryFallback = C.DM_UDEV_DISABLE_LIBRARY_FALLBACK
)
// DeviceMapper mapped functions.
var (
DmGetLibraryVersion = dmGetLibraryVersionFct
DmGetNextTarget = dmGetNextTargetFct
DmLogInitVerbose = dmLogInitVerboseFct
DmSetDevDir = dmSetDevDirFct
DmTaskAddTarget = dmTaskAddTargetFct
DmTaskCreate = dmTaskCreateFct
DmTaskDestroy = dmTaskDestroyFct
DmTaskGetDeps = dmTaskGetDepsFct
DmTaskGetInfo = dmTaskGetInfoFct
DmTaskGetDriverVersion = dmTaskGetDriverVersionFct
DmTaskRun = dmTaskRunFct
DmTaskSetAddNode = dmTaskSetAddNodeFct
DmTaskSetCookie = dmTaskSetCookieFct
DmTaskSetMessage = dmTaskSetMessageFct
DmTaskSetName = dmTaskSetNameFct
DmTaskSetRo = dmTaskSetRoFct
DmTaskSetSector = dmTaskSetSectorFct
DmUdevWait = dmUdevWaitFct
DmUdevSetSyncSupport = dmUdevSetSyncSupportFct
DmUdevGetSyncSupport = dmUdevGetSyncSupportFct
DmCookieSupported = dmCookieSupportedFct
LogWithErrnoInit = logWithErrnoInitFct
DmTaskDeferredRemove = dmTaskDeferredRemoveFct
DmTaskGetInfoWithDeferred = dmTaskGetInfoWithDeferredFct
)
func free(p *C.char) {
C.free(unsafe.Pointer(p))
}
func dmTaskDestroyFct(task *cdmTask) {
C.dm_task_destroy((*C.struct_dm_task)(task))
}
func dmTaskCreateFct(taskType int) *cdmTask {
return (*cdmTask)(C.dm_task_create(C.int(taskType)))
}
func dmTaskRunFct(task *cdmTask) int {
ret, _ := C.dm_task_run((*C.struct_dm_task)(task))
return int(ret)
}
func dmTaskSetNameFct(task *cdmTask, name string) int {
Cname := C.CString(name)
defer free(Cname)
return int(C.dm_task_set_name((*C.struct_dm_task)(task), Cname))
}
func dmTaskSetMessageFct(task *cdmTask, message string) int {
Cmessage := C.CString(message)
defer free(Cmessage)
return int(C.dm_task_set_message((*C.struct_dm_task)(task), Cmessage))
}
func dmTaskSetSectorFct(task *cdmTask, sector uint64) int {
return int(C.dm_task_set_sector((*C.struct_dm_task)(task), C.uint64_t(sector)))
}
func dmTaskSetCookieFct(task *cdmTask, cookie *uint, flags uint16) int {
cCookie := C.uint32_t(*cookie)
defer func() {
*cookie = uint(cCookie)
}()
return int(C.dm_task_set_cookie((*C.struct_dm_task)(task), &cCookie, C.uint16_t(flags)))
}
func dmTaskSetAddNodeFct(task *cdmTask, addNode AddNodeType) int {
return int(C.dm_task_set_add_node((*C.struct_dm_task)(task), C.dm_add_node_t(addNode)))
}
func dmTaskSetRoFct(task *cdmTask) int {
return int(C.dm_task_set_ro((*C.struct_dm_task)(task)))
}
func dmTaskAddTargetFct(task *cdmTask,
start, size uint64, ttype, params string) int {
Cttype := C.CString(ttype)
defer free(Cttype)
Cparams := C.CString(params)
defer free(Cparams)
return int(C.dm_task_add_target((*C.struct_dm_task)(task), C.uint64_t(start), C.uint64_t(size), Cttype, Cparams))
}
func dmTaskGetDepsFct(task *cdmTask) *Deps {
Cdeps := C.dm_task_get_deps((*C.struct_dm_task)(task))
if Cdeps == nil {
return nil
}
// golang issue: https://github.com/golang/go/issues/11925
hdr := reflect.SliceHeader{
Data: uintptr(unsafe.Pointer(uintptr(unsafe.Pointer(Cdeps)) + unsafe.Sizeof(*Cdeps))),
Len: int(Cdeps.count),
Cap: int(Cdeps.count),
}
devices := *(*[]C.uint64_t)(unsafe.Pointer(&hdr))
deps := &Deps{
Count: uint32(Cdeps.count),
Filler: uint32(Cdeps.filler),
}
for _, device := range devices {
deps.Device = append(deps.Device, uint64(device))
}
return deps
}
func dmTaskGetInfoFct(task *cdmTask, info *Info) int {
Cinfo := C.struct_dm_info{}
defer func() {
info.Exists = int(Cinfo.exists)
info.Suspended = int(Cinfo.suspended)
info.LiveTable = int(Cinfo.live_table)
info.InactiveTable = int(Cinfo.inactive_table)
info.OpenCount = int32(Cinfo.open_count)
info.EventNr = uint32(Cinfo.event_nr)
info.Major = uint32(Cinfo.major)
info.Minor = uint32(Cinfo.minor)
info.ReadOnly = int(Cinfo.read_only)
info.TargetCount = int32(Cinfo.target_count)
}()
return int(C.dm_task_get_info((*C.struct_dm_task)(task), &Cinfo))
}
func dmTaskGetDriverVersionFct(task *cdmTask) string {
buffer := C.malloc(128)
defer C.free(buffer)
res := C.dm_task_get_driver_version((*C.struct_dm_task)(task), (*C.char)(buffer), 128)
if res == 0 {
return ""
}
return C.GoString((*C.char)(buffer))
}
func dmGetNextTargetFct(task *cdmTask, next unsafe.Pointer, start, length *uint64, target, params *string) unsafe.Pointer {
var (
Cstart, Clength C.uint64_t
CtargetType, Cparams *C.char
)
defer func() {
*start = uint64(Cstart)
*length = uint64(Clength)
*target = C.GoString(CtargetType)
*params = C.GoString(Cparams)
}()
nextp := C.dm_get_next_target((*C.struct_dm_task)(task), next, &Cstart, &Clength, &CtargetType, &Cparams)
return nextp
}
func dmUdevSetSyncSupportFct(syncWithUdev int) {
(C.dm_udev_set_sync_support(C.int(syncWithUdev)))
}
func dmUdevGetSyncSupportFct() int {
return int(C.dm_udev_get_sync_support())
}
func dmUdevWaitFct(cookie uint) int {
return int(C.dm_udev_wait(C.uint32_t(cookie)))
}
func dmCookieSupportedFct() int {
return int(C.dm_cookie_supported())
}
func dmLogInitVerboseFct(level int) {
C.dm_log_init_verbose(C.int(level))
}
func logWithErrnoInitFct() {
C.log_with_errno_init()
}
func dmSetDevDirFct(dir string) int {
Cdir := C.CString(dir)
defer free(Cdir)
return int(C.dm_set_dev_dir(Cdir))
}
func dmGetLibraryVersionFct(version *string) int {
buffer := C.CString(string(make([]byte, 128)))
defer free(buffer)
defer func() {
*version = C.GoString(buffer)
}()
return int(C.dm_get_library_version(buffer, 128))
}

View File

@@ -1,34 +0,0 @@
// +build linux,!libdm_no_deferred_remove
package devicemapper
/*
#cgo LDFLAGS: -L. -ldevmapper
#include <libdevmapper.h>
*/
import "C"
// LibraryDeferredRemovalSupport is supported when statically linked.
const LibraryDeferredRemovalSupport = true
func dmTaskDeferredRemoveFct(task *cdmTask) int {
return int(C.dm_task_deferred_remove((*C.struct_dm_task)(task)))
}
func dmTaskGetInfoWithDeferredFct(task *cdmTask, info *Info) int {
Cinfo := C.struct_dm_info{}
defer func() {
info.Exists = int(Cinfo.exists)
info.Suspended = int(Cinfo.suspended)
info.LiveTable = int(Cinfo.live_table)
info.InactiveTable = int(Cinfo.inactive_table)
info.OpenCount = int32(Cinfo.open_count)
info.EventNr = uint32(Cinfo.event_nr)
info.Major = uint32(Cinfo.major)
info.Minor = uint32(Cinfo.minor)
info.ReadOnly = int(Cinfo.read_only)
info.TargetCount = int32(Cinfo.target_count)
info.DeferredRemove = int(Cinfo.deferred_remove)
}()
return int(C.dm_task_get_info((*C.struct_dm_task)(task), &Cinfo))
}

View File

@@ -1,15 +0,0 @@
// +build linux,libdm_no_deferred_remove
package devicemapper
// LibraryDeferredRemovalsupport is not supported when statically linked.
const LibraryDeferredRemovalSupport = false
func dmTaskDeferredRemoveFct(task *cdmTask) int {
// Error. Nobody should be calling it.
return -1
}
func dmTaskGetInfoWithDeferredFct(task *cdmTask, info *Info) int {
return -1
}

View File

@@ -1,27 +0,0 @@
// +build linux
package devicemapper
import (
"syscall"
"unsafe"
)
func ioctlBlkGetSize64(fd uintptr) (int64, error) {
var size int64
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkGetSize64, uintptr(unsafe.Pointer(&size))); err != 0 {
return 0, err
}
return size, nil
}
func ioctlBlkDiscard(fd uintptr, offset, length uint64) error {
var r [2]uint64
r[0] = offset
r[1] = length
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, BlkDiscard, uintptr(unsafe.Pointer(&r[0]))); err != 0 {
return err
}
return nil
}

View File

@@ -1,11 +0,0 @@
package devicemapper
// definitions from lvm2 lib/log/log.h
const (
LogLevelFatal = 2 + iota // _LOG_FATAL
LogLevelErr // _LOG_ERR
LogLevelWarn // _LOG_WARN
LogLevelNotice // _LOG_NOTICE
LogLevelInfo // _LOG_INFO
LogLevelDebug // _LOG_DEBUG
)

View File

@@ -1,26 +0,0 @@
package directory
import (
"io/ioutil"
"os"
"path/filepath"
)
// MoveToSubdir moves all contents of a directory to a subdirectory underneath the original path
func MoveToSubdir(oldpath, subdir string) error {
infos, err := ioutil.ReadDir(oldpath)
if err != nil {
return err
}
for _, info := range infos {
if info.Name() != subdir {
oldName := filepath.Join(oldpath, info.Name())
newName := filepath.Join(oldpath, subdir, info.Name())
if err := os.Rename(oldName, newName); err != nil {
return err
}
}
}
return nil
}

View File

@@ -1,182 +0,0 @@
package directory
import (
"io/ioutil"
"os"
"path/filepath"
"reflect"
"sort"
"testing"
)
// Size of an empty directory should be 0
func TestSizeEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeEmptyDirectory"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var size int64
if size, _ = Size(dir); size != 0 {
t.Fatalf("empty directory has size: %d", size)
}
}
// Size of a directory with one empty file should be 0
func TestSizeEmptyFile(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeEmptyFile"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
var size int64
if size, _ = Size(file.Name()); size != 0 {
t.Fatalf("directory with one file has size: %d", size)
}
}
// Size of a directory with one 5-byte file should be 5
func TestSizeNonemptyFile(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeNonemptyFile"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
d := []byte{97, 98, 99, 100, 101}
file.Write(d)
var size int64
if size, _ = Size(file.Name()); size != 5 {
t.Fatalf("directory with one 5-byte file has size: %d", size)
}
}
// Size of a directory with one empty directory should be 0
func TestSizeNestedDirectoryEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeNestedDirectoryEmpty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dir, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var size int64
if size, _ = Size(dir); size != 0 {
t.Fatalf("directory with one empty directory has size: %d", size)
}
}
// Test directory with 1 file and 1 empty directory
func TestSizeFileAndNestedDirectoryEmpty(t *testing.T) {
var dir string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "testSizeFileAndNestedDirectoryEmpty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dir, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
d := []byte{100, 111, 99, 107, 101, 114}
file.Write(d)
var size int64
if size, _ = Size(dir); size != 6 {
t.Fatalf("directory with 6-byte file and empty directory has size: %d", size)
}
}
// Test directory with 1 file and 1 non-empty directory
func TestSizeFileAndNestedDirectoryNonempty(t *testing.T) {
var dir, dirNested string
var err error
if dir, err = ioutil.TempDir(os.TempDir(), "TestSizeFileAndNestedDirectoryNonempty"); err != nil {
t.Fatalf("failed to create directory: %s", err)
}
if dirNested, err = ioutil.TempDir(dir, "nested"); err != nil {
t.Fatalf("failed to create nested directory: %s", err)
}
var file *os.File
if file, err = ioutil.TempFile(dir, "file"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
data := []byte{100, 111, 99, 107, 101, 114}
file.Write(data)
var nestedFile *os.File
if nestedFile, err = ioutil.TempFile(dirNested, "file"); err != nil {
t.Fatalf("failed to create file in nested directory: %s", err)
}
nestedData := []byte{100, 111, 99, 107, 101, 114}
nestedFile.Write(nestedData)
var size int64
if size, _ = Size(dir); size != 12 {
t.Fatalf("directory with 6-byte file and nested directory with 6-byte file has size: %d", size)
}
}
// Test migration of directory to a subdir underneath itself
func TestMoveToSubdir(t *testing.T) {
var outerDir, subDir string
var err error
if outerDir, err = ioutil.TempDir(os.TempDir(), "TestMoveToSubdir"); err != nil {
t.Fatalf("failed to create directory: %v", err)
}
if subDir, err = ioutil.TempDir(outerDir, "testSub"); err != nil {
t.Fatalf("failed to create subdirectory: %v", err)
}
// write 4 temp files in the outer dir to get moved
filesList := []string{"a", "b", "c", "d"}
for _, fName := range filesList {
if file, err := os.Create(filepath.Join(outerDir, fName)); err != nil {
t.Fatalf("couldn't create temp file %q: %v", fName, err)
} else {
file.WriteString(fName)
file.Close()
}
}
if err = MoveToSubdir(outerDir, filepath.Base(subDir)); err != nil {
t.Fatalf("Error during migration of content to subdirectory: %v", err)
}
// validate that the files were moved to the subdirectory
infos, err := ioutil.ReadDir(subDir)
if len(infos) != 4 {
t.Fatalf("Should be four files in the subdir after the migration: actual length: %d", len(infos))
}
var results []string
for _, info := range infos {
results = append(results, info.Name())
}
sort.Sort(sort.StringSlice(results))
if !reflect.DeepEqual(filesList, results) {
t.Fatalf("Results after migration do not equal list of files: expected: %v, got: %v", filesList, results)
}
}

View File

@@ -1,39 +0,0 @@
// +build linux freebsd
package directory
import (
"os"
"path/filepath"
"syscall"
)
// Size walks a directory tree and returns its total size in bytes.
func Size(dir string) (size int64, err error) {
data := make(map[uint64]struct{})
err = filepath.Walk(dir, func(d string, fileInfo os.FileInfo, e error) error {
// Ignore directory sizes
if fileInfo == nil {
return nil
}
s := fileInfo.Size()
if fileInfo.IsDir() || s == 0 {
return nil
}
// Check inode to handle hard links correctly
inode := fileInfo.Sys().(*syscall.Stat_t).Ino
// inode is not a uint64 on all platforms. Cast it to avoid issues.
if _, exists := data[uint64(inode)]; exists {
return nil
}
// inode is not a uint64 on all platforms. Cast it to avoid issues.
data[uint64(inode)] = struct{}{}
size += s
return nil
})
return
}

View File

@@ -1,35 +0,0 @@
// +build windows
package directory
import (
"os"
"path/filepath"
"github.com/hyperhq/hypercli/pkg/longpath"
)
// Size walks a directory tree and returns its total size in bytes.
func Size(dir string) (size int64, err error) {
fixedPath, err := filepath.Abs(dir)
if err != nil {
return
}
fixedPath = longpath.AddPrefix(fixedPath)
err = filepath.Walk(dir, func(d string, fileInfo os.FileInfo, e error) error {
// Ignore directory sizes
if fileInfo == nil {
return nil
}
s := fileInfo.Size()
if fileInfo.IsDir() || s == 0 {
return nil
}
size += s
return nil
})
return
}

View File

@@ -1,41 +0,0 @@
---
page_title: Docker discovery
page_description: discovery
page_keywords: docker, clustering, discovery
---
# Discovery
Docker comes with multiple Discovery backends.
## Backends
### Using etcd
Point your Docker Engine instances to a common etcd instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store etcd://<etcd_ip1>,<etcd_ip2>/<path>
```
### Using consul
Point your Docker Engine instances to a common Consul instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store consul://<consul_ip>/<path>
```
### Using zookeeper
Point your Docker Engine instances to a common Zookeeper instance. You can specify
the address Docker uses to advertise the node using the `--cluster-advertise`
flag.
```bash
$ docker daemon -H=<node_ip:2376> --cluster-advertise=<node_ip:2376> --cluster-store zk://<zk_addr1>,<zk_addr2>/<path>
```

View File

@@ -1,107 +0,0 @@
package discovery
import (
"fmt"
"net"
"strings"
"time"
log "github.com/Sirupsen/logrus"
)
var (
// Backends is a global map of discovery backends indexed by their
// associated scheme.
backends = make(map[string]Backend)
)
// Register makes a discovery backend available by the provided scheme.
// If Register is called twice with the same scheme an error is returned.
func Register(scheme string, d Backend) error {
if _, exists := backends[scheme]; exists {
return fmt.Errorf("scheme already registered %s", scheme)
}
log.WithField("name", scheme).Debug("Registering discovery service")
backends[scheme] = d
return nil
}
func parse(rawurl string) (string, string) {
parts := strings.SplitN(rawurl, "://", 2)
// nodes:port,node2:port => nodes://node1:port,node2:port
if len(parts) == 1 {
return "nodes", parts[0]
}
return parts[0], parts[1]
}
// ParseAdvertise parses the --cluster-advertise daemon config which accepts
// <ip-address>:<port> or <interface-name>:<port>
func ParseAdvertise(advertise string) (string, error) {
var (
iface *net.Interface
addrs []net.Addr
err error
)
addr, port, err := net.SplitHostPort(advertise)
if err != nil {
return "", fmt.Errorf("invalid --cluster-advertise configuration: %s: %v", advertise, err)
}
ip := net.ParseIP(addr)
// If it is a valid ip-address, use it as is
if ip != nil {
return advertise, nil
}
// If advertise is a valid interface name, get the valid ipv4 address and use it to advertise
ifaceName := addr
iface, err = net.InterfaceByName(ifaceName)
if err != nil {
return "", fmt.Errorf("invalid cluster advertise IP address or interface name (%s) : %v", advertise, err)
}
addrs, err = iface.Addrs()
if err != nil {
return "", fmt.Errorf("unable to get advertise IP address from interface (%s) : %v", advertise, err)
}
if addrs == nil || len(addrs) == 0 {
return "", fmt.Errorf("no available advertise IP address in interface (%s)", advertise)
}
addr = ""
for _, a := range addrs {
ip, _, err := net.ParseCIDR(a.String())
if err != nil {
return "", fmt.Errorf("error deriving advertise ip-address in interface (%s) : %v", advertise, err)
}
if ip.To4() == nil || ip.IsLoopback() {
continue
}
addr = ip.String()
break
}
if addr == "" {
return "", fmt.Errorf("couldnt find a valid ip-address in interface %s", advertise)
}
addr = fmt.Sprintf("%s:%s", addr, port)
return addr, nil
}
// New returns a new Discovery given a URL, heartbeat and ttl settings.
// Returns an error if the URL scheme is not supported.
func New(rawurl string, heartbeat time.Duration, ttl time.Duration, clusterOpts map[string]string) (Backend, error) {
scheme, uri := parse(rawurl)
if backend, exists := backends[scheme]; exists {
log.WithFields(log.Fields{"name": scheme, "uri": uri}).Debug("Initializing discovery service")
err := backend.Initialize(uri, heartbeat, ttl, clusterOpts)
return backend, err
}
return nil, ErrNotSupported
}

View File

@@ -1,35 +0,0 @@
package discovery
import (
"errors"
"time"
)
var (
// ErrNotSupported is returned when a discovery service is not supported.
ErrNotSupported = errors.New("discovery service not supported")
// ErrNotImplemented is returned when discovery feature is not implemented
// by discovery backend.
ErrNotImplemented = errors.New("not implemented in this discovery service")
)
// Watcher provides watching over a cluster for nodes joining and leaving.
type Watcher interface {
// Watch the discovery for entry changes.
// Returns a channel that will receive changes or an error.
// Providing a non-nil stopCh can be used to stop watching.
Watch(stopCh <-chan struct{}) (<-chan Entries, <-chan error)
}
// Backend is implemented by discovery backends which manage cluster entries.
type Backend interface {
// Watcher must be provided by every backend.
Watcher
// Initialize the discovery with URIs, a heartbeat, a ttl and optional settings.
Initialize(string, time.Duration, time.Duration, map[string]string) error
// Register to the discovery.
Register(string) error
}

View File

@@ -1,131 +0,0 @@
package discovery
import (
"testing"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestNewEntry(c *check.C) {
entry, err := NewEntry("127.0.0.1:2375")
c.Assert(err, check.IsNil)
c.Assert(entry.Equals(&Entry{Host: "127.0.0.1", Port: "2375"}), check.Equals, true)
c.Assert(entry.String(), check.Equals, "127.0.0.1:2375")
_, err = NewEntry("127.0.0.1")
c.Assert(err, check.NotNil)
}
func (s *DiscoverySuite) TestParse(c *check.C) {
scheme, uri := parse("127.0.0.1:2375")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "127.0.0.1:2375")
scheme, uri = parse("localhost:2375")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "localhost:2375")
scheme, uri = parse("scheme://127.0.0.1:2375")
c.Assert(scheme, check.Equals, "scheme")
c.Assert(uri, check.Equals, "127.0.0.1:2375")
scheme, uri = parse("scheme://localhost:2375")
c.Assert(scheme, check.Equals, "scheme")
c.Assert(uri, check.Equals, "localhost:2375")
scheme, uri = parse("")
c.Assert(scheme, check.Equals, "nodes")
c.Assert(uri, check.Equals, "")
}
func (s *DiscoverySuite) TestCreateEntries(c *check.C) {
entries, err := CreateEntries(nil)
c.Assert(entries, check.DeepEquals, Entries{})
c.Assert(err, check.IsNil)
entries, err = CreateEntries([]string{"127.0.0.1:2375", "127.0.0.2:2375", ""})
c.Assert(err, check.IsNil)
expected := Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}
c.Assert(entries.Equals(expected), check.Equals, true)
_, err = CreateEntries([]string{"127.0.0.1", "127.0.0.2"})
c.Assert(err, check.NotNil)
}
func (s *DiscoverySuite) TestContainsEntry(c *check.C) {
entries, err := CreateEntries([]string{"127.0.0.1:2375", "127.0.0.2:2375", ""})
c.Assert(err, check.IsNil)
c.Assert(entries.Contains(&Entry{Host: "127.0.0.1", Port: "2375"}), check.Equals, true)
c.Assert(entries.Contains(&Entry{Host: "127.0.0.3", Port: "2375"}), check.Equals, false)
}
func (s *DiscoverySuite) TestEntriesEquality(c *check.C) {
entries := Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}
// Same
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
}), check.
Equals, true)
// Different size
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.2", Port: "2375"},
&Entry{Host: "127.0.0.3", Port: "2375"},
}), check.
Equals, false)
// Different content
c.Assert(entries.Equals(Entries{
&Entry{Host: "127.0.0.1", Port: "2375"},
&Entry{Host: "127.0.0.42", Port: "2375"},
}), check.
Equals, false)
}
func (s *DiscoverySuite) TestEntriesDiff(c *check.C) {
entry1 := &Entry{Host: "1.1.1.1", Port: "1111"}
entry2 := &Entry{Host: "2.2.2.2", Port: "2222"}
entry3 := &Entry{Host: "3.3.3.3", Port: "3333"}
entries := Entries{entry1, entry2}
// No diff
added, removed := entries.Diff(Entries{entry2, entry1})
c.Assert(added, check.HasLen, 0)
c.Assert(removed, check.HasLen, 0)
// Add
added, removed = entries.Diff(Entries{entry2, entry3, entry1})
c.Assert(added, check.HasLen, 1)
c.Assert(added.Contains(entry3), check.Equals, true)
c.Assert(removed, check.HasLen, 0)
// Remove
added, removed = entries.Diff(Entries{entry2})
c.Assert(added, check.HasLen, 0)
c.Assert(removed, check.HasLen, 1)
c.Assert(removed.Contains(entry1), check.Equals, true)
// Add and remove
added, removed = entries.Diff(Entries{entry1, entry3})
c.Assert(added, check.HasLen, 1)
c.Assert(added.Contains(entry3), check.Equals, true)
c.Assert(removed, check.HasLen, 1)
c.Assert(removed.Contains(entry2), check.Equals, true)
}

View File

@@ -1,97 +0,0 @@
package discovery
import (
"fmt"
"net"
)
// NewEntry creates a new entry.
func NewEntry(url string) (*Entry, error) {
host, port, err := net.SplitHostPort(url)
if err != nil {
return nil, err
}
return &Entry{host, port}, nil
}
// An Entry represents a host.
type Entry struct {
Host string
Port string
}
// Equals returns true if cmp contains the same data.
func (e *Entry) Equals(cmp *Entry) bool {
return e.Host == cmp.Host && e.Port == cmp.Port
}
// String returns the string form of an entry.
func (e *Entry) String() string {
return fmt.Sprintf("%s:%s", e.Host, e.Port)
}
// Entries is a list of *Entry with some helpers.
type Entries []*Entry
// Equals returns true if cmp contains the same data.
func (e Entries) Equals(cmp Entries) bool {
// Check if the file has really changed.
if len(e) != len(cmp) {
return false
}
for i := range e {
if !e[i].Equals(cmp[i]) {
return false
}
}
return true
}
// Contains returns true if the Entries contain a given Entry.
func (e Entries) Contains(entry *Entry) bool {
for _, curr := range e {
if curr.Equals(entry) {
return true
}
}
return false
}
// Diff compares two entries and returns the added and removed entries.
func (e Entries) Diff(cmp Entries) (Entries, Entries) {
added := Entries{}
for _, entry := range cmp {
if !e.Contains(entry) {
added = append(added, entry)
}
}
removed := Entries{}
for _, entry := range e {
if !cmp.Contains(entry) {
removed = append(removed, entry)
}
}
return added, removed
}
// CreateEntries returns an array of entries based on the given addresses.
func CreateEntries(addrs []string) (Entries, error) {
entries := Entries{}
if addrs == nil {
return entries, nil
}
for _, addr := range addrs {
if len(addr) == 0 {
continue
}
entry, err := NewEntry(addr)
if err != nil {
return nil, err
}
entries = append(entries, entry)
}
return entries, nil
}

View File

@@ -1,109 +0,0 @@
package file
import (
"fmt"
"io/ioutil"
"strings"
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery is exported
type Discovery struct {
heartbeat time.Duration
path string
}
func init() {
Init()
}
// Init is exported
func Init() {
discovery.Register("file", &Discovery{})
}
// Initialize is exported
func (s *Discovery) Initialize(path string, heartbeat time.Duration, ttl time.Duration, _ map[string]string) error {
s.path = path
s.heartbeat = heartbeat
return nil
}
func parseFileContent(content []byte) []string {
var result []string
for _, line := range strings.Split(strings.TrimSpace(string(content)), "\n") {
line = strings.TrimSpace(line)
// Ignoring line starts with #
if strings.HasPrefix(line, "#") {
continue
}
// Inlined # comment also ignored.
if strings.Contains(line, "#") {
line = line[0:strings.Index(line, "#")]
// Trim additional spaces caused by above stripping.
line = strings.TrimSpace(line)
}
for _, ip := range discovery.Generate(line) {
result = append(result, ip)
}
}
return result
}
func (s *Discovery) fetch() (discovery.Entries, error) {
fileContent, err := ioutil.ReadFile(s.path)
if err != nil {
return nil, fmt.Errorf("failed to read '%s': %v", s.path, err)
}
return discovery.CreateEntries(parseFileContent(fileContent))
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
ticker := time.NewTicker(s.heartbeat)
go func() {
defer close(errCh)
defer close(ch)
// Send the initial entries if available.
currentEntries, err := s.fetch()
if err != nil {
errCh <- err
} else {
ch <- currentEntries
}
// Periodically send updates.
for {
select {
case <-ticker.C:
newEntries, err := s.fetch()
if err != nil {
errCh <- err
continue
}
// Check if the file has really changed.
if !newEntries.Equals(currentEntries) {
ch <- newEntries
}
currentEntries = newEntries
case <-stopCh:
ticker.Stop()
return
}
}
}()
return ch, errCh
}
// Register is exported
func (s *Discovery) Register(addr string) error {
return discovery.ErrNotImplemented
}

View File

@@ -1,114 +0,0 @@
package file
import (
"io/ioutil"
"os"
"testing"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestInitialize(c *check.C) {
d := &Discovery{}
d.Initialize("/path/to/file", 1000, 0, nil)
c.Assert(d.path, check.Equals, "/path/to/file")
}
func (s *DiscoverySuite) TestNew(c *check.C) {
d, err := discovery.New("file:///path/to/file", 0, 0, nil)
c.Assert(err, check.IsNil)
c.Assert(d.(*Discovery).path, check.Equals, "/path/to/file")
}
func (s *DiscoverySuite) TestContent(c *check.C) {
data := `
1.1.1.[1:2]:1111
2.2.2.[2:4]:2222
`
ips := parseFileContent([]byte(data))
c.Assert(ips, check.HasLen, 5)
c.Assert(ips[0], check.Equals, "1.1.1.1:1111")
c.Assert(ips[1], check.Equals, "1.1.1.2:1111")
c.Assert(ips[2], check.Equals, "2.2.2.2:2222")
c.Assert(ips[3], check.Equals, "2.2.2.3:2222")
c.Assert(ips[4], check.Equals, "2.2.2.4:2222")
}
func (s *DiscoverySuite) TestRegister(c *check.C) {
discovery := &Discovery{path: "/path/to/file"}
c.Assert(discovery.Register("0.0.0.0"), check.NotNil)
}
func (s *DiscoverySuite) TestParsingContentsWithComments(c *check.C) {
data := `
### test ###
1.1.1.1:1111 # inline comment
# 2.2.2.2:2222
### empty line with comment
3.3.3.3:3333
### test ###
`
ips := parseFileContent([]byte(data))
c.Assert(ips, check.HasLen, 2)
c.Assert("1.1.1.1:1111", check.Equals, ips[0])
c.Assert("3.3.3.3:3333", check.Equals, ips[1])
}
func (s *DiscoverySuite) TestWatch(c *check.C) {
data := `
1.1.1.1:1111
2.2.2.2:2222
`
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
// Create a temporary file and remove it.
tmp, err := ioutil.TempFile(os.TempDir(), "discovery-file-test")
c.Assert(err, check.IsNil)
c.Assert(tmp.Close(), check.IsNil)
c.Assert(os.Remove(tmp.Name()), check.IsNil)
// Set up file discovery.
d := &Discovery{}
d.Initialize(tmp.Name(), 1000, 0, nil)
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// Make sure it fires errors since the file doesn't exist.
c.Assert(<-errCh, check.NotNil)
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
// Write the file and make sure we get the expected value back.
c.Assert(ioutil.WriteFile(tmp.Name(), []byte(data), 0600), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
// Add a new entry and look it up.
expected = append(expected, &discovery.Entry{Host: "3.3.3.3", Port: "3333"})
f, err := os.OpenFile(tmp.Name(), os.O_APPEND|os.O_WRONLY, 0600)
c.Assert(err, check.IsNil)
c.Assert(f, check.NotNil)
_, err = f.WriteString("\n3.3.3.3:3333\n")
c.Assert(err, check.IsNil)
f.Close()
c.Assert(<-ch, check.DeepEquals, expected)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}

View File

@@ -1,35 +0,0 @@
package discovery
import (
"fmt"
"regexp"
"strconv"
)
// Generate takes care of IP generation
func Generate(pattern string) []string {
re, _ := regexp.Compile(`\[(.+):(.+)\]`)
submatch := re.FindStringSubmatch(pattern)
if submatch == nil {
return []string{pattern}
}
from, err := strconv.Atoi(submatch[1])
if err != nil {
return []string{pattern}
}
to, err := strconv.Atoi(submatch[2])
if err != nil {
return []string{pattern}
}
template := re.ReplaceAllString(pattern, "%d")
var result []string
for val := from; val <= to; val++ {
entry := fmt.Sprintf(template, val)
result = append(result, entry)
}
return result
}

View File

@@ -1,53 +0,0 @@
package discovery
import (
"github.com/go-check/check"
)
func (s *DiscoverySuite) TestGeneratorNotGenerate(c *check.C) {
ips := Generate("127.0.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.1")
}
func (s *DiscoverySuite) TestGeneratorWithPortNotGenerate(c *check.C) {
ips := Generate("127.0.0.1:8080")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.1:8080")
}
func (s *DiscoverySuite) TestGeneratorMatchFailedNotGenerate(c *check.C) {
ips := Generate("127.0.0.[1]")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, "127.0.0.[1]")
}
func (s *DiscoverySuite) TestGeneratorWithPort(c *check.C) {
ips := Generate("127.0.0.[1:11]:2375")
c.Assert(len(ips), check.Equals, 11)
c.Assert(ips[0], check.Equals, "127.0.0.1:2375")
c.Assert(ips[1], check.Equals, "127.0.0.2:2375")
c.Assert(ips[2], check.Equals, "127.0.0.3:2375")
c.Assert(ips[3], check.Equals, "127.0.0.4:2375")
c.Assert(ips[4], check.Equals, "127.0.0.5:2375")
c.Assert(ips[5], check.Equals, "127.0.0.6:2375")
c.Assert(ips[6], check.Equals, "127.0.0.7:2375")
c.Assert(ips[7], check.Equals, "127.0.0.8:2375")
c.Assert(ips[8], check.Equals, "127.0.0.9:2375")
c.Assert(ips[9], check.Equals, "127.0.0.10:2375")
c.Assert(ips[10], check.Equals, "127.0.0.11:2375")
}
func (s *DiscoverySuite) TestGenerateWithMalformedInputAtRangeStart(c *check.C) {
malformedInput := "127.0.0.[x:11]:2375"
ips := Generate(malformedInput)
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, malformedInput)
}
func (s *DiscoverySuite) TestGenerateWithMalformedInputAtRangeEnd(c *check.C) {
malformedInput := "127.0.0.[1:x]:2375"
ips := Generate(malformedInput)
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0], check.Equals, malformedInput)
}

View File

@@ -1,192 +0,0 @@
package kv
import (
"fmt"
"path"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/docker/go-connections/tlsconfig"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
"github.com/docker/libkv/store/consul"
"github.com/docker/libkv/store/etcd"
"github.com/docker/libkv/store/zookeeper"
"github.com/hyperhq/hypercli/pkg/discovery"
)
const (
defaultDiscoveryPath = "docker/nodes"
)
// Discovery is exported
type Discovery struct {
backend store.Backend
store store.Store
heartbeat time.Duration
ttl time.Duration
prefix string
path string
}
func init() {
Init()
}
// Init is exported
func Init() {
// Register to libkv
zookeeper.Register()
consul.Register()
etcd.Register()
// Register to internal discovery service
discovery.Register("zk", &Discovery{backend: store.ZK})
discovery.Register("consul", &Discovery{backend: store.CONSUL})
discovery.Register("etcd", &Discovery{backend: store.ETCD})
}
// Initialize is exported
func (s *Discovery) Initialize(uris string, heartbeat time.Duration, ttl time.Duration, clusterOpts map[string]string) error {
var (
parts = strings.SplitN(uris, "/", 2)
addrs = strings.Split(parts[0], ",")
err error
)
// A custom prefix to the path can be optionally used.
if len(parts) == 2 {
s.prefix = parts[1]
}
s.heartbeat = heartbeat
s.ttl = ttl
// Use a custom path if specified in discovery options
dpath := defaultDiscoveryPath
if clusterOpts["kv.path"] != "" {
dpath = clusterOpts["kv.path"]
}
s.path = path.Join(s.prefix, dpath)
var config *store.Config
if clusterOpts["kv.cacertfile"] != "" && clusterOpts["kv.certfile"] != "" && clusterOpts["kv.keyfile"] != "" {
log.Info("Initializing discovery with TLS")
tlsConfig, err := tlsconfig.Client(tlsconfig.Options{
CAFile: clusterOpts["kv.cacertfile"],
CertFile: clusterOpts["kv.certfile"],
KeyFile: clusterOpts["kv.keyfile"],
})
if err != nil {
return err
}
config = &store.Config{
// Set ClientTLS to trigger https (bug in libkv/etcd)
ClientTLS: &store.ClientTLSConfig{
CACertFile: clusterOpts["kv.cacertfile"],
CertFile: clusterOpts["kv.certfile"],
KeyFile: clusterOpts["kv.keyfile"],
},
// The actual TLS config that will be used
TLS: tlsConfig,
}
} else {
log.Info("Initializing discovery without TLS")
}
// Creates a new store, will ignore options given
// if not supported by the chosen store
s.store, err = libkv.NewStore(s.backend, addrs, config)
return err
}
// Watch the store until either there's a store error or we receive a stop request.
// Returns false if we shouldn't attempt watching the store anymore (stop request received).
func (s *Discovery) watchOnce(stopCh <-chan struct{}, watchCh <-chan []*store.KVPair, discoveryCh chan discovery.Entries, errCh chan error) bool {
for {
select {
case pairs := <-watchCh:
if pairs == nil {
return true
}
log.WithField("discovery", s.backend).Debugf("Watch triggered with %d nodes", len(pairs))
// Convert `KVPair` into `discovery.Entry`.
addrs := make([]string, len(pairs))
for _, pair := range pairs {
addrs = append(addrs, string(pair.Value))
}
entries, err := discovery.CreateEntries(addrs)
if err != nil {
errCh <- err
} else {
discoveryCh <- entries
}
case <-stopCh:
// We were requested to stop watching.
return false
}
}
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
go func() {
defer close(ch)
defer close(errCh)
// Forever: Create a store watch, watch until we get an error and then try again.
// Will only stop if we receive a stopCh request.
for {
// Create the path to watch if it does not exist yet
exists, err := s.store.Exists(s.path)
if err != nil {
errCh <- err
}
if !exists {
if err := s.store.Put(s.path, []byte(""), &store.WriteOptions{IsDir: true}); err != nil {
errCh <- err
}
}
// Set up a watch.
watchCh, err := s.store.WatchTree(s.path, stopCh)
if err != nil {
errCh <- err
} else {
if !s.watchOnce(stopCh, watchCh, ch, errCh) {
return
}
}
// If we get here it means the store watch channel was closed. This
// is unexpected so let's retry later.
errCh <- fmt.Errorf("Unexpected watch error")
time.Sleep(s.heartbeat)
}
}()
return ch, errCh
}
// Register is exported
func (s *Discovery) Register(addr string) error {
opts := &store.WriteOptions{TTL: s.ttl}
return s.store.Put(path.Join(s.path, addr), []byte(addr), opts)
}
// Store returns the underlying store used by KV discovery.
func (s *Discovery) Store() store.Store {
return s.store
}
// Prefix returns the store prefix
func (s *Discovery) Prefix() string {
return s.prefix
}

View File

@@ -1,324 +0,0 @@
package kv
import (
"errors"
"io/ioutil"
"os"
"path"
"testing"
"time"
"github.com/docker/libkv"
"github.com/docker/libkv/store"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (ds *DiscoverySuite) TestInitialize(c *check.C) {
storeMock := &FakeStore{
Endpoints: []string{"127.0.0.1"},
}
d := &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1", 0, 0, nil)
d.store = storeMock
s := d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 1)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1")
c.Assert(d.path, check.Equals, defaultDiscoveryPath)
storeMock = &FakeStore{
Endpoints: []string{"127.0.0.1:1234"},
}
d = &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234/path", 0, 0, nil)
d.store = storeMock
s = d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 1)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1:1234")
c.Assert(d.path, check.Equals, "path/"+defaultDiscoveryPath)
storeMock = &FakeStore{
Endpoints: []string{"127.0.0.1:1234", "127.0.0.2:1234", "127.0.0.3:1234"},
}
d = &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234,127.0.0.2:1234,127.0.0.3:1234/path", 0, 0, nil)
d.store = storeMock
s = d.store.(*FakeStore)
c.Assert(s.Endpoints, check.HasLen, 3)
c.Assert(s.Endpoints[0], check.Equals, "127.0.0.1:1234")
c.Assert(s.Endpoints[1], check.Equals, "127.0.0.2:1234")
c.Assert(s.Endpoints[2], check.Equals, "127.0.0.3:1234")
c.Assert(d.path, check.Equals, "path/"+defaultDiscoveryPath)
}
// Extremely limited mock store so we can test initialization
type Mock struct {
// Endpoints passed to InitializeMock
Endpoints []string
// Options passed to InitializeMock
Options *store.Config
}
func NewMock(endpoints []string, options *store.Config) (store.Store, error) {
s := &Mock{}
s.Endpoints = endpoints
s.Options = options
return s, nil
}
func (s *Mock) Put(key string, value []byte, opts *store.WriteOptions) error {
return errors.New("Put not supported")
}
func (s *Mock) Get(key string) (*store.KVPair, error) {
return nil, errors.New("Get not supported")
}
func (s *Mock) Delete(key string) error {
return errors.New("Delete not supported")
}
// Exists mock
func (s *Mock) Exists(key string) (bool, error) {
return false, errors.New("Exists not supported")
}
// Watch mock
func (s *Mock) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
return nil, errors.New("Watch not supported")
}
// WatchTree mock
func (s *Mock) WatchTree(prefix string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
return nil, errors.New("WatchTree not supported")
}
// NewLock mock
func (s *Mock) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
return nil, errors.New("NewLock not supported")
}
// List mock
func (s *Mock) List(prefix string) ([]*store.KVPair, error) {
return nil, errors.New("List not supported")
}
// DeleteTree mock
func (s *Mock) DeleteTree(prefix string) error {
return errors.New("DeleteTree not supported")
}
// AtomicPut mock
func (s *Mock) AtomicPut(key string, value []byte, previous *store.KVPair, opts *store.WriteOptions) (bool, *store.KVPair, error) {
return false, nil, errors.New("AtomicPut not supported")
}
// AtomicDelete mock
func (s *Mock) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
return false, errors.New("AtomicDelete not supported")
}
// Close mock
func (s *Mock) Close() {
return
}
func (ds *DiscoverySuite) TestInitializeWithCerts(c *check.C) {
cert := `-----BEGIN CERTIFICATE-----
MIIDCDCCAfKgAwIBAgIICifG7YeiQOEwCwYJKoZIhvcNAQELMBIxEDAOBgNVBAMT
B1Rlc3QgQ0EwHhcNMTUxMDAxMjMwMDAwWhcNMjAwOTI5MjMwMDAwWjASMRAwDgYD
VQQDEwdUZXN0IENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1wRC
O+flnLTK5ImjTurNRHwSejuqGbc4CAvpB0hS+z0QlSs4+zE9h80aC4hz+6caRpds
+J908Q+RvAittMHbpc7VjbZP72G6fiXk7yPPl6C10HhRSoSi3nY+B7F2E8cuz14q
V2e+ejhWhSrBb/keyXpcyjoW1BOAAJ2TIclRRkICSCZrpXUyXxAvzXfpFXo1RhSb
UywN11pfiCQzDUN7sPww9UzFHuAHZHoyfTr27XnJYVUerVYrCPq8vqfn//01qz55
Xs0hvzGdlTFXhuabFtQnKFH5SNwo/fcznhB7rePOwHojxOpXTBepUCIJLbtNnWFT
V44t9gh5IqIWtoBReQIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAAYwEgYDVR0TAQH/
BAgwBgEB/wIBAjAdBgNVHQ4EFgQUZKUI8IIjIww7X/6hvwggQK4bD24wHwYDVR0j
BBgwFoAUZKUI8IIjIww7X/6hvwggQK4bD24wCwYJKoZIhvcNAQELA4IBAQDES2cz
7sCQfDCxCIWH7X8kpi/JWExzUyQEJ0rBzN1m3/x8ySRxtXyGekimBqQwQdFqlwMI
xzAQKkh3ue8tNSzRbwqMSyH14N1KrSxYS9e9szJHfUasoTpQGPmDmGIoRJuq1h6M
ej5x1SCJ7GWCR6xEXKUIE9OftXm9TdFzWa7Ja3OHz/mXteii8VXDuZ5ACq6EE5bY
8sP4gcICfJ5fTrpTlk9FIqEWWQrCGa5wk95PGEj+GJpNogjXQ97wVoo/Y3p1brEn
t5zjN9PAq4H1fuCMdNNA+p1DHNwd+ELTxcMAnb2ajwHvV6lKPXutrTFc4umJToBX
FpTxDmJHEV4bzUzh
-----END CERTIFICATE-----
`
key := `-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA1wRCO+flnLTK5ImjTurNRHwSejuqGbc4CAvpB0hS+z0QlSs4
+zE9h80aC4hz+6caRpds+J908Q+RvAittMHbpc7VjbZP72G6fiXk7yPPl6C10HhR
SoSi3nY+B7F2E8cuz14qV2e+ejhWhSrBb/keyXpcyjoW1BOAAJ2TIclRRkICSCZr
pXUyXxAvzXfpFXo1RhSbUywN11pfiCQzDUN7sPww9UzFHuAHZHoyfTr27XnJYVUe
rVYrCPq8vqfn//01qz55Xs0hvzGdlTFXhuabFtQnKFH5SNwo/fcznhB7rePOwHoj
xOpXTBepUCIJLbtNnWFTV44t9gh5IqIWtoBReQIDAQABAoIBAHSWipORGp/uKFXj
i/mut776x8ofsAxhnLBARQr93ID+i49W8H7EJGkOfaDjTICYC1dbpGrri61qk8sx
qX7p3v/5NzKwOIfEpirgwVIqSNYe/ncbxnhxkx6tXtUtFKmEx40JskvSpSYAhmmO
1XSx0E/PWaEN/nLgX/f1eWJIlxlQkk3QeqL+FGbCXI48DEtlJ9+MzMu4pAwZTpj5
5qtXo5JJ0jRGfJVPAOznRsYqv864AhMdMIWguzk6EGnbaCWwPcfcn+h9a5LMdony
MDHfBS7bb5tkF3+AfnVY3IBMVx7YlsD9eAyajlgiKu4zLbwTRHjXgShy+4Oussz0
ugNGnkECgYEA/hi+McrZC8C4gg6XqK8+9joD8tnyDZDz88BQB7CZqABUSwvjDqlP
L8hcwo/lzvjBNYGkqaFPUICGWKjeCtd8pPS2DCVXxDQX4aHF1vUur0uYNncJiV3N
XQz4Iemsa6wnKf6M67b5vMXICw7dw0HZCdIHD1hnhdtDz0uVpeevLZ8CgYEA2KCT
Y43lorjrbCgMqtlefkr3GJA9dey+hTzCiWEOOqn9RqGoEGUday0sKhiLofOgmN2B
LEukpKIey8s+Q/cb6lReajDVPDsMweX8i7hz3Wa4Ugp4Xa5BpHqu8qIAE2JUZ7bU
t88aQAYE58pUF+/Lq1QzAQdrjjzQBx6SrBxieecCgYEAvukoPZEC8mmiN1VvbTX+
QFHmlZha3QaDxChB+QUe7bMRojEUL/fVnzkTOLuVFqSfxevaI/km9n0ac5KtAchV
xjp2bTnBb5EUQFqjopYktWA+xO07JRJtMfSEmjZPbbay1kKC7rdTfBm961EIHaRj
xZUf6M+rOE8964oGrdgdLlECgYEA046GQmx6fh7/82FtdZDRQp9tj3SWQUtSiQZc
qhO59Lq8mjUXz+MgBuJXxkiwXRpzlbaFB0Bca1fUoYw8o915SrDYf/Zu2OKGQ/qa
V81sgiVmDuEgycR7YOlbX6OsVUHrUlpwhY3hgfMe6UtkMvhBvHF/WhroBEIJm1pV
PXZ/CbMCgYEApNWVktFBjOaYfY6SNn4iSts1jgsQbbpglg3kT7PLKjCAhI6lNsbk
dyT7ut01PL6RaW4SeQWtrJIVQaM6vF3pprMKqlc5XihOGAmVqH7rQx9rtQB5TicL
BFrwkQE4HQtQBV60hYQUzzlSk44VFDz+jxIEtacRHaomDRh2FtOTz+I=
-----END RSA PRIVATE KEY-----
`
certFile, err := ioutil.TempFile("", "cert")
c.Assert(err, check.IsNil)
defer os.Remove(certFile.Name())
certFile.Write([]byte(cert))
certFile.Close()
keyFile, err := ioutil.TempFile("", "key")
c.Assert(err, check.IsNil)
defer os.Remove(keyFile.Name())
keyFile.Write([]byte(key))
keyFile.Close()
libkv.AddStore("mock", NewMock)
d := &Discovery{backend: "mock"}
err = d.Initialize("127.0.0.3:1234", 0, 0, map[string]string{
"kv.cacertfile": certFile.Name(),
"kv.certfile": certFile.Name(),
"kv.keyfile": keyFile.Name(),
})
c.Assert(err, check.IsNil)
s := d.store.(*Mock)
c.Assert(s.Options.TLS, check.NotNil)
c.Assert(s.Options.TLS.RootCAs, check.NotNil)
c.Assert(s.Options.TLS.Certificates, check.HasLen, 1)
}
func (ds *DiscoverySuite) TestWatch(c *check.C) {
mockCh := make(chan []*store.KVPair)
storeMock := &FakeStore{
Endpoints: []string{"127.0.0.1:1234"},
mockKVChan: mockCh,
}
d := &Discovery{backend: store.CONSUL}
d.Initialize("127.0.0.1:1234/path", 0, 0, nil)
d.store = storeMock
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
kvs := []*store.KVPair{
{Key: path.Join("path", defaultDiscoveryPath, "1.1.1.1"), Value: []byte("1.1.1.1:1111")},
{Key: path.Join("path", defaultDiscoveryPath, "2.2.2.2"), Value: []byte("2.2.2.2:2222")},
}
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// It should fire an error since the first WatchTree call failed.
c.Assert(<-errCh, check.ErrorMatches, "test error")
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
// Push the entries into the store channel and make sure discovery emits.
mockCh <- kvs
c.Assert(<-ch, check.DeepEquals, expected)
// Add a new entry.
expected = append(expected, &discovery.Entry{Host: "3.3.3.3", Port: "3333"})
kvs = append(kvs, &store.KVPair{Key: path.Join("path", defaultDiscoveryPath, "3.3.3.3"), Value: []byte("3.3.3.3:3333")})
mockCh <- kvs
c.Assert(<-ch, check.DeepEquals, expected)
close(mockCh)
// Give it enough time to call WatchTree.
time.Sleep(3)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}
// FakeStore implements store.Store methods. It mocks all store
// function in a simple, naive way.
type FakeStore struct {
Endpoints []string
Options *store.Config
mockKVChan <-chan []*store.KVPair
watchTreeCallCount int
}
func (s *FakeStore) Put(key string, value []byte, options *store.WriteOptions) error {
return nil
}
func (s *FakeStore) Get(key string) (*store.KVPair, error) {
return nil, nil
}
func (s *FakeStore) Delete(key string) error {
return nil
}
func (s *FakeStore) Exists(key string) (bool, error) {
return true, nil
}
func (s *FakeStore) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) {
return nil, nil
}
// WatchTree will fail the first time, and return the mockKVchan afterwards.
// This is the behavior we need for testing.. If we need 'moar', should update this.
func (s *FakeStore) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) {
if s.watchTreeCallCount == 0 {
s.watchTreeCallCount = 1
return nil, errors.New("test error")
}
// First calls error
return s.mockKVChan, nil
}
func (s *FakeStore) NewLock(key string, options *store.LockOptions) (store.Locker, error) {
return nil, nil
}
func (s *FakeStore) List(directory string) ([]*store.KVPair, error) {
return []*store.KVPair{}, nil
}
func (s *FakeStore) DeleteTree(directory string) error {
return nil
}
func (s *FakeStore) AtomicPut(key string, value []byte, previous *store.KVPair, options *store.WriteOptions) (bool, *store.KVPair, error) {
return true, nil, nil
}
func (s *FakeStore) AtomicDelete(key string, previous *store.KVPair) (bool, error) {
return true, nil
}
func (s *FakeStore) Close() {
}

View File

@@ -1,83 +0,0 @@
package memory
import (
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery implements a descovery backend that keeps
// data in memory.
type Discovery struct {
heartbeat time.Duration
values []string
}
func init() {
Init()
}
// Init registers the memory backend on demand.
func Init() {
discovery.Register("memory", &Discovery{})
}
// Initialize sets the heartbeat for the memory backend.
func (s *Discovery) Initialize(_ string, heartbeat time.Duration, _ time.Duration, _ map[string]string) error {
s.heartbeat = heartbeat
s.values = make([]string, 0)
return nil
}
// Watch sends periodic discovery updates to a channel.
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
errCh := make(chan error)
ticker := time.NewTicker(s.heartbeat)
go func() {
defer close(errCh)
defer close(ch)
// Send the initial entries if available.
var currentEntries discovery.Entries
if len(s.values) > 0 {
var err error
currentEntries, err = discovery.CreateEntries(s.values)
if err != nil {
errCh <- err
} else {
ch <- currentEntries
}
}
// Periodically send updates.
for {
select {
case <-ticker.C:
newEntries, err := discovery.CreateEntries(s.values)
if err != nil {
errCh <- err
continue
}
// Check if the file has really changed.
if !newEntries.Equals(currentEntries) {
ch <- newEntries
}
currentEntries = newEntries
case <-stopCh:
ticker.Stop()
return
}
}
}()
return ch, errCh
}
// Register adds a new address to the discovery.
func (s *Discovery) Register(addr string) error {
s.values = append(s.values, addr)
return nil
}

View File

@@ -1,48 +0,0 @@
package memory
import (
"testing"
"github.com/go-check/check"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type discoverySuite struct{}
var _ = check.Suite(&discoverySuite{})
func (s *discoverySuite) TestWatch(c *check.C) {
d := &Discovery{}
d.Initialize("foo", 1000, 0, nil)
stopCh := make(chan struct{})
ch, errCh := d.Watch(stopCh)
// We have to drain the error channel otherwise Watch will get stuck.
go func() {
for range errCh {
}
}()
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
}
c.Assert(d.Register("1.1.1.1:1111"), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
expected = discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
c.Assert(d.Register("2.2.2.2:2222"), check.IsNil)
c.Assert(<-ch, check.DeepEquals, expected)
// Stop and make sure it closes all channels.
close(stopCh)
c.Assert(<-ch, check.IsNil)
c.Assert(<-errCh, check.IsNil)
}

View File

@@ -1,54 +0,0 @@
package nodes
import (
"fmt"
"strings"
"time"
"github.com/hyperhq/hypercli/pkg/discovery"
)
// Discovery is exported
type Discovery struct {
entries discovery.Entries
}
func init() {
Init()
}
// Init is exported
func Init() {
discovery.Register("nodes", &Discovery{})
}
// Initialize is exported
func (s *Discovery) Initialize(uris string, _ time.Duration, _ time.Duration, _ map[string]string) error {
for _, input := range strings.Split(uris, ",") {
for _, ip := range discovery.Generate(input) {
entry, err := discovery.NewEntry(ip)
if err != nil {
return fmt.Errorf("%s, please check you are using the correct discovery (missing token:// ?)", err.Error())
}
s.entries = append(s.entries, entry)
}
}
return nil
}
// Watch is exported
func (s *Discovery) Watch(stopCh <-chan struct{}) (<-chan discovery.Entries, <-chan error) {
ch := make(chan discovery.Entries)
go func() {
defer close(ch)
ch <- s.entries
<-stopCh
}()
return ch, nil
}
// Register is exported
func (s *Discovery) Register(addr string) error {
return discovery.ErrNotImplemented
}

View File

@@ -1,51 +0,0 @@
package nodes
import (
"testing"
"github.com/hyperhq/hypercli/pkg/discovery"
"github.com/go-check/check"
)
// Hook up gocheck into the "go test" runner.
func Test(t *testing.T) { check.TestingT(t) }
type DiscoverySuite struct{}
var _ = check.Suite(&DiscoverySuite{})
func (s *DiscoverySuite) TestInitialize(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.1:1111,2.2.2.2:2222", 0, 0, nil)
c.Assert(len(d.entries), check.Equals, 2)
c.Assert(d.entries[0].String(), check.Equals, "1.1.1.1:1111")
c.Assert(d.entries[1].String(), check.Equals, "2.2.2.2:2222")
}
func (s *DiscoverySuite) TestInitializeWithPattern(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.[1:2]:1111,2.2.2.[2:4]:2222", 0, 0, nil)
c.Assert(len(d.entries), check.Equals, 5)
c.Assert(d.entries[0].String(), check.Equals, "1.1.1.1:1111")
c.Assert(d.entries[1].String(), check.Equals, "1.1.1.2:1111")
c.Assert(d.entries[2].String(), check.Equals, "2.2.2.2:2222")
c.Assert(d.entries[3].String(), check.Equals, "2.2.2.3:2222")
c.Assert(d.entries[4].String(), check.Equals, "2.2.2.4:2222")
}
func (s *DiscoverySuite) TestWatch(c *check.C) {
d := &Discovery{}
d.Initialize("1.1.1.1:1111,2.2.2.2:2222", 0, 0, nil)
expected := discovery.Entries{
&discovery.Entry{Host: "1.1.1.1", Port: "1111"},
&discovery.Entry{Host: "2.2.2.2", Port: "2222"},
}
ch, _ := d.Watch(nil)
c.Assert(expected.Equals(<-ch), check.Equals, true)
}
func (s *DiscoverySuite) TestRegister(c *check.C) {
d := &Discovery{}
c.Assert(d.Register("0.0.0.0"), check.NotNil)
}

View File

@@ -1,40 +0,0 @@
// Package filenotify provides a mechanism for watching file(s) for changes.
// Generally leans on fsnotify, but provides a poll-based notifier which fsnotify does not support.
// These are wrapped up in a common interface so that either can be used interchangeably in your code.
package filenotify
import "gopkg.in/fsnotify.v1"
// FileWatcher is an interface for implementing file notification watchers
type FileWatcher interface {
Events() <-chan fsnotify.Event
Errors() <-chan error
Add(name string) error
Remove(name string) error
Close() error
}
// New tries to use an fs-event watcher, and falls back to the poller if there is an error
func New() (FileWatcher, error) {
if watcher, err := NewEventWatcher(); err == nil {
return watcher, nil
}
return NewPollingWatcher(), nil
}
// NewPollingWatcher returns a poll-based file watcher
func NewPollingWatcher() FileWatcher {
return &filePoller{
events: make(chan fsnotify.Event),
errors: make(chan error),
}
}
// NewEventWatcher returns an fs-event based file watcher
func NewEventWatcher() (FileWatcher, error) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return nil, err
}
return &fsNotifyWatcher{watcher}, nil
}

View File

@@ -1,18 +0,0 @@
package filenotify
import "gopkg.in/fsnotify.v1"
// fsNotify wraps the fsnotify package to satisfy the FileNotifer interface
type fsNotifyWatcher struct {
*fsnotify.Watcher
}
// GetEvents returns the fsnotify event channel receiver
func (w *fsNotifyWatcher) Events() <-chan fsnotify.Event {
return w.Watcher.Events
}
// GetErrors returns the fsnotify error channel receiver
func (w *fsNotifyWatcher) Errors() <-chan error {
return w.Watcher.Errors
}

View File

@@ -1,205 +0,0 @@
package filenotify
import (
"errors"
"fmt"
"os"
"sync"
"time"
"github.com/Sirupsen/logrus"
"gopkg.in/fsnotify.v1"
)
var (
// errPollerClosed is returned when the poller is closed
errPollerClosed = errors.New("poller is closed")
// errNoSuchPoller is returned when trying to remove a watch that doesn't exist
errNoSuchWatch = errors.New("poller does not exist")
)
// watchWaitTime is the time to wait between file poll loops
const watchWaitTime = 200 * time.Millisecond
// filePoller is used to poll files for changes, especially in cases where fsnotify
// can't be run (e.g. when inotify handles are exhausted)
// filePoller satisfies the FileWatcher interface
type filePoller struct {
// watches is the list of files currently being polled, close the associated channel to stop the watch
watches map[string]chan struct{}
// events is the channel to listen to for watch events
events chan fsnotify.Event
// errors is the channel to listen to for watch errors
errors chan error
// mu locks the poller for modification
mu sync.Mutex
// closed is used to specify when the poller has already closed
closed bool
}
// Add adds a filename to the list of watches
// once added the file is polled for changes in a separate goroutine
func (w *filePoller) Add(name string) error {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed == true {
return errPollerClosed
}
f, err := os.Open(name)
if err != nil {
return err
}
fi, err := os.Stat(name)
if err != nil {
return err
}
if w.watches == nil {
w.watches = make(map[string]chan struct{})
}
if _, exists := w.watches[name]; exists {
return fmt.Errorf("watch exists")
}
chClose := make(chan struct{})
w.watches[name] = chClose
go w.watch(f, fi, chClose)
return nil
}
// Remove stops and removes watch with the specified name
func (w *filePoller) Remove(name string) error {
w.mu.Lock()
defer w.mu.Unlock()
return w.remove(name)
}
func (w *filePoller) remove(name string) error {
if w.closed == true {
return errPollerClosed
}
chClose, exists := w.watches[name]
if !exists {
return errNoSuchWatch
}
close(chClose)
delete(w.watches, name)
return nil
}
// Events returns the event channel
// This is used for notifications on events about watched files
func (w *filePoller) Events() <-chan fsnotify.Event {
return w.events
}
// Errors returns the errors channel
// This is used for notifications about errors on watched files
func (w *filePoller) Errors() <-chan error {
return w.errors
}
// Close closes the poller
// All watches are stopped, removed, and the poller cannot be added to
func (w *filePoller) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.closed {
return nil
}
w.closed = true
for name := range w.watches {
w.remove(name)
delete(w.watches, name)
}
close(w.events)
close(w.errors)
return nil
}
// sendEvent publishes the specified event to the events channel
func (w *filePoller) sendEvent(e fsnotify.Event, chClose <-chan struct{}) error {
select {
case w.events <- e:
case <-chClose:
return fmt.Errorf("closed")
}
return nil
}
// sendErr publishes the specified error to the errors channel
func (w *filePoller) sendErr(e error, chClose <-chan struct{}) error {
select {
case w.errors <- e:
case <-chClose:
return fmt.Errorf("closed")
}
return nil
}
// watch is responsible for polling the specified file for changes
// upon finding changes to a file or errors, sendEvent/sendErr is called
func (w *filePoller) watch(f *os.File, lastFi os.FileInfo, chClose chan struct{}) {
for {
time.Sleep(watchWaitTime)
select {
case <-chClose:
logrus.Debugf("watch for %s closed", f.Name())
return
default:
}
fi, err := os.Stat(f.Name())
if err != nil {
// if we got an error here and lastFi is not set, we can presume that nothing has changed
// This should be safe since before `watch()` is called, a stat is performed, there is any error `watch` is not called
if lastFi == nil {
continue
}
// If it doesn't exist at this point, it must have been removed
// no need to send the error here since this is a valid operation
if os.IsNotExist(err) {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Remove, Name: f.Name()}, chClose); err != nil {
return
}
lastFi = nil
continue
}
// at this point, send the error
if err := w.sendErr(err, chClose); err != nil {
return
}
continue
}
if lastFi == nil {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Create, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
if fi.Mode() != lastFi.Mode() {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Chmod, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
if fi.ModTime() != lastFi.ModTime() || fi.Size() != lastFi.Size() {
if err := w.sendEvent(fsnotify.Event{Op: fsnotify.Write, Name: fi.Name()}, chClose); err != nil {
return
}
lastFi = fi
continue
}
}
}

View File

@@ -1,133 +0,0 @@
package filenotify
import (
"fmt"
"io/ioutil"
"os"
"testing"
"time"
"gopkg.in/fsnotify.v1"
)
func TestPollerAddRemove(t *testing.T) {
w := NewPollingWatcher()
if err := w.Add("no-such-file"); err == nil {
t.Fatal("should have gotten error when adding a non-existent file")
}
if err := w.Remove("no-such-file"); err == nil {
t.Fatal("should have gotten error when removing non-existent watch")
}
f, err := ioutil.TempFile("", "asdf")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(f.Name())
if err := w.Add(f.Name()); err != nil {
t.Fatal(err)
}
if err := w.Remove(f.Name()); err != nil {
t.Fatal(err)
}
}
func TestPollerEvent(t *testing.T) {
w := NewPollingWatcher()
f, err := ioutil.TempFile("", "test-poller")
if err != nil {
t.Fatal("error creating temp file")
}
defer os.RemoveAll(f.Name())
f.Close()
if err := w.Add(f.Name()); err != nil {
t.Fatal(err)
}
select {
case <-w.Events():
t.Fatal("got event before anything happened")
case <-w.Errors():
t.Fatal("got error before anything happened")
default:
}
if err := ioutil.WriteFile(f.Name(), []byte("hello"), 644); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Write); err != nil {
t.Fatal(err)
}
if err := os.Chmod(f.Name(), 600); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Chmod); err != nil {
t.Fatal(err)
}
if err := os.Remove(f.Name()); err != nil {
t.Fatal(err)
}
if err := assertEvent(w, fsnotify.Remove); err != nil {
t.Fatal(err)
}
}
func TestPollerClose(t *testing.T) {
w := NewPollingWatcher()
if err := w.Close(); err != nil {
t.Fatal(err)
}
// test double-close
if err := w.Close(); err != nil {
t.Fatal(err)
}
select {
case _, open := <-w.Events():
if open {
t.Fatal("event chan should be closed")
}
default:
t.Fatal("event chan should be closed")
}
select {
case _, open := <-w.Errors():
if open {
t.Fatal("errors chan should be closed")
}
default:
t.Fatal("errors chan should be closed")
}
f, err := ioutil.TempFile("", "asdf")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(f.Name())
if err := w.Add(f.Name()); err == nil {
t.Fatal("should have gotten error adding watch for closed watcher")
}
}
func assertEvent(w FileWatcher, eType fsnotify.Op) error {
var err error
select {
case e := <-w.Events():
if e.Op != eType {
err = fmt.Errorf("got wrong event type, expected %q: %v", eType, e)
}
case e := <-w.Errors():
err = fmt.Errorf("got unexpected error waiting for events %v: %v", eType, e)
case <-time.After(watchWaitTime * 3):
err = fmt.Errorf("timeout waiting for event %v", eType)
}
return err
}

View File

@@ -1,573 +0,0 @@
package fileutils
import (
"io/ioutil"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"testing"
)
// CopyFile with invalid src
func TestCopyFileWithInvalidSrc(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile("/invalid/file/path", path.Join(tempFolder, "dest"))
if err == nil {
t.Fatal("Should have fail to copy an invalid src file")
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes")
}
}
// CopyFile with invalid dest
func TestCopyFileWithInvalidDest(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
src := path.Join(tempFolder, "file")
err = ioutil.WriteFile(src, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(src, path.Join(tempFolder, "/invalid/dest/path"))
if err == nil {
t.Fatal("Should have fail to copy an invalid src file")
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes")
}
}
// CopyFile with same src and dest
func TestCopyFileWithSameSrcAndDest(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
file := path.Join(tempFolder, "file")
err = ioutil.WriteFile(file, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(file, file)
if err != nil {
t.Fatal(err)
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes as it is the same file.")
}
}
// CopyFile with same src and dest but path is different and not clean
func TestCopyFileWithSameSrcAndDestWithPathNameDifferent(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
testFolder := path.Join(tempFolder, "test")
err = os.MkdirAll(testFolder, 0740)
if err != nil {
t.Fatal(err)
}
file := path.Join(testFolder, "file")
sameFile := testFolder + "/../test/file"
err = ioutil.WriteFile(file, []byte("content"), 0740)
if err != nil {
t.Fatal(err)
}
bytes, err := CopyFile(file, sameFile)
if err != nil {
t.Fatal(err)
}
if bytes != 0 {
t.Fatal("Should have written 0 bytes as it is the same file.")
}
}
func TestCopyFile(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
defer os.RemoveAll(tempFolder)
if err != nil {
t.Fatal(err)
}
src := path.Join(tempFolder, "src")
dest := path.Join(tempFolder, "dest")
ioutil.WriteFile(src, []byte("content"), 0777)
ioutil.WriteFile(dest, []byte("destContent"), 0777)
bytes, err := CopyFile(src, dest)
if err != nil {
t.Fatal(err)
}
if bytes != 7 {
t.Fatalf("Should have written %d bytes but wrote %d", 7, bytes)
}
actual, err := ioutil.ReadFile(dest)
if err != nil {
t.Fatal(err)
}
if string(actual) != "content" {
t.Fatalf("Dest content was '%s', expected '%s'", string(actual), "content")
}
}
// Reading a symlink to a directory must return the directory
func TestReadSymlinkedDirectoryExistingDirectory(t *testing.T) {
var err error
if err = os.Mkdir("/tmp/testReadSymlinkToExistingDirectory", 0777); err != nil {
t.Errorf("failed to create directory: %s", err)
}
if err = os.Symlink("/tmp/testReadSymlinkToExistingDirectory", "/tmp/dirLinkTest"); err != nil {
t.Errorf("failed to create symlink: %s", err)
}
var path string
if path, err = ReadSymlinkedDirectory("/tmp/dirLinkTest"); err != nil {
t.Fatalf("failed to read symlink to directory: %s", err)
}
if path != "/tmp/testReadSymlinkToExistingDirectory" {
t.Fatalf("symlink returned unexpected directory: %s", path)
}
if err = os.Remove("/tmp/testReadSymlinkToExistingDirectory"); err != nil {
t.Errorf("failed to remove temporary directory: %s", err)
}
if err = os.Remove("/tmp/dirLinkTest"); err != nil {
t.Errorf("failed to remove symlink: %s", err)
}
}
// Reading a non-existing symlink must fail
func TestReadSymlinkedDirectoryNonExistingSymlink(t *testing.T) {
var path string
var err error
if path, err = ReadSymlinkedDirectory("/tmp/test/foo/Non/ExistingPath"); err == nil {
t.Fatalf("error expected for non-existing symlink")
}
if path != "" {
t.Fatalf("expected empty path, but '%s' was returned", path)
}
}
// Reading a symlink to a file must fail
func TestReadSymlinkedDirectoryToFile(t *testing.T) {
var err error
var file *os.File
if file, err = os.Create("/tmp/testReadSymlinkToFile"); err != nil {
t.Fatalf("failed to create file: %s", err)
}
file.Close()
if err = os.Symlink("/tmp/testReadSymlinkToFile", "/tmp/fileLinkTest"); err != nil {
t.Errorf("failed to create symlink: %s", err)
}
var path string
if path, err = ReadSymlinkedDirectory("/tmp/fileLinkTest"); err == nil {
t.Fatalf("ReadSymlinkedDirectory on a symlink to a file should've failed")
}
if path != "" {
t.Fatalf("path should've been empty: %s", path)
}
if err = os.Remove("/tmp/testReadSymlinkToFile"); err != nil {
t.Errorf("failed to remove file: %s", err)
}
if err = os.Remove("/tmp/fileLinkTest"); err != nil {
t.Errorf("failed to remove symlink: %s", err)
}
}
func TestWildcardMatches(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*"})
if match != true {
t.Errorf("failed to get a wildcard match, got %v", match)
}
}
// A simple pattern match should return true.
func TestPatternMatches(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*.go"})
if match != true {
t.Errorf("failed to get a match, got %v", match)
}
}
// An exclusion followed by an inclusion should return true.
func TestExclusionPatternMatchesPatternBefore(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"!fileutils.go", "*.go"})
if match != true {
t.Errorf("failed to get true match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderWithSlashExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs/", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A folder pattern followed by an exception should return false.
func TestPatternMatchesFolderWildcardExclusions(t *testing.T) {
match, _ := Matches("docs/README.md", []string{"docs/*", "!docs/README.md"})
if match != false {
t.Errorf("failed to get a false match on exclusion pattern, got %v", match)
}
}
// A pattern followed by an exclusion should return false.
func TestExclusionPatternMatchesPatternAfter(t *testing.T) {
match, _ := Matches("fileutils.go", []string{"*.go", "!fileutils.go"})
if match != false {
t.Errorf("failed to get false match on exclusion pattern, got %v", match)
}
}
// A filename evaluating to . should return false.
func TestExclusionPatternMatchesWholeDirectory(t *testing.T) {
match, _ := Matches(".", []string{"*.go"})
if match != false {
t.Errorf("failed to get false match on ., got %v", match)
}
}
// A single ! pattern should return an error.
func TestSingleExclamationError(t *testing.T) {
_, err := Matches("fileutils.go", []string{"!"})
if err == nil {
t.Errorf("failed to get an error for a single exclamation point, got %v", err)
}
}
// A string preceded with a ! should return true from Exclusion.
func TestExclusion(t *testing.T) {
exclusion := exclusion("!")
if !exclusion {
t.Errorf("failed to get true for a single !, got %v", exclusion)
}
}
// Matches with no patterns
func TestMatchesWithNoPatterns(t *testing.T) {
matches, err := Matches("/any/path/there", []string{})
if err != nil {
t.Fatal(err)
}
if matches {
t.Fatalf("Should not have match anything")
}
}
// Matches with malformed patterns
func TestMatchesWithMalformedPatterns(t *testing.T) {
matches, err := Matches("/any/path/there", []string{"["})
if err == nil {
t.Fatal("Should have failed because of a malformed syntax in the pattern")
}
if matches {
t.Fatalf("Should not have match anything")
}
}
// Test lots of variants of patterns & strings
func TestMatches(t *testing.T) {
tests := []struct {
pattern string
text string
pass bool
}{
{"**", "file", true},
{"**", "file/", true},
{"**/", "file", true}, // weird one
{"**/", "file/", true},
{"**", "/", true},
{"**/", "/", true},
{"**", "dir/file", true},
{"**/", "dir/file", false},
{"**", "dir/file/", true},
{"**/", "dir/file/", true},
{"**/**", "dir/file", true},
{"**/**", "dir/file/", true},
{"dir/**", "dir/file", true},
{"dir/**", "dir/file/", true},
{"dir/**", "dir/dir2/file", true},
{"dir/**", "dir/dir2/file/", true},
{"**/dir2/*", "dir/dir2/file", true},
{"**/dir2/*", "dir/dir2/file/", false},
{"**/dir2/**", "dir/dir2/dir3/file", true},
{"**/dir2/**", "dir/dir2/dir3/file/", true},
{"**file", "file", true},
{"**file", "dir/file", true},
{"**/file", "dir/file", true},
{"**file", "dir/dir/file", true},
{"**/file", "dir/dir/file", true},
{"**/file*", "dir/dir/file", true},
{"**/file*", "dir/dir/file.txt", true},
{"**/file*txt", "dir/dir/file.txt", true},
{"**/file*.txt", "dir/dir/file.txt", true},
{"**/file*.txt*", "dir/dir/file.txt", true},
{"**/**/*.txt", "dir/dir/file.txt", true},
{"**/**/*.txt2", "dir/dir/file.txt", false},
{"**/*.txt", "file.txt", true},
{"**/**/*.txt", "file.txt", true},
{"a**/*.txt", "a/file.txt", true},
{"a**/*.txt", "a/dir/file.txt", true},
{"a**/*.txt", "a/dir/dir/file.txt", true},
{"a/*.txt", "a/dir/file.txt", false},
{"a/*.txt", "a/file.txt", true},
{"a/*.txt**", "a/file.txt", true},
{"a[b-d]e", "ae", false},
{"a[b-d]e", "ace", true},
{"a[b-d]e", "aae", false},
{"a[^b-d]e", "aze", true},
{".*", ".foo", true},
{".*", "foo", false},
{"abc.def", "abcdef", false},
{"abc.def", "abc.def", true},
{"abc.def", "abcZdef", false},
{"abc?def", "abcZdef", true},
{"abc?def", "abcdef", false},
{"a\\*b", "a*b", true},
{"a\\", "a", false},
{"a\\", "a\\", false},
{"a\\\\", "a\\", true},
{"**/foo/bar", "foo/bar", true},
{"**/foo/bar", "dir/foo/bar", true},
{"**/foo/bar", "dir/dir2/foo/bar", true},
{"abc/**", "abc", false},
{"abc/**", "abc/def", true},
{"abc/**", "abc/def/ghi", true},
}
for _, test := range tests {
res, _ := regexpMatch(test.pattern, test.text)
if res != test.pass {
t.Fatalf("Failed: %v - res:%v", test, res)
}
}
}
// An empty string should return true from Empty.
func TestEmpty(t *testing.T) {
empty := empty("")
if !empty {
t.Errorf("failed to get true for an empty string, got %v", empty)
}
}
func TestCleanPatterns(t *testing.T) {
cleaned, _, _, _ := CleanPatterns([]string{"docs", "config"})
if len(cleaned) != 2 {
t.Errorf("expected 2 element slice, got %v", len(cleaned))
}
}
func TestCleanPatternsStripEmptyPatterns(t *testing.T) {
cleaned, _, _, _ := CleanPatterns([]string{"docs", "config", ""})
if len(cleaned) != 2 {
t.Errorf("expected 2 element slice, got %v", len(cleaned))
}
}
func TestCleanPatternsExceptionFlag(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md"})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsLeadingSpaceTrimmed(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", " !docs/README.md"})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsTrailingSpaceTrimmed(t *testing.T) {
_, _, exceptions, _ := CleanPatterns([]string{"docs", "!docs/README.md "})
if !exceptions {
t.Errorf("expected exceptions to be true, got %v", exceptions)
}
}
func TestCleanPatternsErrorSingleException(t *testing.T) {
_, _, _, err := CleanPatterns([]string{"!"})
if err == nil {
t.Errorf("expected error on single exclamation point, got %v", err)
}
}
func TestCleanPatternsFolderSplit(t *testing.T) {
_, dirs, _, _ := CleanPatterns([]string{"docs/config/CONFIG.md"})
if dirs[0][0] != "docs" {
t.Errorf("expected first element in dirs slice to be docs, got %v", dirs[0][1])
}
if dirs[0][1] != "config" {
t.Errorf("expected first element in dirs slice to be config, got %v", dirs[0][1])
}
}
func TestCreateIfNotExistsDir(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempFolder)
folderToCreate := filepath.Join(tempFolder, "tocreate")
if err := CreateIfNotExists(folderToCreate, true); err != nil {
t.Fatal(err)
}
fileinfo, err := os.Stat(folderToCreate)
if err != nil {
t.Fatalf("Should have create a folder, got %v", err)
}
if !fileinfo.IsDir() {
t.Fatalf("Should have been a dir, seems it's not")
}
}
func TestCreateIfNotExistsFile(t *testing.T) {
tempFolder, err := ioutil.TempDir("", "docker-fileutils-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempFolder)
fileToCreate := filepath.Join(tempFolder, "file/to/create")
if err := CreateIfNotExists(fileToCreate, false); err != nil {
t.Fatal(err)
}
fileinfo, err := os.Stat(fileToCreate)
if err != nil {
t.Fatalf("Should have create a file, got %v", err)
}
if fileinfo.IsDir() {
t.Fatalf("Should have been a file, seems it's not")
}
}
// These matchTests are stolen from go's filepath Match tests.
type matchTest struct {
pattern, s string
match bool
err error
}
var matchTests = []matchTest{
{"abc", "abc", true, nil},
{"*", "abc", true, nil},
{"*c", "abc", true, nil},
{"a*", "a", true, nil},
{"a*", "abc", true, nil},
{"a*", "ab/c", false, nil},
{"a*/b", "abc/b", true, nil},
{"a*/b", "a/c/b", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/f", true, nil},
{"a*b*c*d*e*/f", "axbxcxdxe/xxx/f", false, nil},
{"a*b*c*d*e*/f", "axbxcxdxexxx/fff", false, nil},
{"a*b?c*x", "abxbbxdbxebxczzx", true, nil},
{"a*b?c*x", "abxbbxdbxebxczzy", false, nil},
{"ab[c]", "abc", true, nil},
{"ab[b-d]", "abc", true, nil},
{"ab[e-g]", "abc", false, nil},
{"ab[^c]", "abc", false, nil},
{"ab[^b-d]", "abc", false, nil},
{"ab[^e-g]", "abc", true, nil},
{"a\\*b", "a*b", true, nil},
{"a\\*b", "ab", false, nil},
{"a?b", "a☺b", true, nil},
{"a[^a]b", "a☺b", true, nil},
{"a???b", "a☺b", false, nil},
{"a[^a][^a][^a]b", "a☺b", false, nil},
{"[a-ζ]*", "α", true, nil},
{"*[a-ζ]", "A", false, nil},
{"a?b", "a/b", false, nil},
{"a*b", "a/b", false, nil},
{"[\\]a]", "]", true, nil},
{"[\\-]", "-", true, nil},
{"[x\\-]", "x", true, nil},
{"[x\\-]", "-", true, nil},
{"[x\\-]", "z", false, nil},
{"[\\-x]", "x", true, nil},
{"[\\-x]", "-", true, nil},
{"[\\-x]", "a", false, nil},
{"[]a]", "]", false, filepath.ErrBadPattern},
{"[-]", "-", false, filepath.ErrBadPattern},
{"[x-]", "x", false, filepath.ErrBadPattern},
{"[x-]", "-", false, filepath.ErrBadPattern},
{"[x-]", "z", false, filepath.ErrBadPattern},
{"[-x]", "x", false, filepath.ErrBadPattern},
{"[-x]", "-", false, filepath.ErrBadPattern},
{"[-x]", "a", false, filepath.ErrBadPattern},
{"\\", "a", false, filepath.ErrBadPattern},
{"[a-b-c]", "a", false, filepath.ErrBadPattern},
{"[", "a", false, filepath.ErrBadPattern},
{"[^", "a", false, filepath.ErrBadPattern},
{"[^bc", "a", false, filepath.ErrBadPattern},
{"a[", "a", false, filepath.ErrBadPattern}, // was nil but IMO its wrong
{"a[", "ab", false, filepath.ErrBadPattern},
{"*x", "xxx", true, nil},
}
func errp(e error) string {
if e == nil {
return "<nil>"
}
return e.Error()
}
// TestMatch test's our version of filepath.Match, called regexpMatch.
func TestMatch(t *testing.T) {
for _, tt := range matchTests {
pattern := tt.pattern
s := tt.s
if runtime.GOOS == "windows" {
if strings.Index(pattern, "\\") >= 0 {
// no escape allowed on windows.
continue
}
pattern = filepath.Clean(pattern)
s = filepath.Clean(s)
}
ok, err := regexpMatch(pattern, s)
if ok != tt.match || err != tt.err {
t.Fatalf("Match(%#q, %#q) = %v, %q want %v, %q", pattern, s, ok, errp(err), tt.match, errp(tt.err))
}
}
}

View File

@@ -1,100 +0,0 @@
package gitutils
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/hyperhq/hypercli/pkg/symlink"
"github.com/hyperhq/hypercli/pkg/urlutil"
)
// Clone clones a repository into a newly created directory which
// will be under "docker-build-git"
func Clone(remoteURL string) (string, error) {
if !urlutil.IsGitTransport(remoteURL) {
remoteURL = "https://" + remoteURL
}
root, err := ioutil.TempDir("", "docker-build-git")
if err != nil {
return "", err
}
u, err := url.Parse(remoteURL)
if err != nil {
return "", err
}
fragment := u.Fragment
clone := cloneArgs(u, root)
if output, err := git(clone...); err != nil {
return "", fmt.Errorf("Error trying to use git: %s (%s)", err, output)
}
return checkoutGit(fragment, root)
}
func cloneArgs(remoteURL *url.URL, root string) []string {
args := []string{"clone", "--recursive"}
shallow := len(remoteURL.Fragment) == 0
if shallow && strings.HasPrefix(remoteURL.Scheme, "http") {
res, err := http.Head(fmt.Sprintf("%s/info/refs?service=git-upload-pack", remoteURL))
if err != nil || res.Header.Get("Content-Type") != "application/x-git-upload-pack-advertisement" {
shallow = false
}
}
if shallow {
args = append(args, "--depth", "1")
}
if remoteURL.Fragment != "" {
remoteURL.Fragment = ""
}
return append(args, remoteURL.String(), root)
}
func checkoutGit(fragment, root string) (string, error) {
refAndDir := strings.SplitN(fragment, ":", 2)
if len(refAndDir[0]) != 0 {
if output, err := gitWithinDir(root, "checkout", refAndDir[0]); err != nil {
return "", fmt.Errorf("Error trying to use git: %s (%s)", err, output)
}
}
if len(refAndDir) > 1 && len(refAndDir[1]) != 0 {
newCtx, err := symlink.FollowSymlinkInScope(filepath.Join(root, refAndDir[1]), root)
if err != nil {
return "", fmt.Errorf("Error setting git context, %q not within git root: %s", refAndDir[1], err)
}
fi, err := os.Stat(newCtx)
if err != nil {
return "", err
}
if !fi.IsDir() {
return "", fmt.Errorf("Error setting git context, not a directory: %s", newCtx)
}
root = newCtx
}
return root, nil
}
func gitWithinDir(dir string, args ...string) ([]byte, error) {
a := []string{"--work-tree", dir, "--git-dir", filepath.Join(dir, ".git")}
return git(append(a, args...)...)
}
func git(args ...string) ([]byte, error) {
return exec.Command("git", args...).CombinedOutput()
}

View File

@@ -1,186 +0,0 @@
package gitutils
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
"testing"
)
func TestCloneArgsSmartHttp(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
serverURL, _ := url.Parse(server.URL)
serverURL.Path = "/repo.git"
gitURL := serverURL.String()
mux.HandleFunc("/repo.git/info/refs", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query().Get("service")
w.Header().Set("Content-Type", fmt.Sprintf("application/x-%s-advertisement", q))
})
args := cloneArgs(serverURL, "/tmp")
exp := []string{"clone", "--recursive", "--depth", "1", gitURL, "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsDumbHttp(t *testing.T) {
mux := http.NewServeMux()
server := httptest.NewServer(mux)
serverURL, _ := url.Parse(server.URL)
serverURL.Path = "/repo.git"
gitURL := serverURL.String()
mux.HandleFunc("/repo.git/info/refs", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
})
args := cloneArgs(serverURL, "/tmp")
exp := []string{"clone", "--recursive", gitURL, "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsGit(t *testing.T) {
u, _ := url.Parse("git://github.com/hyperhq/hypercli")
args := cloneArgs(u, "/tmp")
exp := []string{"clone", "--recursive", "--depth", "1", "git://github.com/hyperhq/hypercli", "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCloneArgsStripFragment(t *testing.T) {
u, _ := url.Parse("git://github.com/hyperhq/hypercli#test")
args := cloneArgs(u, "/tmp")
exp := []string{"clone", "--recursive", "git://github.com/hyperhq/hypercli", "/tmp"}
if !reflect.DeepEqual(args, exp) {
t.Fatalf("Expected %v, got %v", exp, args)
}
}
func TestCheckoutGit(t *testing.T) {
root, err := ioutil.TempDir("", "docker-build-git-checkout")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(root)
gitDir := filepath.Join(root, "repo")
_, err = git("init", gitDir)
if err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "config", "user.email", "test@docker.com"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "config", "user.name", "Docker test"); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(gitDir, "Dockerfile"), []byte("FROM scratch"), 0644); err != nil {
t.Fatal(err)
}
subDir := filepath.Join(gitDir, "subdir")
if err = os.Mkdir(subDir, 0755); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(subDir, "Dockerfile"), []byte("FROM scratch\nEXPOSE 5000"), 0644); err != nil {
t.Fatal(err)
}
if err = os.Symlink("../subdir", filepath.Join(gitDir, "parentlink")); err != nil {
t.Fatal(err)
}
if err = os.Symlink("/subdir", filepath.Join(gitDir, "absolutelink")); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "add", "-A"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "commit", "-am", "First commit"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "checkout", "-b", "test"); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(gitDir, "Dockerfile"), []byte("FROM scratch\nEXPOSE 3000"), 0644); err != nil {
t.Fatal(err)
}
if err = ioutil.WriteFile(filepath.Join(subDir, "Dockerfile"), []byte("FROM busybox\nEXPOSE 5000"), 0644); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "add", "-A"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "commit", "-am", "Branch commit"); err != nil {
t.Fatal(err)
}
if _, err = gitWithinDir(gitDir, "checkout", "master"); err != nil {
t.Fatal(err)
}
cases := []struct {
frag string
exp string
fail bool
}{
{"", "FROM scratch", false},
{"master", "FROM scratch", false},
{":subdir", "FROM scratch\nEXPOSE 5000", false},
{":nosubdir", "", true}, // missing directory error
{":Dockerfile", "", true}, // not a directory error
{"master:nosubdir", "", true},
{"master:subdir", "FROM scratch\nEXPOSE 5000", false},
{"master:parentlink", "FROM scratch\nEXPOSE 5000", false},
{"master:absolutelink", "FROM scratch\nEXPOSE 5000", false},
{"master:../subdir", "", true},
{"test", "FROM scratch\nEXPOSE 3000", false},
{"test:", "FROM scratch\nEXPOSE 3000", false},
{"test:subdir", "FROM busybox\nEXPOSE 5000", false},
}
for _, c := range cases {
r, err := checkoutGit(c.frag, gitDir)
fail := err != nil
if fail != c.fail {
t.Fatalf("Expected %v failure, error was %v\n", c.fail, err)
}
if c.fail {
continue
}
b, err := ioutil.ReadFile(filepath.Join(r, "Dockerfile"))
if err != nil {
t.Fatal(err)
}
if string(b) != c.exp {
t.Fatalf("Expected %v, was %v\n", c.exp, string(b))
}
}
}

View File

@@ -1,15 +0,0 @@
// +build cgo
package graphdb
import "database/sql"
// NewSqliteConn opens a connection to a sqlite
// database.
func NewSqliteConn(root string) (*Database, error) {
conn, err := sql.Open("sqlite3", root)
if err != nil {
return nil, err
}
return NewDatabase(conn)
}

View File

@@ -1,7 +0,0 @@
// +build cgo,!windows
package graphdb
import (
_ "github.com/mattn/go-sqlite3" // registers sqlite
)

View File

@@ -1,7 +0,0 @@
// +build cgo,windows
package graphdb
import (
_ "github.com/mattn/go-sqlite3" // registers sqlite
)

View File

@@ -1,8 +0,0 @@
// +build !cgo
package graphdb
// NewSqliteConn return a new sqlite connection.
func NewSqliteConn(root string) (*Database, error) {
panic("Not implemented")
}

View File

@@ -1,551 +0,0 @@
package graphdb
import (
"database/sql"
"fmt"
"path"
"strings"
"sync"
)
const (
createEntityTable = `
CREATE TABLE IF NOT EXISTS entity (
id text NOT NULL PRIMARY KEY
);`
createEdgeTable = `
CREATE TABLE IF NOT EXISTS edge (
"entity_id" text NOT NULL,
"parent_id" text NULL,
"name" text NOT NULL,
CONSTRAINT "parent_fk" FOREIGN KEY ("parent_id") REFERENCES "entity" ("id"),
CONSTRAINT "entity_fk" FOREIGN KEY ("entity_id") REFERENCES "entity" ("id")
);
`
createEdgeIndices = `
CREATE UNIQUE INDEX IF NOT EXISTS "name_parent_ix" ON "edge" (parent_id, name);
`
)
// Entity with a unique id.
type Entity struct {
id string
}
// An Edge connects two entities together.
type Edge struct {
EntityID string
Name string
ParentID string
}
// Entities stores the list of entities.
type Entities map[string]*Entity
// Edges stores the relationships between entities.
type Edges []*Edge
// WalkFunc is a function invoked to process an individual entity.
type WalkFunc func(fullPath string, entity *Entity) error
// Database is a graph database for storing entities and their relationships.
type Database struct {
conn *sql.DB
mux sync.RWMutex
}
// IsNonUniqueNameError processes the error to check if it's caused by
// a constraint violation.
// This is necessary because the error isn't the same across various
// sqlite versions.
func IsNonUniqueNameError(err error) bool {
str := err.Error()
// sqlite 3.7.17-1ubuntu1 returns:
// Set failure: Abort due to constraint violation: columns parent_id, name are not unique
if strings.HasSuffix(str, "name are not unique") {
return true
}
// sqlite-3.8.3-1.fc20 returns:
// Set failure: Abort due to constraint violation: UNIQUE constraint failed: edge.parent_id, edge.name
if strings.Contains(str, "UNIQUE constraint failed") && strings.Contains(str, "edge.name") {
return true
}
// sqlite-3.6.20-1.el6 returns:
// Set failure: Abort due to constraint violation: constraint failed
if strings.HasSuffix(str, "constraint failed") {
return true
}
return false
}
// NewDatabase creates a new graph database initialized with a root entity.
func NewDatabase(conn *sql.DB) (*Database, error) {
if conn == nil {
return nil, fmt.Errorf("Database connection cannot be nil")
}
db := &Database{conn: conn}
// Create root entities
tx, err := conn.Begin()
if err != nil {
return nil, err
}
if _, err := tx.Exec(createEntityTable); err != nil {
return nil, err
}
if _, err := tx.Exec(createEdgeTable); err != nil {
return nil, err
}
if _, err := tx.Exec(createEdgeIndices); err != nil {
return nil, err
}
if _, err := tx.Exec("DELETE FROM entity where id = ?", "0"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("INSERT INTO entity (id) VALUES (?);", "0"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("DELETE FROM edge where entity_id=? and name=?", "0", "/"); err != nil {
tx.Rollback()
return nil, err
}
if _, err := tx.Exec("INSERT INTO edge (entity_id, name) VALUES(?,?);", "0", "/"); err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return db, nil
}
// Close the underlying connection to the database.
func (db *Database) Close() error {
return db.conn.Close()
}
// Set the entity id for a given path.
func (db *Database) Set(fullPath, id string) (*Entity, error) {
db.mux.Lock()
defer db.mux.Unlock()
tx, err := db.conn.Begin()
if err != nil {
return nil, err
}
var entityID string
if err := tx.QueryRow("SELECT id FROM entity WHERE id = ?;", id).Scan(&entityID); err != nil {
if err == sql.ErrNoRows {
if _, err := tx.Exec("INSERT INTO entity (id) VALUES(?);", id); err != nil {
tx.Rollback()
return nil, err
}
} else {
tx.Rollback()
return nil, err
}
}
e := &Entity{id}
parentPath, name := splitPath(fullPath)
if err := db.setEdge(parentPath, name, e, tx); err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return e, nil
}
// Exists returns true if a name already exists in the database.
func (db *Database) Exists(name string) bool {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return false
}
return e != nil
}
func (db *Database) setEdge(parentPath, name string, e *Entity, tx *sql.Tx) error {
parent, err := db.get(parentPath)
if err != nil {
return err
}
if parent.id == e.id {
return fmt.Errorf("Cannot set self as child")
}
if _, err := tx.Exec("INSERT INTO edge (parent_id, name, entity_id) VALUES (?,?,?);", parent.id, name, e.id); err != nil {
return err
}
return nil
}
// RootEntity returns the root "/" entity for the database.
func (db *Database) RootEntity() *Entity {
return &Entity{
id: "0",
}
}
// Get returns the entity for a given path.
func (db *Database) Get(name string) *Entity {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil
}
return e
}
func (db *Database) get(name string) (*Entity, error) {
e := db.RootEntity()
// We always know the root name so return it if
// it is requested
if name == "/" {
return e, nil
}
parts := split(name)
for i := 1; i < len(parts); i++ {
p := parts[i]
if p == "" {
continue
}
next := db.child(e, p)
if next == nil {
return nil, fmt.Errorf("Cannot find child for %s", name)
}
e = next
}
return e, nil
}
// List all entities by from the name.
// The key will be the full path of the entity.
func (db *Database) List(name string, depth int) Entities {
db.mux.RLock()
defer db.mux.RUnlock()
out := Entities{}
e, err := db.get(name)
if err != nil {
return out
}
children, err := db.children(e, name, depth, nil)
if err != nil {
return out
}
for _, c := range children {
out[c.FullPath] = c.Entity
}
return out
}
// Walk through the child graph of an entity, calling walkFunc for each child entity.
// It is safe for walkFunc to call graph functions.
func (db *Database) Walk(name string, walkFunc WalkFunc, depth int) error {
children, err := db.Children(name, depth)
if err != nil {
return err
}
// Note: the database lock must not be held while calling walkFunc
for _, c := range children {
if err := walkFunc(c.FullPath, c.Entity); err != nil {
return err
}
}
return nil
}
// Children returns the children of the specified entity.
func (db *Database) Children(name string, depth int) ([]WalkMeta, error) {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil, err
}
return db.children(e, name, depth, nil)
}
// Parents returns the parents of a specified entity.
func (db *Database) Parents(name string) ([]string, error) {
db.mux.RLock()
defer db.mux.RUnlock()
e, err := db.get(name)
if err != nil {
return nil, err
}
return db.parents(e)
}
// Refs returns the reference count for a specified id.
func (db *Database) Refs(id string) int {
db.mux.RLock()
defer db.mux.RUnlock()
var count int
if err := db.conn.QueryRow("SELECT COUNT(*) FROM edge WHERE entity_id = ?;", id).Scan(&count); err != nil {
return 0
}
return count
}
// RefPaths returns all the id's path references.
func (db *Database) RefPaths(id string) Edges {
db.mux.RLock()
defer db.mux.RUnlock()
refs := Edges{}
rows, err := db.conn.Query("SELECT name, parent_id FROM edge WHERE entity_id = ?;", id)
if err != nil {
return refs
}
defer rows.Close()
for rows.Next() {
var name string
var parentID string
if err := rows.Scan(&name, &parentID); err != nil {
return refs
}
refs = append(refs, &Edge{
EntityID: id,
Name: name,
ParentID: parentID,
})
}
return refs
}
// Delete the reference to an entity at a given path.
func (db *Database) Delete(name string) error {
db.mux.Lock()
defer db.mux.Unlock()
if name == "/" {
return fmt.Errorf("Cannot delete root entity")
}
parentPath, n := splitPath(name)
parent, err := db.get(parentPath)
if err != nil {
return err
}
if _, err := db.conn.Exec("DELETE FROM edge WHERE parent_id = ? AND name = ?;", parent.id, n); err != nil {
return err
}
return nil
}
// Purge removes the entity with the specified id
// Walk the graph to make sure all references to the entity
// are removed and return the number of references removed
func (db *Database) Purge(id string) (int, error) {
db.mux.Lock()
defer db.mux.Unlock()
tx, err := db.conn.Begin()
if err != nil {
return -1, err
}
// Delete all edges
rows, err := tx.Exec("DELETE FROM edge WHERE entity_id = ?;", id)
if err != nil {
tx.Rollback()
return -1, err
}
changes, err := rows.RowsAffected()
if err != nil {
return -1, err
}
// Clear who's using this id as parent
refs, err := tx.Exec("DELETE FROM edge WHERE parent_id = ?;", id)
if err != nil {
tx.Rollback()
return -1, err
}
refsCount, err := refs.RowsAffected()
if err != nil {
return -1, err
}
// Delete entity
if _, err := tx.Exec("DELETE FROM entity where id = ?;", id); err != nil {
tx.Rollback()
return -1, err
}
if err := tx.Commit(); err != nil {
return -1, err
}
return int(changes + refsCount), nil
}
// Rename an edge for a given path
func (db *Database) Rename(currentName, newName string) error {
db.mux.Lock()
defer db.mux.Unlock()
parentPath, name := splitPath(currentName)
newParentPath, newEdgeName := splitPath(newName)
if parentPath != newParentPath {
return fmt.Errorf("Cannot rename when root paths do not match %s != %s", parentPath, newParentPath)
}
parent, err := db.get(parentPath)
if err != nil {
return err
}
rows, err := db.conn.Exec("UPDATE edge SET name = ? WHERE parent_id = ? AND name = ?;", newEdgeName, parent.id, name)
if err != nil {
return err
}
i, err := rows.RowsAffected()
if err != nil {
return err
}
if i == 0 {
return fmt.Errorf("Cannot locate edge for %s %s", parent.id, name)
}
return nil
}
// WalkMeta stores the walk metadata.
type WalkMeta struct {
Parent *Entity
Entity *Entity
FullPath string
Edge *Edge
}
func (db *Database) children(e *Entity, name string, depth int, entities []WalkMeta) ([]WalkMeta, error) {
if e == nil {
return entities, nil
}
rows, err := db.conn.Query("SELECT entity_id, name FROM edge where parent_id = ?;", e.id)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var entityID, entityName string
if err := rows.Scan(&entityID, &entityName); err != nil {
return nil, err
}
child := &Entity{entityID}
edge := &Edge{
ParentID: e.id,
Name: entityName,
EntityID: child.id,
}
meta := WalkMeta{
Parent: e,
Entity: child,
FullPath: path.Join(name, edge.Name),
Edge: edge,
}
entities = append(entities, meta)
if depth != 0 {
nDepth := depth
if depth != -1 {
nDepth--
}
entities, err = db.children(child, meta.FullPath, nDepth, entities)
if err != nil {
return nil, err
}
}
}
return entities, nil
}
func (db *Database) parents(e *Entity) (parents []string, err error) {
if e == nil {
return parents, nil
}
rows, err := db.conn.Query("SELECT parent_id FROM edge where entity_id = ?;", e.id)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var parentID string
if err := rows.Scan(&parentID); err != nil {
return nil, err
}
parents = append(parents, parentID)
}
return parents, nil
}
// Return the entity based on the parent path and name.
func (db *Database) child(parent *Entity, name string) *Entity {
var id string
if err := db.conn.QueryRow("SELECT entity_id FROM edge WHERE parent_id = ? AND name = ?;", parent.id, name).Scan(&id); err != nil {
return nil
}
return &Entity{id}
}
// ID returns the id used to reference this entity.
func (e *Entity) ID() string {
return e.id
}
// Paths returns the paths sorted by depth.
func (e Entities) Paths() []string {
out := make([]string, len(e))
var i int
for k := range e {
out[i] = k
i++
}
sortByDepth(out)
return out
}

View File

@@ -1,657 +0,0 @@
package graphdb
import (
"database/sql"
"fmt"
"os"
"path"
"strconv"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func newTestDb(t *testing.T) (*Database, string) {
p := path.Join(os.TempDir(), "sqlite.db")
conn, err := sql.Open("sqlite3", p)
db, err := NewDatabase(conn)
if err != nil {
t.Fatal(err)
}
return db, p
}
func destroyTestDb(dbPath string) {
os.Remove(dbPath)
}
func TestNewDatabase(t *testing.T) {
db, dbpath := newTestDb(t)
if db == nil {
t.Fatal("Database should not be nil")
}
db.Close()
defer destroyTestDb(dbpath)
}
func TestCreateRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
root := db.RootEntity()
if root == nil {
t.Fatal("Root entity should not be nil")
}
}
func TestGetRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
e := db.Get("/")
if e == nil {
t.Fatal("Entity should not be nil")
}
if e.ID() != "0" {
t.Fatalf("Entity id should be 0, got %s", e.ID())
}
}
func TestSetEntityWithDifferentName(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/test", "1")
if _, err := db.Set("/other", "1"); err != nil {
t.Fatal(err)
}
}
func TestSetDuplicateEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if _, err := db.Set("/foo", "42"); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/foo", "43"); err == nil {
t.Fatalf("Creating an entry with a duplicate path did not cause an error")
}
}
func TestCreateChild(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
child, err := db.Set("/db", "1")
if err != nil {
t.Fatal(err)
}
if child == nil {
t.Fatal("Child should not be nil")
}
if child.ID() != "1" {
t.Fail()
}
}
func TestParents(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set("/"+a, a); err != nil {
t.Fatal(err)
}
}
for i := 6; i < 11; i++ {
a := strconv.Itoa(i)
p := strconv.Itoa(i - 5)
key := fmt.Sprintf("/%s/%s", p, a)
if _, err := db.Set(key, a); err != nil {
t.Fatal(err)
}
parents, err := db.Parents(key)
if err != nil {
t.Fatal(err)
}
if len(parents) != 1 {
t.Fatalf("Expected 1 entry for %s got %d", key, len(parents))
}
if parents[0] != p {
t.Fatalf("ID %s received, %s expected", parents[0], p)
}
}
}
func TestChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
str := "/"
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set(str+a, a); err != nil {
t.Fatal(err)
}
str = str + a + "/"
}
str = "/"
for i := 10; i < 30; i++ { // 20 entities
a := strconv.Itoa(i)
if _, err := db.Set(str+a, a); err != nil {
t.Fatal(err)
}
str = str + a + "/"
}
entries, err := db.Children("/", 5)
if err != nil {
t.Fatal(err)
}
if len(entries) != 11 {
t.Fatalf("Expect 11 entries for / got %d", len(entries))
}
entries, err = db.Children("/", 20)
if err != nil {
t.Fatal(err)
}
if len(entries) != 25 {
t.Fatalf("Expect 25 entries for / got %d", len(entries))
}
}
func TestListAllRootChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
for i := 1; i < 6; i++ {
a := strconv.Itoa(i)
if _, err := db.Set("/"+a, a); err != nil {
t.Fatal(err)
}
}
entries := db.List("/", -1)
if len(entries) != 5 {
t.Fatalf("Expect 5 entries for / got %d", len(entries))
}
}
func TestListAllSubChildren(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
entries := db.List("/webapp", 1)
if len(entries) != 3 {
t.Fatalf("Expect 3 entries for / got %d", len(entries))
}
entries = db.List("/webapp", 0)
if len(entries) != 2 {
t.Fatalf("Expect 2 entries for / got %d", len(entries))
}
}
func TestAddSelfAsChild(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
child, err := db.Set("/test", "1")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/test/other", child.ID()); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestAddChildToNonExistentRoot(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if _, err := db.Set("/myapp", "1"); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/myapp/proxy/db", "2"); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestWalkAll(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/db/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
if err := db.Walk("/", func(p string, e *Entity) error {
t.Logf("Path: %s Entity: %s", p, e.ID())
return nil
}, -1); err != nil {
t.Fatal(err)
}
}
func TestGetEntityByPath(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
entity := db.Get("/webapp/db/logs")
if entity == nil {
t.Fatal("Entity should not be nil")
}
if entity.ID() != "4" {
t.Fatalf("Expected to get entity with id 4, got %s", entity.ID())
}
}
func TestEnitiesPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
out := db.List("/", -1)
for _, p := range out.Paths() {
t.Log(p)
}
}
func TestDeleteRootEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
if err := db.Delete("/"); err == nil {
t.Fatal("Error should not be nil")
}
}
func TestDeleteEntity(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
_, err := db.Set("/webapp", "1")
if err != nil {
t.Fatal(err)
}
child2, err := db.Set("/db", "2")
if err != nil {
t.Fatal(err)
}
child4, err := db.Set("/logs", "4")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/db/logs", child4.ID()); err != nil {
t.Fatal(err)
}
child3, err := db.Set("/sentry", "3")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/sentry", child3.ID()); err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/db", child2.ID()); err != nil {
t.Fatal(err)
}
child5, err := db.Set("/gograph", "5")
if err != nil {
t.Fatal(err)
}
if _, err := db.Set("/webapp/same-ref-diff-name", child5.ID()); err != nil {
t.Fatal(err)
}
if err := db.Delete("/webapp/sentry"); err != nil {
t.Fatal(err)
}
entity := db.Get("/webapp/sentry")
if entity != nil {
t.Fatal("Entity /webapp/sentry should be nil")
}
}
func TestCountRefs(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if db.Refs("1") != 1 {
t.Fatal("Expect reference count to be 1")
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
if db.Refs("2") != 2 {
t.Fatal("Expect reference count to be 2")
}
}
func TestPurgeId(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if c := db.Refs("1"); c != 1 {
t.Fatalf("Expect reference count to be 1, got %d", c)
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
count, err := db.Purge("2")
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Fatalf("Expected 2 references to be removed, got %d", count)
}
}
// Regression test https://github.com/hyperhq/hypercli/issues/12334
func TestPurgeIdRefPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
db.Set("/db", "2")
db.Set("/db/webapp", "1")
if c := db.Refs("1"); c != 2 {
t.Fatalf("Expected 2 reference for webapp, got %d", c)
}
if c := db.Refs("2"); c != 1 {
t.Fatalf("Expected 1 reference for db, got %d", c)
}
if rp := db.RefPaths("2"); len(rp) != 1 {
t.Fatalf("Expected 1 reference path for db, got %d", len(rp))
}
count, err := db.Purge("2")
if err != nil {
t.Fatal(err)
}
if count != 2 {
t.Fatalf("Expected 2 rows to be removed, got %d", count)
}
if c := db.Refs("2"); c != 0 {
t.Fatalf("Expected 0 reference for db, got %d", c)
}
if c := db.Refs("1"); c != 1 {
t.Fatalf("Expected 1 reference for webapp, got %d", c)
}
}
func TestRename(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
if db.Refs("1") != 1 {
t.Fatal("Expect reference count to be 1")
}
db.Set("/db", "2")
db.Set("/webapp/db", "2")
if db.Get("/webapp/db") == nil {
t.Fatal("Cannot find entity at path /webapp/db")
}
if err := db.Rename("/webapp/db", "/webapp/newdb"); err != nil {
t.Fatal(err)
}
if db.Get("/webapp/db") != nil {
t.Fatal("Entity should not exist at /webapp/db")
}
if db.Get("/webapp/newdb") == nil {
t.Fatal("Cannot find entity at path /webapp/newdb")
}
}
func TestCreateMultipleNames(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/db", "1")
if _, err := db.Set("/myapp", "1"); err != nil {
t.Fatal(err)
}
db.Walk("/", func(p string, e *Entity) error {
t.Logf("%s\n", p)
return nil
}, -1)
}
func TestRefPaths(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/webapp", "1")
db.Set("/db", "2")
db.Set("/webapp/db", "2")
refs := db.RefPaths("2")
if len(refs) != 2 {
t.Fatalf("Expected reference count to be 2, got %d", len(refs))
}
}
func TestExistsTrue(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/testing", "1")
if !db.Exists("/testing") {
t.Fatalf("/tesing should exist")
}
}
func TestExistsFalse(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/toerhe", "1")
if db.Exists("/testing") {
t.Fatalf("/tesing should not exist")
}
}
func TestGetNameWithTrailingSlash(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
db.Set("/todo", "1")
e := db.Get("/todo/")
if e == nil {
t.Fatalf("Entity should not be nil")
}
}
func TestConcurrentWrites(t *testing.T) {
db, dbpath := newTestDb(t)
defer destroyTestDb(dbpath)
errs := make(chan error, 2)
save := func(name string, id string) {
if _, err := db.Set(fmt.Sprintf("/%s", name), id); err != nil {
errs <- err
}
errs <- nil
}
purge := func(id string) {
if _, err := db.Purge(id); err != nil {
errs <- err
}
errs <- nil
}
save("/1", "1")
go purge("1")
go save("/2", "2")
any := false
for i := 0; i < 2; i++ {
if err := <-errs; err != nil {
any = true
t.Log(err)
}
}
if any {
t.Fail()
}
}

View File

@@ -1,27 +0,0 @@
package graphdb
import "sort"
type pathSorter struct {
paths []string
by func(i, j string) bool
}
func sortByDepth(paths []string) {
s := &pathSorter{paths, func(i, j string) bool {
return PathDepth(i) > PathDepth(j)
}}
sort.Sort(s)
}
func (s *pathSorter) Len() int {
return len(s.paths)
}
func (s *pathSorter) Swap(i, j int) {
s.paths[i], s.paths[j] = s.paths[j], s.paths[i]
}
func (s *pathSorter) Less(i, j int) bool {
return s.by(s.paths[i], s.paths[j])
}

View File

@@ -1,29 +0,0 @@
package graphdb
import (
"testing"
)
func TestSort(t *testing.T) {
paths := []string{
"/",
"/myreallylongname",
"/app/db",
}
sortByDepth(paths)
if len(paths) != 3 {
t.Fatalf("Expected 3 parts got %d", len(paths))
}
if paths[0] != "/app/db" {
t.Fatalf("Expected /app/db got %s", paths[0])
}
if paths[1] != "/myreallylongname" {
t.Fatalf("Expected /myreallylongname got %s", paths[1])
}
if paths[2] != "/" {
t.Fatalf("Expected / got %s", paths[2])
}
}

View File

@@ -1,32 +0,0 @@
package graphdb
import (
"path"
"strings"
)
// Split p on /
func split(p string) []string {
return strings.Split(p, "/")
}
// PathDepth returns the depth or number of / in a given path
func PathDepth(p string) int {
parts := split(p)
if len(parts) == 2 && parts[1] == "" {
return 1
}
return len(parts)
}
func splitPath(p string) (parent, name string) {
if p[0] != '/' {
p = "/" + p
}
parent, name = path.Split(p)
l := len(parent)
if parent[l-1] == '/' {
parent = parent[:l-1]
}
return
}

View File

@@ -1,24 +0,0 @@
package homedir
import (
"path/filepath"
"testing"
)
func TestGet(t *testing.T) {
home := Get()
if home == "" {
t.Fatal("returned home directory is empty")
}
if !filepath.IsAbs(home) {
t.Fatalf("returned path is not absolute: %s", home)
}
}
func TestGetShortcutString(t *testing.T) {
shortcut := GetShortcutString()
if shortcut == "" {
t.Fatal("returned shortcut string is empty")
}
}

View File

@@ -1,115 +0,0 @@
package httputils
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestDownload(t *testing.T) {
expected := "Hello, docker !"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, expected)
}))
defer ts.Close()
response, err := Download(ts.URL)
if err != nil {
t.Fatal(err)
}
actual, err := ioutil.ReadAll(response.Body)
response.Body.Close()
if err != nil || string(actual) != expected {
t.Fatalf("Expected the response %q, got err:%v, response:%v, actual:%s", expected, err, response, string(actual))
}
}
func TestDownload400Errors(t *testing.T) {
expectedError := "Got HTTP status code >= 400: 403 Forbidden"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 403
http.Error(w, "something failed (forbidden)", http.StatusForbidden)
}))
defer ts.Close()
// Expected status code = 403
if _, err := Download(ts.URL); err == nil || err.Error() != expectedError {
t.Fatalf("Expected the the error %q, got %v", expectedError, err)
}
}
func TestDownloadOtherErrors(t *testing.T) {
if _, err := Download("I'm not an url.."); err == nil || !strings.Contains(err.Error(), "unsupported protocol scheme") {
t.Fatalf("Expected an error with 'unsupported protocol scheme', got %v", err)
}
}
func TestNewHTTPRequestError(t *testing.T) {
errorMessage := "Some error message"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 403
http.Error(w, errorMessage, http.StatusForbidden)
}))
defer ts.Close()
httpResponse, err := http.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
if err := NewHTTPRequestError(errorMessage, httpResponse); err.Error() != errorMessage {
t.Fatalf("Expected err to be %q, got %v", errorMessage, err)
}
}
func TestParseServerHeader(t *testing.T) {
inputs := map[string][]string{
"bad header": {"error"},
"(bad header)": {"error"},
"(without/spaces)": {"error"},
"(header/with spaces)": {"error"},
"foo/bar (baz)": {"foo", "bar", "baz"},
"foo/bar": {"error"},
"foo": {"error"},
"foo/bar (baz space)": {"foo", "bar", "baz space"},
" f f / b b ( b s ) ": {"f f", "b b", "b s"},
"foo/bar (baz) ignore": {"foo", "bar", "baz"},
"foo/bar ()": {"error"},
"foo/bar()": {"error"},
"foo/bar(baz)": {"foo", "bar", "baz"},
"foo/bar/zzz(baz)": {"foo/bar", "zzz", "baz"},
"foo/bar(baz/abc)": {"foo", "bar", "baz/abc"},
"foo/bar(baz (abc))": {"foo", "bar", "baz (abc)"},
}
for header, values := range inputs {
serverHeader, err := ParseServerHeader(header)
if err != nil {
if err != errInvalidHeader {
t.Fatalf("Failed to parse %q, and got some unexpected error: %q", header, err)
}
if values[0] == "error" {
continue
}
t.Fatalf("Header %q failed to parse when it shouldn't have", header)
}
if values[0] == "error" {
t.Fatalf("Header %q parsed ok when it should have failed(%q).", header, serverHeader)
}
if serverHeader.App != values[0] {
t.Fatalf("Expected serverHeader.App for %q to equal %q, got %q", header, values[0], serverHeader.App)
}
if serverHeader.Ver != values[1] {
t.Fatalf("Expected serverHeader.Ver for %q to equal %q, got %q", header, values[1], serverHeader.Ver)
}
if serverHeader.OS != values[2] {
t.Fatalf("Expected serverHeader.OS for %q to equal %q, got %q", header, values[2], serverHeader.OS)
}
}
}

View File

@@ -1,13 +0,0 @@
package httputils
import (
"testing"
)
func TestDetectContentType(t *testing.T) {
input := []byte("That is just a plain text")
if contentType, _, err := DetectContentType(input); err != nil || contentType != "text/plain" {
t.Errorf("TestDetectContentType failed")
}
}

View File

@@ -1,307 +0,0 @@
package httputils
import (
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestResumableRequestHeaderSimpleErrors(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Hello, world !")
}))
defer ts.Close()
client := &http.Client{}
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
expectedError := "client and request can't be nil\n"
resreq := &resumableRequestReader{}
_, err = resreq.Read([]byte{})
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
}
resreq = &resumableRequestReader{
client: client,
request: req,
totalSize: -1,
}
expectedError = "failed to auto detect content length"
_, err = resreq.Read([]byte{})
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected an error with '%s', got %v.", expectedError, err)
}
}
// Not too much failures, bails out after some wait
func TestResumableRequestHeaderNotTooMuchFailures(t *testing.T) {
client := &http.Client{}
var badReq *http.Request
badReq, err := http.NewRequest("GET", "I'm not an url", nil)
if err != nil {
t.Fatal(err)
}
resreq := &resumableRequestReader{
client: client,
request: badReq,
failures: 0,
maxFailures: 2,
}
read, err := resreq.Read([]byte{})
if err != nil || read != 0 {
t.Fatalf("Expected no error and no byte read, got err:%v, read:%v.", err, read)
}
}
// Too much failures, returns the error
func TestResumableRequestHeaderTooMuchFailures(t *testing.T) {
client := &http.Client{}
var badReq *http.Request
badReq, err := http.NewRequest("GET", "I'm not an url", nil)
if err != nil {
t.Fatal(err)
}
resreq := &resumableRequestReader{
client: client,
request: badReq,
failures: 0,
maxFailures: 1,
}
defer resreq.Close()
expectedError := `Get I%27m%20not%20an%20url: unsupported protocol scheme ""`
read, err := resreq.Read([]byte{})
if err == nil || err.Error() != expectedError || read != 0 {
t.Fatalf("Expected the error '%s', got err:%v, read:%v.", expectedError, err, read)
}
}
type errorReaderCloser struct{}
func (errorReaderCloser) Close() error { return nil }
func (errorReaderCloser) Read(p []byte) (n int, err error) {
return 0, fmt.Errorf("A error occured")
}
// If a an unknown error is encountered, return 0, nil and log it
func TestResumableRequestReaderWithReadError(t *testing.T) {
var req *http.Request
req, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
response := &http.Response{
Status: "500 Internal Server",
StatusCode: 500,
ContentLength: 0,
Close: true,
Body: errorReaderCloser{},
}
resreq := &resumableRequestReader{
client: client,
request: req,
currentResponse: response,
lastRange: 1,
totalSize: 1,
}
defer resreq.Close()
buf := make([]byte, 1)
read, err := resreq.Read(buf)
if err != nil {
t.Fatal(err)
}
if read != 0 {
t.Fatalf("Expected to have read nothing, but read %v", read)
}
}
func TestResumableRequestReaderWithEOFWith416Response(t *testing.T) {
var req *http.Request
req, err := http.NewRequest("GET", "", nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
response := &http.Response{
Status: "416 Requested Range Not Satisfiable",
StatusCode: 416,
ContentLength: 0,
Close: true,
Body: ioutil.NopCloser(strings.NewReader("")),
}
resreq := &resumableRequestReader{
client: client,
request: req,
currentResponse: response,
lastRange: 1,
totalSize: 1,
}
defer resreq.Close()
buf := make([]byte, 1)
_, err = resreq.Read(buf)
if err == nil || err != io.EOF {
t.Fatalf("Expected an io.EOF error, got %v", err)
}
}
func TestResumableRequestReaderWithServerDoesntSupportByteRanges(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Range") == "" {
t.Fatalf("Expected a Range HTTP header, got nothing")
}
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
resreq := &resumableRequestReader{
client: client,
request: req,
lastRange: 1,
}
defer resreq.Close()
buf := make([]byte, 2)
_, err = resreq.Read(buf)
if err == nil || err.Error() != "the server doesn't support byte ranges" {
t.Fatalf("Expected an error 'the server doesn't support byte ranges', got %v", err)
}
}
func TestResumableRequestReaderWithZeroTotalSize(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
resreq := ResumableRequestReader(client, req, retries, 0)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}
func TestResumableRequestReader(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
imgSize := int64(len(srvtxt))
resreq := ResumableRequestReader(client, req, retries, imgSize)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}
func TestResumableRequestReaderWithInitialResponse(t *testing.T) {
srvtxt := "some response text data"
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, srvtxt)
}))
defer ts.Close()
var req *http.Request
req, err := http.NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatal(err)
}
client := &http.Client{}
retries := uint32(5)
imgSize := int64(len(srvtxt))
res, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
resreq := ResumableRequestReaderWithInitialResponse(client, req, retries, imgSize, res)
defer resreq.Close()
data, err := ioutil.ReadAll(resreq)
if err != nil {
t.Fatal(err)
}
resstr := strings.TrimSuffix(string(data), "\n")
if resstr != srvtxt {
t.Errorf("resstr != srvtxt")
}
}

View File

@@ -1,243 +0,0 @@
// +build !windows
package idtools
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"syscall"
"testing"
)
type node struct {
uid int
gid int
}
func TestMkdirAllAs(t *testing.T) {
dirName, err := ioutil.TempDir("", "mkdirall")
if err != nil {
t.Fatalf("Couldn't create temp dir: %v", err)
}
defer os.RemoveAll(dirName)
testTree := map[string]node{
"usr": {0, 0},
"usr/bin": {0, 0},
"lib": {33, 33},
"lib/x86_64": {45, 45},
"lib/x86_64/share": {1, 1},
}
if err := buildTree(dirName, testTree); err != nil {
t.Fatal(err)
}
// test adding a directory to a pre-existing dir; only the new dir is owned by the uid/gid
if err := MkdirAllAs(filepath.Join(dirName, "usr", "share"), 0755, 99, 99); err != nil {
t.Fatal(err)
}
testTree["usr/share"] = node{99, 99}
verifyTree, err := readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
// test 2-deep new directories--both should be owned by the uid/gid pair
if err := MkdirAllAs(filepath.Join(dirName, "lib", "some", "other"), 0755, 101, 101); err != nil {
t.Fatal(err)
}
testTree["lib/some"] = node{101, 101}
testTree["lib/some/other"] = node{101, 101}
verifyTree, err = readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
// test a directory that already exists; should be chowned, but nothing else
if err := MkdirAllAs(filepath.Join(dirName, "usr"), 0755, 102, 102); err != nil {
t.Fatal(err)
}
testTree["usr"] = node{102, 102}
verifyTree, err = readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
}
func TestMkdirAllNewAs(t *testing.T) {
dirName, err := ioutil.TempDir("", "mkdirnew")
if err != nil {
t.Fatalf("Couldn't create temp dir: %v", err)
}
defer os.RemoveAll(dirName)
testTree := map[string]node{
"usr": {0, 0},
"usr/bin": {0, 0},
"lib": {33, 33},
"lib/x86_64": {45, 45},
"lib/x86_64/share": {1, 1},
}
if err := buildTree(dirName, testTree); err != nil {
t.Fatal(err)
}
// test adding a directory to a pre-existing dir; only the new dir is owned by the uid/gid
if err := MkdirAllNewAs(filepath.Join(dirName, "usr", "share"), 0755, 99, 99); err != nil {
t.Fatal(err)
}
testTree["usr/share"] = node{99, 99}
verifyTree, err := readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
// test 2-deep new directories--both should be owned by the uid/gid pair
if err := MkdirAllNewAs(filepath.Join(dirName, "lib", "some", "other"), 0755, 101, 101); err != nil {
t.Fatal(err)
}
testTree["lib/some"] = node{101, 101}
testTree["lib/some/other"] = node{101, 101}
verifyTree, err = readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
// test a directory that already exists; should NOT be chowned
if err := MkdirAllNewAs(filepath.Join(dirName, "usr"), 0755, 102, 102); err != nil {
t.Fatal(err)
}
verifyTree, err = readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
}
func TestMkdirAs(t *testing.T) {
dirName, err := ioutil.TempDir("", "mkdir")
if err != nil {
t.Fatalf("Couldn't create temp dir: %v", err)
}
defer os.RemoveAll(dirName)
testTree := map[string]node{
"usr": {0, 0},
}
if err := buildTree(dirName, testTree); err != nil {
t.Fatal(err)
}
// test a directory that already exists; should just chown to the requested uid/gid
if err := MkdirAs(filepath.Join(dirName, "usr"), 0755, 99, 99); err != nil {
t.Fatal(err)
}
testTree["usr"] = node{99, 99}
verifyTree, err := readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
// create a subdir under a dir which doesn't exist--should fail
if err := MkdirAs(filepath.Join(dirName, "usr", "bin", "subdir"), 0755, 102, 102); err == nil {
t.Fatalf("Trying to create a directory with Mkdir where the parent doesn't exist should have failed")
}
// create a subdir under an existing dir; should only change the ownership of the new subdir
if err := MkdirAs(filepath.Join(dirName, "usr", "bin"), 0755, 102, 102); err != nil {
t.Fatal(err)
}
testTree["usr/bin"] = node{102, 102}
verifyTree, err = readTree(dirName, "")
if err != nil {
t.Fatal(err)
}
if err := compareTrees(testTree, verifyTree); err != nil {
t.Fatal(err)
}
}
func buildTree(base string, tree map[string]node) error {
for path, node := range tree {
fullPath := filepath.Join(base, path)
if err := os.MkdirAll(fullPath, 0755); err != nil {
return fmt.Errorf("Couldn't create path: %s; error: %v", fullPath, err)
}
if err := os.Chown(fullPath, node.uid, node.gid); err != nil {
return fmt.Errorf("Couldn't chown path: %s; error: %v", fullPath, err)
}
}
return nil
}
func readTree(base, root string) (map[string]node, error) {
tree := make(map[string]node)
dirInfos, err := ioutil.ReadDir(base)
if err != nil {
return nil, fmt.Errorf("Couldn't read directory entries for %q: %v", base, err)
}
for _, info := range dirInfos {
s := &syscall.Stat_t{}
if err := syscall.Stat(filepath.Join(base, info.Name()), s); err != nil {
return nil, fmt.Errorf("Can't stat file %q: %v", filepath.Join(base, info.Name()), err)
}
tree[filepath.Join(root, info.Name())] = node{int(s.Uid), int(s.Gid)}
if info.IsDir() {
// read the subdirectory
subtree, err := readTree(filepath.Join(base, info.Name()), filepath.Join(root, info.Name()))
if err != nil {
return nil, err
}
for path, nodeinfo := range subtree {
tree[path] = nodeinfo
}
}
}
return tree, nil
}
func compareTrees(left, right map[string]node) error {
if len(left) != len(right) {
return fmt.Errorf("Trees aren't the same size")
}
for path, nodeLeft := range left {
if nodeRight, ok := right[path]; ok {
if nodeRight.uid != nodeLeft.uid || nodeRight.gid != nodeLeft.gid {
// mismatch
return fmt.Errorf("mismatched ownership for %q: expected: %d:%d, got: %d:%d", path,
nodeLeft.uid, nodeLeft.gid, nodeRight.uid, nodeRight.gid)
}
continue
}
return fmt.Errorf("right tree didn't contain path %q", path)
}
return nil
}

View File

@@ -1,46 +0,0 @@
// Package checker provide Docker specific implementations of the go-check.Checker interface.
package checker
import (
"github.com/go-check/check"
"github.com/vdemeester/shakers"
)
// As a commodity, we bring all check.Checker variables into the current namespace to avoid having
// to think about check.X versus checker.X.
var (
DeepEquals = check.DeepEquals
ErrorMatches = check.ErrorMatches
FitsTypeOf = check.FitsTypeOf
HasLen = check.HasLen
Implements = check.Implements
IsNil = check.IsNil
Matches = check.Matches
Not = check.Not
NotNil = check.NotNil
PanicMatches = check.PanicMatches
Panics = check.Panics
Contains = shakers.Contains
ContainsAny = shakers.ContainsAny
Count = shakers.Count
Equals = shakers.Equals
EqualFold = shakers.EqualFold
False = shakers.False
GreaterOrEqualThan = shakers.GreaterOrEqualThan
GreaterThan = shakers.GreaterThan
HasPrefix = shakers.HasPrefix
HasSuffix = shakers.HasSuffix
Index = shakers.Index
IndexAny = shakers.IndexAny
IsAfter = shakers.IsAfter
IsBefore = shakers.IsBefore
IsBetween = shakers.IsBetween
IsLower = shakers.IsLower
IsUpper = shakers.IsUpper
LessOrEqualThan = shakers.LessOrEqualThan
LessThan = shakers.LessThan
TimeEquals = shakers.TimeEquals
True = shakers.True
TimeIgnore = shakers.TimeIgnore
)

View File

@@ -1,71 +0,0 @@
package integration
import (
"fmt"
"os/exec"
"strings"
"time"
"github.com/go-check/check"
)
var execCommand = exec.Command
// DockerCmdWithError executes a docker command that is supposed to fail and returns
// the output, the exit code and the error.
func DockerCmdWithError(dockerBinary string, args ...string) (string, int, error) {
return RunCommandWithOutput(execCommand(dockerBinary, args...))
}
// DockerCmdWithStdoutStderr executes a docker command and returns the content of the
// stdout, stderr and the exit code. If a check.C is passed, it will fail and stop tests
// if the error is not nil.
func DockerCmdWithStdoutStderr(dockerBinary string, c *check.C, args ...string) (string, string, int) {
stdout, stderr, status, err := RunCommandWithStdoutStderr(execCommand(dockerBinary, args...))
if c != nil {
c.Assert(err, check.IsNil, check.Commentf("%q failed with errors: %s, %v", strings.Join(args, " "), stderr, err))
}
return stdout, stderr, status
}
// DockerCmd executes a docker command and returns the output and the exit code. If the
// command returns an error, it will fail and stop the tests.
func DockerCmd(dockerBinary string, c *check.C, args ...string) (string, int) {
out, status, err := RunCommandWithOutput(execCommand(dockerBinary, args...))
c.Assert(err, check.IsNil, check.Commentf("%q failed with errors: %s, %v", strings.Join(args, " "), out, err))
return out, status
}
// DockerCmdWithTimeout executes a docker command with a timeout, and returns the output,
// the exit code and the error (if any).
func DockerCmdWithTimeout(dockerBinary string, timeout time.Duration, args ...string) (string, int, error) {
out, status, err := RunCommandWithOutputAndTimeout(execCommand(dockerBinary, args...), timeout)
if err != nil {
return out, status, fmt.Errorf("%q failed with errors: %v : %q", strings.Join(args, " "), err, out)
}
return out, status, err
}
// DockerCmdInDir executes a docker command in a directory and returns the output, the
// exit code and the error (if any).
func DockerCmdInDir(dockerBinary string, path string, args ...string) (string, int, error) {
dockerCommand := execCommand(dockerBinary, args...)
dockerCommand.Dir = path
out, status, err := RunCommandWithOutput(dockerCommand)
if err != nil {
return out, status, fmt.Errorf("%q failed with errors: %v : %q", strings.Join(args, " "), err, out)
}
return out, status, err
}
// DockerCmdInDirWithTimeout executes a docker command in a directory with a timeout and
// returns the output, the exit code and the error (if any).
func DockerCmdInDirWithTimeout(dockerBinary string, timeout time.Duration, path string, args ...string) (string, int, error) {
dockerCommand := execCommand(dockerBinary, args...)
dockerCommand.Dir = path
out, status, err := RunCommandWithOutputAndTimeout(dockerCommand, timeout)
if err != nil {
return out, status, fmt.Errorf("%q failed with errors: %v : %q", strings.Join(args, " "), err, out)
}
return out, status, err
}

View File

@@ -1,405 +0,0 @@
package integration
import (
"fmt"
"os"
"os/exec"
"testing"
"io/ioutil"
"strings"
"time"
"github.com/go-check/check"
)
const dockerBinary = "docker"
// Setup go-check for this test
func Test(t *testing.T) {
check.TestingT(t)
}
func init() {
check.Suite(&DockerCmdSuite{})
}
type DockerCmdSuite struct{}
// Fake the exec.Command to use our mock.
func (s *DockerCmdSuite) SetUpTest(c *check.C) {
execCommand = fakeExecCommand
}
// And bring it back to normal after the test.
func (s *DockerCmdSuite) TearDownTest(c *check.C) {
execCommand = exec.Command
}
// DockerCmdWithError tests
func (s *DockerCmdSuite) TestDockerCmdWithError(c *check.C) {
cmds := []struct {
binary string
args []string
expectedOut string
expectedExitCode int
expectedError error
}{
{
"doesnotexists",
[]string{},
"Command doesnotexists not found.",
1,
fmt.Errorf("exit status 1"),
},
{
dockerBinary,
[]string{"an", "error"},
"an error has occurred",
1,
fmt.Errorf("exit status 1"),
},
{
dockerBinary,
[]string{"an", "exitCode", "127"},
"an error has occurred with exitCode 127",
127,
fmt.Errorf("exit status 127"),
},
{
dockerBinary,
[]string{"run", "-ti", "ubuntu", "echo", "hello"},
"hello",
0,
nil,
},
}
for _, cmd := range cmds {
out, exitCode, error := DockerCmdWithError(cmd.binary, cmd.args...)
c.Assert(out, check.Equals, cmd.expectedOut, check.Commentf("Expected output %q for arguments %v, got %q", cmd.expectedOut, cmd.args, out))
c.Assert(exitCode, check.Equals, cmd.expectedExitCode, check.Commentf("Expected exitCode %q for arguments %v, got %q", cmd.expectedExitCode, cmd.args, exitCode))
if cmd.expectedError != nil {
c.Assert(error, check.NotNil, check.Commentf("Expected an error %q, got nothing", cmd.expectedError))
c.Assert(error.Error(), check.Equals, cmd.expectedError.Error(), check.Commentf("Expected error %q for arguments %v, got %q", cmd.expectedError.Error(), cmd.args, error.Error()))
} else {
c.Assert(error, check.IsNil, check.Commentf("Expected no error, got %v", error))
}
}
}
// DockerCmdWithStdoutStderr tests
type dockerCmdWithStdoutStderrErrorSuite struct{}
func (s *dockerCmdWithStdoutStderrErrorSuite) Test(c *check.C) {
// Should fail, the test too
DockerCmdWithStdoutStderr(dockerBinary, c, "an", "error")
}
type dockerCmdWithStdoutStderrSuccessSuite struct{}
func (s *dockerCmdWithStdoutStderrSuccessSuite) Test(c *check.C) {
stdout, stderr, exitCode := DockerCmdWithStdoutStderr(dockerBinary, c, "run", "-ti", "ubuntu", "echo", "hello")
c.Assert(stdout, check.Equals, "hello")
c.Assert(stderr, check.Equals, "")
c.Assert(exitCode, check.Equals, 0)
}
func (s *DockerCmdSuite) TestDockerCmdWithStdoutStderrError(c *check.C) {
// Run error suite, should fail.
output := String{}
result := check.Run(&dockerCmdWithStdoutStderrErrorSuite{}, &check.RunConf{Output: &output})
c.Check(result.Succeeded, check.Equals, 0)
c.Check(result.Failed, check.Equals, 1)
}
func (s *DockerCmdSuite) TestDockerCmdWithStdoutStderrSuccess(c *check.C) {
// Run error suite, should fail.
output := String{}
result := check.Run(&dockerCmdWithStdoutStderrSuccessSuite{}, &check.RunConf{Output: &output})
c.Check(result.Succeeded, check.Equals, 1)
c.Check(result.Failed, check.Equals, 0)
}
// DockerCmd tests
type dockerCmdErrorSuite struct{}
func (s *dockerCmdErrorSuite) Test(c *check.C) {
// Should fail, the test too
DockerCmd(dockerBinary, c, "an", "error")
}
type dockerCmdSuccessSuite struct{}
func (s *dockerCmdSuccessSuite) Test(c *check.C) {
stdout, exitCode := DockerCmd(dockerBinary, c, "run", "-ti", "ubuntu", "echo", "hello")
c.Assert(stdout, check.Equals, "hello")
c.Assert(exitCode, check.Equals, 0)
}
func (s *DockerCmdSuite) TestDockerCmdError(c *check.C) {
// Run error suite, should fail.
output := String{}
result := check.Run(&dockerCmdErrorSuite{}, &check.RunConf{Output: &output})
c.Check(result.Succeeded, check.Equals, 0)
c.Check(result.Failed, check.Equals, 1)
}
func (s *DockerCmdSuite) TestDockerCmdSuccess(c *check.C) {
// Run error suite, should fail.
output := String{}
result := check.Run(&dockerCmdSuccessSuite{}, &check.RunConf{Output: &output})
c.Check(result.Succeeded, check.Equals, 1)
c.Check(result.Failed, check.Equals, 0)
}
// DockerCmdWithTimeout tests
func (s *DockerCmdSuite) TestDockerCmdWithTimeout(c *check.C) {
cmds := []struct {
binary string
args []string
timeout time.Duration
expectedOut string
expectedExitCode int
expectedError error
}{
{
"doesnotexists",
[]string{},
200 * time.Millisecond,
`Command doesnotexists not found.`,
1,
fmt.Errorf(`"" failed with errors: exit status 1 : "Command doesnotexists not found."`),
},
{
dockerBinary,
[]string{"an", "error"},
200 * time.Millisecond,
`an error has occurred`,
1,
fmt.Errorf(`"an error" failed with errors: exit status 1 : "an error has occurred"`),
},
{
dockerBinary,
[]string{"a", "command", "that", "times", "out"},
5 * time.Millisecond,
"",
0,
fmt.Errorf(`"a command that times out" failed with errors: command timed out : ""`),
},
{
dockerBinary,
[]string{"run", "-ti", "ubuntu", "echo", "hello"},
200 * time.Millisecond,
"hello",
0,
nil,
},
}
for _, cmd := range cmds {
out, exitCode, error := DockerCmdWithTimeout(cmd.binary, cmd.timeout, cmd.args...)
c.Assert(out, check.Equals, cmd.expectedOut, check.Commentf("Expected output %q for arguments %v, got %q", cmd.expectedOut, cmd.args, out))
c.Assert(exitCode, check.Equals, cmd.expectedExitCode, check.Commentf("Expected exitCode %q for arguments %v, got %q", cmd.expectedExitCode, cmd.args, exitCode))
if cmd.expectedError != nil {
c.Assert(error, check.NotNil, check.Commentf("Expected an error %q, got nothing", cmd.expectedError))
c.Assert(error.Error(), check.Equals, cmd.expectedError.Error(), check.Commentf("Expected error %q for arguments %v, got %q", cmd.expectedError.Error(), cmd.args, error.Error()))
} else {
c.Assert(error, check.IsNil, check.Commentf("Expected no error, got %v", error))
}
}
}
// DockerCmdInDir tests
func (s *DockerCmdSuite) TestDockerCmdInDir(c *check.C) {
tempFolder, err := ioutil.TempDir("", "test-docker-cmd-in-dir")
c.Assert(err, check.IsNil)
cmds := []struct {
binary string
args []string
expectedOut string
expectedExitCode int
expectedError error
}{
{
"doesnotexists",
[]string{},
`Command doesnotexists not found.`,
1,
fmt.Errorf(`"dir:%s" failed with errors: exit status 1 : "Command doesnotexists not found."`, tempFolder),
},
{
dockerBinary,
[]string{"an", "error"},
`an error has occurred`,
1,
fmt.Errorf(`"dir:%s an error" failed with errors: exit status 1 : "an error has occurred"`, tempFolder),
},
{
dockerBinary,
[]string{"run", "-ti", "ubuntu", "echo", "hello"},
"hello",
0,
nil,
},
}
for _, cmd := range cmds {
// We prepend the arguments with dir:thefolder.. the fake command will check
// that the current workdir is the same as the one we are passing.
args := append([]string{"dir:" + tempFolder}, cmd.args...)
out, exitCode, error := DockerCmdInDir(cmd.binary, tempFolder, args...)
c.Assert(out, check.Equals, cmd.expectedOut, check.Commentf("Expected output %q for arguments %v, got %q", cmd.expectedOut, cmd.args, out))
c.Assert(exitCode, check.Equals, cmd.expectedExitCode, check.Commentf("Expected exitCode %q for arguments %v, got %q", cmd.expectedExitCode, cmd.args, exitCode))
if cmd.expectedError != nil {
c.Assert(error, check.NotNil, check.Commentf("Expected an error %q, got nothing", cmd.expectedError))
c.Assert(error.Error(), check.Equals, cmd.expectedError.Error(), check.Commentf("Expected error %q for arguments %v, got %q", cmd.expectedError.Error(), cmd.args, error.Error()))
} else {
c.Assert(error, check.IsNil, check.Commentf("Expected no error, got %v", error))
}
}
}
// DockerCmdInDirWithTimeout tests
func (s *DockerCmdSuite) TestDockerCmdInDirWithTimeout(c *check.C) {
tempFolder, err := ioutil.TempDir("", "test-docker-cmd-in-dir")
c.Assert(err, check.IsNil)
cmds := []struct {
binary string
args []string
timeout time.Duration
expectedOut string
expectedExitCode int
expectedError error
}{
{
"doesnotexists",
[]string{},
200 * time.Millisecond,
`Command doesnotexists not found.`,
1,
fmt.Errorf(`"dir:%s" failed with errors: exit status 1 : "Command doesnotexists not found."`, tempFolder),
},
{
dockerBinary,
[]string{"an", "error"},
200 * time.Millisecond,
`an error has occurred`,
1,
fmt.Errorf(`"dir:%s an error" failed with errors: exit status 1 : "an error has occurred"`, tempFolder),
},
{
dockerBinary,
[]string{"a", "command", "that", "times", "out"},
5 * time.Millisecond,
"",
0,
fmt.Errorf(`"dir:%s a command that times out" failed with errors: command timed out : ""`, tempFolder),
},
{
dockerBinary,
[]string{"run", "-ti", "ubuntu", "echo", "hello"},
200 * time.Millisecond,
"hello",
0,
nil,
},
}
for _, cmd := range cmds {
// We prepend the arguments with dir:thefolder.. the fake command will check
// that the current workdir is the same as the one we are passing.
args := append([]string{"dir:" + tempFolder}, cmd.args...)
out, exitCode, error := DockerCmdInDirWithTimeout(cmd.binary, cmd.timeout, tempFolder, args...)
c.Assert(out, check.Equals, cmd.expectedOut, check.Commentf("Expected output %q for arguments %v, got %q", cmd.expectedOut, cmd.args, out))
c.Assert(exitCode, check.Equals, cmd.expectedExitCode, check.Commentf("Expected exitCode %q for arguments %v, got %q", cmd.expectedExitCode, cmd.args, exitCode))
if cmd.expectedError != nil {
c.Assert(error, check.NotNil, check.Commentf("Expected an error %q, got nothing", cmd.expectedError))
c.Assert(error.Error(), check.Equals, cmd.expectedError.Error(), check.Commentf("Expected error %q for arguments %v, got %q", cmd.expectedError.Error(), cmd.args, error.Error()))
} else {
c.Assert(error, check.IsNil, check.Commentf("Expected no error, got %v", error))
}
}
}
// Helpers :)
// Type implementing the io.Writer interface for analyzing output.
type String struct {
value string
}
// The only function required by the io.Writer interface. Will append
// written data to the String.value string.
func (s *String) Write(p []byte) (n int, err error) {
s.value += string(p)
return len(p), nil
}
// Helper function that mock the exec.Command call (and call the test binary)
func fakeExecCommand(command string, args ...string) *exec.Cmd {
cs := []string{"-test.run=TestHelperProcess", "--", command}
cs = append(cs, args...)
cmd := exec.Command(os.Args[0], cs...)
cmd.Env = []string{"GO_WANT_HELPER_PROCESS=1"}
return cmd
}
func TestHelperProcess(t *testing.T) {
if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
return
}
args := os.Args
// Previous arguments are tests stuff, that looks like :
// /tmp/go-build970079519/…/_test/integration.test -test.run=TestHelperProcess --
cmd, args := args[3], args[4:]
// Handle the case where args[0] is dir:...
if len(args) > 0 && strings.HasPrefix(args[0], "dir:") {
expectedCwd := args[0][4:]
if len(args) > 1 {
args = args[1:]
}
cwd, err := os.Getwd()
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to get workingdir: %v", err)
os.Exit(1)
}
// This checks that the given path is the same as the currend working dire
if expectedCwd != cwd {
fmt.Fprintf(os.Stderr, "Current workdir should be %q, but is %q", expectedCwd, cwd)
}
}
switch cmd {
case dockerBinary:
argsStr := strings.Join(args, " ")
switch argsStr {
case "an exitCode 127":
fmt.Fprintf(os.Stderr, "an error has occurred with exitCode 127")
os.Exit(127)
case "an error":
fmt.Fprintf(os.Stderr, "an error has occurred")
os.Exit(1)
case "a command that times out":
time.Sleep(10 * time.Second)
fmt.Fprintf(os.Stdout, "too long, should be killed")
// A random exit code (that should never happened in tests)
os.Exit(7)
case "run -ti ubuntu echo hello":
fmt.Fprintf(os.Stdout, "hello")
default:
fmt.Fprintf(os.Stdout, "no arguments")
}
default:
fmt.Fprintf(os.Stderr, "Command %s not found.", cmd)
os.Exit(1)
}
// some code here to check arguments perhaps?
os.Exit(0)
}

View File

@@ -1,361 +0,0 @@
package integration
import (
"archive/tar"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"reflect"
"strings"
"syscall"
"time"
"github.com/hyperhq/hypercli/pkg/stringutils"
)
// GetExitCode returns the ExitStatus of the specified error if its type is
// exec.ExitError, returns 0 and an error otherwise.
func GetExitCode(err error) (int, error) {
exitCode := 0
if exiterr, ok := err.(*exec.ExitError); ok {
if procExit, ok := exiterr.Sys().(syscall.WaitStatus); ok {
return procExit.ExitStatus(), nil
}
}
return exitCode, fmt.Errorf("failed to get exit code")
}
// ProcessExitCode process the specified error and returns the exit status code
// if the error was of type exec.ExitError, returns nothing otherwise.
func ProcessExitCode(err error) (exitCode int) {
if err != nil {
var exiterr error
if exitCode, exiterr = GetExitCode(err); exiterr != nil {
// TODO: Fix this so we check the error's text.
// we've failed to retrieve exit code, so we set it to 127
exitCode = 127
}
}
return
}
// IsKilled process the specified error and returns whether the process was killed or not.
func IsKilled(err error) bool {
if exitErr, ok := err.(*exec.ExitError); ok {
status, ok := exitErr.Sys().(syscall.WaitStatus)
if !ok {
return false
}
// status.ExitStatus() is required on Windows because it does not
// implement Signal() nor Signaled(). Just check it had a bad exit
// status could mean it was killed (and in tests we do kill)
return (status.Signaled() && status.Signal() == os.Kill) || status.ExitStatus() != 0
}
return false
}
// RunCommandWithOutput runs the specified command and returns the combined output (stdout/stderr)
// with the exitCode different from 0 and the error if something bad happened
func RunCommandWithOutput(cmd *exec.Cmd) (output string, exitCode int, err error) {
exitCode = 0
out, err := cmd.CombinedOutput()
exitCode = ProcessExitCode(err)
output = string(out)
return
}
// RunCommandWithStdoutStderr runs the specified command and returns stdout and stderr separately
// with the exitCode different from 0 and the error if something bad happened
func RunCommandWithStdoutStderr(cmd *exec.Cmd) (stdout string, stderr string, exitCode int, err error) {
var (
stderrBuffer, stdoutBuffer bytes.Buffer
)
exitCode = 0
cmd.Stderr = &stderrBuffer
cmd.Stdout = &stdoutBuffer
err = cmd.Run()
exitCode = ProcessExitCode(err)
stdout = stdoutBuffer.String()
stderr = stderrBuffer.String()
return
}
// RunCommandWithOutputForDuration runs the specified command "timeboxed" by the specified duration.
// If the process is still running when the timebox is finished, the process will be killed and .
// It will returns the output with the exitCode different from 0 and the error if something bad happened
// and a boolean whether it has been killed or not.
func RunCommandWithOutputForDuration(cmd *exec.Cmd, duration time.Duration) (output string, exitCode int, timedOut bool, err error) {
var outputBuffer bytes.Buffer
if cmd.Stdout != nil {
err = errors.New("cmd.Stdout already set")
return
}
cmd.Stdout = &outputBuffer
if cmd.Stderr != nil {
err = errors.New("cmd.Stderr already set")
return
}
cmd.Stderr = &outputBuffer
// Start the command in the main thread..
err = cmd.Start()
if err != nil {
err = fmt.Errorf("Fail to start command %v : %v", cmd, err)
}
type exitInfo struct {
exitErr error
exitCode int
}
done := make(chan exitInfo, 1)
go func() {
// And wait for it to exit in the goroutine :)
info := exitInfo{}
info.exitErr = cmd.Wait()
info.exitCode = ProcessExitCode(info.exitErr)
done <- info
}()
select {
case <-time.After(duration):
killErr := cmd.Process.Kill()
if killErr != nil {
fmt.Printf("failed to kill (pid=%d): %v\n", cmd.Process.Pid, killErr)
}
timedOut = true
case info := <-done:
err = info.exitErr
exitCode = info.exitCode
}
output = outputBuffer.String()
return
}
var errCmdTimeout = fmt.Errorf("command timed out")
// RunCommandWithOutputAndTimeout runs the specified command "timeboxed" by the specified duration.
// It returns the output with the exitCode different from 0 and the error if something bad happened or
// if the process timed out (and has been killed).
func RunCommandWithOutputAndTimeout(cmd *exec.Cmd, timeout time.Duration) (output string, exitCode int, err error) {
var timedOut bool
output, exitCode, timedOut, err = RunCommandWithOutputForDuration(cmd, timeout)
if timedOut {
err = errCmdTimeout
}
return
}
// RunCommand runs the specified command and returns the exitCode different from 0
// and the error if something bad happened.
func RunCommand(cmd *exec.Cmd) (exitCode int, err error) {
exitCode = 0
err = cmd.Run()
exitCode = ProcessExitCode(err)
return
}
// RunCommandPipelineWithOutput runs the array of commands with the output
// of each pipelined with the following (like cmd1 | cmd2 | cmd3 would do).
// It returns the final output, the exitCode different from 0 and the error
// if something bad happened.
func RunCommandPipelineWithOutput(cmds ...*exec.Cmd) (output string, exitCode int, err error) {
if len(cmds) < 2 {
return "", 0, errors.New("pipeline does not have multiple cmds")
}
// connect stdin of each cmd to stdout pipe of previous cmd
for i, cmd := range cmds {
if i > 0 {
prevCmd := cmds[i-1]
cmd.Stdin, err = prevCmd.StdoutPipe()
if err != nil {
return "", 0, fmt.Errorf("cannot set stdout pipe for %s: %v", cmd.Path, err)
}
}
}
// start all cmds except the last
for _, cmd := range cmds[:len(cmds)-1] {
if err = cmd.Start(); err != nil {
return "", 0, fmt.Errorf("starting %s failed with error: %v", cmd.Path, err)
}
}
var pipelineError error
defer func() {
// wait all cmds except the last to release their resources
for _, cmd := range cmds[:len(cmds)-1] {
if err := cmd.Wait(); err != nil {
pipelineError = fmt.Errorf("command %s failed with error: %v", cmd.Path, err)
break
}
}
}()
if pipelineError != nil {
return "", 0, pipelineError
}
// wait on last cmd
return RunCommandWithOutput(cmds[len(cmds)-1])
}
// UnmarshalJSON deserialize a JSON in the given interface.
func UnmarshalJSON(data []byte, result interface{}) error {
if err := json.Unmarshal(data, result); err != nil {
return err
}
return nil
}
// ConvertSliceOfStringsToMap converts a slices of string in a map
// with the strings as key and an empty string as values.
func ConvertSliceOfStringsToMap(input []string) map[string]struct{} {
output := make(map[string]struct{})
for _, v := range input {
output[v] = struct{}{}
}
return output
}
// CompareDirectoryEntries compares two sets of FileInfo (usually taken from a directory)
// and returns an error if different.
func CompareDirectoryEntries(e1 []os.FileInfo, e2 []os.FileInfo) error {
var (
e1Entries = make(map[string]struct{})
e2Entries = make(map[string]struct{})
)
for _, e := range e1 {
e1Entries[e.Name()] = struct{}{}
}
for _, e := range e2 {
e2Entries[e.Name()] = struct{}{}
}
if !reflect.DeepEqual(e1Entries, e2Entries) {
return fmt.Errorf("entries differ")
}
return nil
}
// ListTar lists the entries of a tar.
func ListTar(f io.Reader) ([]string, error) {
tr := tar.NewReader(f)
var entries []string
for {
th, err := tr.Next()
if err == io.EOF {
// end of tar archive
return entries, nil
}
if err != nil {
return entries, err
}
entries = append(entries, th.Name)
}
}
// RandomTmpDirPath provides a temporary path with rand string appended.
// does not create or checks if it exists.
func RandomTmpDirPath(s string, platform string) string {
tmp := "/tmp"
if platform == "windows" {
tmp = os.Getenv("TEMP")
}
path := filepath.Join(tmp, fmt.Sprintf("%s.%s", s, stringutils.GenerateRandomAlphaOnlyString(10)))
if platform == "windows" {
return filepath.FromSlash(path) // Using \
}
return filepath.ToSlash(path) // Using /
}
// ConsumeWithSpeed reads chunkSize bytes from reader before sleeping
// for interval duration. Returns total read bytes. Send true to the
// stop channel to return before reading to EOF on the reader.
func ConsumeWithSpeed(reader io.Reader, chunkSize int, interval time.Duration, stop chan bool) (n int, err error) {
buffer := make([]byte, chunkSize)
for {
var readBytes int
readBytes, err = reader.Read(buffer)
n += readBytes
if err != nil {
if err == io.EOF {
err = nil
}
return
}
select {
case <-stop:
return
case <-time.After(interval):
}
}
}
// ParseCgroupPaths parses 'procCgroupData', which is output of '/proc/<pid>/cgroup', and returns
// a map which cgroup name as key and path as value.
func ParseCgroupPaths(procCgroupData string) map[string]string {
cgroupPaths := map[string]string{}
for _, line := range strings.Split(procCgroupData, "\n") {
parts := strings.Split(line, ":")
if len(parts) != 3 {
continue
}
cgroupPaths[parts[1]] = parts[2]
}
return cgroupPaths
}
// ChannelBuffer holds a chan of byte array that can be populate in a goroutine.
type ChannelBuffer struct {
C chan []byte
}
// Write implements Writer.
func (c *ChannelBuffer) Write(b []byte) (int, error) {
c.C <- b
return len(b), nil
}
// Close closes the go channel.
func (c *ChannelBuffer) Close() error {
close(c.C)
return nil
}
// ReadTimeout reads the content of the channel in the specified byte array with
// the specified duration as timeout.
func (c *ChannelBuffer) ReadTimeout(p []byte, n time.Duration) (int, error) {
select {
case b := <-c.C:
return copy(p[0:], b), nil
case <-time.After(n):
return -1, fmt.Errorf("timeout reading from channel")
}
}
// RunAtDifferentDate runs the specified function with the given time.
// It changes the date of the system, which can led to weird behaviors.
func RunAtDifferentDate(date time.Time, block func()) {
// Layout for date. MMDDhhmmYYYY
const timeLayout = "010203042006"
// Ensure we bring time back to now
now := time.Now().Format(timeLayout)
dateReset := exec.Command("date", now)
defer RunCommand(dateReset)
dateChange := exec.Command("date", date.Format(timeLayout))
RunCommand(dateChange)
block()
return
}

View File

@@ -1,492 +0,0 @@
package integration
import (
"io"
"io/ioutil"
"os"
"os/exec"
"path"
"runtime"
"strings"
"testing"
"time"
)
func TestIsKilledFalseWithNonKilledProcess(t *testing.T) {
lsCmd := exec.Command("ls")
lsCmd.Start()
// Wait for it to finish
err := lsCmd.Wait()
if IsKilled(err) {
t.Fatalf("Expected the ls command to not be killed, was.")
}
}
func TestIsKilledTrueWithKilledProcess(t *testing.T) {
longCmd := exec.Command("top")
// Start a command
longCmd.Start()
// Capture the error when *dying*
done := make(chan error, 1)
go func() {
done <- longCmd.Wait()
}()
// Then kill it
longCmd.Process.Kill()
// Get the error
err := <-done
if !IsKilled(err) {
t.Fatalf("Expected the command to be killed, was not.")
}
}
func TestRunCommandWithOutput(t *testing.T) {
echoHelloWorldCmd := exec.Command("echo", "hello", "world")
out, exitCode, err := RunCommandWithOutput(echoHelloWorldCmd)
expected := "hello world\n"
if out != expected || exitCode != 0 || err != nil {
t.Fatalf("Expected command to output %s, got %s, %v with exitCode %v", expected, out, err, exitCode)
}
}
func TestRunCommandWithOutputError(t *testing.T) {
cmd := exec.Command("doesnotexists")
out, exitCode, err := RunCommandWithOutput(cmd)
expectedError := `exec: "doesnotexists": executable file not found in $PATH`
if out != "" || exitCode != 127 || err == nil || err.Error() != expectedError {
t.Fatalf("Expected command to output %s, got %s, %v with exitCode %v", expectedError, out, err, exitCode)
}
wrongLsCmd := exec.Command("ls", "-z")
expected := `ls: invalid option -- 'z'
Try 'ls --help' for more information.
`
out, exitCode, err = RunCommandWithOutput(wrongLsCmd)
if out != expected || exitCode != 2 || err == nil || err.Error() != "exit status 2" {
t.Fatalf("Expected command to output %s, got out:%s, err:%v with exitCode %v", expected, out, err, exitCode)
}
}
func TestRunCommandWithStdoutStderr(t *testing.T) {
echoHelloWorldCmd := exec.Command("echo", "hello", "world")
stdout, stderr, exitCode, err := RunCommandWithStdoutStderr(echoHelloWorldCmd)
expected := "hello world\n"
if stdout != expected || stderr != "" || exitCode != 0 || err != nil {
t.Fatalf("Expected command to output %s, got stdout:%s, stderr:%s, err:%v with exitCode %v", expected, stdout, stderr, err, exitCode)
}
}
func TestRunCommandWithStdoutStderrError(t *testing.T) {
cmd := exec.Command("doesnotexists")
stdout, stderr, exitCode, err := RunCommandWithStdoutStderr(cmd)
expectedError := `exec: "doesnotexists": executable file not found in $PATH`
if stdout != "" || stderr != "" || exitCode != 127 || err == nil || err.Error() != expectedError {
t.Fatalf("Expected command to output out:%s, stderr:%s, got stdout:%s, stderr:%s, err:%v with exitCode %v", "", "", stdout, stderr, err, exitCode)
}
wrongLsCmd := exec.Command("ls", "-z")
expected := `ls: invalid option -- 'z'
Try 'ls --help' for more information.
`
stdout, stderr, exitCode, err = RunCommandWithStdoutStderr(wrongLsCmd)
if stdout != "" && stderr != expected || exitCode != 2 || err == nil || err.Error() != "exit status 2" {
t.Fatalf("Expected command to output out:%s, stderr:%s, got stdout:%s, stderr:%s, err:%v with exitCode %v", "", expectedError, stdout, stderr, err, exitCode)
}
}
func TestRunCommandWithOutputForDurationFinished(t *testing.T) {
cmd := exec.Command("ls")
out, exitCode, timedOut, err := RunCommandWithOutputForDuration(cmd, 50*time.Millisecond)
if out == "" || exitCode != 0 || timedOut || err != nil {
t.Fatalf("Expected the command to run for less 50 milliseconds and thus not time out, but did not : out:[%s], exitCode:[%d], timedOut:[%v], err:[%v]", out, exitCode, timedOut, err)
}
}
func TestRunCommandWithOutputForDurationKilled(t *testing.T) {
cmd := exec.Command("sh", "-c", "while true ; do echo 1 ; sleep .1 ; done")
out, exitCode, timedOut, err := RunCommandWithOutputForDuration(cmd, 500*time.Millisecond)
ones := strings.Split(out, "\n")
if len(ones) != 6 || exitCode != 0 || !timedOut || err != nil {
t.Fatalf("Expected the command to run for 500 milliseconds (and thus print six lines (five with 1, one empty) and time out, but did not : out:[%s], exitCode:%d, timedOut:%v, err:%v", out, exitCode, timedOut, err)
}
}
func TestRunCommandWithOutputForDurationErrors(t *testing.T) {
cmd := exec.Command("ls")
cmd.Stdout = os.Stdout
if _, _, _, err := RunCommandWithOutputForDuration(cmd, 1*time.Millisecond); err == nil || err.Error() != "cmd.Stdout already set" {
t.Fatalf("Expected an error as cmd.Stdout was already set, did not (err:%s).", err)
}
cmd = exec.Command("ls")
cmd.Stderr = os.Stderr
if _, _, _, err := RunCommandWithOutputForDuration(cmd, 1*time.Millisecond); err == nil || err.Error() != "cmd.Stderr already set" {
t.Fatalf("Expected an error as cmd.Stderr was already set, did not (err:%s).", err)
}
}
func TestRunCommandWithOutputAndTimeoutFinished(t *testing.T) {
cmd := exec.Command("ls")
out, exitCode, err := RunCommandWithOutputAndTimeout(cmd, 50*time.Millisecond)
if out == "" || exitCode != 0 || err != nil {
t.Fatalf("Expected the command to run for less 50 milliseconds and thus not time out, but did not : out:[%s], exitCode:[%d], err:[%v]", out, exitCode, err)
}
}
func TestRunCommandWithOutputAndTimeoutKilled(t *testing.T) {
cmd := exec.Command("sh", "-c", "while true ; do echo 1 ; sleep .1 ; done")
out, exitCode, err := RunCommandWithOutputAndTimeout(cmd, 500*time.Millisecond)
ones := strings.Split(out, "\n")
if len(ones) != 6 || exitCode != 0 || err == nil || err.Error() != "command timed out" {
t.Fatalf("Expected the command to run for 500 milliseconds (and thus print six lines (five with 1, one empty) and time out with an error 'command timed out', but did not : out:[%s], exitCode:%d, err:%v", out, exitCode, err)
}
}
func TestRunCommandWithOutputAndTimeoutErrors(t *testing.T) {
cmd := exec.Command("ls")
cmd.Stdout = os.Stdout
if _, _, err := RunCommandWithOutputAndTimeout(cmd, 1*time.Millisecond); err == nil || err.Error() != "cmd.Stdout already set" {
t.Fatalf("Expected an error as cmd.Stdout was already set, did not (err:%s).", err)
}
cmd = exec.Command("ls")
cmd.Stderr = os.Stderr
if _, _, err := RunCommandWithOutputAndTimeout(cmd, 1*time.Millisecond); err == nil || err.Error() != "cmd.Stderr already set" {
t.Fatalf("Expected an error as cmd.Stderr was already set, did not (err:%s).", err)
}
}
func TestRunCommand(t *testing.T) {
lsCmd := exec.Command("ls")
exitCode, err := RunCommand(lsCmd)
if exitCode != 0 || err != nil {
t.Fatalf("Expected runCommand to run the command successfully, got: exitCode:%d, err:%v", exitCode, err)
}
var expectedError string
exitCode, err = RunCommand(exec.Command("doesnotexists"))
expectedError = `exec: "doesnotexists": executable file not found in $PATH`
if exitCode != 127 || err == nil || err.Error() != expectedError {
t.Fatalf("Expected runCommand to run the command successfully, got: exitCode:%d, err:%v", exitCode, err)
}
wrongLsCmd := exec.Command("ls", "-z")
expected := 2
expectedError = `exit status 2`
exitCode, err = RunCommand(wrongLsCmd)
if exitCode != expected || err == nil || err.Error() != expectedError {
t.Fatalf("Expected runCommand to run the command successfully, got: exitCode:%d, err:%v", exitCode, err)
}
}
func TestRunCommandPipelineWithOutputWithNotEnoughCmds(t *testing.T) {
_, _, err := RunCommandPipelineWithOutput(exec.Command("ls"))
expectedError := "pipeline does not have multiple cmds"
if err == nil || err.Error() != expectedError {
t.Fatalf("Expected an error with %s, got err:%s", expectedError, err)
}
}
func TestRunCommandPipelineWithOutputErrors(t *testing.T) {
cmd1 := exec.Command("ls")
cmd1.Stdout = os.Stdout
cmd2 := exec.Command("anything really")
_, _, err := RunCommandPipelineWithOutput(cmd1, cmd2)
if err == nil || err.Error() != "cannot set stdout pipe for anything really: exec: Stdout already set" {
t.Fatalf("Expected an error, got %v", err)
}
cmdWithError := exec.Command("doesnotexists")
cmdCat := exec.Command("cat")
_, _, err = RunCommandPipelineWithOutput(cmdWithError, cmdCat)
if err == nil || err.Error() != `starting doesnotexists failed with error: exec: "doesnotexists": executable file not found in $PATH` {
t.Fatalf("Expected an error, got %v", err)
}
}
func TestRunCommandPipelineWithOutput(t *testing.T) {
cmds := []*exec.Cmd{
// Print 2 characters
exec.Command("echo", "-n", "11"),
// Count the number or char from stdin (previous command)
exec.Command("wc", "-m"),
}
out, exitCode, err := RunCommandPipelineWithOutput(cmds...)
expectedOutput := "2\n"
if out != expectedOutput || exitCode != 0 || err != nil {
t.Fatalf("Expected %s for commands %v, got out:%s, exitCode:%d, err:%v", expectedOutput, cmds, out, exitCode, err)
}
}
// Simple simple test as it is just a passthrough for json.Unmarshal
func TestUnmarshalJSON(t *testing.T) {
emptyResult := struct{}{}
if err := UnmarshalJSON([]byte(""), &emptyResult); err == nil {
t.Fatalf("Expected an error, got nothing")
}
result := struct{ Name string }{}
if err := UnmarshalJSON([]byte(`{"name": "name"}`), &result); err != nil {
t.Fatal(err)
}
if result.Name != "name" {
t.Fatalf("Expected result.name to be 'name', was '%s'", result.Name)
}
}
func TestConvertSliceOfStringsToMap(t *testing.T) {
input := []string{"a", "b"}
actual := ConvertSliceOfStringsToMap(input)
for _, key := range input {
if _, ok := actual[key]; !ok {
t.Fatalf("Expected output to contains key %s, did not: %v", key, actual)
}
}
}
func TestCompareDirectoryEntries(t *testing.T) {
tmpFolder, err := ioutil.TempDir("", "integration-cli-utils-compare-directories")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpFolder)
file1 := path.Join(tmpFolder, "file1")
file2 := path.Join(tmpFolder, "file2")
os.Create(file1)
os.Create(file2)
fi1, err := os.Stat(file1)
if err != nil {
t.Fatal(err)
}
fi1bis, err := os.Stat(file1)
if err != nil {
t.Fatal(err)
}
fi2, err := os.Stat(file2)
if err != nil {
t.Fatal(err)
}
cases := []struct {
e1 []os.FileInfo
e2 []os.FileInfo
shouldError bool
}{
// Empty directories
{
[]os.FileInfo{},
[]os.FileInfo{},
false,
},
// Same FileInfos
{
[]os.FileInfo{fi1},
[]os.FileInfo{fi1},
false,
},
// Different FileInfos but same names
{
[]os.FileInfo{fi1},
[]os.FileInfo{fi1bis},
false,
},
// Different FileInfos, different names
{
[]os.FileInfo{fi1},
[]os.FileInfo{fi2},
true,
},
}
for _, elt := range cases {
err := CompareDirectoryEntries(elt.e1, elt.e2)
if elt.shouldError && err == nil {
t.Fatalf("Should have return an error, did not with %v and %v", elt.e1, elt.e2)
}
if !elt.shouldError && err != nil {
t.Fatalf("Should have not returned an error, but did : %v with %v and %v", err, elt.e1, elt.e2)
}
}
}
// FIXME make an "unhappy path" test for ListTar without "panicking" :-)
func TestListTar(t *testing.T) {
tmpFolder, err := ioutil.TempDir("", "integration-cli-utils-list-tar")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmpFolder)
// Let's create a Tar file
srcFile := path.Join(tmpFolder, "src")
tarFile := path.Join(tmpFolder, "src.tar")
os.Create(srcFile)
cmd := exec.Command("/bin/sh", "-c", "tar cf "+tarFile+" "+srcFile)
_, err = cmd.CombinedOutput()
if err != nil {
t.Fatal(err)
}
reader, err := os.Open(tarFile)
if err != nil {
t.Fatal(err)
}
defer reader.Close()
entries, err := ListTar(reader)
if err != nil {
t.Fatal(err)
}
if len(entries) != 1 && entries[0] != "src" {
t.Fatalf("Expected a tar file with 1 entry (%s), got %v", srcFile, entries)
}
}
func TestRandomTmpDirPath(t *testing.T) {
path := RandomTmpDirPath("something", runtime.GOOS)
prefix := "/tmp/something"
if runtime.GOOS == "windows" {
prefix = os.Getenv("TEMP") + `\something`
}
expectedSize := len(prefix) + 11
if !strings.HasPrefix(path, prefix) {
t.Fatalf("Expected generated path to have '%s' as prefix, got %s'", prefix, path)
}
if len(path) != expectedSize {
t.Fatalf("Expected generated path to be %d, got %d", expectedSize, len(path))
}
}
func TestConsumeWithSpeed(t *testing.T) {
reader := strings.NewReader("1234567890")
chunksize := 2
bytes1, err := ConsumeWithSpeed(reader, chunksize, 1*time.Second, nil)
if err != nil {
t.Fatal(err)
}
if bytes1 != 10 {
t.Fatalf("Expected to have read 10 bytes, got %d", bytes1)
}
}
func TestConsumeWithSpeedWithStop(t *testing.T) {
reader := strings.NewReader("1234567890")
chunksize := 2
stopIt := make(chan bool)
go func() {
time.Sleep(1 * time.Millisecond)
stopIt <- true
}()
bytes1, err := ConsumeWithSpeed(reader, chunksize, 20*time.Millisecond, stopIt)
if err != nil {
t.Fatal(err)
}
if bytes1 != 2 {
t.Fatalf("Expected to have read 2 bytes, got %d", bytes1)
}
}
func TestParseCgroupPathsEmpty(t *testing.T) {
cgroupMap := ParseCgroupPaths("")
if len(cgroupMap) != 0 {
t.Fatalf("Expected an empty map, got %v", cgroupMap)
}
cgroupMap = ParseCgroupPaths("\n")
if len(cgroupMap) != 0 {
t.Fatalf("Expected an empty map, got %v", cgroupMap)
}
cgroupMap = ParseCgroupPaths("something:else\nagain:here")
if len(cgroupMap) != 0 {
t.Fatalf("Expected an empty map, got %v", cgroupMap)
}
}
func TestParseCgroupPaths(t *testing.T) {
cgroupMap := ParseCgroupPaths("2:memory:/a\n1:cpuset:/b")
if len(cgroupMap) != 2 {
t.Fatalf("Expected a map with 2 entries, got %v", cgroupMap)
}
if value, ok := cgroupMap["memory"]; !ok || value != "/a" {
t.Fatalf("Expected cgroupMap to contains an entry for 'memory' with value '/a', got %v", cgroupMap)
}
if value, ok := cgroupMap["cpuset"]; !ok || value != "/b" {
t.Fatalf("Expected cgroupMap to contains an entry for 'cpuset' with value '/b', got %v", cgroupMap)
}
}
func TestChannelBufferTimeout(t *testing.T) {
expected := "11"
buf := &ChannelBuffer{make(chan []byte, 1)}
defer buf.Close()
go func() {
time.Sleep(100 * time.Millisecond)
io.Copy(buf, strings.NewReader(expected))
}()
// Wait long enough
b := make([]byte, 2)
_, err := buf.ReadTimeout(b, 50*time.Millisecond)
if err == nil && err.Error() != "timeout reading from channel" {
t.Fatalf("Expected an error, got %s", err)
}
// Wait for the end :)
time.Sleep(150 * time.Millisecond)
}
func TestChannelBuffer(t *testing.T) {
expected := "11"
buf := &ChannelBuffer{make(chan []byte, 1)}
defer buf.Close()
go func() {
time.Sleep(100 * time.Millisecond)
io.Copy(buf, strings.NewReader(expected))
}()
// Wait long enough
b := make([]byte, 2)
_, err := buf.ReadTimeout(b, 200*time.Millisecond)
if err != nil {
t.Fatal(err)
}
if string(b) != expected {
t.Fatalf("Expected '%s', got '%s'", expected, string(b))
}
}
// FIXME doesn't work
// func TestRunAtDifferentDate(t *testing.T) {
// var date string
// // Layout for date. MMDDhhmmYYYY
// const timeLayout = "20060102"
// expectedDate := "20100201"
// theDate, err := time.Parse(timeLayout, expectedDate)
// if err != nil {
// t.Fatal(err)
// }
// RunAtDifferentDate(theDate, func() {
// cmd := exec.Command("date", "+%Y%M%d")
// out, err := cmd.Output()
// if err != nil {
// t.Fatal(err)
// }
// date = string(out)
// })
// }

View File

@@ -1,158 +0,0 @@
package ioutils
import (
"crypto/sha1"
"encoding/hex"
"math/rand"
"testing"
"time"
)
func TestBytesPipeRead(t *testing.T) {
buf := NewBytesPipe(nil)
buf.Write([]byte("12"))
buf.Write([]byte("34"))
buf.Write([]byte("56"))
buf.Write([]byte("78"))
buf.Write([]byte("90"))
rd := make([]byte, 4)
n, err := buf.Read(rd)
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatalf("Wrong number of bytes read: %d, should be %d", n, 4)
}
if string(rd) != "1234" {
t.Fatalf("Read %s, but must be %s", rd, "1234")
}
n, err = buf.Read(rd)
if err != nil {
t.Fatal(err)
}
if n != 4 {
t.Fatalf("Wrong number of bytes read: %d, should be %d", n, 4)
}
if string(rd) != "5678" {
t.Fatalf("Read %s, but must be %s", rd, "5679")
}
n, err = buf.Read(rd)
if err != nil {
t.Fatal(err)
}
if n != 2 {
t.Fatalf("Wrong number of bytes read: %d, should be %d", n, 2)
}
if string(rd[:n]) != "90" {
t.Fatalf("Read %s, but must be %s", rd, "90")
}
}
func TestBytesPipeWrite(t *testing.T) {
buf := NewBytesPipe(nil)
buf.Write([]byte("12"))
buf.Write([]byte("34"))
buf.Write([]byte("56"))
buf.Write([]byte("78"))
buf.Write([]byte("90"))
if string(buf.buf[0]) != "1234567890" {
t.Fatalf("Buffer %s, must be %s", buf.buf, "1234567890")
}
}
// Write and read in different speeds/chunk sizes and check valid data is read.
func TestBytesPipeWriteRandomChunks(t *testing.T) {
cases := []struct{ iterations, writesPerLoop, readsPerLoop int }{
{100, 10, 1},
{1000, 10, 5},
{1000, 100, 0},
{1000, 5, 6},
{10000, 50, 25},
}
testMessage := []byte("this is a random string for testing")
// random slice sizes to read and write
writeChunks := []int{25, 35, 15, 20}
readChunks := []int{5, 45, 20, 25}
for _, c := range cases {
// first pass: write directly to hash
hash := sha1.New()
for i := 0; i < c.iterations*c.writesPerLoop; i++ {
if _, err := hash.Write(testMessage[:writeChunks[i%len(writeChunks)]]); err != nil {
t.Fatal(err)
}
}
expected := hex.EncodeToString(hash.Sum(nil))
// write/read through buffer
buf := NewBytesPipe(nil)
hash.Reset()
done := make(chan struct{})
go func() {
// random delay before read starts
<-time.After(time.Duration(rand.Intn(10)) * time.Millisecond)
for i := 0; ; i++ {
p := make([]byte, readChunks[(c.iterations*c.readsPerLoop+i)%len(readChunks)])
n, _ := buf.Read(p)
if n == 0 {
break
}
hash.Write(p[:n])
}
close(done)
}()
for i := 0; i < c.iterations; i++ {
for w := 0; w < c.writesPerLoop; w++ {
buf.Write(testMessage[:writeChunks[(i*c.writesPerLoop+w)%len(writeChunks)]])
}
}
buf.Close()
<-done
actual := hex.EncodeToString(hash.Sum(nil))
if expected != actual {
t.Fatalf("BytesPipe returned invalid data. Expected checksum %v, got %v", expected, actual)
}
}
}
func BenchmarkBytesPipeWrite(b *testing.B) {
for i := 0; i < b.N; i++ {
readBuf := make([]byte, 1024)
buf := NewBytesPipe(nil)
go func() {
var err error
for err == nil {
_, err = buf.Read(readBuf)
}
}()
for j := 0; j < 1000; j++ {
buf.Write([]byte("pretty short line, because why not?"))
}
buf.Close()
}
}
func BenchmarkBytesPipeRead(b *testing.B) {
rd := make([]byte, 512)
for i := 0; i < b.N; i++ {
b.StopTimer()
buf := NewBytesPipe(nil)
for j := 0; j < 500; j++ {
buf.Write(make([]byte, 1024))
}
b.StartTimer()
for j := 0; j < 1000; j++ {
if n, _ := buf.Read(rd); n != 512 {
b.Fatalf("Wrong number of bytes: %d", n)
}
}
}
}

View File

@@ -1,17 +0,0 @@
package ioutils
import "testing"
func TestFprintfIfNotEmpty(t *testing.T) {
wc := NewWriteCounter(&NopWriter{})
n, _ := FprintfIfNotEmpty(wc, "foo%s", "")
if wc.Count != 0 || n != 0 {
t.Errorf("Wrong count: %v vs. %v vs. 0", wc.Count, n)
}
n, _ = FprintfIfNotEmpty(wc, "foo%s", "bar")
if wc.Count != 6 || n != 6 {
t.Errorf("Wrong count: %v vs. %v vs. 6", wc.Count, n)
}
}

View File

@@ -1,149 +0,0 @@
package ioutils
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"testing"
)
func TestMultiReadSeekerReadAll(t *testing.T) {
str := "hello world"
s1 := strings.NewReader(str + " 1")
s2 := strings.NewReader(str + " 2")
s3 := strings.NewReader(str + " 3")
mr := MultiReadSeeker(s1, s2, s3)
expectedSize := int64(s1.Len() + s2.Len() + s3.Len())
b, err := ioutil.ReadAll(mr)
if err != nil {
t.Fatal(err)
}
expected := "hello world 1hello world 2hello world 3"
if string(b) != expected {
t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected)
}
size, err := mr.Seek(0, os.SEEK_END)
if err != nil {
t.Fatal(err)
}
if size != expectedSize {
t.Fatalf("reader size does not match, got %d, expected %d", size, expectedSize)
}
// Reset the position and read again
pos, err := mr.Seek(0, os.SEEK_SET)
if err != nil {
t.Fatal(err)
}
if pos != 0 {
t.Fatalf("expected position to be set to 0, got %d", pos)
}
b, err = ioutil.ReadAll(mr)
if err != nil {
t.Fatal(err)
}
if string(b) != expected {
t.Fatalf("ReadAll failed, got: %q, expected %q", string(b), expected)
}
}
func TestMultiReadSeekerReadEach(t *testing.T) {
str := "hello world"
s1 := strings.NewReader(str + " 1")
s2 := strings.NewReader(str + " 2")
s3 := strings.NewReader(str + " 3")
mr := MultiReadSeeker(s1, s2, s3)
var totalBytes int64
for i, s := range []*strings.Reader{s1, s2, s3} {
sLen := int64(s.Len())
buf := make([]byte, s.Len())
expected := []byte(fmt.Sprintf("%s %d", str, i+1))
if _, err := mr.Read(buf); err != nil && err != io.EOF {
t.Fatal(err)
}
if !bytes.Equal(buf, expected) {
t.Fatalf("expected %q to be %q", string(buf), string(expected))
}
pos, err := mr.Seek(0, os.SEEK_CUR)
if err != nil {
t.Fatalf("iteration: %d, error: %v", i+1, err)
}
// check that the total bytes read is the current position of the seeker
totalBytes += sLen
if pos != totalBytes {
t.Fatalf("expected current position to be: %d, got: %d, iteration: %d", totalBytes, pos, i+1)
}
// This tests not only that SEEK_SET and SEEK_CUR give the same values, but that the next iteration is in the expected position as well
newPos, err := mr.Seek(pos, os.SEEK_SET)
if err != nil {
t.Fatal(err)
}
if newPos != pos {
t.Fatalf("expected to get same position when calling SEEK_SET with value from SEEK_CUR, cur: %d, set: %d", pos, newPos)
}
}
}
func TestMultiReadSeekerReadSpanningChunks(t *testing.T) {
str := "hello world"
s1 := strings.NewReader(str + " 1")
s2 := strings.NewReader(str + " 2")
s3 := strings.NewReader(str + " 3")
mr := MultiReadSeeker(s1, s2, s3)
buf := make([]byte, s1.Len()+3)
_, err := mr.Read(buf)
if err != nil {
t.Fatal(err)
}
// expected is the contents of s1 + 3 bytes from s2, ie, the `hel` at the end of this string
expected := "hello world 1hel"
if string(buf) != expected {
t.Fatalf("expected %s to be %s", string(buf), expected)
}
}
func TestMultiReadSeekerNegativeSeek(t *testing.T) {
str := "hello world"
s1 := strings.NewReader(str + " 1")
s2 := strings.NewReader(str + " 2")
s3 := strings.NewReader(str + " 3")
mr := MultiReadSeeker(s1, s2, s3)
s1Len := s1.Len()
s2Len := s2.Len()
s3Len := s3.Len()
s, err := mr.Seek(int64(-1*s3.Len()), os.SEEK_END)
if err != nil {
t.Fatal(err)
}
if s != int64(s1Len+s2Len) {
t.Fatalf("expected %d to be %d", s, s1.Len()+s2.Len())
}
buf := make([]byte, s3Len)
if _, err := mr.Read(buf); err != nil && err != io.EOF {
t.Fatal(err)
}
expected := fmt.Sprintf("%s %d", str, 3)
if string(buf) != fmt.Sprintf("%s %d", str, 3) {
t.Fatalf("expected %q to be %q", string(buf), expected)
}
}

View File

@@ -1,94 +0,0 @@
package ioutils
import (
"fmt"
"io/ioutil"
"strings"
"testing"
"time"
"golang.org/x/net/context"
)
// Implement io.Reader
type errorReader struct{}
func (r *errorReader) Read(p []byte) (int, error) {
return 0, fmt.Errorf("Error reader always fail.")
}
func TestReadCloserWrapperClose(t *testing.T) {
reader := strings.NewReader("A string reader")
wrapper := NewReadCloserWrapper(reader, func() error {
return fmt.Errorf("This will be called when closing")
})
err := wrapper.Close()
if err == nil || !strings.Contains(err.Error(), "This will be called when closing") {
t.Fatalf("readCloserWrapper should have call the anonymous func and thus, fail.")
}
}
func TestReaderErrWrapperReadOnError(t *testing.T) {
called := false
reader := &errorReader{}
wrapper := NewReaderErrWrapper(reader, func() {
called = true
})
_, err := wrapper.Read([]byte{})
if err == nil || !strings.Contains(err.Error(), "Error reader always fail.") {
t.Fatalf("readErrWrapper should returned an error")
}
if !called {
t.Fatalf("readErrWrapper should have call the anonymous function on failure")
}
}
func TestReaderErrWrapperRead(t *testing.T) {
reader := strings.NewReader("a string reader.")
wrapper := NewReaderErrWrapper(reader, func() {
t.Fatalf("readErrWrapper should not have called the anonymous function")
})
// Read 20 byte (should be ok with the string above)
num, err := wrapper.Read(make([]byte, 20))
if err != nil {
t.Fatal(err)
}
if num != 16 {
t.Fatalf("readerErrWrapper should have read 16 byte, but read %d", num)
}
}
func TestHashData(t *testing.T) {
reader := strings.NewReader("hash-me")
actual, err := HashData(reader)
if err != nil {
t.Fatal(err)
}
expected := "sha256:4d11186aed035cc624d553e10db358492c84a7cd6b9670d92123c144930450aa"
if actual != expected {
t.Fatalf("Expecting %s, got %s", expected, actual)
}
}
type perpetualReader struct{}
func (p *perpetualReader) Read(buf []byte) (n int, err error) {
for i := 0; i != len(buf); i++ {
buf[i] = 'a'
}
return len(buf), nil
}
func TestCancelReadCloser(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond)
cancelReadCloser := NewCancelReadCloser(ctx, ioutil.NopCloser(&perpetualReader{}))
for {
var buf [128]byte
_, err := cancelReadCloser.Read(buf[:])
if err == context.DeadlineExceeded {
break
} else if err != nil {
t.Fatalf("got unexpected error: %v", err)
}
}
}

View File

@@ -1,65 +0,0 @@
package ioutils
import (
"bytes"
"strings"
"testing"
)
func TestWriteCloserWrapperClose(t *testing.T) {
called := false
writer := bytes.NewBuffer([]byte{})
wrapper := NewWriteCloserWrapper(writer, func() error {
called = true
return nil
})
if err := wrapper.Close(); err != nil {
t.Fatal(err)
}
if !called {
t.Fatalf("writeCloserWrapper should have call the anonymous function.")
}
}
func TestNopWriteCloser(t *testing.T) {
writer := bytes.NewBuffer([]byte{})
wrapper := NopWriteCloser(writer)
if err := wrapper.Close(); err != nil {
t.Fatal("NopWriteCloser always return nil on Close.")
}
}
func TestNopWriter(t *testing.T) {
nw := &NopWriter{}
l, err := nw.Write([]byte{'c'})
if err != nil {
t.Fatal(err)
}
if l != 1 {
t.Fatalf("Expected 1 got %d", l)
}
}
func TestWriteCounter(t *testing.T) {
dummy1 := "This is a dummy string."
dummy2 := "This is another dummy string."
totalLength := int64(len(dummy1) + len(dummy2))
reader1 := strings.NewReader(dummy1)
reader2 := strings.NewReader(dummy2)
var buffer bytes.Buffer
wc := NewWriteCounter(&buffer)
reader1.WriteTo(wc)
reader2.WriteTo(wc)
if wc.Count != totalLength {
t.Errorf("Wrong count: %d vs. %d", wc.Count, totalLength)
}
if buffer.String() != dummy1+dummy2 {
t.Error("Wrong message written")
}
}

View File

@@ -1,34 +0,0 @@
package jsonlog
import (
"regexp"
"testing"
)
func TestJSONLogMarshalJSON(t *testing.T) {
logs := map[JSONLog]string{
{Log: `"A log line with \\"`}: `^{\"log\":\"\\\"A log line with \\\\\\\\\\\"\",\"time\":\".{20,}\"}$`,
{Log: "A log line"}: `^{\"log\":\"A log line\",\"time\":\".{20,}\"}$`,
{Log: "A log line with \r"}: `^{\"log\":\"A log line with \\r\",\"time\":\".{20,}\"}$`,
{Log: "A log line with & < >"}: `^{\"log\":\"A log line with \\u0026 \\u003c \\u003e\",\"time\":\".{20,}\"}$`,
{Log: "A log line with utf8 : 🚀 ψ ω β"}: `^{\"log\":\"A log line with utf8 : 🚀 ψ ω β\",\"time\":\".{20,}\"}$`,
{Stream: "stdout"}: `^{\"stream\":\"stdout\",\"time\":\".{20,}\"}$`,
{}: `^{\"time\":\".{20,}\"}$`,
// These ones are a little weird
{Log: "\u2028 \u2029"}: `^{\"log\":\"\\u2028 \\u2029\",\"time\":\".{20,}\"}$`,
{Log: string([]byte{0xaF})}: `^{\"log\":\"\\ufffd\",\"time\":\".{20,}\"}$`,
{Log: string([]byte{0x7F})}: `^{\"log\":\"\x7f\",\"time\":\".{20,}\"}$`,
}
for jsonLog, expression := range logs {
data, err := jsonLog.MarshalJSON()
if err != nil {
t.Fatal(err)
}
res := string(data)
t.Logf("Result of WriteLog: %q", res)
logRe := regexp.MustCompile(expression)
if !logRe.MatchString(res) {
t.Fatalf("Log line not in expected format [%v]: %q", expression, res)
}
}
}

View File

@@ -1,39 +0,0 @@
package jsonlog
import (
"bytes"
"regexp"
"testing"
)
func TestJSONLogsMarshalJSONBuf(t *testing.T) {
logs := map[*JSONLogs]string{
{Log: []byte(`"A log line with \\"`)}: `^{\"log\":\"\\\"A log line with \\\\\\\\\\\"\",\"time\":}$`,
{Log: []byte("A log line")}: `^{\"log\":\"A log line\",\"time\":}$`,
{Log: []byte("A log line with \r")}: `^{\"log\":\"A log line with \\r\",\"time\":}$`,
{Log: []byte("A log line with & < >")}: `^{\"log\":\"A log line with \\u0026 \\u003c \\u003e\",\"time\":}$`,
{Log: []byte("A log line with utf8 : 🚀 ψ ω β")}: `^{\"log\":\"A log line with utf8 : 🚀 ψ ω β\",\"time\":}$`,
{Stream: "stdout"}: `^{\"stream\":\"stdout\",\"time\":}$`,
{Stream: "stdout", Log: []byte("A log line")}: `^{\"log\":\"A log line\",\"stream\":\"stdout\",\"time\":}$`,
{Created: "time"}: `^{\"time\":time}$`,
{}: `^{\"time\":}$`,
// These ones are a little weird
{Log: []byte("\u2028 \u2029")}: `^{\"log\":\"\\u2028 \\u2029\",\"time\":}$`,
{Log: []byte{0xaF}}: `^{\"log\":\"\\ufffd\",\"time\":}$`,
{Log: []byte{0x7F}}: `^{\"log\":\"\x7f\",\"time\":}$`,
// with raw attributes
{Log: []byte("A log line"), RawAttrs: []byte(`{"hello":"world","value":1234}`)}: `^{\"log\":\"A log line\",\"attrs\":{\"hello\":\"world\",\"value\":1234},\"time\":}$`,
}
for jsonLog, expression := range logs {
var buf bytes.Buffer
if err := jsonLog.MarshalJSONBuf(&buf); err != nil {
t.Fatal(err)
}
res := buf.String()
t.Logf("Result of WriteLog: %q", res)
logRe := regexp.MustCompile(expression)
if !logRe.MatchString(res) {
t.Fatalf("Log line not in expected format [%v]: %q", expression, res)
}
}
}

View File

@@ -1,47 +0,0 @@
package jsonlog
import (
"testing"
"time"
)
// Testing to ensure 'year' fields is between 0 and 9999
func TestFastTimeMarshalJSONWithInvalidDate(t *testing.T) {
aTime := time.Date(-1, 1, 1, 0, 0, 0, 0, time.Local)
json, err := FastTimeMarshalJSON(aTime)
if err == nil {
t.Fatalf("FastTimeMarshalJSON should throw an error, but was '%v'", json)
}
anotherTime := time.Date(10000, 1, 1, 0, 0, 0, 0, time.Local)
json, err = FastTimeMarshalJSON(anotherTime)
if err == nil {
t.Fatalf("FastTimeMarshalJSON should throw an error, but was '%v'", json)
}
}
func TestFastTimeMarshalJSON(t *testing.T) {
aTime := time.Date(2015, 5, 29, 11, 1, 2, 3, time.UTC)
json, err := FastTimeMarshalJSON(aTime)
if err != nil {
t.Fatal(err)
}
expected := "\"2015-05-29T11:01:02.000000003Z\""
if json != expected {
t.Fatalf("Expected %v, got %v", expected, json)
}
location, err := time.LoadLocation("Europe/Paris")
if err != nil {
t.Fatal(err)
}
aTime = time.Date(2015, 5, 29, 11, 1, 2, 3, location)
json, err = FastTimeMarshalJSON(aTime)
if err != nil {
t.Fatal(err)
}
expected = "\"2015-05-29T11:01:02.000000003+02:00\""
if json != expected {
t.Fatalf("Expected %v, got %v", expected, json)
}
}

View File

@@ -1,231 +0,0 @@
package jsonmessage
import (
"bytes"
"fmt"
"strings"
"testing"
"time"
"github.com/hyperhq/hypercli/pkg/jsonlog"
"github.com/hyperhq/hypercli/pkg/term"
)
func TestError(t *testing.T) {
je := JSONError{404, "Not found"}
if je.Error() != "Not found" {
t.Fatalf("Expected 'Not found' got '%s'", je.Error())
}
}
func TestProgress(t *testing.T) {
jp := JSONProgress{}
if jp.String() != "" {
t.Fatalf("Expected empty string, got '%s'", jp.String())
}
expected := " 1 B"
jp2 := JSONProgress{Current: 1}
if jp2.String() != expected {
t.Fatalf("Expected %q, got %q", expected, jp2.String())
}
expectedStart := "[==========> ] 20 B/100 B"
jp3 := JSONProgress{Current: 20, Total: 100, Start: time.Now().Unix()}
// Just look at the start of the string
// (the remaining time is really hard to test -_-)
if jp3.String()[:len(expectedStart)] != expectedStart {
t.Fatalf("Expected to start with %q, got %q", expectedStart, jp3.String())
}
expected = "[=========================> ] 50 B/100 B"
jp4 := JSONProgress{Current: 50, Total: 100}
if jp4.String() != expected {
t.Fatalf("Expected %q, got %q", expected, jp4.String())
}
// this number can't be negative gh#7136
expected = "[==================================================>] 50 B"
jp5 := JSONProgress{Current: 50, Total: 40}
if jp5.String() != expected {
t.Fatalf("Expected %q, got %q", expected, jp5.String())
}
}
func TestJSONMessageDisplay(t *testing.T) {
now := time.Now()
messages := map[JSONMessage][]string{
// Empty
{}: {"\n", "\n"},
// Status
{
Status: "status",
}: {
"status\n",
"status\n",
},
// General
{
Time: now.Unix(),
ID: "ID",
From: "From",
Status: "status",
}: {
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(now.Unix(), 0).Format(jsonlog.RFC3339NanoFixed)),
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(now.Unix(), 0).Format(jsonlog.RFC3339NanoFixed)),
},
// General, with nano precision time
{
TimeNano: now.UnixNano(),
ID: "ID",
From: "From",
Status: "status",
}: {
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(0, now.UnixNano()).Format(jsonlog.RFC3339NanoFixed)),
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(0, now.UnixNano()).Format(jsonlog.RFC3339NanoFixed)),
},
// General, with both times Nano is preferred
{
Time: now.Unix(),
TimeNano: now.UnixNano(),
ID: "ID",
From: "From",
Status: "status",
}: {
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(0, now.UnixNano()).Format(jsonlog.RFC3339NanoFixed)),
fmt.Sprintf("%v ID: (from From) status\n", time.Unix(0, now.UnixNano()).Format(jsonlog.RFC3339NanoFixed)),
},
// Stream over status
{
Status: "status",
Stream: "stream",
}: {
"stream",
"stream",
},
// With progress message
{
Status: "status",
ProgressMessage: "progressMessage",
}: {
"status progressMessage",
"status progressMessage",
},
// With progress, stream empty
{
Status: "status",
Stream: "",
Progress: &JSONProgress{Current: 1},
}: {
"",
fmt.Sprintf("%c[2K\rstatus 1 B\r", 27),
},
}
// The tests :)
for jsonMessage, expectedMessages := range messages {
// Without terminal
data := bytes.NewBuffer([]byte{})
if err := jsonMessage.Display(data, false); err != nil {
t.Fatal(err)
}
if data.String() != expectedMessages[0] {
t.Fatalf("Expected [%v], got [%v]", expectedMessages[0], data.String())
}
// With terminal
data = bytes.NewBuffer([]byte{})
if err := jsonMessage.Display(data, true); err != nil {
t.Fatal(err)
}
if data.String() != expectedMessages[1] {
t.Fatalf("Expected [%v], got [%v]", expectedMessages[1], data.String())
}
}
}
// Test JSONMessage with an Error. It will return an error with the text as error, not the meaning of the HTTP code.
func TestJSONMessageDisplayWithJSONError(t *testing.T) {
data := bytes.NewBuffer([]byte{})
jsonMessage := JSONMessage{Error: &JSONError{404, "Can't find it"}}
err := jsonMessage.Display(data, true)
if err == nil || err.Error() != "Can't find it" {
t.Fatalf("Expected a JSONError 404, got [%v]", err)
}
jsonMessage = JSONMessage{Error: &JSONError{401, "Anything"}}
err = jsonMessage.Display(data, true)
if err == nil || err.Error() != "Authentication is required." {
t.Fatalf("Expected an error [Authentication is required.], got [%v]", err)
}
}
func TestDisplayJSONMessagesStreamInvalidJSON(t *testing.T) {
var (
inFd uintptr
)
data := bytes.NewBuffer([]byte{})
reader := strings.NewReader("This is not a 'valid' JSON []")
inFd, _ = term.GetFdInfo(reader)
if err := DisplayJSONMessagesStream(reader, data, inFd, false, nil); err == nil && err.Error()[:17] != "invalid character" {
t.Fatalf("Should have thrown an error (invalid character in ..), got [%v]", err)
}
}
func TestDisplayJSONMessagesStream(t *testing.T) {
var (
inFd uintptr
)
messages := map[string][]string{
// empty string
"": {
"",
""},
// Without progress & ID
"{ \"status\": \"status\" }": {
"status\n",
"status\n",
},
// Without progress, with ID
"{ \"id\": \"ID\",\"status\": \"status\" }": {
"ID: status\n",
fmt.Sprintf("ID: status\n%c[%dB", 27, 0),
},
// With progress
"{ \"id\": \"ID\", \"status\": \"status\", \"progress\": \"ProgressMessage\" }": {
"ID: status ProgressMessage",
fmt.Sprintf("\n%c[%dAID: status ProgressMessage%c[%dB", 27, 0, 27, 0),
},
// With progressDetail
"{ \"id\": \"ID\", \"status\": \"status\", \"progressDetail\": { \"Current\": 1} }": {
"", // progressbar is disabled in non-terminal
fmt.Sprintf("\n%c[%dA%c[2K\rID: status 1 B\r%c[%dB", 27, 0, 27, 27, 0),
},
}
for jsonMessage, expectedMessages := range messages {
data := bytes.NewBuffer([]byte{})
reader := strings.NewReader(jsonMessage)
inFd, _ = term.GetFdInfo(reader)
// Without terminal
if err := DisplayJSONMessagesStream(reader, data, inFd, false, nil); err != nil {
t.Fatal(err)
}
if data.String() != expectedMessages[0] {
t.Fatalf("Expected an [%v], got [%v]", expectedMessages[0], data.String())
}
// With terminal
data = bytes.NewBuffer([]byte{})
reader = strings.NewReader(jsonMessage)
if err := DisplayJSONMessagesStream(reader, data, inFd, true, nil); err != nil {
t.Fatal(err)
}
if data.String() != expectedMessages[1] {
t.Fatalf("Expected an [%v], got [%v]", expectedMessages[1], data.String())
}
}
}

View File

@@ -1,65 +0,0 @@
Locker
=====
locker provides a mechanism for creating finer-grained locking to help
free up more global locks to handle other tasks.
The implementation looks close to a sync.Mutex, however the user must provide a
reference to use to refer to the underlying lock when locking and unlocking,
and unlock may generate an error.
If a lock with a given name does not exist when `Lock` is called, one is
created.
Lock references are automatically cleaned up on `Unlock` if nothing else is
waiting for the lock.
## Usage
```go
package important
import (
"sync"
"time"
"github.com/docker/docker/pkg/locker"
)
type important struct {
locks *locker.Locker
data map[string]interface{}
mu sync.Mutex
}
func (i *important) Get(name string) interface{} {
i.locks.Lock(name)
defer i.locks.Unlock(name)
return data[name]
}
func (i *important) Create(name string, data interface{}) {
i.locks.Lock(name)
defer i.locks.Unlock(name)
i.createImportant(data)
s.mu.Lock()
i.data[name] = data
s.mu.Unlock()
}
func (i *important) createImportant(data interface{}) {
time.Sleep(10 * time.Second)
}
```
For functions dealing with a given name, always lock at the beginning of the
function (or before doing anything with the underlying state), this ensures any
other function that is dealing with the same name will block.
When needing to modify the underlying data, use the global lock to ensure nothing
else is modfying it at the same time.
Since name lock is already in place, no reads will occur while the modification
is being performed.

View File

@@ -1,112 +0,0 @@
/*
Package locker provides a mechanism for creating finer-grained locking to help
free up more global locks to handle other tasks.
The implementation looks close to a sync.Mutex, however the user must provide a
reference to use to refer to the underlying lock when locking and unlocking,
and unlock may generate an error.
If a lock with a given name does not exist when `Lock` is called, one is
created.
Lock references are automatically cleaned up on `Unlock` if nothing else is
waiting for the lock.
*/
package locker
import (
"errors"
"sync"
"sync/atomic"
)
// ErrNoSuchLock is returned when the requested lock does not exist
var ErrNoSuchLock = errors.New("no such lock")
// Locker provides a locking mechanism based on the passed in reference name
type Locker struct {
mu sync.Mutex
locks map[string]*lockCtr
}
// lockCtr is used by Locker to represent a lock with a given name.
type lockCtr struct {
mu sync.Mutex
// waiters is the number of waiters waiting to acquire the lock
// this is int32 instead of uint32 so we can add `-1` in `dec()`
waiters int32
}
// inc increments the number of waiters waiting for the lock
func (l *lockCtr) inc() {
atomic.AddInt32(&l.waiters, 1)
}
// dec decrements the number of waiters waiting on the lock
func (l *lockCtr) dec() {
atomic.AddInt32(&l.waiters, -1)
}
// count gets the current number of waiters
func (l *lockCtr) count() int32 {
return atomic.LoadInt32(&l.waiters)
}
// Lock locks the mutex
func (l *lockCtr) Lock() {
l.mu.Lock()
}
// Unlock unlocks the mutex
func (l *lockCtr) Unlock() {
l.mu.Unlock()
}
// New creates a new Locker
func New() *Locker {
return &Locker{
locks: make(map[string]*lockCtr),
}
}
// Lock locks a mutex with the given name. If it doesn't exist, one is created
func (l *Locker) Lock(name string) {
l.mu.Lock()
if l.locks == nil {
l.locks = make(map[string]*lockCtr)
}
nameLock, exists := l.locks[name]
if !exists {
nameLock = &lockCtr{}
l.locks[name] = nameLock
}
// increment the nameLock waiters while inside the main mutex
// this makes sure that the lock isn't deleted if `Lock` and `Unlock` are called concurrently
nameLock.inc()
l.mu.Unlock()
// Lock the nameLock outside the main mutex so we don't block other operations
// once locked then we can decrement the number of waiters for this lock
nameLock.Lock()
nameLock.dec()
}
// Unlock unlocks the mutex with the given name
// If the given lock is not being waited on by any other callers, it is deleted
func (l *Locker) Unlock(name string) error {
l.mu.Lock()
nameLock, exists := l.locks[name]
if !exists {
l.mu.Unlock()
return ErrNoSuchLock
}
if nameLock.count() == 0 {
delete(l.locks, name)
}
nameLock.Unlock()
l.mu.Unlock()
return nil
}

View File

@@ -1,124 +0,0 @@
package locker
import (
"sync"
"testing"
"time"
)
func TestLockCounter(t *testing.T) {
l := &lockCtr{}
l.inc()
if l.waiters != 1 {
t.Fatal("counter inc failed")
}
l.dec()
if l.waiters != 0 {
t.Fatal("counter dec failed")
}
}
func TestLockerLock(t *testing.T) {
l := New()
l.Lock("test")
ctr := l.locks["test"]
if ctr.count() != 0 {
t.Fatalf("expected waiters to be 0, got :%d", ctr.waiters)
}
chDone := make(chan struct{})
go func() {
l.Lock("test")
close(chDone)
}()
chWaiting := make(chan struct{})
go func() {
for range time.Tick(1 * time.Millisecond) {
if ctr.count() == 1 {
close(chWaiting)
break
}
}
}()
select {
case <-chWaiting:
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for lock waiters to be incremented")
}
select {
case <-chDone:
t.Fatal("lock should not have returned while it was still held")
default:
}
if err := l.Unlock("test"); err != nil {
t.Fatal(err)
}
select {
case <-chDone:
case <-time.After(3 * time.Second):
t.Fatalf("lock should have completed")
}
if ctr.count() != 0 {
t.Fatalf("expected waiters to be 0, got: %d", ctr.count())
}
}
func TestLockerUnlock(t *testing.T) {
l := New()
l.Lock("test")
l.Unlock("test")
chDone := make(chan struct{})
go func() {
l.Lock("test")
close(chDone)
}()
select {
case <-chDone:
case <-time.After(3 * time.Second):
t.Fatalf("lock should not be blocked")
}
}
func TestLockerConcurrency(t *testing.T) {
l := New()
var wg sync.WaitGroup
for i := 0; i <= 10000; i++ {
wg.Add(1)
go func() {
l.Lock("test")
// if there is a concurrency issue, will very likely panic here
l.Unlock("test")
wg.Done()
}()
}
chDone := make(chan struct{})
go func() {
wg.Wait()
close(chDone)
}()
select {
case <-chDone:
case <-time.After(10 * time.Second):
t.Fatal("timeout waiting for locks to complete")
}
// Since everything has unlocked this should not exist anymore
if ctr, exists := l.locks["test"]; exists {
t.Fatalf("lock should not exist: %v", ctr)
}
}

View File

@@ -1,22 +0,0 @@
package longpath
import (
"strings"
"testing"
)
func TestStandardLongPath(t *testing.T) {
c := `C:\simple\path`
longC := AddPrefix(c)
if !strings.EqualFold(longC, `\\?\C:\simple\path`) {
t.Errorf("Wrong long path returned. Original = %s ; Long = %s", c, longC)
}
}
func TestUNCLongPath(t *testing.T) {
c := `\\server\share\path`
longC := AddPrefix(c)
if !strings.EqualFold(longC, `\\?\UNC\server\share\path`) {
t.Errorf("Wrong UNC long path returned. Original = %s ; Long = %s", c, longC)
}
}

View File

@@ -1,137 +0,0 @@
// +build linux
package loopback
import (
"errors"
"fmt"
"os"
"syscall"
"github.com/Sirupsen/logrus"
)
// Loopback related errors
var (
ErrAttachLoopbackDevice = errors.New("loopback attach failed")
ErrGetLoopbackBackingFile = errors.New("Unable to get loopback backing file")
ErrSetCapacity = errors.New("Unable set loopback capacity")
)
func stringToLoopName(src string) [LoNameSize]uint8 {
var dst [LoNameSize]uint8
copy(dst[:], src[:])
return dst
}
func getNextFreeLoopbackIndex() (int, error) {
f, err := os.OpenFile("/dev/loop-control", os.O_RDONLY, 0644)
if err != nil {
return 0, err
}
defer f.Close()
index, err := ioctlLoopCtlGetFree(f.Fd())
if index < 0 {
index = 0
}
return index, err
}
func openNextAvailableLoopback(index int, sparseFile *os.File) (loopFile *os.File, err error) {
// Start looking for a free /dev/loop
for {
target := fmt.Sprintf("/dev/loop%d", index)
index++
fi, err := os.Stat(target)
if err != nil {
if os.IsNotExist(err) {
logrus.Errorf("There are no more loopback devices available.")
}
return nil, ErrAttachLoopbackDevice
}
if fi.Mode()&os.ModeDevice != os.ModeDevice {
logrus.Errorf("Loopback device %s is not a block device.", target)
continue
}
// OpenFile adds O_CLOEXEC
loopFile, err = os.OpenFile(target, os.O_RDWR, 0644)
if err != nil {
logrus.Errorf("Error opening loopback device: %s", err)
return nil, ErrAttachLoopbackDevice
}
// Try to attach to the loop file
if err := ioctlLoopSetFd(loopFile.Fd(), sparseFile.Fd()); err != nil {
loopFile.Close()
// If the error is EBUSY, then try the next loopback
if err != syscall.EBUSY {
logrus.Errorf("Cannot set up loopback device %s: %s", target, err)
return nil, ErrAttachLoopbackDevice
}
// Otherwise, we keep going with the loop
continue
}
// In case of success, we finished. Break the loop.
break
}
// This can't happen, but let's be sure
if loopFile == nil {
logrus.Errorf("Unreachable code reached! Error attaching %s to a loopback device.", sparseFile.Name())
return nil, ErrAttachLoopbackDevice
}
return loopFile, nil
}
// AttachLoopDevice attaches the given sparse file to the next
// available loopback device. It returns an opened *os.File.
func AttachLoopDevice(sparseName string) (loop *os.File, err error) {
// Try to retrieve the next available loopback device via syscall.
// If it fails, we discard error and start looping for a
// loopback from index 0.
startIndex, err := getNextFreeLoopbackIndex()
if err != nil {
logrus.Debugf("Error retrieving the next available loopback: %s", err)
}
// OpenFile adds O_CLOEXEC
sparseFile, err := os.OpenFile(sparseName, os.O_RDWR, 0644)
if err != nil {
logrus.Errorf("Error opening sparse file %s: %s", sparseName, err)
return nil, ErrAttachLoopbackDevice
}
defer sparseFile.Close()
loopFile, err := openNextAvailableLoopback(startIndex, sparseFile)
if err != nil {
return nil, err
}
// Set the status of the loopback device
loopInfo := &loopInfo64{
loFileName: stringToLoopName(loopFile.Name()),
loOffset: 0,
loFlags: LoFlagsAutoClear,
}
if err := ioctlLoopSetStatus64(loopFile.Fd(), loopInfo); err != nil {
logrus.Errorf("Cannot set up loopback device info: %s", err)
// If the call failed, then free the loopback device
if err := ioctlLoopClrFd(loopFile.Fd()); err != nil {
logrus.Errorf("Error while cleaning up the loopback device")
}
loopFile.Close()
return nil, ErrAttachLoopbackDevice
}
return loopFile, nil
}

View File

@@ -1,53 +0,0 @@
// +build linux
package loopback
import (
"syscall"
"unsafe"
)
func ioctlLoopCtlGetFree(fd uintptr) (int, error) {
index, _, err := syscall.Syscall(syscall.SYS_IOCTL, fd, LoopCtlGetFree, 0)
if err != 0 {
return 0, err
}
return int(index), nil
}
func ioctlLoopSetFd(loopFd, sparseFd uintptr) error {
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, loopFd, LoopSetFd, sparseFd); err != 0 {
return err
}
return nil
}
func ioctlLoopSetStatus64(loopFd uintptr, loopInfo *loopInfo64) error {
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, loopFd, LoopSetStatus64, uintptr(unsafe.Pointer(loopInfo))); err != 0 {
return err
}
return nil
}
func ioctlLoopClrFd(loopFd uintptr) error {
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, loopFd, LoopClrFd, 0); err != 0 {
return err
}
return nil
}
func ioctlLoopGetStatus64(loopFd uintptr) (*loopInfo64, error) {
loopInfo := &loopInfo64{}
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, loopFd, LoopGetStatus64, uintptr(unsafe.Pointer(loopInfo))); err != 0 {
return nil, err
}
return loopInfo, nil
}
func ioctlLoopSetCapacity(loopFd uintptr, value int) error {
if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, loopFd, LoopSetCapacity, uintptr(value)); err != 0 {
return err
}
return nil
}

View File

@@ -1,52 +0,0 @@
// +build linux
package loopback
/*
#include <linux/loop.h> // FIXME: present only for defines, maybe we can remove it?
#ifndef LOOP_CTL_GET_FREE
#define LOOP_CTL_GET_FREE 0x4C82
#endif
#ifndef LO_FLAGS_PARTSCAN
#define LO_FLAGS_PARTSCAN 8
#endif
*/
import "C"
type loopInfo64 struct {
loDevice uint64 /* ioctl r/o */
loInode uint64 /* ioctl r/o */
loRdevice uint64 /* ioctl r/o */
loOffset uint64
loSizelimit uint64 /* bytes, 0 == max available */
loNumber uint32 /* ioctl r/o */
loEncryptType uint32
loEncryptKeySize uint32 /* ioctl w/o */
loFlags uint32 /* ioctl r/o */
loFileName [LoNameSize]uint8
loCryptName [LoNameSize]uint8
loEncryptKey [LoKeySize]uint8 /* ioctl w/o */
loInit [2]uint64
}
// IOCTL consts
const (
LoopSetFd = C.LOOP_SET_FD
LoopCtlGetFree = C.LOOP_CTL_GET_FREE
LoopGetStatus64 = C.LOOP_GET_STATUS64
LoopSetStatus64 = C.LOOP_SET_STATUS64
LoopClrFd = C.LOOP_CLR_FD
LoopSetCapacity = C.LOOP_SET_CAPACITY
)
// LOOP consts.
const (
LoFlagsAutoClear = C.LO_FLAGS_AUTOCLEAR
LoFlagsReadOnly = C.LO_FLAGS_READ_ONLY
LoFlagsPartScan = C.LO_FLAGS_PARTSCAN
LoKeySize = C.LO_KEY_SIZE
LoNameSize = C.LO_NAME_SIZE
)

View File

@@ -1,63 +0,0 @@
// +build linux
package loopback
import (
"fmt"
"os"
"syscall"
"github.com/Sirupsen/logrus"
)
func getLoopbackBackingFile(file *os.File) (uint64, uint64, error) {
loopInfo, err := ioctlLoopGetStatus64(file.Fd())
if err != nil {
logrus.Errorf("Error get loopback backing file: %s", err)
return 0, 0, ErrGetLoopbackBackingFile
}
return loopInfo.loDevice, loopInfo.loInode, nil
}
// SetCapacity reloads the size for the loopback device.
func SetCapacity(file *os.File) error {
if err := ioctlLoopSetCapacity(file.Fd(), 0); err != nil {
logrus.Errorf("Error loopbackSetCapacity: %s", err)
return ErrSetCapacity
}
return nil
}
// FindLoopDeviceFor returns a loopback device file for the specified file which
// is backing file of a loop back device.
func FindLoopDeviceFor(file *os.File) *os.File {
stat, err := file.Stat()
if err != nil {
return nil
}
targetInode := stat.Sys().(*syscall.Stat_t).Ino
targetDevice := stat.Sys().(*syscall.Stat_t).Dev
for i := 0; true; i++ {
path := fmt.Sprintf("/dev/loop%d", i)
file, err := os.OpenFile(path, os.O_RDWR, 0)
if err != nil {
if os.IsNotExist(err) {
return nil
}
// Ignore all errors until the first not-exist
// we want to continue looking for the file
continue
}
dev, inode, err := getLoopbackBackingFile(file)
if err == nil && dev == targetDevice && inode == targetInode {
return file
}
file.Close()
}
return nil
}

View File

@@ -1,40 +0,0 @@
Package mflag (aka multiple-flag) implements command-line flag parsing.
It's an **hacky** fork of the [official golang package](http://golang.org/pkg/flag/)
It adds:
* both short and long flag version
`./example -s red` `./example --string blue`
* multiple names for the same option
```
$>./example -h
Usage of example:
-s, --string="": a simple string
```
___
It is very flexible on purpose, so you can do things like:
```
$>./example -h
Usage of example:
-s, -string, --string="": a simple string
```
Or:
```
$>./example -h
Usage of example:
-oldflag, --newflag="": a simple string
```
You can also hide some flags from the usage, so if we want only `--newflag`:
```
$>./example -h
Usage of example:
--newflag="": a simple string
$>./example -oldflag str
str
```
See [example.go](example/example.go) for more details.

View File

@@ -1,36 +0,0 @@
package main
import (
"fmt"
flag "github.com/hyperhq/hypercli/pkg/mflag"
)
var (
i int
str string
b, b2, h bool
)
func init() {
flag.Bool([]string{"#hp", "#-halp"}, false, "display the halp")
flag.BoolVar(&b, []string{"b", "#bal", "#bol", "-bal"}, false, "a simple bool")
flag.BoolVar(&b, []string{"g", "#gil"}, false, "a simple bool")
flag.BoolVar(&b2, []string{"#-bool"}, false, "a simple bool")
flag.IntVar(&i, []string{"-integer", "-number"}, -1, "a simple integer")
flag.StringVar(&str, []string{"s", "#hidden", "-string"}, "", "a simple string") //-s -hidden and --string will work, but -hidden won't be in the usage
flag.BoolVar(&h, []string{"h", "#help", "-help"}, false, "display the help")
flag.StringVar(&str, []string{"mode"}, "mode1", "set the mode\nmode1: use the mode1\nmode2: use the mode2\nmode3: use the mode3")
flag.Parse()
}
func main() {
if h {
flag.PrintDefaults()
} else {
fmt.Printf("s/#hidden/-string: %s\n", str)
fmt.Printf("b: %t\n", b)
fmt.Printf("-bool: %t\n", b2)
fmt.Printf("s/#hidden/-string(via lookup): %s\n", flag.Lookup("s").Value.String())
fmt.Printf("ARGS: %v\n", flag.Args())
}
}

View File

@@ -1,527 +0,0 @@
// Copyright 2014-2016 The Docker & Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mflag
import (
"bytes"
"fmt"
"os"
"sort"
"strings"
"testing"
"time"
)
// ResetForTesting clears all flag state and sets the usage function as directed.
// After calling ResetForTesting, parse errors in flag handling will not
// exit the program.
func ResetForTesting(usage func()) {
CommandLine = NewFlagSet(os.Args[0], ContinueOnError)
Usage = usage
}
func boolString(s string) string {
if s == "0" {
return "false"
}
return "true"
}
func TestEverything(t *testing.T) {
ResetForTesting(nil)
Bool([]string{"test_bool"}, false, "bool value")
Int([]string{"test_int"}, 0, "int value")
Int64([]string{"test_int64"}, 0, "int64 value")
Uint([]string{"test_uint"}, 0, "uint value")
Uint64([]string{"test_uint64"}, 0, "uint64 value")
String([]string{"test_string"}, "0", "string value")
Float64([]string{"test_float64"}, 0, "float64 value")
Duration([]string{"test_duration"}, 0, "time.Duration value")
m := make(map[string]*Flag)
desired := "0"
visitor := func(f *Flag) {
for _, name := range f.Names {
if len(name) > 5 && name[0:5] == "test_" {
m[name] = f
ok := false
switch {
case f.Value.String() == desired:
ok = true
case name == "test_bool" && f.Value.String() == boolString(desired):
ok = true
case name == "test_duration" && f.Value.String() == desired+"s":
ok = true
}
if !ok {
t.Error("Visit: bad value", f.Value.String(), "for", name)
}
}
}
}
VisitAll(visitor)
if len(m) != 8 {
t.Error("VisitAll misses some flags")
for k, v := range m {
t.Log(k, *v)
}
}
m = make(map[string]*Flag)
Visit(visitor)
if len(m) != 0 {
t.Errorf("Visit sees unset flags")
for k, v := range m {
t.Log(k, *v)
}
}
// Now set all flags
Set("test_bool", "true")
Set("test_int", "1")
Set("test_int64", "1")
Set("test_uint", "1")
Set("test_uint64", "1")
Set("test_string", "1")
Set("test_float64", "1")
Set("test_duration", "1s")
desired = "1"
Visit(visitor)
if len(m) != 8 {
t.Error("Visit fails after set")
for k, v := range m {
t.Log(k, *v)
}
}
// Now test they're visited in sort order.
var flagNames []string
Visit(func(f *Flag) {
for _, name := range f.Names {
flagNames = append(flagNames, name)
}
})
if !sort.StringsAreSorted(flagNames) {
t.Errorf("flag names not sorted: %v", flagNames)
}
}
func TestGet(t *testing.T) {
ResetForTesting(nil)
Bool([]string{"test_bool"}, true, "bool value")
Int([]string{"test_int"}, 1, "int value")
Int64([]string{"test_int64"}, 2, "int64 value")
Uint([]string{"test_uint"}, 3, "uint value")
Uint64([]string{"test_uint64"}, 4, "uint64 value")
String([]string{"test_string"}, "5", "string value")
Float64([]string{"test_float64"}, 6, "float64 value")
Duration([]string{"test_duration"}, 7, "time.Duration value")
visitor := func(f *Flag) {
for _, name := range f.Names {
if len(name) > 5 && name[0:5] == "test_" {
g, ok := f.Value.(Getter)
if !ok {
t.Errorf("Visit: value does not satisfy Getter: %T", f.Value)
return
}
switch name {
case "test_bool":
ok = g.Get() == true
case "test_int":
ok = g.Get() == int(1)
case "test_int64":
ok = g.Get() == int64(2)
case "test_uint":
ok = g.Get() == uint(3)
case "test_uint64":
ok = g.Get() == uint64(4)
case "test_string":
ok = g.Get() == "5"
case "test_float64":
ok = g.Get() == float64(6)
case "test_duration":
ok = g.Get() == time.Duration(7)
}
if !ok {
t.Errorf("Visit: bad value %T(%v) for %s", g.Get(), g.Get(), name)
}
}
}
}
VisitAll(visitor)
}
func testParse(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
boolFlag := f.Bool([]string{"bool"}, false, "bool value")
bool2Flag := f.Bool([]string{"bool2"}, false, "bool2 value")
f.Bool([]string{"bool3"}, false, "bool3 value")
bool4Flag := f.Bool([]string{"bool4"}, false, "bool4 value")
intFlag := f.Int([]string{"-int"}, 0, "int value")
int64Flag := f.Int64([]string{"-int64"}, 0, "int64 value")
uintFlag := f.Uint([]string{"uint"}, 0, "uint value")
uint64Flag := f.Uint64([]string{"-uint64"}, 0, "uint64 value")
stringFlag := f.String([]string{"string"}, "0", "string value")
f.String([]string{"string2"}, "0", "string2 value")
singleQuoteFlag := f.String([]string{"squote"}, "", "single quoted value")
doubleQuoteFlag := f.String([]string{"dquote"}, "", "double quoted value")
mixedQuoteFlag := f.String([]string{"mquote"}, "", "mixed quoted value")
mixed2QuoteFlag := f.String([]string{"mquote2"}, "", "mixed2 quoted value")
nestedQuoteFlag := f.String([]string{"nquote"}, "", "nested quoted value")
nested2QuoteFlag := f.String([]string{"nquote2"}, "", "nested2 quoted value")
float64Flag := f.Float64([]string{"float64"}, 0, "float64 value")
durationFlag := f.Duration([]string{"duration"}, 5*time.Second, "time.Duration value")
extra := "one-extra-argument"
args := []string{
"-bool",
"-bool2=true",
"-bool4=false",
"--int", "22",
"--int64", "0x23",
"-uint", "24",
"--uint64", "25",
"-string", "hello",
"-squote='single'",
`-dquote="double"`,
`-mquote='mixed"`,
`-mquote2="mixed2'`,
`-nquote="'single nested'"`,
`-nquote2='"double nested"'`,
"-float64", "2718e28",
"-duration", "2m",
extra,
}
if err := f.Parse(args); err != nil {
t.Fatal(err)
}
if !f.Parsed() {
t.Error("f.Parse() = false after Parse")
}
if *boolFlag != true {
t.Error("bool flag should be true, is ", *boolFlag)
}
if *bool2Flag != true {
t.Error("bool2 flag should be true, is ", *bool2Flag)
}
if !f.IsSet("bool2") {
t.Error("bool2 should be marked as set")
}
if f.IsSet("bool3") {
t.Error("bool3 should not be marked as set")
}
if !f.IsSet("bool4") {
t.Error("bool4 should be marked as set")
}
if *bool4Flag != false {
t.Error("bool4 flag should be false, is ", *bool4Flag)
}
if *intFlag != 22 {
t.Error("int flag should be 22, is ", *intFlag)
}
if *int64Flag != 0x23 {
t.Error("int64 flag should be 0x23, is ", *int64Flag)
}
if *uintFlag != 24 {
t.Error("uint flag should be 24, is ", *uintFlag)
}
if *uint64Flag != 25 {
t.Error("uint64 flag should be 25, is ", *uint64Flag)
}
if *stringFlag != "hello" {
t.Error("string flag should be `hello`, is ", *stringFlag)
}
if !f.IsSet("string") {
t.Error("string flag should be marked as set")
}
if f.IsSet("string2") {
t.Error("string2 flag should not be marked as set")
}
if *singleQuoteFlag != "single" {
t.Error("single quote string flag should be `single`, is ", *singleQuoteFlag)
}
if *doubleQuoteFlag != "double" {
t.Error("double quote string flag should be `double`, is ", *doubleQuoteFlag)
}
if *mixedQuoteFlag != `'mixed"` {
t.Error("mixed quote string flag should be `'mixed\"`, is ", *mixedQuoteFlag)
}
if *mixed2QuoteFlag != `"mixed2'` {
t.Error("mixed2 quote string flag should be `\"mixed2'`, is ", *mixed2QuoteFlag)
}
if *nestedQuoteFlag != "'single nested'" {
t.Error("nested quote string flag should be `'single nested'`, is ", *nestedQuoteFlag)
}
if *nested2QuoteFlag != `"double nested"` {
t.Error("double quote string flag should be `\"double nested\"`, is ", *nested2QuoteFlag)
}
if *float64Flag != 2718e28 {
t.Error("float64 flag should be 2718e28, is ", *float64Flag)
}
if *durationFlag != 2*time.Minute {
t.Error("duration flag should be 2m, is ", *durationFlag)
}
if len(f.Args()) != 1 {
t.Error("expected one argument, got", len(f.Args()))
} else if f.Args()[0] != extra {
t.Errorf("expected argument %q got %q", extra, f.Args()[0])
}
}
func testPanic(f *FlagSet, t *testing.T) {
f.Int([]string{"-int"}, 0, "int value")
if f.Parsed() {
t.Error("f.Parse() = true before Parse")
}
args := []string{
"-int", "21",
}
f.Parse(args)
}
func TestParsePanic(t *testing.T) {
ResetForTesting(func() {})
testPanic(CommandLine, t)
}
func TestParse(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testParse(CommandLine, t)
}
func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t)
}
// Declare a user-defined flag type.
type flagVar []string
func (f *flagVar) String() string {
return fmt.Sprint([]string(*f))
}
func (f *flagVar) Set(value string) error {
*f = append(*f, value)
return nil
}
func TestUserDefined(t *testing.T) {
var flags FlagSet
flags.Init("test", ContinueOnError)
var v flagVar
flags.Var(&v, []string{"v"}, "usage")
if err := flags.Parse([]string{"-v", "1", "-v", "2", "-v=3"}); err != nil {
t.Error(err)
}
if len(v) != 3 {
t.Fatal("expected 3 args; got ", len(v))
}
expect := "[1 2 3]"
if v.String() != expect {
t.Errorf("expected value %q got %q", expect, v.String())
}
}
// Declare a user-defined boolean flag type.
type boolFlagVar struct {
count int
}
func (b *boolFlagVar) String() string {
return fmt.Sprintf("%d", b.count)
}
func (b *boolFlagVar) Set(value string) error {
if value == "true" {
b.count++
}
return nil
}
func (b *boolFlagVar) IsBoolFlag() bool {
return b.count < 4
}
func TestUserDefinedBool(t *testing.T) {
var flags FlagSet
flags.Init("test", ContinueOnError)
var b boolFlagVar
var err error
flags.Var(&b, []string{"b"}, "usage")
if err = flags.Parse([]string{"-b", "-b", "-b", "-b=true", "-b=false", "-b", "barg", "-b"}); err != nil {
if b.count < 4 {
t.Error(err)
}
}
if b.count != 4 {
t.Errorf("want: %d; got: %d", 4, b.count)
}
if err == nil {
t.Error("expected error; got none")
}
}
func TestSetOutput(t *testing.T) {
var flags FlagSet
var buf bytes.Buffer
flags.SetOutput(&buf)
flags.Init("test", ContinueOnError)
flags.Parse([]string{"-unknown"})
if out := buf.String(); !strings.Contains(out, "-unknown") {
t.Logf("expected output mentioning unknown; got %q", out)
}
}
// This tests that one can reset the flags. This still works but not well, and is
// superseded by FlagSet.
func TestChangingArgs(t *testing.T) {
ResetForTesting(func() { t.Fatal("bad parse") })
oldArgs := os.Args
defer func() { os.Args = oldArgs }()
os.Args = []string{"cmd", "-before", "subcmd", "-after", "args"}
before := Bool([]string{"before"}, false, "")
if err := CommandLine.Parse(os.Args[1:]); err != nil {
t.Fatal(err)
}
cmd := Arg(0)
os.Args = Args()
after := Bool([]string{"after"}, false, "")
Parse()
args := Args()
if !*before || cmd != "subcmd" || !*after || len(args) != 1 || args[0] != "args" {
t.Fatalf("expected true subcmd true [args] got %v %v %v %v", *before, cmd, *after, args)
}
}
// Test that -help invokes the usage message and returns ErrHelp.
func TestHelp(t *testing.T) {
var helpCalled = false
fs := NewFlagSet("help test", ContinueOnError)
fs.Usage = func() { helpCalled = true }
var flag bool
fs.BoolVar(&flag, []string{"flag"}, false, "regular flag")
// Regular flag invocation should work
err := fs.Parse([]string{"-flag=true"})
if err != nil {
t.Fatal("expected no error; got ", err)
}
if !flag {
t.Error("flag was not set by -flag")
}
if helpCalled {
t.Error("help called for regular flag")
helpCalled = false // reset for next test
}
// Help flag should work as expected.
err = fs.Parse([]string{"-help"})
if err == nil {
t.Fatal("error expected")
}
if err != ErrHelp {
t.Fatal("expected ErrHelp; got ", err)
}
if !helpCalled {
t.Fatal("help was not called")
}
// If we define a help flag, that should override.
var help bool
fs.BoolVar(&help, []string{"help"}, false, "help flag")
helpCalled = false
err = fs.Parse([]string{"-help"})
if err != nil {
t.Fatal("expected no error for defined -help; got ", err)
}
if helpCalled {
t.Fatal("help was called; should not have been for defined help flag")
}
}
// Test the flag count functions.
func TestFlagCounts(t *testing.T) {
fs := NewFlagSet("help test", ContinueOnError)
var flag bool
fs.BoolVar(&flag, []string{"flag1"}, false, "regular flag")
fs.BoolVar(&flag, []string{"#deprecated1"}, false, "regular flag")
fs.BoolVar(&flag, []string{"f", "flag2"}, false, "regular flag")
fs.BoolVar(&flag, []string{"#d", "#deprecated2"}, false, "regular flag")
fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag")
fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag")
if fs.FlagCount() != 6 {
t.Fatal("FlagCount wrong. ", fs.FlagCount())
}
if fs.FlagCountUndeprecated() != 4 {
t.Fatal("FlagCountUndeprecated wrong. ", fs.FlagCountUndeprecated())
}
if fs.NFlag() != 0 {
t.Fatal("NFlag wrong. ", fs.NFlag())
}
err := fs.Parse([]string{"-fd", "-g", "-flag4"})
if err != nil {
t.Fatal("expected no error for defined -help; got ", err)
}
if fs.NFlag() != 4 {
t.Fatal("NFlag wrong. ", fs.NFlag())
}
}
// Show up bug in sortFlags
func TestSortFlags(t *testing.T) {
fs := NewFlagSet("help TestSortFlags", ContinueOnError)
var err error
var b bool
fs.BoolVar(&b, []string{"b", "-banana"}, false, "usage")
err = fs.Parse([]string{"--banana=true"})
if err != nil {
t.Fatal("expected no error; got ", err)
}
count := 0
fs.VisitAll(func(flag *Flag) {
count++
if flag == nil {
t.Fatal("VisitAll should not return a nil flag")
}
})
flagcount := fs.FlagCount()
if flagcount != count {
t.Fatalf("FlagCount (%d) != number (%d) of elements visited", flagcount, count)
}
// Make sure its idempotent
if flagcount != fs.FlagCount() {
t.Fatalf("FlagCount (%d) != fs.FlagCount() (%d) of elements visited", flagcount, fs.FlagCount())
}
count = 0
fs.Visit(func(flag *Flag) {
count++
if flag == nil {
t.Fatal("Visit should not return a nil flag")
}
})
nflag := fs.NFlag()
if nflag != count {
t.Fatalf("NFlag (%d) != number (%d) of elements visited", nflag, count)
}
if nflag != fs.NFlag() {
t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag())
}
}
func TestMergeFlags(t *testing.T) {
base := NewFlagSet("base", ContinueOnError)
base.String([]string{"f"}, "", "")
fs := NewFlagSet("test", ContinueOnError)
Merge(fs, base)
if len(fs.formal) != 1 {
t.Fatalf("FlagCount (%d) != number (1) of elements merged", len(fs.formal))
}
}

View File

@@ -1,92 +0,0 @@
package mount
import (
"fmt"
"strings"
)
// Parse fstab type mount options into mount() flags
// and device specific data
func parseOptions(options string) (int, string) {
var (
flag int
data []string
)
flags := map[string]struct {
clear bool
flag int
}{
"defaults": {false, 0},
"ro": {false, RDONLY},
"rw": {true, RDONLY},
"suid": {true, NOSUID},
"nosuid": {false, NOSUID},
"dev": {true, NODEV},
"nodev": {false, NODEV},
"exec": {true, NOEXEC},
"noexec": {false, NOEXEC},
"sync": {false, SYNCHRONOUS},
"async": {true, SYNCHRONOUS},
"dirsync": {false, DIRSYNC},
"remount": {false, REMOUNT},
"mand": {false, MANDLOCK},
"nomand": {true, MANDLOCK},
"atime": {true, NOATIME},
"noatime": {false, NOATIME},
"diratime": {true, NODIRATIME},
"nodiratime": {false, NODIRATIME},
"bind": {false, BIND},
"rbind": {false, RBIND},
"unbindable": {false, UNBINDABLE},
"runbindable": {false, RUNBINDABLE},
"private": {false, PRIVATE},
"rprivate": {false, RPRIVATE},
"shared": {false, SHARED},
"rshared": {false, RSHARED},
"slave": {false, SLAVE},
"rslave": {false, RSLAVE},
"relatime": {false, RELATIME},
"norelatime": {true, RELATIME},
"strictatime": {false, STRICTATIME},
"nostrictatime": {true, STRICTATIME},
}
for _, o := range strings.Split(options, ",") {
// If the option does not exist in the flags table or the flag
// is not supported on the platform,
// then it is a data value for a specific fs type
if f, exists := flags[o]; exists && f.flag != 0 {
if f.clear {
flag &= ^f.flag
} else {
flag |= f.flag
}
} else {
data = append(data, o)
}
}
return flag, strings.Join(data, ",")
}
// ParseTmpfsOptions parse fstab type mount options into flags and data
func ParseTmpfsOptions(options string) (int, string, error) {
flags, data := parseOptions(options)
validFlags := map[string]bool{
"": true,
"size": true,
"mode": true,
"uid": true,
"gid": true,
"nr_inodes": true,
"nr_blocks": true,
"mpol": true,
}
for _, o := range strings.Split(data, ",") {
opt := strings.SplitN(o, "=", 2)
if !validFlags[opt[0]] {
return 0, "", fmt.Errorf("Invalid tmpfs option %q", opt)
}
}
return flags, data, nil
}

View File

@@ -1,48 +0,0 @@
// +build freebsd,cgo
package mount
/*
#include <sys/mount.h>
*/
import "C"
const (
// RDONLY will mount the filesystem as read-only.
RDONLY = C.MNT_RDONLY
// NOSUID will not allow set-user-identifier or set-group-identifier bits to
// take effect.
NOSUID = C.MNT_NOSUID
// NOEXEC will not allow execution of any binaries on the mounted file system.
NOEXEC = C.MNT_NOEXEC
// SYNCHRONOUS will allow any I/O to the file system to be done synchronously.
SYNCHRONOUS = C.MNT_SYNCHRONOUS
// NOATIME will not update the file access time when reading from a file.
NOATIME = C.MNT_NOATIME
)
// These flags are unsupported.
const (
BIND = 0
DIRSYNC = 0
MANDLOCK = 0
NODEV = 0
NODIRATIME = 0
UNBINDABLE = 0
RUNBINDABLE = 0
PRIVATE = 0
RPRIVATE = 0
SHARED = 0
RSHARED = 0
SLAVE = 0
RSLAVE = 0
RBIND = 0
RELATIVE = 0
RELATIME = 0
REMOUNT = 0
STRICTATIME = 0
)

View File

@@ -1,85 +0,0 @@
package mount
import (
"syscall"
)
const (
// RDONLY will mount the file system read-only.
RDONLY = syscall.MS_RDONLY
// NOSUID will not allow set-user-identifier or set-group-identifier bits to
// take effect.
NOSUID = syscall.MS_NOSUID
// NODEV will not interpret character or block special devices on the file
// system.
NODEV = syscall.MS_NODEV
// NOEXEC will not allow execution of any binaries on the mounted file system.
NOEXEC = syscall.MS_NOEXEC
// SYNCHRONOUS will allow I/O to the file system to be done synchronously.
SYNCHRONOUS = syscall.MS_SYNCHRONOUS
// DIRSYNC will force all directory updates within the file system to be done
// synchronously. This affects the following system calls: create, link,
// unlink, symlink, mkdir, rmdir, mknod and rename.
DIRSYNC = syscall.MS_DIRSYNC
// REMOUNT will attempt to remount an already-mounted file system. This is
// commonly used to change the mount flags for a file system, especially to
// make a readonly file system writeable. It does not change device or mount
// point.
REMOUNT = syscall.MS_REMOUNT
// MANDLOCK will force mandatory locks on a filesystem.
MANDLOCK = syscall.MS_MANDLOCK
// NOATIME will not update the file access time when reading from a file.
NOATIME = syscall.MS_NOATIME
// NODIRATIME will not update the directory access time.
NODIRATIME = syscall.MS_NODIRATIME
// BIND remounts a subtree somewhere else.
BIND = syscall.MS_BIND
// RBIND remounts a subtree and all possible submounts somewhere else.
RBIND = syscall.MS_BIND | syscall.MS_REC
// UNBINDABLE creates a mount which cannot be cloned through a bind operation.
UNBINDABLE = syscall.MS_UNBINDABLE
// RUNBINDABLE marks the entire mount tree as UNBINDABLE.
RUNBINDABLE = syscall.MS_UNBINDABLE | syscall.MS_REC
// PRIVATE creates a mount which carries no propagation abilities.
PRIVATE = syscall.MS_PRIVATE
// RPRIVATE marks the entire mount tree as PRIVATE.
RPRIVATE = syscall.MS_PRIVATE | syscall.MS_REC
// SLAVE creates a mount which receives propagation from its master, but not
// vice versa.
SLAVE = syscall.MS_SLAVE
// RSLAVE marks the entire mount tree as SLAVE.
RSLAVE = syscall.MS_SLAVE | syscall.MS_REC
// SHARED creates a mount which provides the ability to create mirrors of
// that mount such that mounts and unmounts within any of the mirrors
// propagate to the other mirrors.
SHARED = syscall.MS_SHARED
// RSHARED marks the entire mount tree as SHARED.
RSHARED = syscall.MS_SHARED | syscall.MS_REC
// RELATIME updates inode access times relative to modify or change time.
RELATIME = syscall.MS_RELATIME
// STRICTATIME allows to explicitly request full atime updates. This makes
// it possible for the kernel to default to relatime or noatime but still
// allow userspace to override it.
STRICTATIME = syscall.MS_STRICTATIME
)

View File

@@ -1,30 +0,0 @@
// +build !linux,!freebsd freebsd,!cgo
package mount
// These flags are unsupported.
const (
BIND = 0
DIRSYNC = 0
MANDLOCK = 0
NOATIME = 0
NODEV = 0
NODIRATIME = 0
NOEXEC = 0
NOSUID = 0
UNBINDABLE = 0
RUNBINDABLE = 0
PRIVATE = 0
RPRIVATE = 0
SHARED = 0
RSHARED = 0
SLAVE = 0
RSLAVE = 0
RBIND = 0
RELATIME = 0
RELATIVE = 0
REMOUNT = 0
STRICTATIME = 0
SYNCHRONOUS = 0
RDONLY = 0
)

View File

@@ -1,74 +0,0 @@
package mount
import (
"time"
)
// GetMounts retrieves a list of mounts for the current running process.
func GetMounts() ([]*Info, error) {
return parseMountTable()
}
// Mounted looks at /proc/self/mountinfo to determine of the specified
// mountpoint has been mounted
func Mounted(mountpoint string) (bool, error) {
entries, err := parseMountTable()
if err != nil {
return false, err
}
// Search the table for the mountpoint
for _, e := range entries {
if e.Mountpoint == mountpoint {
return true, nil
}
}
return false, nil
}
// Mount will mount filesystem according to the specified configuration, on the
// condition that the target path is *not* already mounted. Options must be
// specified like the mount or fstab unix commands: "opt1=val1,opt2=val2". See
// flags.go for supported option flags.
func Mount(device, target, mType, options string) error {
flag, _ := parseOptions(options)
if flag&REMOUNT != REMOUNT {
if mounted, err := Mounted(target); err != nil || mounted {
return err
}
}
return ForceMount(device, target, mType, options)
}
// ForceMount will mount a filesystem according to the specified configuration,
// *regardless* if the target path is not already mounted. Options must be
// specified like the mount or fstab unix commands: "opt1=val1,opt2=val2". See
// flags.go for supported option flags.
func ForceMount(device, target, mType, options string) error {
flag, data := parseOptions(options)
if err := mount(device, target, mType, uintptr(flag), data); err != nil {
return err
}
return nil
}
// Unmount will unmount the target filesystem, so long as it is mounted.
func Unmount(target string) error {
if mounted, err := Mounted(target); err != nil || !mounted {
return err
}
return ForceUnmount(target)
}
// ForceUnmount will force an unmount of the target filesystem, regardless if
// it is mounted or not.
func ForceUnmount(target string) (err error) {
// Simple retry logic for unmount
for i := 0; i < 10; i++ {
if err = unmount(target, 0); err == nil {
return nil
}
time.Sleep(100 * time.Millisecond)
}
return
}

View File

@@ -1,137 +0,0 @@
package mount
import (
"os"
"path"
"testing"
)
func TestMountOptionsParsing(t *testing.T) {
options := "noatime,ro,size=10k"
flag, data := parseOptions(options)
if data != "size=10k" {
t.Fatalf("Expected size=10 got %s", data)
}
expectedFlag := NOATIME | RDONLY
if flag != expectedFlag {
t.Fatalf("Expected %d got %d", expectedFlag, flag)
}
}
func TestMounted(t *testing.T) {
tmp := path.Join(os.TempDir(), "mount-tests")
if err := os.MkdirAll(tmp, 0777); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmp)
var (
sourceDir = path.Join(tmp, "source")
targetDir = path.Join(tmp, "target")
sourcePath = path.Join(sourceDir, "file.txt")
targetPath = path.Join(targetDir, "file.txt")
)
os.Mkdir(sourceDir, 0777)
os.Mkdir(targetDir, 0777)
f, err := os.Create(sourcePath)
if err != nil {
t.Fatal(err)
}
f.WriteString("hello")
f.Close()
f, err = os.Create(targetPath)
if err != nil {
t.Fatal(err)
}
f.Close()
if err := Mount(sourceDir, targetDir, "none", "bind,rw"); err != nil {
t.Fatal(err)
}
defer func() {
if err := Unmount(targetDir); err != nil {
t.Fatal(err)
}
}()
mounted, err := Mounted(targetDir)
if err != nil {
t.Fatal(err)
}
if !mounted {
t.Fatalf("Expected %s to be mounted", targetDir)
}
if _, err := os.Stat(targetDir); err != nil {
t.Fatal(err)
}
}
func TestMountReadonly(t *testing.T) {
tmp := path.Join(os.TempDir(), "mount-tests")
if err := os.MkdirAll(tmp, 0777); err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tmp)
var (
sourceDir = path.Join(tmp, "source")
targetDir = path.Join(tmp, "target")
sourcePath = path.Join(sourceDir, "file.txt")
targetPath = path.Join(targetDir, "file.txt")
)
os.Mkdir(sourceDir, 0777)
os.Mkdir(targetDir, 0777)
f, err := os.Create(sourcePath)
if err != nil {
t.Fatal(err)
}
f.WriteString("hello")
f.Close()
f, err = os.Create(targetPath)
if err != nil {
t.Fatal(err)
}
f.Close()
if err := Mount(sourceDir, targetDir, "none", "bind,ro"); err != nil {
t.Fatal(err)
}
defer func() {
if err := Unmount(targetDir); err != nil {
t.Fatal(err)
}
}()
f, err = os.OpenFile(targetPath, os.O_RDWR, 0777)
if err == nil {
t.Fatal("Should not be able to open a ro file as rw")
}
}
func TestGetMounts(t *testing.T) {
mounts, err := GetMounts()
if err != nil {
t.Fatal(err)
}
root := false
for _, entry := range mounts {
if entry.Mountpoint == "/" {
root = true
}
}
if !root {
t.Fatal("/ should be mounted at least")
}
}

Some files were not shown because too many files have changed in this diff Show More